From f71de230a6cf9f15c54a1efd237d062c039a0bde Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 13 Dec 2023 13:15:18 +0800 Subject: [PATCH 01/21] audit, client/http: use X-Caller-ID to replace the component signature key (#7536) ref tikv/pd#7300 Use `X-Caller-ID` to replace the component signature key. Signed-off-by: JmPotato --- client/http/client.go | 4 +- client/http/client_test.go | 2 +- pkg/audit/audit.go | 2 +- pkg/audit/audit_test.go | 12 +++--- pkg/utils/apiutil/apiutil.go | 55 ++++++++++++++++----------- pkg/utils/apiutil/apiutil_test.go | 8 ++-- pkg/utils/requestutil/context_test.go | 4 +- pkg/utils/requestutil/request_info.go | 14 ++++--- server/api/middleware.go | 2 +- server/metrics.go | 2 +- tests/pdctl/global_test.go | 12 +++--- tests/server/api/api_test.go | 10 ++--- tools/pd-ctl/pdctl/command/global.go | 17 +++++---- 13 files changed, 80 insertions(+), 64 deletions(-) diff --git a/client/http/client.go b/client/http/client.go index b79aa9ca002..613ebf33294 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -221,7 +221,7 @@ func (c *client) execDuration(name string, duration time.Duration) { // Header key definition constants. const ( pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" - componentSignatureKey = "component" + xCallerIDKey = "X-Caller-ID" ) // HeaderOption configures the HTTP header. @@ -279,7 +279,7 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } - req.Header.Set(componentSignatureKey, c.callerID) + req.Header.Set(xCallerIDKey, c.callerID) start := time.Now() resp, err := c.inner.cli.Do(req) diff --git a/client/http/client_test.go b/client/http/client_test.go index 621910e29ea..70c2ddee08b 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -59,7 +59,7 @@ func TestCallerID(t *testing.T) { re := require.New(t) expectedVal := defaultCallerID httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { - val := req.Header.Get(componentSignatureKey) + val := req.Header.Get(xCallerIDKey) if val != expectedVal { re.Failf("Caller ID header check failed", "should be %s, but got %s", expectedVal, val) diff --git a/pkg/audit/audit.go b/pkg/audit/audit.go index 3553e5f8377..b971b09ed7e 100644 --- a/pkg/audit/audit.go +++ b/pkg/audit/audit.go @@ -98,7 +98,7 @@ func (b *PrometheusHistogramBackend) ProcessHTTPRequest(req *http.Request) bool if !ok { return false } - b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.Component, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp)) + b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.CallerID, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp)) return true } diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index d59c9627115..8098b36975e 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -51,7 +51,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { Name: "audit_handling_seconds_test", Help: "PD server service handling audit", Buckets: prometheus.DefBuckets, - }, []string{"service", "method", "component", "ip"}) + }, []string{"service", "method", "caller_id", "ip"}) prometheus.MustRegister(serviceAuditHistogramTest) @@ -62,7 +62,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://127.0.0.1:2379/test?test=test", http.NoBody) info := requestutil.GetRequestInfo(req) info.ServiceLabel = "test" - info.Component = "user1" + info.CallerID = "user1" info.IP = 
"localhost" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) re.False(backend.ProcessHTTPRequest(req)) @@ -73,7 +73,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { re.True(backend.ProcessHTTPRequest(req)) re.True(backend.ProcessHTTPRequest(req)) - info.Component = "user2" + info.CallerID = "user2" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) re.True(backend.ProcessHTTPRequest(req)) @@ -85,8 +85,8 @@ func TestPrometheusHistogramBackend(t *testing.T) { defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2") - re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1") } func TestLocalLogBackendUsingFile(t *testing.T) { @@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) { b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) re.Equal( - fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+ + fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, CallerID:anonymous, IP:, Port:, "+ "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", time.Unix(info.StartTimeStamp, 0).String()), output[3], diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 164b3c0783d..53fab682fcb 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -39,15 +39,14 @@ import ( "go.uber.org/zap" ) -var ( - // componentSignatureKey is used for http request header key - // to identify component signature +const ( + // componentSignatureKey is used for http request header key to identify component signature. + // Deprecated: please use `XCallerIDHeader` below to obtain a more granular source identification. + // This is kept for backward compatibility. componentSignatureKey = "component" - // componentAnonymousValue identifies anonymous request source - componentAnonymousValue = "anonymous" -) + // anonymousValue identifies anonymous request source + anonymousValue = "anonymous" -const ( // PDRedirectorHeader is used to mark which PD redirected this request. PDRedirectorHeader = "PD-Redirector" // PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD. @@ -58,6 +57,8 @@ const ( XForwardedPortHeader = "X-Forwarded-Port" // XRealIPHeader is used to mark the real client IP. XRealIPHeader = "X-Real-Ip" + // XCallerIDHeader is used to mark the caller ID. + XCallerIDHeader = "X-Caller-ID" // ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service. ForwardToMicroServiceHeader = "Forward-To-Micro-Service" @@ -112,7 +113,7 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) { // GetIPPortFromHTTPRequest returns http client host IP and port from context. // Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension), -// so `X-Forwarded-For` has the higher priority than `X-Real-IP`. 
+// so `X-Forwarded-For` has the higher priority than `X-Real-Ip`. // And both of them have the higher priority than `RemoteAddr` func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",") @@ -136,32 +137,42 @@ func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { return splitIP, splitPort } -// GetComponentNameOnHTTP returns component name from Request Header -func GetComponentNameOnHTTP(r *http.Request) string { +// getComponentNameOnHTTP returns component name from the request header. +func getComponentNameOnHTTP(r *http.Request) string { componentName := r.Header.Get(componentSignatureKey) if len(componentName) == 0 { - componentName = componentAnonymousValue + componentName = anonymousValue } return componentName } -// ComponentSignatureRoundTripper is used to add component signature in HTTP header -type ComponentSignatureRoundTripper struct { - proxied http.RoundTripper - component string +// GetCallerIDOnHTTP returns caller ID from the request header. +func GetCallerIDOnHTTP(r *http.Request) string { + callerID := r.Header.Get(XCallerIDHeader) + if len(callerID) == 0 { + // Fall back to get the component name to keep backward compatibility. + callerID = getComponentNameOnHTTP(r) + } + return callerID +} + +// CallerIDRoundTripper is used to add caller ID in the HTTP header. +type CallerIDRoundTripper struct { + proxied http.RoundTripper + callerID string } -// NewComponentSignatureRoundTripper returns a new ComponentSignatureRoundTripper. -func NewComponentSignatureRoundTripper(roundTripper http.RoundTripper, componentName string) *ComponentSignatureRoundTripper { - return &ComponentSignatureRoundTripper{ - proxied: roundTripper, - component: componentName, +// NewCallerIDRoundTripper returns a new `CallerIDRoundTripper`. 
+func NewCallerIDRoundTripper(roundTripper http.RoundTripper, callerID string) *CallerIDRoundTripper { + return &CallerIDRoundTripper{ + proxied: roundTripper, + callerID: callerID, } } // RoundTrip is used to implement RoundTripper -func (rt *ComponentSignatureRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { - req.Header.Add(componentSignatureKey, rt.component) +func (rt *CallerIDRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { + req.Header.Add(XCallerIDHeader, rt.callerID) // Send the request, get the response and the error resp, err = rt.proxied.RoundTrip(req) return diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index a4e7b97aa4d..106d3fb21cb 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -101,7 +101,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "127.0.0.1", port: "5299", }, - // IPv4 "X-Real-IP" with port + // IPv4 "X-Real-Ip" with port { r: &http.Request{ Header: map[string][]string{ @@ -111,7 +111,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "127.0.0.1", port: "5299", }, - // IPv4 "X-Real-IP" without port + // IPv4 "X-Real-Ip" without port { r: &http.Request{ Header: map[string][]string{ @@ -158,7 +158,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "::1", port: "", }, - // IPv6 "X-Real-IP" with port + // IPv6 "X-Real-Ip" with port { r: &http.Request{ Header: map[string][]string{ @@ -168,7 +168,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "::1", port: "5299", }, - // IPv6 "X-Real-IP" without port + // IPv6 "X-Real-Ip" without port { r: &http.Request{ Header: map[string][]string{ diff --git a/pkg/utils/requestutil/context_test.go b/pkg/utils/requestutil/context_test.go index 475b109e410..298fc1ff8a3 100644 --- a/pkg/utils/requestutil/context_test.go +++ b/pkg/utils/requestutil/context_test.go @@ -34,7 +34,7 @@ func TestRequestInfo(t *testing.T) { RequestInfo{ ServiceLabel: "test label", Method: http.MethodPost, - Component: "pdctl", + CallerID: "pdctl", IP: "localhost", URLParam: "{\"id\"=1}", BodyParam: "{\"state\"=\"Up\"}", @@ -45,7 +45,7 @@ func TestRequestInfo(t *testing.T) { re.True(ok) re.Equal("test label", result.ServiceLabel) re.Equal(http.MethodPost, result.Method) - re.Equal("pdctl", result.Component) + re.Equal("pdctl", result.CallerID) re.Equal("localhost", result.IP) re.Equal("{\"id\"=1}", result.URLParam) re.Equal("{\"state\"=\"Up\"}", result.BodyParam) diff --git a/pkg/utils/requestutil/request_info.go b/pkg/utils/requestutil/request_info.go index 40724bb790f..cc5403f7232 100644 --- a/pkg/utils/requestutil/request_info.go +++ b/pkg/utils/requestutil/request_info.go @@ -27,9 +27,11 @@ import ( // RequestInfo holds service information from http.Request type RequestInfo struct { - ServiceLabel string - Method string - Component string + ServiceLabel string + Method string + // CallerID is used to identify the specific source of a HTTP request, it will be marked in + // the PD HTTP client, with granularity that can be refined to a specific functionality within a component. 
+ CallerID string IP string Port string URLParam string @@ -38,8 +40,8 @@ type RequestInfo struct { } func (info *RequestInfo) String() string { - s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", - info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) + s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, CallerID:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", + info.ServiceLabel, info.Method, info.CallerID, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) return s } @@ -49,7 +51,7 @@ func GetRequestInfo(r *http.Request) RequestInfo { return RequestInfo{ ServiceLabel: apiutil.GetRouteName(r), Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path), - Component: apiutil.GetComponentNameOnHTTP(r), + CallerID: apiutil.GetCallerIDOnHTTP(r), IP: ip, Port: port, URLParam: getURLParam(r), diff --git a/server/api/middleware.go b/server/api/middleware.go index 627d7fecc92..4173c37b396 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -69,7 +69,7 @@ func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques w.Header().Add("body-param", requestInfo.BodyParam) w.Header().Add("url-param", requestInfo.URLParam) w.Header().Add("method", requestInfo.Method) - w.Header().Add("component", requestInfo.Component) + w.Header().Add("caller-id", requestInfo.CallerID) w.Header().Add("ip", requestInfo.IP) }) diff --git a/server/metrics.go b/server/metrics.go index 2d13d02d564..54c5830dc52 100644 --- a/server/metrics.go +++ b/server/metrics.go @@ -151,7 +151,7 @@ var ( Name: "audit_handling_seconds", Help: "PD server service handling audit", Buckets: prometheus.DefBuckets, - }, []string{"service", "method", "component", "ip"}) + }, []string{"service", "method", "caller_id", "ip"}) serverMaxProcs = prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: "pd", diff --git a/tests/pdctl/global_test.go b/tests/pdctl/global_test.go index 7e57f589249..00d31a384d5 100644 --- a/tests/pdctl/global_test.go +++ b/tests/pdctl/global_test.go @@ -30,18 +30,20 @@ import ( "go.uber.org/zap" ) +const pdControlCallerID = "pd-ctl" + func TestSendAndGetComponent(t *testing.T) { re := require.New(t) handler := func(ctx context.Context, s *server.Server) (http.Handler, apiutil.APIServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) { - component := apiutil.GetComponentNameOnHTTP(r) + callerID := apiutil.GetCallerIDOnHTTP(r) for k := range r.Header { log.Info("header", zap.String("key", k)) } - log.Info("component", zap.String("component", component)) - re.Equal("pdctl", component) - fmt.Fprint(w, component) + log.Info("caller id", zap.String("caller-id", callerID)) + re.Equal(pdControlCallerID, callerID) + fmt.Fprint(w, callerID) }) info := apiutil.APIServiceGroup{ IsCore: true, @@ -65,5 +67,5 @@ func TestSendAndGetComponent(t *testing.T) { args := []string{"-u", pdAddr, "health"} output, err := ExecuteCommand(cmd, args...) 
re.NoError(err) - re.Equal("pdctl\n", string(output)) + re.Equal(fmt.Sprintf("%s\n", pdControlCallerID), string(output)) } diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 905fb8ec096..f5db6bb2513 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -164,7 +164,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { suite.Equal("{\"force\":[\"true\"]}", resp.Header.Get("url-param")) suite.Equal("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param")) suite.Equal("HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile", resp.Header.Get("method")) - suite.Equal("anonymous", resp.Header.Get("component")) + suite.Equal("anonymous", resp.Header.Get("caller-id")) suite.Equal("127.0.0.1", resp.Header.Get("ip")) input = map[string]interface{}{ @@ -408,7 +408,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1") + suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1") // resign to test persist config oldLeaderName := leader.GetServer().Name() @@ -434,7 +434,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { defer resp.Body.Close() content, _ = io.ReadAll(resp.Body) output = string(content) - suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2") + suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2") input = map[string]interface{}{ "enable-audit": "false", @@ -543,7 +543,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { func doTestRequestWithLogAudit(srv *tests.TestServer) { req, _ := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/pd/api/v1/admin/cache/regions", srv.GetAddr()), http.NoBody) - req.Header.Set("component", "test") + req.Header.Set(apiutil.XCallerIDHeader, "test") resp, _ := dialClient.Do(req) resp.Body.Close() } @@ -551,7 +551,7 @@ func doTestRequestWithLogAudit(srv *tests.TestServer) { func doTestRequestWithPrometheus(srv *tests.TestServer) { timeUnix := time.Now().Unix() - 20 req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", srv.GetAddr(), timeUnix), http.NoBody) - req.Header.Set("component", "test") + req.Header.Set(apiutil.XCallerIDHeader, "test") resp, _ := dialClient.Do(req) resp.Body.Close() } diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 5d8552da51a..0b1f4b4409a 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -29,14 +29,15 @@ import ( "go.etcd.io/etcd/pkg/transport" ) -var ( - pdControllerComponentName = "pdctl" - dialClient = &http.Client{ - Transport: apiutil.NewComponentSignatureRoundTripper(http.DefaultTransport, pdControllerComponentName), - } - pingPrefix = "pd/api/v1/ping" +const ( + pdControlCallerID = "pd-ctl" + pingPrefix = "pd/api/v1/ping" ) +var dialClient = &http.Client{ + Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID), +} + // InitHTTPSClient creates https client with ca file func InitHTTPSClient(caPath, certPath, keyPath string) error { tlsInfo := transport.TLSInfo{ @@ 
-50,8 +51,8 @@ func InitHTTPSClient(caPath, certPath, keyPath string) error { } dialClient = &http.Client{ - Transport: apiutil.NewComponentSignatureRoundTripper( - &http.Transport{TLSClientConfig: tlsConfig}, pdControllerComponentName), + Transport: apiutil.NewCallerIDRoundTripper( + &http.Transport{TLSClientConfig: tlsConfig}, pdControlCallerID), } return nil From 1eed4948168ac555d07f15574096e54b927fa802 Mon Sep 17 00:00:00 2001 From: Hu# Date: Wed, 13 Dec 2023 14:40:49 +0800 Subject: [PATCH 02/21] client/http: add func for jenkins test (#7516) ref tikv/pd#7300 Signed-off-by: husharp Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/http/api.go | 13 ++++ client/http/client.go | 71 +++++++++++++++++++ client/http/types.go | 10 +++ tests/integrations/client/http_client_test.go | 62 +++++++++++++++- 4 files changed, 155 insertions(+), 1 deletion(-) diff --git a/client/http/api.go b/client/http/api.go index f744fd0c395..2153cd286e8 100644 --- a/client/http/api.go +++ b/client/http/api.go @@ -38,6 +38,9 @@ const ( store = "/pd/api/v1/store" Stores = "/pd/api/v1/stores" StatsRegion = "/pd/api/v1/stats/region" + membersPrefix = "/pd/api/v1/members" + leaderPrefix = "/pd/api/v1/leader" + transferLeader = "/pd/api/v1/leader/transfer" // Config Config = "/pd/api/v1/config" ClusterVersion = "/pd/api/v1/config/cluster-version" @@ -124,6 +127,16 @@ func StoreLabelByID(id uint64) string { return fmt.Sprintf("%s/%d/label", store, id) } +// LabelByStoreID returns the path of PD HTTP API to set store label. +func LabelByStoreID(storeID int64) string { + return fmt.Sprintf("%s/%d/label", store, storeID) +} + +// TransferLeaderByID returns the path of PD HTTP API to transfer leader by ID. +func TransferLeaderByID(leaderID string) string { + return fmt.Sprintf("%s/%s", transferLeader, leaderID) +} + // ConfigWithTTLSeconds returns the config API with the TTL seconds parameter. 
func ConfigWithTTLSeconds(ttlSeconds float64) string { return fmt.Sprintf("%s?ttlSecond=%.0f", Config, ttlSeconds) diff --git a/client/http/client.go b/client/http/client.go index 613ebf33294..958c52489fb 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -26,6 +26,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" @@ -54,9 +55,16 @@ type Client interface { GetHistoryHotRegions(context.Context, *HistoryHotRegionsRequest) (*HistoryHotRegions, error) GetRegionStatusByKeyRange(context.Context, *KeyRange, bool) (*RegionStats, error) GetStores(context.Context) (*StoresInfo, error) + SetStoreLabels(context.Context, int64, map[string]string) error + GetMembers(context.Context) (*MembersInfo, error) + GetLeader(context.Context) (*pdpb.Member, error) + TransferLeader(context.Context, string) error /* Config-related interfaces */ GetScheduleConfig(context.Context) (map[string]interface{}, error) SetScheduleConfig(context.Context, map[string]interface{}) error + /* Scheduler-related interfaces */ + GetSchedulers(context.Context) ([]string, error) + CreateScheduler(ctx context.Context, name string, storeID uint64) error /* Rule-related interfaces */ GetAllPlacementRuleBundles(context.Context) ([]*GroupBundle, error) GetPlacementRuleBundleByGroup(context.Context, string) (*GroupBundle, error) @@ -458,6 +466,44 @@ func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRan return ®ionStats, nil } +// SetStoreLabels sets the labels of a store. +func (c *client) SetStoreLabels(ctx context.Context, storeID int64, storeLabels map[string]string) error { + jsonInput, err := json.Marshal(storeLabels) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, "SetStoreLabel", LabelByStoreID(storeID), + http.MethodPost, bytes.NewBuffer(jsonInput), nil) +} + +func (c *client) GetMembers(ctx context.Context) (*MembersInfo, error) { + var members MembersInfo + err := c.requestWithRetry(ctx, + "GetMembers", membersPrefix, + http.MethodGet, http.NoBody, &members) + if err != nil { + return nil, err + } + return &members, nil +} + +// GetLeader gets the leader of PD cluster. +func (c *client) GetLeader(ctx context.Context) (*pdpb.Member, error) { + var leader pdpb.Member + err := c.requestWithRetry(ctx, "GetLeader", leaderPrefix, + http.MethodGet, http.NoBody, &leader) + if err != nil { + return nil, err + } + return &leader, nil +} + +// TransferLeader transfers the PD leader. +func (c *client) TransferLeader(ctx context.Context, newLeader string) error { + return c.requestWithRetry(ctx, "TransferLeader", TransferLeaderByID(newLeader), + http.MethodPost, http.NoBody, nil) +} + // GetScheduleConfig gets the schedule configurations. func (c *client) GetScheduleConfig(ctx context.Context) (map[string]interface{}, error) { var config map[string]interface{} @@ -662,6 +708,31 @@ func (c *client) PatchRegionLabelRules(ctx context.Context, labelRulePatch *Labe http.MethodPatch, bytes.NewBuffer(labelRulePatchJSON), nil) } +// GetSchedulers gets the schedulers from PD cluster. +func (c *client) GetSchedulers(ctx context.Context) ([]string, error) { + var schedulers []string + err := c.requestWithRetry(ctx, "GetSchedulers", Schedulers, + http.MethodGet, http.NoBody, &schedulers) + if err != nil { + return nil, err + } + return schedulers, nil +} + +// CreateScheduler creates a scheduler to PD cluster. 
+func (c *client) CreateScheduler(ctx context.Context, name string, storeID uint64) error { + inputJSON, err := json.Marshal(map[string]interface{}{ + "name": name, + "store_id": storeID, + }) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "CreateScheduler", Schedulers, + http.MethodPost, bytes.NewBuffer(inputJSON), nil) +} + // AccelerateSchedule accelerates the scheduling of the regions within the given key range. // The keys in the key range should be encoded in the hex bytes format (without encoding to the UTF-8 bytes). func (c *client) AccelerateSchedule(ctx context.Context, keyRange *KeyRange) error { diff --git a/client/http/types.go b/client/http/types.go index 1d8db36d100..b05e8e0efba 100644 --- a/client/http/types.go +++ b/client/http/types.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/kvproto/pkg/pdpb" ) // KeyRange defines a range of keys in bytes. @@ -574,3 +575,12 @@ type LabelRulePatch struct { SetRules []*LabelRule `json:"sets"` DeleteRules []string `json:"deletes"` } + +// MembersInfo is PD members info returned from PD RESTful interface +// type Members map[string][]*pdpb.Member +type MembersInfo struct { + Header *pdpb.ResponseHeader `json:"header,omitempty"` + Members []*pdpb.Member `json:"members,omitempty"` + Leader *pdpb.Member `json:"leader,omitempty"` + EtcdLeader *pdpb.Member `json:"etcd_leader,omitempty"` +} diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 6c636d2a2a1..476b4d2f541 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -49,7 +49,7 @@ func (suite *httpClientTestSuite) SetupSuite() { re := suite.Require() var err error suite.ctx, suite.cancelFunc = context.WithCancel(context.Background()) - suite.cluster, err = tests.NewTestCluster(suite.ctx, 1) + suite.cluster, err = tests.NewTestCluster(suite.ctx, 2) re.NoError(err) err = suite.cluster.RunInitialServers() re.NoError(err) @@ -384,3 +384,63 @@ func (suite *httpClientTestSuite) TestScheduleConfig() { re.Equal(float64(8), config["leader-schedule-limit"]) re.Equal(float64(2048), config["region-schedule-limit"]) } + +func (suite *httpClientTestSuite) TestSchedulers() { + re := suite.Require() + schedulers, err := suite.client.GetSchedulers(suite.ctx) + re.NoError(err) + re.Len(schedulers, 0) + + err = suite.client.CreateScheduler(suite.ctx, "evict-leader-scheduler", 1) + re.NoError(err) + schedulers, err = suite.client.GetSchedulers(suite.ctx) + re.NoError(err) + re.Len(schedulers, 1) +} + +func (suite *httpClientTestSuite) TestSetStoreLabels() { + re := suite.Require() + resp, err := suite.client.GetStores(suite.ctx) + re.NoError(err) + setStore := resp.Stores[0] + re.Empty(setStore.Store.Labels, nil) + storeLabels := map[string]string{ + "zone": "zone1", + } + err = suite.client.SetStoreLabels(suite.ctx, 1, storeLabels) + re.NoError(err) + + resp, err = suite.client.GetStores(suite.ctx) + re.NoError(err) + for _, store := range resp.Stores { + if store.Store.ID == setStore.Store.ID { + for _, label := range store.Store.Labels { + re.Equal(label.Value, storeLabels[label.Key]) + } + } + } +} + +func (suite *httpClientTestSuite) TestTransferLeader() { + re := suite.Require() + members, err := suite.client.GetMembers(suite.ctx) + re.NoError(err) + re.Len(members.Members, 2) + + oldLeader, err := suite.client.GetLeader(suite.ctx) + re.NoError(err) + + // Transfer leader to another pd + for 
_, member := range members.Members { + if member.Name != oldLeader.Name { + err = suite.client.TransferLeader(suite.ctx, member.Name) + re.NoError(err) + break + } + } + + newLeader := suite.cluster.WaitLeader() + re.NotEmpty(newLeader) + re.NoError(err) + re.NotEqual(oldLeader.Name, newLeader) +} From 859502b957be228b1be2e2cd6bebeb516c8f38bb Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 13 Dec 2023 16:16:49 +0800 Subject: [PATCH 03/21] mcs: fix panic when getting the member which is not started (#7540) close tikv/pd#7539 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/mcs/scheduling/server/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkg/mcs/scheduling/server/server.go b/pkg/mcs/scheduling/server/server.go index a43cbbebd86..8ee8b81ae47 100644 --- a/pkg/mcs/scheduling/server/server.go +++ b/pkg/mcs/scheduling/server/server.go @@ -192,6 +192,10 @@ func (s *Server) updateAPIServerMemberLoop() { continue } for _, ep := range members.Members { + if len(ep.GetClientURLs()) == 0 { // This member is not started yet. + log.Info("member is not started yet", zap.String("member-id", fmt.Sprintf("%x", ep.GetID())), errs.ZapError(err)) + continue + } status, err := s.GetClient().Status(ctx, ep.ClientURLs[0]) if err != nil { log.Info("failed to get status of member", zap.String("member-id", fmt.Sprintf("%x", ep.ID)), zap.String("endpoint", ep.ClientURLs[0]), errs.ZapError(err)) From cf718a799fcddda5413e6e8c8df0a14270b3d20e Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Wed, 13 Dec 2023 17:30:19 +0800 Subject: [PATCH 04/21] pkg/ratelimit: refactor for BBR (#7239) ref tikv/pd#7167 Signed-off-by: Cabinfever_B Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/ratelimit/controller.go | 79 ++++++ pkg/ratelimit/controller_test.go | 426 +++++++++++++++++++++++++++++++ pkg/ratelimit/limiter.go | 170 +++++++----- pkg/ratelimit/limiter_test.go | 164 ++++++------ pkg/ratelimit/option.go | 55 +--- server/api/middleware.go | 4 +- server/grpc_service.go | 38 +-- server/server.go | 12 +- 8 files changed, 723 insertions(+), 225 deletions(-) create mode 100644 pkg/ratelimit/controller.go create mode 100644 pkg/ratelimit/controller_test.go diff --git a/pkg/ratelimit/controller.go b/pkg/ratelimit/controller.go new file mode 100644 index 00000000000..0c95be9b11b --- /dev/null +++ b/pkg/ratelimit/controller.go @@ -0,0 +1,79 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "sync" + + "golang.org/x/time/rate" +) + +var emptyFunc = func() {} + +// Controller is a controller which holds multiple limiters to manage the request rate of different objects. +type Controller struct { + limiters sync.Map + // the label which is in labelAllowList won't be limited, and only inited by hard code. + labelAllowList map[string]struct{} +} + +// NewController returns a global limiter which can be updated in the later. 
+func NewController() *Controller { + return &Controller{ + labelAllowList: make(map[string]struct{}), + } +} + +// Allow is used to check whether it has enough token. +func (l *Controller) Allow(label string) (DoneFunc, error) { + var ok bool + lim, ok := l.limiters.Load(label) + if ok { + return lim.(*limiter).allow() + } + return emptyFunc, nil +} + +// Update is used to update Ratelimiter with Options +func (l *Controller) Update(label string, opts ...Option) UpdateStatus { + var status UpdateStatus + for _, opt := range opts { + status |= opt(label, l) + } + return status +} + +// GetQPSLimiterStatus returns the status of a given label's QPS limiter. +func (l *Controller) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int) { + if limit, exist := l.limiters.Load(label); exist { + return limit.(*limiter).getQPSLimiterStatus() + } + return 0, 0 +} + +// GetConcurrencyLimiterStatus returns the status of a given label's concurrency limiter. +func (l *Controller) GetConcurrencyLimiterStatus(label string) (limit uint64, current uint64) { + if limit, exist := l.limiters.Load(label); exist { + return limit.(*limiter).getConcurrencyLimiterStatus() + } + return 0, 0 +} + +// IsInAllowList returns whether this label is in allow list. +// If returns true, the given label won't be limited +func (l *Controller) IsInAllowList(label string) bool { + _, allow := l.labelAllowList[label] + return allow +} diff --git a/pkg/ratelimit/controller_test.go b/pkg/ratelimit/controller_test.go new file mode 100644 index 00000000000..a830217cb9f --- /dev/null +++ b/pkg/ratelimit/controller_test.go @@ -0,0 +1,426 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ratelimit + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/utils/syncutil" + "golang.org/x/time/rate" +) + +type changeAndResult struct { + opt Option + checkOptionStatus func(string, Option) + totalRequest int + success int + fail int + release int + waitDuration time.Duration + checkStatusFunc func(string) +} + +type labelCase struct { + label string + round []changeAndResult +} + +func runMulitLabelLimiter(t *testing.T, limiter *Controller, testCase []labelCase) { + re := require.New(t) + var caseWG sync.WaitGroup + for _, tempCas := range testCase { + caseWG.Add(1) + cas := tempCas + go func() { + var lock syncutil.Mutex + successCount, failedCount := 0, 0 + var wg sync.WaitGroup + r := &releaseUtil{} + for _, rd := range cas.round { + rd.checkOptionStatus(cas.label, rd.opt) + time.Sleep(rd.waitDuration) + for i := 0; i < rd.totalRequest; i++ { + wg.Add(1) + go func() { + countRateLimiterHandleResult(limiter, cas.label, &successCount, &failedCount, &lock, &wg, r) + }() + } + wg.Wait() + re.Equal(rd.fail, failedCount) + re.Equal(rd.success, successCount) + for i := 0; i < rd.release; i++ { + r.release() + } + rd.checkStatusFunc(cas.label) + failedCount -= rd.fail + successCount -= rd.success + } + caseWG.Done() + }() + } + caseWG.Wait() +} + +func TestControllerWithConcurrencyLimiter(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 15, + fail: 5, + success: 10, + release: 10, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyNoChange != 0) + }, + checkStatusFunc: func(label string) {}, + }, + { + opt: UpdateConcurrencyLimiter(5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 15, + fail: 10, + success: 5, + release: 5, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(5), limit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateConcurrencyLimiter(0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyDeleted != 0) + }, + totalRequest: 15, + fail: 0, + success: 15, + release: 5, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(0), limit) + re.Equal(uint64(0), current) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateConcurrencyLimiter(15), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(15), limit) + re.Equal(uint64(10), current) + }, + }, + { + opt: 
UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 10, + fail: 10, + success: 0, + release: 10, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + +func TestBlockList(t *testing.T) { + t.Parallel() + re := require.New(t) + opts := []Option{AddLabelAllowList()} + limiter := NewController() + label := "test" + + re.False(limiter.IsInAllowList(label)) + for _, opt := range opts { + opt(label, limiter) + } + re.True(limiter.IsInAllowList(label)) + + status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) + re.True(status&InAllowList != 0) + for i := 0; i < 10; i++ { + _, err := limiter.Allow(label) + re.NoError(err) + } +} + +func TestControllerWithQPSLimiter(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 3, + fail: 2, + success: 1, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(1), limit) + re.Equal(1, burst) + }, + }, + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSNoChange != 0) + }, + checkStatusFunc: func(label string) {}, + }, + { + opt: UpdateQPSLimiter(5, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(5), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(50, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(50), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + 
+func TestControllerWithTwoLimiters(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateDimensionConfig(&DimensionConfig{ + QPS: 100, + QPSBurst: 100, + ConcurrencyLimit: 100, + }), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 200, + fail: 100, + success: 100, + release: 100, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(100), limit) + re.Equal(100, burst) + climit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(100), climit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 200, + fail: 199, + success: 1, + release: 0, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(100), limit) + re.Equal(uint64(1), current) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(50, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(50), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + +func countRateLimiterHandleResult(limiter *Controller, label string, successCount *int, + failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup, r *releaseUtil) { + doneFucn, err := limiter.Allow(label) + lock.Lock() + defer lock.Unlock() + if err == nil { + *successCount++ + r.append(doneFucn) + } else { + *failedCount++ + } + wg.Done() +} diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index 4bf930ed6c5..444b5aa2481 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -1,4 +1,4 @@ -// Copyright 2022 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,11 +15,16 @@ package ratelimit import ( - "sync" + "math" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/utils/syncutil" "golang.org/x/time/rate" ) +// DoneFunc is done function. +type DoneFunc func() + // DimensionConfig is the limit dimension config of one label type DimensionConfig struct { // qps conifg @@ -29,92 +34,125 @@ type DimensionConfig struct { ConcurrencyLimit uint64 } -// Limiter is a controller for the request rate. 
-type Limiter struct { - qpsLimiter sync.Map - concurrencyLimiter sync.Map - // the label which is in labelAllowList won't be limited - labelAllowList map[string]struct{} +type limiter struct { + mu syncutil.RWMutex + concurrency *concurrencyLimiter + rate *RateLimiter } -// NewLimiter returns a global limiter which can be updated in the later. -func NewLimiter() *Limiter { - return &Limiter{ - labelAllowList: make(map[string]struct{}), - } +func newLimiter() *limiter { + lim := &limiter{} + return lim } -// Allow is used to check whether it has enough token. -func (l *Limiter) Allow(label string) bool { - var cl *concurrencyLimiter - var ok bool - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - if cl, ok = limiter.(*concurrencyLimiter); ok && !cl.allow() { - return false - } - } +func (l *limiter) getConcurrencyLimiter() *concurrencyLimiter { + l.mu.RLock() + defer l.mu.RUnlock() + return l.concurrency +} - if limiter, exist := l.qpsLimiter.Load(label); exist { - if ql, ok := limiter.(*RateLimiter); ok && !ql.Allow() { - if cl != nil { - cl.release() - } - return false - } - } +func (l *limiter) getRateLimiter() *RateLimiter { + l.mu.RLock() + defer l.mu.RUnlock() + return l.rate +} - return true +func (l *limiter) deleteRateLimiter() bool { + l.mu.Lock() + defer l.mu.Unlock() + l.rate = nil + return l.isEmpty() } -// Release is used to refill token. It may be not uesful for some limiters because they will refill automatically -func (l *Limiter) Release(label string) { - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - if cl, ok := limiter.(*concurrencyLimiter); ok { - cl.release() - } - } +func (l *limiter) deleteConcurrency() bool { + l.mu.Lock() + defer l.mu.Unlock() + l.concurrency = nil + return l.isEmpty() } -// Update is used to update Ratelimiter with Options -func (l *Limiter) Update(label string, opts ...Option) UpdateStatus { - var status UpdateStatus - for _, opt := range opts { - status |= opt(label, l) - } - return status +func (l *limiter) isEmpty() bool { + return l.concurrency == nil && l.rate == nil } -// GetQPSLimiterStatus returns the status of a given label's QPS limiter. -func (l *Limiter) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int) { - if limiter, exist := l.qpsLimiter.Load(label); exist { - return limiter.(*RateLimiter).Limit(), limiter.(*RateLimiter).Burst() +func (l *limiter) getQPSLimiterStatus() (limit rate.Limit, burst int) { + baseLimiter := l.getRateLimiter() + if baseLimiter != nil { + return baseLimiter.Limit(), baseLimiter.Burst() } - return 0, 0 } -// QPSUnlimit deletes QPS limiter of the given label -func (l *Limiter) QPSUnlimit(label string) { - l.qpsLimiter.Delete(label) +func (l *limiter) getConcurrencyLimiterStatus() (limit uint64, current uint64) { + baseLimiter := l.getConcurrencyLimiter() + if baseLimiter != nil { + return baseLimiter.getLimit(), baseLimiter.getCurrent() + } + return 0, 0 } -// GetConcurrencyLimiterStatus returns the status of a given label's concurrency limiter. 
-func (l *Limiter) GetConcurrencyLimiterStatus(label string) (limit uint64, current uint64) { - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - return limiter.(*concurrencyLimiter).getLimit(), limiter.(*concurrencyLimiter).getCurrent() +func (l *limiter) updateConcurrencyConfig(limit uint64) UpdateStatus { + oldConcurrencyLimit, _ := l.getConcurrencyLimiterStatus() + if oldConcurrencyLimit == limit { + return ConcurrencyNoChange + } + if limit < 1 { + l.deleteConcurrency() + return ConcurrencyDeleted } - return 0, 0 + l.mu.Lock() + defer l.mu.Unlock() + if l.concurrency != nil { + l.concurrency.setLimit(limit) + } else { + l.concurrency = newConcurrencyLimiter(limit) + } + return ConcurrencyChanged +} + +func (l *limiter) updateQPSConfig(limit float64, burst int) UpdateStatus { + oldQPSLimit, oldBurst := l.getQPSLimiterStatus() + if math.Abs(float64(oldQPSLimit)-limit) < eps && oldBurst == burst { + return QPSNoChange + } + if limit <= eps || burst < 1 { + l.deleteRateLimiter() + return QPSDeleted + } + l.mu.Lock() + defer l.mu.Unlock() + if l.rate != nil { + l.rate.SetLimit(rate.Limit(limit)) + l.rate.SetBurst(burst) + } else { + l.rate = NewRateLimiter(limit, burst) + } + return QPSChanged } -// ConcurrencyUnlimit deletes concurrency limiter of the given label -func (l *Limiter) ConcurrencyUnlimit(label string) { - l.concurrencyLimiter.Delete(label) +func (l *limiter) updateDimensionConfig(cfg *DimensionConfig) UpdateStatus { + status := l.updateQPSConfig(cfg.QPS, cfg.QPSBurst) + status |= l.updateConcurrencyConfig(cfg.ConcurrencyLimit) + return status } -// IsInAllowList returns whether this label is in allow list. -// If returns true, the given label won't be limited -func (l *Limiter) IsInAllowList(label string) bool { - _, allow := l.labelAllowList[label] - return allow +func (l *limiter) allow() (DoneFunc, error) { + concurrency := l.getConcurrencyLimiter() + if concurrency != nil && !concurrency.allow() { + return nil, errs.ErrRateLimitExceeded + } + + rate := l.getRateLimiter() + if rate != nil && !rate.Allow() { + if concurrency != nil { + concurrency.release() + } + return nil, errs.ErrRateLimitExceeded + } + return func() { + if concurrency != nil { + concurrency.release() + } + }, nil } diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index d5d9829816a..8834495f3e9 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -1,4 +1,4 @@ -// Copyright 2022 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,162 +24,145 @@ import ( "golang.org/x/time/rate" ) -func TestUpdateConcurrencyLimiter(t *testing.T) { +type releaseUtil struct { + dones []DoneFunc +} + +func (r *releaseUtil) release() { + if len(r.dones) > 0 { + r.dones[0]() + r.dones = r.dones[1:] + } +} + +func (r *releaseUtil) append(d DoneFunc) { + r.dones = append(r.dones, d) +} + +func TestWithConcurrencyLimiter(t *testing.T) { t.Parallel() re := require.New(t) - opts := []Option{UpdateConcurrencyLimiter(10)} - limiter := NewLimiter() - - label := "test" - status := limiter.Update(label, opts...) 
+ limiter := newLimiter() + status := limiter.updateConcurrencyConfig(10) re.True(status&ConcurrencyChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} for i := 0; i < 15; i++ { wg.Add(1) go func() { - countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) }() } wg.Wait() re.Equal(5, failedCount) re.Equal(10, successCount) for i := 0; i < 10; i++ { - limiter.Release(label) + r.release() } - limit, current := limiter.GetConcurrencyLimiterStatus(label) + limit, current := limiter.getConcurrencyLimiterStatus() re.Equal(uint64(10), limit) re.Equal(uint64(0), current) - status = limiter.Update(label, UpdateConcurrencyLimiter(10)) + status = limiter.updateConcurrencyConfig(10) re.True(status&ConcurrencyNoChange != 0) - status = limiter.Update(label, UpdateConcurrencyLimiter(5)) + status = limiter.updateConcurrencyConfig(5) re.True(status&ConcurrencyChanged != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(10, failedCount) re.Equal(5, successCount) for i := 0; i < 5; i++ { - limiter.Release(label) + r.release() } - status = limiter.Update(label, UpdateConcurrencyLimiter(0)) + status = limiter.updateConcurrencyConfig(0) re.True(status&ConcurrencyDeleted != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(0, failedCount) re.Equal(15, successCount) - limit, current = limiter.GetConcurrencyLimiterStatus(label) + limit, current = limiter.getConcurrencyLimiterStatus() re.Equal(uint64(0), limit) re.Equal(uint64(0), current) } -func TestBlockList(t *testing.T) { +func TestWithQPSLimiter(t *testing.T) { t.Parallel() re := require.New(t) - opts := []Option{AddLabelAllowList()} - limiter := NewLimiter() - label := "test" - - re.False(limiter.IsInAllowList(label)) - for _, opt := range opts { - opt(label, limiter) - } - re.True(limiter.IsInAllowList(label)) - - status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) - re.True(status&InAllowList != 0) - for i := 0; i < 10; i++ { - re.True(limiter.Allow(label)) - } -} - -func TestUpdateQPSLimiter(t *testing.T) { - t.Parallel() - re := require.New(t) - opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} - limiter := NewLimiter() - - label := "test" - status := limiter.Update(label, opts...) 
+ limiter := newLimiter() + status := limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) re.True(status&QPSChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} wg.Add(3) for i := 0; i < 3; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(2, failedCount) re.Equal(1, successCount) - limit, burst := limiter.GetQPSLimiterStatus(label) + limit, burst := limiter.getQPSLimiterStatus() re.Equal(rate.Limit(1), limit) re.Equal(1, burst) - status = limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)) + status = limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) re.True(status&QPSNoChange != 0) - status = limiter.Update(label, UpdateQPSLimiter(5, 5)) + status = limiter.updateQPSConfig(5, 5) re.True(status&QPSChanged != 0) - limit, burst = limiter.GetQPSLimiterStatus(label) + limit, burst = limiter.getQPSLimiterStatus() re.Equal(rate.Limit(5), limit) re.Equal(5, burst) time.Sleep(time.Second) for i := 0; i < 10; i++ { if i < 5 { - re.True(limiter.Allow(label)) + _, err := limiter.allow() + re.NoError(err) } else { - re.False(limiter.Allow(label)) + _, err := limiter.allow() + re.Error(err) } } time.Sleep(time.Second) - status = limiter.Update(label, UpdateQPSLimiter(0, 0)) + status = limiter.updateQPSConfig(0, 0) re.True(status&QPSDeleted != 0) for i := 0; i < 10; i++ { - re.True(limiter.Allow(label)) + _, err := limiter.allow() + re.NoError(err) } - qLimit, qCurrent := limiter.GetQPSLimiterStatus(label) + qLimit, qCurrent := limiter.getQPSLimiterStatus() re.Equal(rate.Limit(0), qLimit) re.Equal(0, qCurrent) -} -func TestQPSLimiter(t *testing.T) { - t.Parallel() - re := require.New(t) - opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} - limiter := NewLimiter() - - label := "test" - for _, opt := range opts { - opt(label, limiter) - } - - var lock syncutil.Mutex - successCount, failedCount := 0, 0 - var wg sync.WaitGroup + successCount = 0 + failedCount = 0 + status = limiter.updateQPSConfig(float64(rate.Every(3*time.Second)), 100) + re.True(status&QPSChanged != 0) wg.Add(200) for i := 0; i < 200; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(200, failedCount+successCount) @@ -188,12 +171,12 @@ func TestQPSLimiter(t *testing.T) { time.Sleep(4 * time.Second) // 3+1 wg.Add(1) - countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) wg.Wait() re.Equal(101, successCount) } -func TestTwoLimiters(t *testing.T) { +func TestWithTwoLimiters(t *testing.T) { t.Parallel() re := require.New(t) cfg := &DimensionConfig{ @@ -201,20 +184,18 @@ func TestTwoLimiters(t *testing.T) { QPSBurst: 100, ConcurrencyLimit: 100, } - opts := []Option{UpdateDimensionConfig(cfg)} - limiter := NewLimiter() - - label := "test" - for _, opt := range opts { - opt(label, limiter) - } + limiter := newLimiter() + status := limiter.updateDimensionConfig(cfg) + re.True(status&QPSChanged != 0) + re.True(status&ConcurrencyChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} wg.Add(200) for i := 0; i < 200; i++ { - go 
countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(100, failedCount) @@ -223,35 +204,42 @@ func TestTwoLimiters(t *testing.T) { wg.Add(100) for i := 0; i < 100; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(200, failedCount) re.Equal(100, successCount) for i := 0; i < 100; i++ { - limiter.Release(label) + r.release() } - limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(10*time.Second)), 1)) + status = limiter.updateQPSConfig(float64(rate.Every(10*time.Second)), 1) + re.True(status&QPSChanged != 0) wg.Add(100) for i := 0; i < 100; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(101, successCount) re.Equal(299, failedCount) - limit, current := limiter.GetConcurrencyLimiterStatus(label) + limit, current := limiter.getConcurrencyLimiterStatus() re.Equal(uint64(100), limit) re.Equal(uint64(1), current) + + cfg = &DimensionConfig{} + status = limiter.updateDimensionConfig(cfg) + re.True(status&ConcurrencyDeleted != 0) + re.True(status&QPSDeleted != 0) } -func countRateLimiterHandleResult(limiter *Limiter, label string, successCount *int, - failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup) { - result := limiter.Allow(label) +func countSingleLimiterHandleResult(limiter *limiter, successCount *int, + failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup, r *releaseUtil) { + doneFucn, err := limiter.allow() lock.Lock() defer lock.Unlock() - if result { + if err == nil { *successCount++ + r.append(doneFucn) } else { *failedCount++ } diff --git a/pkg/ratelimit/option.go b/pkg/ratelimit/option.go index 53afb9926d4..b1cc459d786 100644 --- a/pkg/ratelimit/option.go +++ b/pkg/ratelimit/option.go @@ -14,8 +14,6 @@ package ratelimit -import "golang.org/x/time/rate" - // UpdateStatus is flags for updating limiter config. type UpdateStatus uint32 @@ -40,77 +38,46 @@ const ( // Option is used to create a limiter with the optional settings. // these setting is used to add a kind of limiter for a service -type Option func(string, *Limiter) UpdateStatus +type Option func(string, *Controller) UpdateStatus // AddLabelAllowList adds a label into allow list. 
// It means the given label will not be limited func AddLabelAllowList() Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { l.labelAllowList[label] = struct{}{} return 0 } } -func updateConcurrencyConfig(l *Limiter, label string, limit uint64) UpdateStatus { - oldConcurrencyLimit, _ := l.GetConcurrencyLimiterStatus(label) - if oldConcurrencyLimit == limit { - return ConcurrencyNoChange - } - if limit < 1 { - l.ConcurrencyUnlimit(label) - return ConcurrencyDeleted - } - if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { - limiter.(*concurrencyLimiter).setLimit(limit) - } - return ConcurrencyChanged -} - -func updateQPSConfig(l *Limiter, label string, limit float64, burst int) UpdateStatus { - oldQPSLimit, oldBurst := l.GetQPSLimiterStatus(label) - - if (float64(oldQPSLimit)-limit < eps && float64(oldQPSLimit)-limit > -eps) && oldBurst == burst { - return QPSNoChange - } - if limit <= eps || burst < 1 { - l.QPSUnlimit(label) - return QPSDeleted - } - if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(limit, burst)); exist { - limiter.(*RateLimiter).SetLimit(rate.Limit(limit)) - limiter.(*RateLimiter).SetBurst(burst) - } - return QPSChanged -} - // UpdateConcurrencyLimiter creates a concurrency limiter for a given label if it doesn't exist. func UpdateConcurrencyLimiter(limit uint64) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - return updateConcurrencyConfig(l, label, limit) + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateConcurrencyConfig(limit) } } // UpdateQPSLimiter creates a QPS limiter for a given label if it doesn't exist. func UpdateQPSLimiter(limit float64, burst int) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - return updateQPSConfig(l, label, limit, burst) + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateQPSConfig(limit, burst) } } // UpdateDimensionConfig creates QPS limiter and concurrency limiter for a given label by config if it doesn't exist. func UpdateDimensionConfig(cfg *DimensionConfig) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - status := updateQPSConfig(l, label, cfg.QPS, cfg.QPSBurst) - status |= updateConcurrencyConfig(l, label, cfg.ConcurrencyLimit) - return status + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateDimensionConfig(cfg) } } diff --git a/server/api/middleware.go b/server/api/middleware.go index 4173c37b396..6536935592f 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -177,8 +177,8 @@ func (s *rateLimitMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, // There is no need to check whether rateLimiter is nil. 
CreateServer ensures that it is created rateLimiter := s.svr.GetServiceRateLimiter() - if rateLimiter.Allow(requestInfo.ServiceLabel) { - defer rateLimiter.Release(requestInfo.ServiceLabel) + if done, err := rateLimiter.Allow(requestInfo.ServiceLabel); err == nil { + defer done() next(w, r) } else { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) diff --git a/server/grpc_service.go b/server/grpc_service.go index fa74f1ea8b6..24280f46437 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -431,11 +431,11 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetMembersResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -662,11 +662,11 @@ func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetStoreResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -765,11 +765,11 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetAllStoresResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -810,8 +810,8 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.StoreHeartbeatResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1286,8 +1286,8 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1330,8 +1330,8 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - 
if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1375,8 +1375,8 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1419,8 +1419,8 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.ScanRegionsResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), diff --git a/server/server.go b/server/server.go index c815e7d50c6..187c30dbf7a 100644 --- a/server/server.go +++ b/server/server.go @@ -216,11 +216,11 @@ type Server struct { // related data structures defined in the PD grpc service pdProtoFactory *tsoutil.PDProtoFactory - serviceRateLimiter *ratelimit.Limiter + serviceRateLimiter *ratelimit.Controller serviceLabels map[string][]apiutil.AccessPath apiServiceLabelMap map[apiutil.AccessPath]string - grpcServiceRateLimiter *ratelimit.Limiter + grpcServiceRateLimiter *ratelimit.Controller grpcServiceLabels map[string]struct{} grpcServer *grpc.Server @@ -273,8 +273,8 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le audit.NewLocalLogBackend(true), audit.NewPrometheusHistogramBackend(serviceAuditHistogram, false), } - s.serviceRateLimiter = ratelimit.NewLimiter() - s.grpcServiceRateLimiter = ratelimit.NewLimiter() + s.serviceRateLimiter = ratelimit.NewController() + s.grpcServiceRateLimiter = ratelimit.NewController() s.serviceAuditBackendLabels = make(map[string]*audit.BackendLabels) s.serviceLabels = make(map[string][]apiutil.AccessPath) s.grpcServiceLabels = make(map[string]struct{}) @@ -1467,7 +1467,7 @@ func (s *Server) SetServiceAuditBackendLabels(serviceLabel string, labels []stri } // GetServiceRateLimiter is used to get rate limiter -func (s *Server) GetServiceRateLimiter() *ratelimit.Limiter { +func (s *Server) GetServiceRateLimiter() *ratelimit.Controller { return s.serviceRateLimiter } @@ -1482,7 +1482,7 @@ func (s *Server) UpdateServiceRateLimiter(serviceLabel string, opts ...ratelimit } // GetGRPCRateLimiter is used to get rate limiter -func (s *Server) GetGRPCRateLimiter() *ratelimit.Limiter { +func (s *Server) GetGRPCRateLimiter() *ratelimit.Controller { return s.grpcServiceRateLimiter } From e26a4f7292280345d5813d7a1990a77a8785c0d6 Mon Sep 17 00:00:00 2001 From: lance6716 Date: Wed, 13 Dec 2023 17:46:48 +0800 Subject: [PATCH 05/21] client/http: add more API for lightning's usage, and don't use body io.Reader (#7534) ref tikv/pd#7300 Signed-off-by: lance6716 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/http/client.go | 124 ++++++++++++------ tests/integrations/client/http_client_test.go | 
11 ++ 2 files changed, 95 insertions(+), 40 deletions(-) diff --git a/client/http/client.go b/client/http/client.go index 958c52489fb..d74c77571d6 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -55,6 +55,7 @@ type Client interface { GetHistoryHotRegions(context.Context, *HistoryHotRegionsRequest) (*HistoryHotRegions, error) GetRegionStatusByKeyRange(context.Context, *KeyRange, bool) (*RegionStats, error) GetStores(context.Context) (*StoresInfo, error) + GetStore(context.Context, uint64) (*StoreInfo, error) SetStoreLabels(context.Context, int64, map[string]string) error GetMembers(context.Context) (*MembersInfo, error) GetLeader(context.Context) (*pdpb.Member, error) @@ -62,9 +63,11 @@ type Client interface { /* Config-related interfaces */ GetScheduleConfig(context.Context) (map[string]interface{}, error) SetScheduleConfig(context.Context, map[string]interface{}) error + GetClusterVersion(context.Context) (string, error) /* Scheduler-related interfaces */ GetSchedulers(context.Context) ([]string, error) CreateScheduler(ctx context.Context, name string, storeID uint64) error + SetSchedulerDelay(context.Context, string, int64) error /* Rule-related interfaces */ GetAllPlacementRuleBundles(context.Context) ([]*GroupBundle, error) GetPlacementRuleBundleByGroup(context.Context, string) (*GroupBundle, error) @@ -247,7 +250,7 @@ func WithAllowFollowerHandle() HeaderOption { func (c *client) requestWithRetry( ctx context.Context, name, uri, method string, - body io.Reader, res interface{}, + body []byte, res interface{}, headerOpts ...HeaderOption, ) error { var ( @@ -269,7 +272,7 @@ func (c *client) requestWithRetry( func (c *client) request( ctx context.Context, name, url, method string, - body io.Reader, res interface{}, + body []byte, res interface{}, headerOpts ...HeaderOption, ) error { logFields := []zap.Field{ @@ -279,7 +282,7 @@ func (c *client) request( zap.String("caller-id", c.callerID), } log.Debug("[pd] request the http url", logFields...) - req, err := http.NewRequestWithContext(ctx, method, url, body) + req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(body)) if err != nil { log.Error("[pd] create http request failed", append(logFields, zap.Error(err))...) 
return errors.Trace(err) @@ -341,7 +344,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64) (*RegionInf var region RegionInfo err := c.requestWithRetry(ctx, "GetRegionByID", RegionByID(regionID), - http.MethodGet, http.NoBody, ®ion) + http.MethodGet, nil, ®ion) if err != nil { return nil, err } @@ -353,7 +356,7 @@ func (c *client) GetRegionByKey(ctx context.Context, key []byte) (*RegionInfo, e var region RegionInfo err := c.requestWithRetry(ctx, "GetRegionByKey", RegionByKey(key), - http.MethodGet, http.NoBody, ®ion) + http.MethodGet, nil, ®ion) if err != nil { return nil, err } @@ -365,7 +368,7 @@ func (c *client) GetRegions(ctx context.Context) (*RegionsInfo, error) { var regions RegionsInfo err := c.requestWithRetry(ctx, "GetRegions", Regions, - http.MethodGet, http.NoBody, ®ions) + http.MethodGet, nil, ®ions) if err != nil { return nil, err } @@ -378,7 +381,7 @@ func (c *client) GetRegionsByKeyRange(ctx context.Context, keyRange *KeyRange, l var regions RegionsInfo err := c.requestWithRetry(ctx, "GetRegionsByKeyRange", RegionsByKeyRange(keyRange, limit), - http.MethodGet, http.NoBody, ®ions) + http.MethodGet, nil, ®ions) if err != nil { return nil, err } @@ -390,7 +393,7 @@ func (c *client) GetRegionsByStoreID(ctx context.Context, storeID uint64) (*Regi var regions RegionsInfo err := c.requestWithRetry(ctx, "GetRegionsByStoreID", RegionsByStoreID(storeID), - http.MethodGet, http.NoBody, ®ions) + http.MethodGet, nil, ®ions) if err != nil { return nil, err } @@ -403,7 +406,7 @@ func (c *client) GetRegionsReplicatedStateByKeyRange(ctx context.Context, keyRan var state string err := c.requestWithRetry(ctx, "GetRegionsReplicatedStateByKeyRange", RegionsReplicatedByKeyRange(keyRange), - http.MethodGet, http.NoBody, &state) + http.MethodGet, nil, &state) if err != nil { return "", err } @@ -415,7 +418,7 @@ func (c *client) GetHotReadRegions(ctx context.Context) (*StoreHotPeersInfos, er var hotReadRegions StoreHotPeersInfos err := c.requestWithRetry(ctx, "GetHotReadRegions", HotRead, - http.MethodGet, http.NoBody, &hotReadRegions) + http.MethodGet, nil, &hotReadRegions) if err != nil { return nil, err } @@ -427,7 +430,7 @@ func (c *client) GetHotWriteRegions(ctx context.Context) (*StoreHotPeersInfos, e var hotWriteRegions StoreHotPeersInfos err := c.requestWithRetry(ctx, "GetHotWriteRegions", HotWrite, - http.MethodGet, http.NoBody, &hotWriteRegions) + http.MethodGet, nil, &hotWriteRegions) if err != nil { return nil, err } @@ -443,7 +446,7 @@ func (c *client) GetHistoryHotRegions(ctx context.Context, req *HistoryHotRegion var historyHotRegions HistoryHotRegions err = c.requestWithRetry(ctx, "GetHistoryHotRegions", HotHistory, - http.MethodGet, bytes.NewBuffer(reqJSON), &historyHotRegions, + http.MethodGet, reqJSON, &historyHotRegions, WithAllowFollowerHandle()) if err != nil { return nil, err @@ -458,7 +461,7 @@ func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRan var regionStats RegionStats err := c.requestWithRetry(ctx, "GetRegionStatusByKeyRange", RegionStatsByKeyRange(keyRange, onlyCount), - http.MethodGet, http.NoBody, ®ionStats, + http.MethodGet, nil, ®ionStats, ) if err != nil { return nil, err @@ -473,14 +476,14 @@ func (c *client) SetStoreLabels(ctx context.Context, storeID int64, storeLabels return errors.Trace(err) } return c.requestWithRetry(ctx, "SetStoreLabel", LabelByStoreID(storeID), - http.MethodPost, bytes.NewBuffer(jsonInput), nil) + http.MethodPost, jsonInput, nil) } func (c *client) GetMembers(ctx context.Context) 
(*MembersInfo, error) { var members MembersInfo err := c.requestWithRetry(ctx, "GetMembers", membersPrefix, - http.MethodGet, http.NoBody, &members) + http.MethodGet, nil, &members) if err != nil { return nil, err } @@ -491,7 +494,7 @@ func (c *client) GetMembers(ctx context.Context) (*MembersInfo, error) { func (c *client) GetLeader(ctx context.Context) (*pdpb.Member, error) { var leader pdpb.Member err := c.requestWithRetry(ctx, "GetLeader", leaderPrefix, - http.MethodGet, http.NoBody, &leader) + http.MethodGet, nil, &leader) if err != nil { return nil, err } @@ -501,7 +504,7 @@ func (c *client) GetLeader(ctx context.Context) (*pdpb.Member, error) { // TransferLeader transfers the PD leader. func (c *client) TransferLeader(ctx context.Context, newLeader string) error { return c.requestWithRetry(ctx, "TransferLeader", TransferLeaderByID(newLeader), - http.MethodPost, http.NoBody, nil) + http.MethodPost, nil, nil) } // GetScheduleConfig gets the schedule configurations. @@ -509,7 +512,7 @@ func (c *client) GetScheduleConfig(ctx context.Context) (map[string]interface{}, var config map[string]interface{} err := c.requestWithRetry(ctx, "GetScheduleConfig", ScheduleConfig, - http.MethodGet, http.NoBody, &config) + http.MethodGet, nil, &config) if err != nil { return nil, err } @@ -524,7 +527,7 @@ func (c *client) SetScheduleConfig(ctx context.Context, config map[string]interf } return c.requestWithRetry(ctx, "SetScheduleConfig", ScheduleConfig, - http.MethodPost, bytes.NewBuffer(configJSON), nil) + http.MethodPost, configJSON, nil) } // GetStores gets the stores info. @@ -532,19 +535,43 @@ func (c *client) GetStores(ctx context.Context) (*StoresInfo, error) { var stores StoresInfo err := c.requestWithRetry(ctx, "GetStores", Stores, - http.MethodGet, http.NoBody, &stores) + http.MethodGet, nil, &stores) if err != nil { return nil, err } return &stores, nil } +// GetStore gets the store info by ID. +func (c *client) GetStore(ctx context.Context, storeID uint64) (*StoreInfo, error) { + var store StoreInfo + err := c.requestWithRetry(ctx, + "GetStore", StoreByID(storeID), + http.MethodGet, nil, &store) + if err != nil { + return nil, err + } + return &store, nil +} + +// GetClusterVersion gets the cluster version. +func (c *client) GetClusterVersion(ctx context.Context) (string, error) { + var version string + err := c.requestWithRetry(ctx, + "GetClusterVersion", ClusterVersion, + http.MethodGet, nil, &version) + if err != nil { + return "", err + } + return version, nil +} + // GetAllPlacementRuleBundles gets all placement rules bundles. 
func (c *client) GetAllPlacementRuleBundles(ctx context.Context) ([]*GroupBundle, error) { var bundles []*GroupBundle err := c.requestWithRetry(ctx, "GetPlacementRuleBundle", PlacementRuleBundle, - http.MethodGet, http.NoBody, &bundles) + http.MethodGet, nil, &bundles) if err != nil { return nil, err } @@ -556,7 +583,7 @@ func (c *client) GetPlacementRuleBundleByGroup(ctx context.Context, group string var bundle GroupBundle err := c.requestWithRetry(ctx, "GetPlacementRuleBundleByGroup", PlacementRuleBundleByGroup(group), - http.MethodGet, http.NoBody, &bundle) + http.MethodGet, nil, &bundle) if err != nil { return nil, err } @@ -568,7 +595,7 @@ func (c *client) GetPlacementRulesByGroup(ctx context.Context, group string) ([] var rules []*Rule err := c.requestWithRetry(ctx, "GetPlacementRulesByGroup", PlacementRulesByGroup(group), - http.MethodGet, http.NoBody, &rules) + http.MethodGet, nil, &rules) if err != nil { return nil, err } @@ -583,7 +610,7 @@ func (c *client) SetPlacementRule(ctx context.Context, rule *Rule) error { } return c.requestWithRetry(ctx, "SetPlacementRule", PlacementRule, - http.MethodPost, bytes.NewBuffer(ruleJSON), nil) + http.MethodPost, ruleJSON, nil) } // SetPlacementRuleInBatch sets the placement rules in batch. @@ -594,7 +621,7 @@ func (c *client) SetPlacementRuleInBatch(ctx context.Context, ruleOps []*RuleOp) } return c.requestWithRetry(ctx, "SetPlacementRuleInBatch", PlacementRulesInBatch, - http.MethodPost, bytes.NewBuffer(ruleOpsJSON), nil) + http.MethodPost, ruleOpsJSON, nil) } // SetPlacementRuleBundles sets the placement rule bundles. @@ -606,14 +633,14 @@ func (c *client) SetPlacementRuleBundles(ctx context.Context, bundles []*GroupBu } return c.requestWithRetry(ctx, "SetPlacementRuleBundles", PlacementRuleBundleWithPartialParameter(partial), - http.MethodPost, bytes.NewBuffer(bundlesJSON), nil) + http.MethodPost, bundlesJSON, nil) } // DeletePlacementRule deletes the placement rule. func (c *client) DeletePlacementRule(ctx context.Context, group, id string) error { return c.requestWithRetry(ctx, "DeletePlacementRule", PlacementRuleByGroupAndID(group, id), - http.MethodDelete, http.NoBody, nil) + http.MethodDelete, nil, nil) } // GetAllPlacementRuleGroups gets all placement rule groups. @@ -621,7 +648,7 @@ func (c *client) GetAllPlacementRuleGroups(ctx context.Context) ([]*RuleGroup, e var ruleGroups []*RuleGroup err := c.requestWithRetry(ctx, "GetAllPlacementRuleGroups", placementRuleGroups, - http.MethodGet, http.NoBody, &ruleGroups) + http.MethodGet, nil, &ruleGroups) if err != nil { return nil, err } @@ -633,7 +660,7 @@ func (c *client) GetPlacementRuleGroupByID(ctx context.Context, id string) (*Rul var ruleGroup RuleGroup err := c.requestWithRetry(ctx, "GetPlacementRuleGroupByID", PlacementRuleGroupByID(id), - http.MethodGet, http.NoBody, &ruleGroup) + http.MethodGet, nil, &ruleGroup) if err != nil { return nil, err } @@ -648,14 +675,14 @@ func (c *client) SetPlacementRuleGroup(ctx context.Context, ruleGroup *RuleGroup } return c.requestWithRetry(ctx, "SetPlacementRuleGroup", placementRuleGroup, - http.MethodPost, bytes.NewBuffer(ruleGroupJSON), nil) + http.MethodPost, ruleGroupJSON, nil) } // DeletePlacementRuleGroupByID deletes the placement rule group by ID. 
func (c *client) DeletePlacementRuleGroupByID(ctx context.Context, id string) error { return c.requestWithRetry(ctx, "DeletePlacementRuleGroupByID", PlacementRuleGroupByID(id), - http.MethodDelete, http.NoBody, nil) + http.MethodDelete, nil, nil) } // GetAllRegionLabelRules gets all region label rules. @@ -663,7 +690,7 @@ func (c *client) GetAllRegionLabelRules(ctx context.Context) ([]*LabelRule, erro var labelRules []*LabelRule err := c.requestWithRetry(ctx, "GetAllRegionLabelRules", RegionLabelRules, - http.MethodGet, http.NoBody, &labelRules) + http.MethodGet, nil, &labelRules) if err != nil { return nil, err } @@ -679,7 +706,7 @@ func (c *client) GetRegionLabelRulesByIDs(ctx context.Context, ruleIDs []string) var labelRules []*LabelRule err = c.requestWithRetry(ctx, "GetRegionLabelRulesByIDs", RegionLabelRulesByIDs, - http.MethodGet, bytes.NewBuffer(idsJSON), &labelRules) + http.MethodGet, idsJSON, &labelRules) if err != nil { return nil, err } @@ -694,7 +721,7 @@ func (c *client) SetRegionLabelRule(ctx context.Context, labelRule *LabelRule) e } return c.requestWithRetry(ctx, "SetRegionLabelRule", RegionLabelRule, - http.MethodPost, bytes.NewBuffer(labelRuleJSON), nil) + http.MethodPost, labelRuleJSON, nil) } // PatchRegionLabelRules patches the region label rules. @@ -705,14 +732,14 @@ func (c *client) PatchRegionLabelRules(ctx context.Context, labelRulePatch *Labe } return c.requestWithRetry(ctx, "PatchRegionLabelRules", RegionLabelRules, - http.MethodPatch, bytes.NewBuffer(labelRulePatchJSON), nil) + http.MethodPatch, labelRulePatchJSON, nil) } // GetSchedulers gets the schedulers from PD cluster. func (c *client) GetSchedulers(ctx context.Context) ([]string, error) { var schedulers []string err := c.requestWithRetry(ctx, "GetSchedulers", Schedulers, - http.MethodGet, http.NoBody, &schedulers) + http.MethodGet, nil, &schedulers) if err != nil { return nil, err } @@ -730,7 +757,7 @@ func (c *client) CreateScheduler(ctx context.Context, name string, storeID uint6 } return c.requestWithRetry(ctx, "CreateScheduler", Schedulers, - http.MethodPost, bytes.NewBuffer(inputJSON), nil) + http.MethodPost, inputJSON, nil) } // AccelerateSchedule accelerates the scheduling of the regions within the given key range. @@ -746,7 +773,7 @@ func (c *client) AccelerateSchedule(ctx context.Context, keyRange *KeyRange) err } return c.requestWithRetry(ctx, "AccelerateSchedule", AccelerateSchedule, - http.MethodPost, bytes.NewBuffer(inputJSON), nil) + http.MethodPost, inputJSON, nil) } // AccelerateScheduleInBatch accelerates the scheduling of the regions within the given key ranges in batch. @@ -766,10 +793,27 @@ func (c *client) AccelerateScheduleInBatch(ctx context.Context, keyRanges []*Key } return c.requestWithRetry(ctx, "AccelerateScheduleInBatch", AccelerateScheduleInBatch, - http.MethodPost, bytes.NewBuffer(inputJSON), nil) + http.MethodPost, inputJSON, nil) +} + +// SetSchedulerDelay sets the delay of given scheduler. +func (c *client) SetSchedulerDelay(ctx context.Context, scheduler string, delaySec int64) error { + m := map[string]int64{ + "delay": delaySec, + } + inputJSON, err := json.Marshal(m) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "SetSchedulerDelay", SchedulerByName(scheduler), + http.MethodPost, inputJSON, nil) } // GetMinResolvedTSByStoresIDs get min-resolved-ts by stores IDs. +// - When storeIDs has zero length, it will return (cluster-level's min_resolved_ts, nil, nil) when no error. 
+// - When storeIDs is {"cluster"}, it will return (cluster-level's min_resolved_ts, stores_min_resolved_ts, nil) when no error. +// - When storeID is specified to ID lists, it will return (min_resolved_ts of given stores, stores_min_resolved_ts, nil) when no error. func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) { uri := MinResolvedTSPrefix // scope is an optional parameter, it can be `cluster` or specified store IDs. @@ -791,7 +835,7 @@ func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uin }{} err := c.requestWithRetry(ctx, "GetMinResolvedTSByStoresIDs", uri, - http.MethodGet, http.NoBody, &resp) + http.MethodGet, nil, &resp) if err != nil { return 0, nil, err } diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 476b4d2f541..7c8f66f4826 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -134,6 +134,13 @@ func (suite *httpClientTestSuite) TestMeta() { re.NoError(err) re.Equal(1, store.Count) re.Len(store.Stores, 1) + storeID := uint64(store.Stores[0].Store.ID) // TODO: why type is different? + store2, err := suite.client.GetStore(suite.ctx, storeID) + re.NoError(err) + re.EqualValues(storeID, store2.Store.ID) + version, err := suite.client.GetClusterVersion(suite.ctx) + re.NoError(err) + re.Equal("0.0.0", version) } func (suite *httpClientTestSuite) TestGetMinResolvedTSByStoresIDs() { @@ -396,6 +403,10 @@ func (suite *httpClientTestSuite) TestSchedulers() { schedulers, err = suite.client.GetSchedulers(suite.ctx) re.NoError(err) re.Len(schedulers, 1) + err = suite.client.SetSchedulerDelay(suite.ctx, "evict-leader-scheduler", 100) + re.NoError(err) + err = suite.client.SetSchedulerDelay(suite.ctx, "not-exist", 100) + re.ErrorContains(err, "500 Internal Server Error") // TODO: should return friendly error message } func (suite *httpClientTestSuite) TestSetStoreLabels() { From 48fabb79e8b0197bd6752fb84b0b704b782d3b48 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 13 Dec 2023 18:37:49 +0800 Subject: [PATCH 06/21] mcs: fix unnecessary PDRedirectorHeader (#7538) close tikv/pd#7533 Signed-off-by: lhy1024 Co-authored-by: Ryan Leung --- pkg/utils/apiutil/serverapi/middleware.go | 20 +++---- tests/integrations/mcs/scheduling/api_test.go | 55 +++++++++++++++++-- tests/pdctl/scheduler/scheduler_test.go | 45 --------------- 3 files changed, 60 insertions(+), 60 deletions(-) mode change 100644 => 100755 pkg/utils/apiutil/serverapi/middleware.go diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go old mode 100644 new mode 100755 index eb0f8a5f8eb..c7979dcc038 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -182,14 +182,6 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } - // Prevent more than one redirection. 
- if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { - log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) - http.Error(w, errs.ErrRedirectToNotLeader.FastGenByArgs().Error(), http.StatusInternalServerError) - return - } - - r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) forwardedIP, forwardedPort := apiutil.GetIPPortFromHTTPRequest(r) if len(forwardedIP) > 0 { r.Header.Add(apiutil.XForwardedForHeader, forwardedIP) @@ -208,9 +200,9 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = append(clientUrls, targetAddr) + // Add a header to the response, this is not a failure injection + // it is used for testing, to check whether the request is forwarded to the micro service failpoint.Inject("checkHeader", func() { - // add a header to the response, this is not a failure injection - // it is used for testing, to check whether the request is forwarded to the micro service w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") }) } else { @@ -220,7 +212,15 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = leader.GetClientUrls() + // Prevent more than one redirection among PD/API servers. + if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { + log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) + http.Error(w, errs.ErrRedirectToNotLeader.FastGenByArgs().Error(), http.StatusInternalServerError) + return + } + r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) } + urls := make([]url.URL, 0, len(clientUrls)) for _, item := range clientUrls { u, err := url.Parse(item) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 8f5d37ee1bb..4c71f8f14a3 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -1,6 +1,7 @@ package scheduling_test import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -40,10 +41,12 @@ func TestAPI(t *testing.T) { } func (suite *apiTestSuite) SetupSuite() { + suite.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)")) suite.env = tests.NewSchedulingTestEnvironment(suite.T()) } func (suite *apiTestSuite) TearDownSuite() { + suite.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader")) suite.env.Cleanup() } @@ -99,10 +102,6 @@ func (suite *apiTestSuite) TestAPIForward() { func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { re := suite.Require() - re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)")) - defer func() { - re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader")) - }() leader := cluster.GetLeaderServer().GetServer() urlPrefix := fmt.Sprintf("%s/pd/api/v1", leader.GetAddr()) @@ -300,7 +299,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { rulesArgs, err := json.Marshal(rules) suite.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "/config/rules"), &rules, + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) err = 
testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, @@ -499,3 +498,49 @@ func (suite *apiTestSuite) checkAdminRegionCacheForward(cluster *tests.TestClust re.Equal(0, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) re.Equal(0, apiServer.GetRaftCluster().GetRegionCount([]byte{}, []byte{}).Count) } + +func (suite *apiTestSuite) TestFollowerForward() { + suite.env.RunTestInTwoModes(suite.checkFollowerForward) +} + +func (suite *apiTestSuite) checkFollowerForward(cluster *tests.TestCluster) { + re := suite.Require() + leaderAddr := cluster.GetLeaderServer().GetAddr() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + follower, err := cluster.JoinAPIServer(ctx) + re.NoError(err) + re.NoError(follower.Run()) + re.NotEmpty(cluster.WaitLeader()) + + followerAddr := follower.GetAddr() + if cluster.GetLeaderServer().GetAddr() != leaderAddr { + followerAddr = leaderAddr + } + + urlPrefix := fmt.Sprintf("%s/pd/api/v1", followerAddr) + rules := []*placement.Rule{} + if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { + // follower will forward to scheduling server directly + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"), + ) + re.NoError(err) + } else { + // follower will forward to leader server + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader), + ) + re.NoError(err) + } + + // follower will forward to leader server + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + results := make(map[string]interface{}) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config"), &results, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader), + ) + re.NoError(err) +} diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index d8d54a79d13..fb7c239b431 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -15,7 +15,6 @@ package scheduler_test import ( - "context" "encoding/json" "fmt" "reflect" @@ -691,47 +690,3 @@ func mightExec(re *require.Assertions, cmd *cobra.Command, args []string, v inte } json.Unmarshal(output, v) } - -func TestForwardSchedulerRequest(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestAPICluster(ctx, 1) - re.NoError(err) - re.NoError(cluster.RunInitialServers()) - re.NotEmpty(cluster.WaitLeader()) - server := cluster.GetLeaderServer() - re.NoError(server.BootstrapCluster()) - backendEndpoints := server.GetAddr() - tc, err := tests.NewTestSchedulingCluster(ctx, 1, backendEndpoints) - re.NoError(err) - defer tc.Destroy() - tc.WaitForPrimaryServing(re) - - cmd := pdctlCmd.GetRootCmd() - args := []string{"-u", backendEndpoints, "scheduler", "show"} - var sches []string - testutil.Eventually(re, func() bool { - output, err := pdctl.ExecuteCommand(cmd, args...) - re.NoError(err) - re.NoError(json.Unmarshal(output, &sches)) - return slice.Contains(sches, "balance-leader-scheduler") - }) - - mustUsage := func(args []string) { - output, err := pdctl.ExecuteCommand(cmd, args...) 
- re.NoError(err) - re.Contains(string(output), "Usage") - } - mustUsage([]string{"-u", backendEndpoints, "scheduler", "pause", "balance-leader-scheduler"}) - echo := mustExec(re, cmd, []string{"-u", backendEndpoints, "scheduler", "pause", "balance-leader-scheduler", "60"}, nil) - re.Contains(echo, "Success!") - checkSchedulerWithStatusCommand := func(status string, expected []string) { - var schedulers []string - mustExec(re, cmd, []string{"-u", backendEndpoints, "scheduler", "show", "--status", status}, &schedulers) - re.Equal(expected, schedulers) - } - checkSchedulerWithStatusCommand("paused", []string{ - "balance-leader-scheduler", - }) -} From 0e220b0b39a765762c48a7fa620b15bc78bd0a38 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 13 Dec 2023 18:55:49 +0800 Subject: [PATCH 07/21] api: fix the output of some APIs (#7542) ref tikv/pd#4399 Signed-off-by: Ryan Leung Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- server/api/store.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/api/store.go b/server/api/store.go index 8537cd45c5b..44e178c23fd 100644 --- a/server/api/store.go +++ b/server/api/store.go @@ -334,7 +334,7 @@ func (h *storeHandler) SetStoreLabel(w http.ResponseWriter, r *http.Request) { // @Param id path integer true "Store Id" // @Param body body object true "Labels in json format" // @Produce json -// @Success 200 {string} string "The store's label is updated." +// @Success 200 {string} string "The label is deleted for store." // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /store/{id}/label [delete] @@ -369,7 +369,7 @@ func (h *storeHandler) DeleteStoreLabel(w http.ResponseWriter, r *http.Request) // @Param id path integer true "Store Id" // @Param body body object true "json params" // @Produce json -// @Success 200 {string} string "The store's label is updated." +// @Success 200 {string} string "The store's weight is updated." // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /store/{id}/weight [post] @@ -413,7 +413,7 @@ func (h *storeHandler) SetStoreWeight(w http.ResponseWriter, r *http.Request) { return } - h.rd.JSON(w, http.StatusOK, "The store's label is updated.") + h.rd.JSON(w, http.StatusOK, "The store's weight is updated.") } // FIXME: details of input json body params @@ -423,7 +423,7 @@ func (h *storeHandler) SetStoreWeight(w http.ResponseWriter, r *http.Request) { // @Param id path integer true "Store Id" // @Param body body object true "json params" // @Produce json -// @Success 200 {string} string "The store's label is updated." +// @Success 200 {string} string "The store's limit is updated." // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "PD server failed to proceed the request." 
// @Router /store/{id}/limit [post] @@ -486,7 +486,7 @@ func (h *storeHandler) SetStoreLimit(w http.ResponseWriter, r *http.Request) { return } } - h.rd.JSON(w, http.StatusOK, "The store's label is updated.") + h.rd.JSON(w, http.StatusOK, "The store's limit is updated.") } type storesHandler struct { From f51f9134558e793e7ea08fc32cef672c7afa37b7 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Wed, 13 Dec 2023 19:27:19 +0800 Subject: [PATCH 08/21] errs: remove redundant `FastGenWithCause` in `ZapError` (#7497) close tikv/pd#7496 Signed-off-by: Cabinfever_B Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/errs/errno.go | 2 +- client/errs/errs.go | 2 +- client/tso_dispatcher.go | 6 ++-- errors.toml | 2 +- pkg/autoscaling/prometheus.go | 2 +- pkg/errs/errno.go | 2 +- pkg/errs/errs.go | 2 +- pkg/errs/errs_test.go | 35 +++++++++++++++------- pkg/schedule/hbstream/heartbeat_streams.go | 2 +- pkg/schedule/plugin_interface.go | 6 ++-- pkg/schedule/schedulers/evict_leader.go | 2 +- pkg/schedule/schedulers/grant_leader.go | 2 +- pkg/schedule/schedulers/init.go | 10 +++---- pkg/schedule/schedulers/scheduler.go | 4 +-- pkg/schedule/schedulers/utils.go | 4 +-- pkg/tso/keyspace_group_manager.go | 2 +- pkg/utils/logutil/log.go | 2 +- pkg/utils/tempurl/check_env_linux.go | 2 +- server/handler.go | 2 +- 19 files changed, 53 insertions(+), 38 deletions(-) diff --git a/client/errs/errno.go b/client/errs/errno.go index 646af81929d..0f93ebf1472 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -54,7 +54,7 @@ var ( ErrClientGetMultiResponse = errors.Normalize("get invalid value response %v, must only one", errors.RFCCodeText("PD:client:ErrClientGetMultiResponse")) ErrClientGetServingEndpoint = errors.Normalize("get serving endpoint failed", errors.RFCCodeText("PD:client:ErrClientGetServingEndpoint")) ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) - ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed, %s", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) + ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) ) // grpcutil errors diff --git a/client/errs/errs.go b/client/errs/errs.go index e715056b055..47f7c29a467 100644 --- a/client/errs/errs.go +++ b/client/errs/errs.go @@ -27,7 +27,7 @@ func ZapError(err error, causeError ...error) zap.Field { } if e, ok := err.(*errors.Error); ok { if len(causeError) >= 1 { - err = e.Wrap(causeError[0]).FastGenWithCause() + err = e.Wrap(causeError[0]) } else { err = e.FastGenByArgs() } diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 0de4dc3a49e..6b2c33ca58d 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -412,7 +412,7 @@ tsoBatchLoop: } else { log.Error("[tso] fetch pending tso requests error", zap.String("dc-location", dc), - errs.ZapError(errs.ErrClientGetTSO.FastGenByArgs("when fetch pending tso requests"), err)) + errs.ZapError(errs.ErrClientGetTSO, err)) } return } @@ -495,10 +495,10 @@ tsoBatchLoop: default: } c.svcDiscovery.ScheduleCheckMemberChanged() - log.Error("[tso] getTS error", + log.Error("[tso] getTS error after processing requests", zap.String("dc-location", dc), zap.String("stream-addr", streamAddr), - errs.ZapError(errs.ErrClientGetTSO.FastGenByArgs("after processing requests"), 
err)) + errs.ZapError(errs.ErrClientGetTSO, err)) // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. connectionCtxs.Delete(streamAddr) cancel() diff --git a/errors.toml b/errors.toml index a318fc32492..69aa29d2dda 100644 --- a/errors.toml +++ b/errors.toml @@ -83,7 +83,7 @@ get min TSO failed, %v ["PD:client:ErrClientGetTSO"] error = ''' -get TSO failed, %v +get TSO failed ''' ["PD:client:ErrClientGetTSOTimeout"] diff --git a/pkg/autoscaling/prometheus.go b/pkg/autoscaling/prometheus.go index 43ba768a585..91b813b6ef2 100644 --- a/pkg/autoscaling/prometheus.go +++ b/pkg/autoscaling/prometheus.go @@ -94,7 +94,7 @@ func (prom *PrometheusQuerier) queryMetricsFromPrometheus(query string, timestam resp, warnings, err := prom.api.Query(ctx, query, timestamp) if err != nil { - return nil, errs.ErrPrometheusQuery.Wrap(err).FastGenWithCause() + return nil, errs.ErrPrometheusQuery.Wrap(err) } if len(warnings) > 0 { diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index a4320238374..03fa0f61158 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -86,7 +86,7 @@ var ( var ( ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream")) ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout")) - ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO")) + ErrClientGetTSO = errors.Normalize("get TSO failed", errors.RFCCodeText("PD:client:ErrClientGetTSO")) ErrClientGetLeader = errors.Normalize("get leader failed, %v", errors.RFCCodeText("PD:client:ErrClientGetLeader")) ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember")) ErrClientGetMinTSO = errors.Normalize("get min TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetMinTSO")) diff --git a/pkg/errs/errs.go b/pkg/errs/errs.go index acc42637733..5746b282f10 100644 --- a/pkg/errs/errs.go +++ b/pkg/errs/errs.go @@ -27,7 +27,7 @@ func ZapError(err error, causeError ...error) zap.Field { } if e, ok := err.(*errors.Error); ok { if len(causeError) >= 1 { - err = e.Wrap(causeError[0]).FastGenWithCause() + err = e.Wrap(causeError[0]) } else { err = e.FastGenByArgs() } diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index c242dd994f5..d76c02dc110 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -81,9 +81,19 @@ func TestError(t *testing.T) { log.Error("test", zap.Error(ErrEtcdLeaderNotFound.FastGenByArgs())) re.Contains(lg.Message(), rfc) err := errors.New("test error") - log.Error("test", ZapError(ErrEtcdLeaderNotFound, err)) - rfc = `[error="[PD:member:ErrEtcdLeaderNotFound]test error` - re.Contains(lg.Message(), rfc) + // use Info() because of no stack for comparing. 
+ log.Info("test", ZapError(ErrEtcdLeaderNotFound, err)) + rfc = `[error="[PD:member:ErrEtcdLeaderNotFound]etcd leader not found: test error` + m1 := lg.Message() + re.Contains(m1, rfc) + log.Info("test", zap.Error(ErrEtcdLeaderNotFound.Wrap(err))) + m2 := lg.Message() + idx1 := strings.Index(m1, "[error") + idx2 := strings.Index(m2, "[error") + re.Equal(m1[idx1:], m2[idx2:]) + log.Info("test", zap.Error(ErrEtcdLeaderNotFound.Wrap(err).FastGenWithCause())) + m3 := lg.Message() + re.NotContains(m3, rfc) } func TestErrorEqual(t *testing.T) { @@ -94,24 +104,24 @@ func TestErrorEqual(t *testing.T) { re.True(errors.ErrorEqual(err1, err2)) err := errors.New("test") - err1 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() + err1 = ErrSchedulerNotFound.Wrap(err) + err2 = ErrSchedulerNotFound.Wrap(err) re.True(errors.ErrorEqual(err1, err2)) err1 = ErrSchedulerNotFound.FastGenByArgs() - err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() + err2 = ErrSchedulerNotFound.Wrap(err) re.False(errors.ErrorEqual(err1, err2)) err3 := errors.New("test") err4 := errors.New("test") - err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() - err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() + err1 = ErrSchedulerNotFound.Wrap(err3) + err2 = ErrSchedulerNotFound.Wrap(err4) re.True(errors.ErrorEqual(err1, err2)) err3 = errors.New("test1") err4 = errors.New("test") - err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() - err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() + err1 = ErrSchedulerNotFound.Wrap(err3) + err2 = ErrSchedulerNotFound.Wrap(err4) re.False(errors.ErrorEqual(err1, err2)) } @@ -135,11 +145,16 @@ func TestErrorWithStack(t *testing.T) { m1 := lg.Message() log.Error("test", zap.Error(errors.WithStack(err))) m2 := lg.Message() + log.Error("test", ZapError(ErrStrconvParseInt.GenWithStackByCause(), err)) + m3 := lg.Message() // This test is based on line number and the first log is in line 141, the second is in line 142. // So they have the same length stack. Move this test to another place need to change the corresponding length. 
idx1 := strings.Index(m1, "[stack=") re.GreaterOrEqual(idx1, -1) idx2 := strings.Index(m2, "[stack=") re.GreaterOrEqual(idx2, -1) + idx3 := strings.Index(m3, "[stack=") + re.GreaterOrEqual(idx3, -1) re.Len(m2[idx2:], len(m1[idx1:])) + re.Len(m3[idx3:], len(m1[idx1:])) } diff --git a/pkg/schedule/hbstream/heartbeat_streams.go b/pkg/schedule/hbstream/heartbeat_streams.go index e7d7f688035..57a7521c0a7 100644 --- a/pkg/schedule/hbstream/heartbeat_streams.go +++ b/pkg/schedule/hbstream/heartbeat_streams.go @@ -139,7 +139,7 @@ func (s *HeartbeatStreams) run() { if stream, ok := s.streams[storeID]; ok { if err := stream.Send(msg); err != nil { log.Error("send heartbeat message fail", - zap.Uint64("region-id", msg.GetRegionId()), errs.ZapError(errs.ErrGRPCSend.Wrap(err).GenWithStackByArgs())) + zap.Uint64("region-id", msg.GetRegionId()), errs.ZapError(errs.ErrGRPCSend, err)) delete(s.streams, storeID) heartbeatStreamCounter.WithLabelValues(storeAddress, storeLabel, "push", "err").Inc() } else { diff --git a/pkg/schedule/plugin_interface.go b/pkg/schedule/plugin_interface.go index dd1ce3471e6..62ffe2eb900 100644 --- a/pkg/schedule/plugin_interface.go +++ b/pkg/schedule/plugin_interface.go @@ -46,19 +46,19 @@ func (p *PluginInterface) GetFunction(path string, funcName string) (plugin.Symb // open plugin filePath, err := filepath.Abs(path) if err != nil { - return nil, errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + return nil, errs.ErrFilePathAbs.Wrap(err) } log.Info("open plugin file", zap.String("file-path", filePath)) plugin, err := plugin.Open(filePath) if err != nil { - return nil, errs.ErrLoadPlugin.Wrap(err).FastGenWithCause() + return nil, errs.ErrLoadPlugin.Wrap(err) } p.pluginMap[path] = plugin } // get func from plugin f, err := p.pluginMap[path].Lookup(funcName) if err != nil { - return nil, errs.ErrLookupPluginFunc.Wrap(err).FastGenWithCause() + return nil, errs.ErrLookupPluginFunc.Wrap(err) } return f, nil } diff --git a/pkg/schedule/schedulers/evict_leader.go b/pkg/schedule/schedulers/evict_leader.go index 879aa9869b3..ae8b1ecf1ea 100644 --- a/pkg/schedule/schedulers/evict_leader.go +++ b/pkg/schedule/schedulers/evict_leader.go @@ -80,7 +80,7 @@ func (conf *evictLeaderSchedulerConfig) BuildWithArgs(args []string) error { id, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) if err != nil { diff --git a/pkg/schedule/schedulers/grant_leader.go b/pkg/schedule/schedulers/grant_leader.go index 885f81e2442..027350536aa 100644 --- a/pkg/schedule/schedulers/grant_leader.go +++ b/pkg/schedule/schedulers/grant_leader.go @@ -63,7 +63,7 @@ func (conf *grantLeaderSchedulerConfig) BuildWithArgs(args []string) error { id, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) if err != nil { diff --git a/pkg/schedule/schedulers/init.go b/pkg/schedule/schedulers/init.go index f60be1e5b06..57eb4b90985 100644 --- a/pkg/schedule/schedulers/init.go +++ b/pkg/schedule/schedulers/init.go @@ -129,7 +129,7 @@ func schedulersRegister() { id, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) @@ -180,14 +180,14 @@ func schedulersRegister() { } leaderID, 
err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } storeIDs := make([]uint64, 0) for _, id := range strings.Split(args[1], ",") { storeID, err := strconv.ParseUint(id, 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } storeIDs = append(storeIDs, storeID) } @@ -248,7 +248,7 @@ func schedulersRegister() { id, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) if err != nil { @@ -365,7 +365,7 @@ func schedulersRegister() { if len(args) == 1 { limit, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } conf.Limit = limit } diff --git a/pkg/schedule/schedulers/scheduler.go b/pkg/schedule/schedulers/scheduler.go index 1c788989454..38fc8f5607d 100644 --- a/pkg/schedule/schedulers/scheduler.go +++ b/pkg/schedule/schedulers/scheduler.go @@ -52,7 +52,7 @@ type Scheduler interface { func EncodeConfig(v interface{}) ([]byte, error) { marshaled, err := json.Marshal(v) if err != nil { - return nil, errs.ErrJSONMarshal.Wrap(err).FastGenWithCause() + return nil, errs.ErrJSONMarshal.Wrap(err) } return marshaled, nil } @@ -61,7 +61,7 @@ func EncodeConfig(v interface{}) ([]byte, error) { func DecodeConfig(data []byte, v interface{}) error { err := json.Unmarshal(data, v) if err != nil { - return errs.ErrJSONUnmarshal.Wrap(err).FastGenWithCause() + return errs.ErrJSONUnmarshal.Wrap(err) } return nil } diff --git a/pkg/schedule/schedulers/utils.go b/pkg/schedule/schedulers/utils.go index fea51798d1c..a22f992bda1 100644 --- a/pkg/schedule/schedulers/utils.go +++ b/pkg/schedule/schedulers/utils.go @@ -218,11 +218,11 @@ func getKeyRanges(args []string) ([]core.KeyRange, error) { for len(args) > 1 { startKey, err := url.QueryUnescape(args[0]) if err != nil { - return nil, errs.ErrQueryUnescape.Wrap(err).FastGenWithCause() + return nil, errs.ErrQueryUnescape.Wrap(err) } endKey, err := url.QueryUnescape(args[1]) if err != nil { - return nil, errs.ErrQueryUnescape.Wrap(err).FastGenWithCause() + return nil, errs.ErrQueryUnescape.Wrap(err) } args = args[2:] ranges = append(ranges, core.NewKeyRange(startKey, endKey)) diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go index badcb18d5d8..58534de1642 100644 --- a/pkg/tso/keyspace_group_manager.go +++ b/pkg/tso/keyspace_group_manager.go @@ -542,7 +542,7 @@ func (kgm *KeyspaceGroupManager) InitializeGroupWatchLoop() error { putFn := func(kv *mvccpb.KeyValue) error { group := &endpoint.KeyspaceGroup{} if err := json.Unmarshal(kv.Value, group); err != nil { - return errs.ErrJSONUnmarshal.Wrap(err).FastGenWithCause() + return errs.ErrJSONUnmarshal.Wrap(err) } kgm.updateKeyspaceGroup(group) if group.ID == mcsutils.DefaultKeyspaceGroupID { diff --git a/pkg/utils/logutil/log.go b/pkg/utils/logutil/log.go index 3dc4430b066..8c0977818fa 100644 --- a/pkg/utils/logutil/log.go +++ b/pkg/utils/logutil/log.go @@ -70,7 +70,7 @@ func StringToZapLogLevel(level string) zapcore.Level { func SetupLogger(logConfig log.Config, logger **zap.Logger, logProps **log.ZapProperties, enabled ...bool) error { lg, p, err := log.InitLogger(&logConfig, zap.AddStacktrace(zapcore.FatalLevel)) if err != nil { - return 
errs.ErrInitLogger.Wrap(err).FastGenWithCause() + return errs.ErrInitLogger.Wrap(err) } *logger = lg *logProps = p diff --git a/pkg/utils/tempurl/check_env_linux.go b/pkg/utils/tempurl/check_env_linux.go index cf0e686cada..58f902f4bb7 100644 --- a/pkg/utils/tempurl/check_env_linux.go +++ b/pkg/utils/tempurl/check_env_linux.go @@ -36,7 +36,7 @@ func checkAddr(addr string) (bool, error) { return s.RemoteAddr.String() == addr || s.LocalAddr.String() == addr }) if err != nil { - return false, errs.ErrNetstatTCPSocks.Wrap(err).FastGenWithCause() + return false, errs.ErrNetstatTCPSocks.Wrap(err) } return len(tabs) < 1, nil } diff --git a/server/handler.go b/server/handler.go index 6c0679bd9f9..b91c8e368f9 100644 --- a/server/handler.go +++ b/server/handler.go @@ -470,7 +470,7 @@ func (h *Handler) PluginLoad(pluginPath string) error { // make sure path is in data dir filePath, err := filepath.Abs(pluginPath) if err != nil || !isPathInDirectory(filePath, h.s.GetConfig().DataDir) { - return errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + return errs.ErrFilePathAbs.Wrap(err) } c.LoadPlugin(pluginPath, ch) From 7b60e0928d35bbdb1914850040825684853d63f0 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Thu, 14 Dec 2023 18:21:20 +0800 Subject: [PATCH 09/21] mcs: fix sequence of callback functions (#7548) close tikv/pd#7543 Signed-off-by: Ryan Leung --- pkg/mcs/resourcemanager/server/server.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pkg/mcs/resourcemanager/server/server.go b/pkg/mcs/resourcemanager/server/server.go index 2a1be3e0ca5..43d426bfc40 100644 --- a/pkg/mcs/resourcemanager/server/server.go +++ b/pkg/mcs/resourcemanager/server/server.go @@ -321,13 +321,15 @@ func (s *Server) startServer() (err error) { s.serverLoopWg.Add(1) go utils.StartGRPCAndHTTPServers(s, serverReadyChan, s.GetListener()) <-serverReadyChan - s.startServerLoop() // Run callbacks log.Info("triggering the start callback functions") for _, cb := range s.GetStartCallbacks() { cb() } + // The start callback function will initialize storage, which will be used in service ready callback. + // We should make sure the calling sequence is right. + s.startServerLoop() // Server has started. 
entry := &discovery.ServiceRegistryEntry{ServiceAddr: s.cfg.AdvertiseListenAddr} From 8196d84c04fb4fcd0d99a57304aa242655381056 Mon Sep 17 00:00:00 2001 From: Hu# Date: Fri, 15 Dec 2023 17:27:20 +0800 Subject: [PATCH 10/21] makefile: update golangci (#7556) close tikv/pd#7551 Signed-off-by: husharp --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 946493cd7ce..67cdac99b02 100644 --- a/Makefile +++ b/Makefile @@ -160,7 +160,7 @@ SHELL := env PATH='$(PATH)' GOBIN='$(GO_TOOLS_BIN_PATH)' $(shell which bash) install-tools: @mkdir -p $(GO_TOOLS_BIN_PATH) - @which golangci-lint >/dev/null 2>&1 || curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GO_TOOLS_BIN_PATH) v1.51.2 + @which golangci-lint >/dev/null 2>&1 || curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GO_TOOLS_BIN_PATH) v1.55.2 @grep '_' tools.go | sed 's/"//g' | awk '{print $$2}' | xargs go install .PHONY: install-tools From a4ab7d31da67acb2dcc5945b973e18d2a0f10d1b Mon Sep 17 00:00:00 2001 From: Hu# Date: Mon, 18 Dec 2023 14:06:22 +0800 Subject: [PATCH 11/21] ci: support real cluster test in jenkins (#7493) ref tikv/pd#7298 Signed-off-by: husharp --- Makefile | 7 +- go.mod | 4 +- go.sum | 8 +- pkg/member/member.go | 4 +- tests/integrations/client/go.mod | 2 +- tests/integrations/mcs/go.mod | 2 +- tests/integrations/realtiup/Makefile | 58 ++++ tests/integrations/realtiup/deploy.sh | 23 ++ tests/integrations/realtiup/go.mod | 47 ++++ tests/integrations/realtiup/go.sum | 252 ++++++++++++++++++ tests/integrations/realtiup/mock_db.go | 91 +++++++ tests/integrations/realtiup/reboot_pd_test.go | 71 +++++ .../realtiup/transfer_leader_test.go | 73 +++++ tests/integrations/realtiup/ts_test.go | 45 ++++ tests/integrations/realtiup/util.go | 39 +++ tests/integrations/realtiup/wait_tiup.sh | 22 ++ tests/integrations/tso/go.mod | 2 +- tools/pd-api-bench/go.mod | 2 +- tools/pd-simulator/main.go | 2 +- 19 files changed, 740 insertions(+), 14 deletions(-) create mode 100644 tests/integrations/realtiup/Makefile create mode 100755 tests/integrations/realtiup/deploy.sh create mode 100644 tests/integrations/realtiup/go.mod create mode 100644 tests/integrations/realtiup/go.sum create mode 100644 tests/integrations/realtiup/mock_db.go create mode 100644 tests/integrations/realtiup/reboot_pd_test.go create mode 100644 tests/integrations/realtiup/transfer_leader_test.go create mode 100644 tests/integrations/realtiup/ts_test.go create mode 100644 tests/integrations/realtiup/util.go create mode 100755 tests/integrations/realtiup/wait_tiup.sh diff --git a/Makefile b/Makefile index 67cdac99b02..2a506eb576f 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,12 @@ ifeq ($(ENABLE_FIPS), 1) BUILD_TOOL_CGO_ENABLED := 1 endif -LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDReleaseVersion=$(shell git describe --tags --dirty --always)" +RELEASE_VERSION ?= $(shell git describe --tags --dirty --always) +ifeq ($(RUN_CI), 1) + RELEASE_VERSION := None +endif + +LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDReleaseVersion=$(RELEASE_VERSION)" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDBuildTS=$(shell date -u '+%Y-%m-%d %I:%M:%S')" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDGitHash=$(shell git rev-parse HEAD)" LDFLAGS += -X "$(PD_PKG)/pkg/versioninfo.PDGitBranch=$(shell git rev-parse --abbrev-ref HEAD)" diff --git a/go.mod b/go.mod index 676d350d22d..d5cbc41f654 100644 --- a/go.mod +++ b/go.mod @@ -55,7 +55,7 @@ require ( 
go.uber.org/atomic v1.10.0 go.uber.org/goleak v1.1.12 go.uber.org/zap v1.24.0 - golang.org/x/exp v0.0.0-20230108222341-4b8118a2686a + golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 golang.org/x/text v0.13.0 golang.org/x/time v0.1.0 golang.org/x/tools v0.6.0 @@ -185,7 +185,7 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.14.0 // indirect golang.org/x/image v0.5.0 // indirect - golang.org/x/mod v0.8.0 // indirect + golang.org/x/mod v0.11.0 // indirect golang.org/x/net v0.17.0 // indirect golang.org/x/oauth2 v0.4.0 // indirect golang.org/x/sync v0.1.0 // indirect diff --git a/go.sum b/go.sum index c7ceeee028c..bf35be7eb8c 100644 --- a/go.sum +++ b/go.sum @@ -683,8 +683,8 @@ golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20230108222341-4b8118a2686a h1:tlXy25amD5A7gOfbXdqCGN5k8ESEed/Ee1E5RcrYnqU= -golang.org/x/exp v0.0.0-20230108222341-4b8118a2686a/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4 h1:QLureRX3moex6NVu/Lr4MGakp9FdA7sBHGBmvRW7NaM= +golang.org/x/exp v0.0.0-20230711005742-c3f37128e5a4/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI= golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= @@ -700,8 +700,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= +golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= diff --git a/pkg/member/member.go b/pkg/member/member.go index b411d0c957b..1b901a1d04a 100644 --- a/pkg/member/member.go +++ b/pkg/member/member.go @@ -57,7 +57,7 @@ type EmbeddedEtcdMember struct { id uint64 // etcd server id. member *pdpb.Member // current PD's info. rootPath string - // memberValue is the serialized string of `member`. It will be save in + // memberValue is the serialized string of `member`. It will be saved in // etcd leader key when the PD node is successfully elected as the PD leader // of the cluster. Every write will use it to check PD leadership. memberValue string @@ -199,7 +199,7 @@ func (m *EmbeddedEtcdMember) KeepLeader(ctx context.Context) { m.leadership.Keep(ctx) } -// PreCheckLeader does some pre-check before checking whether or not it's the leader. 
+// PreCheckLeader does some pre-check before checking whether it's the leader. func (m *EmbeddedEtcdMember) PreCheckLeader() error { if m.GetEtcdLeader() == 0 { return errs.ErrEtcdLeaderNotFound diff --git a/tests/integrations/client/go.mod b/tests/integrations/client/go.mod index 799901ff2e3..da130278ae0 100644 --- a/tests/integrations/client/go.mod +++ b/tests/integrations/client/go.mod @@ -17,7 +17,7 @@ require ( github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/stretchr/testify v1.8.3 github.com/tikv/pd v0.0.0-00010101000000-000000000000 - github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 + github.com/tikv/pd/client v0.0.0-20231101084237-a1a1eea8dafd go.etcd.io/etcd v0.5.0-alpha.5.0.20220915004622-85b640cee793 go.uber.org/goleak v1.1.12 go.uber.org/zap v1.24.0 diff --git a/tests/integrations/mcs/go.mod b/tests/integrations/mcs/go.mod index 75d70e3cf06..1823a224fa1 100644 --- a/tests/integrations/mcs/go.mod +++ b/tests/integrations/mcs/go.mod @@ -17,7 +17,7 @@ require ( github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/stretchr/testify v1.8.3 github.com/tikv/pd v0.0.0-00010101000000-000000000000 - github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 + github.com/tikv/pd/client v0.0.0-20231101084237-a1a1eea8dafd go.etcd.io/etcd v0.5.0-alpha.5.0.20220915004622-85b640cee793 go.uber.org/goleak v1.1.12 go.uber.org/zap v1.24.0 diff --git a/tests/integrations/realtiup/Makefile b/tests/integrations/realtiup/Makefile new file mode 100644 index 00000000000..c9ffd3d6599 --- /dev/null +++ b/tests/integrations/realtiup/Makefile @@ -0,0 +1,58 @@ +# Copyright 2023 TiKV Project Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ROOT_PATH := ../../.. +GO_TOOLS_BIN_PATH := $(ROOT_PATH)/.tools/bin +PATH := $(GO_TOOLS_BIN_PATH):$(PATH) +SHELL := env PATH='$(PATH)' GOBIN='$(GO_TOOLS_BIN_PATH)' $(shell which bash) + +static: install-tools + @ echo "gofmt ..." + @ gofmt -s -l -d . 2>&1 | awk '{ print } END { if (NR > 0) { exit 1 } }' + @ echo "golangci-lint ..." + @ golangci-lint run -c $(ROOT_PATH)/.golangci.yml --verbose ./... --allow-parallel-runners + @ echo "revive ..." + @ revive -formatter friendly -config $(ROOT_PATH)/revive.toml ./... + +tidy: + @ go mod tidy + git diff go.mod go.sum | cat + git diff --quiet go.mod go.sum + +check: deploy test kill_tiup + +deploy: kill_tiup + @ echo "deploying..." + ./deploy.sh + @ echo "wait tiup cluster ready..." + ./wait_tiup.sh 15 20 + @ echo "check cluster status..." + @ pid=$$(ps -ef | grep 'tiup' | grep -v grep | awk '{print $$2}' | head -n 1); \ + echo $$pid; + +kill_tiup: + @ echo "kill tiup..." + @ pid=$$(ps -ef | grep 'tiup' | grep -v grep | awk '{print $$2}' | head -n 1); \ + if [ ! -z "$$pid" ]; then \ + echo $$pid; \ + kill $$pid; \ + echo "waiting for tiup to exit..."; \ + sleep 10; \ + fi + +test: + CGO_ENABLED=1 go test ./... 
-v -tags deadlock -race -cover || { exit 1; } + +install-tools: + cd $(ROOT_PATH) && $(MAKE) install-tools diff --git a/tests/integrations/realtiup/deploy.sh b/tests/integrations/realtiup/deploy.sh new file mode 100755 index 00000000000..18e6de7f0b9 --- /dev/null +++ b/tests/integrations/realtiup/deploy.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# deploy `tiup playground` + +TIUP_BIN_DIR=$HOME/.tiup/bin/tiup +CUR_PATH=$(pwd) + +# See https://misc.flogisoft.com/bash/tip_colors_and_formatting. +color-green() { # Green + echo -e "\x1B[1;32m${*}\x1B[0m" +} + +# Install TiUP +color-green "install TiUP..." +curl --proto '=https' --tlsv1.2 -sSf https://tiup-mirrors.pingcap.com/install.sh | sh +$TIUP_BIN_DIR update playground + +cd ../../.. +# Run TiUP +$TIUP_BIN_DIR playground nightly --kv 3 --tiflash 1 --db 1 --pd 3 --without-monitor \ + --pd.binpath ./bin/pd-server --kv.binpath ./bin/tikv-server --db.binpath ./bin/tidb-server --tiflash.binpath ./bin/tiflash --tag pd_test \ + > $CUR_PATH/playground.log 2>&1 & + +cd $CUR_PATH diff --git a/tests/integrations/realtiup/go.mod b/tests/integrations/realtiup/go.mod new file mode 100644 index 00000000000..ccb23548f3e --- /dev/null +++ b/tests/integrations/realtiup/go.mod @@ -0,0 +1,47 @@ +module github.com/tikv/pd/tests/integrations/realtiup + +go 1.21 + +replace github.com/tikv/pd/client => ../../../client + +require ( + github.com/DATA-DOG/go-sqlmock v1.5.0 + github.com/go-sql-driver/mysql v1.7.1 + github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 + github.com/stretchr/testify v1.8.4 + github.com/tikv/pd/client v0.0.0-00010101000000-000000000000 + gorm.io/driver/mysql v1.5.2 + gorm.io/gorm v1.25.5 + moul.io/zapgorm2 v1.3.0 +) + +require ( + github.com/benbjohnson/clock v1.3.0 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/protobuf v1.5.3 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect + github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect + github.com/pingcap/kvproto v0.0.0-20230727073445-53e1f8730c30 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_golang v1.11.1 // indirect + github.com/prometheus/client_model v0.2.0 // indirect + github.com/prometheus/common v0.26.0 // indirect + github.com/prometheus/procfs v0.6.0 // indirect + go.uber.org/atomic v1.10.0 // indirect + go.uber.org/goleak v1.1.12 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.24.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/sys v0.13.0 // indirect + golang.org/x/text v0.13.0 // indirect + google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect + google.golang.org/grpc v1.54.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/tests/integrations/realtiup/go.sum b/tests/integrations/realtiup/go.sum new file mode 100644 index 00000000000..fde38211174 --- /dev/null +++ b/tests/integrations/realtiup/go.sum @@ -0,0 +1,252 @@ +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.5.0 
h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= +github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= +github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= +github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= +github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= +github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= +github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= 
+github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod 
h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= +github.com/pingcap/errors v0.11.0/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8= +github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c h1:xpW9bvK+HuuTmyFqUwr+jcCvpVkK7sumiz+ko5H9eq4= +github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c/go.mod h1:X2r9ueLEUZgtx2cIogM0v4Zj5uvvzhuuiu7Pn8HzMPg= +github.com/pingcap/kvproto v0.0.0-20230727073445-53e1f8730c30 h1:EvqKcDT7ceGLW0mXqM8Cp5Z8DfgQRnwj2YTnlCLj2QI= +github.com/pingcap/kvproto v0.0.0-20230727073445-53e1f8730c30/go.mod h1:r0q/CFcwvyeRhKtoqzmWMBebrtpIziQQ9vR+JKh1knc= +github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8IDP+SZrdhV1Kibl9KrHxJ9eciw= +github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= +github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= +github.com/prometheus/client_golang v1.11.1 h1:+4eQaD7vAZ6DsfsxB15hbE0odUjGI5ARs9yskGu1v4s= +github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= +github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= +github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0 h1:uq5h0d+GuxiXLJLNABMgp2qUWDPiLvgCzz2dUR+/W/M= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= +github.com/prometheus/common v0.26.0 h1:iMAkS2TDoNWnKM+Kopnx/8tnEStIfpYA0ur0xQzzhMQ= +github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= +github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= 
+github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= +github.com/prometheus/procfs v0.6.0 h1:mxy4L2jP6qMonqmq+aTtOx1ifVWUgG/TAmntgbh3xv4= +github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= +github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.7.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= +go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= +golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod 
v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200615200032-f1bc736245b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 
+golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= +google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= +google.golang.org/grpc v1.54.0 h1:EhTqbhiYeixwWQtAEZAxmV9MGqcjEU2mFx52xCzNyag= +google.golang.org/grpc v1.54.0/go.mod h1:PUSEXI6iWghWaB6lXM4knEgpJNu2qUcKfDtNci3EC2g= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= 
+google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.2 h1:QC2HRskSE75wBuOxe0+iCkyJZ+RqpudsQtqkp+IMuXs= +gorm.io/driver/mysql v1.5.2/go.mod h1:pQLhh1Ut/WUAySdTHwBpBv6+JKcj+ua4ZFx1QQTBzb8= +gorm.io/gorm v1.23.6/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= +gorm.io/gorm v1.25.2-0.20230530020048-26663ab9bf55/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +moul.io/zapgorm2 v1.3.0 h1:+CzUTMIcnafd0d/BvBce8T4uPn6DQnpIrz64cyixlkk= +moul.io/zapgorm2 v1.3.0/go.mod h1:nPVy6U9goFKHR4s+zfSo1xVFaoU7Qgd5DoCdOfzoCqs= diff --git a/tests/integrations/realtiup/mock_db.go b/tests/integrations/realtiup/mock_db.go new file mode 100644 index 00000000000..95f3af8a06c --- /dev/null +++ b/tests/integrations/realtiup/mock_db.go @@ -0,0 +1,91 @@ +// Copyright 2023 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package realtiup + +import ( + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + mysqldriver "github.com/go-sql-driver/mysql" + "github.com/pingcap/log" + "github.com/stretchr/testify/require" + "gorm.io/driver/mysql" + "gorm.io/gorm" + "moul.io/zapgorm2" +) + +// TestDB is a test database +type TestDB struct { + inner *gorm.DB + require *require.Assertions + + isUnderlyingMocked bool + mock sqlmock.Sqlmock +} + +// OpenTestDB opens a test database +func OpenTestDB(t *testing.T, configModifier ...func(*mysqldriver.Config, *gorm.Config)) *TestDB { + r := require.New(t) + + dsn := mysqldriver.NewConfig() + dsn.Net = "tcp" + dsn.Addr = "127.0.0.1:4000" + dsn.Params = map[string]string{"time_zone": "'+00:00'"} + dsn.ParseTime = true + dsn.Loc = time.UTC + dsn.User = "root" + dsn.DBName = "test" + + config := &gorm.Config{ + Logger: zapgorm2.New(log.L()), + } + + for _, m := range configModifier { + m(dsn, config) + } + + db, err := gorm.Open(mysql.Open(dsn.FormatDSN()), config) + r.NoError(err) + + return &TestDB{ + inner: db.Debug(), + require: r, + } +} + +// MustClose closes the test database +func (db *TestDB) MustClose() { + if db.isUnderlyingMocked { + db.mock.ExpectClose() + } + + d, err := db.inner.DB() + db.require.NoError(err) + + err = d.Close() + db.require.NoError(err) +} + +// Gorm returns the underlying gorm.DB +func (db *TestDB) Gorm() *gorm.DB { + return db.inner +} + +// MustExec executes a query +func (db *TestDB) MustExec(sql string, values ...interface{}) { + err := db.inner.Exec(sql, values...).Error + db.require.NoError(err) +} diff --git a/tests/integrations/realtiup/reboot_pd_test.go b/tests/integrations/realtiup/reboot_pd_test.go new file mode 100644 index 00000000000..bccf465bde0 --- /dev/null +++ b/tests/integrations/realtiup/reboot_pd_test.go @@ -0,0 +1,71 @@ +// Copyright 2023 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package realtiup + +import ( + "context" + "os/exec" + "testing" + + "github.com/pingcap/log" + "github.com/stretchr/testify/require" +) + +func restartTiUP() { + log.Info("start to restart TiUP") + cmd := exec.Command("make", "deploy") + err := cmd.Run() + if err != nil { + panic(err) + } + log.Info("TiUP restart success") +} + +// https://github.com/tikv/pd/issues/6467 +func TestReloadLabel(t *testing.T) { + re := require.New(t) + ctx := context.Background() + + resp, _ := pdHTTPCli.GetStores(ctx) + setStore := resp.Stores[0] + re.Empty(setStore.Store.Labels, nil) + storeLabel := map[string]string{ + "zone": "zone1", + } + err := pdHTTPCli.SetStoreLabels(ctx, setStore.Store.ID, storeLabel) + re.NoError(err) + + resp, err = pdHTTPCli.GetStores(ctx) + re.NoError(err) + for _, store := range resp.Stores { + if store.Store.ID == setStore.Store.ID { + for _, label := range store.Store.Labels { + re.Equal(label.Value, storeLabel[label.Key]) + } + } + } + + restartTiUP() + + resp, err = pdHTTPCli.GetStores(ctx) + re.NoError(err) + for _, store := range resp.Stores { + if store.Store.ID == setStore.Store.ID { + for _, label := range store.Store.Labels { + re.Equal(label.Value, storeLabel[label.Key]) + } + } + } +} diff --git a/tests/integrations/realtiup/transfer_leader_test.go b/tests/integrations/realtiup/transfer_leader_test.go new file mode 100644 index 00000000000..51142be03f9 --- /dev/null +++ b/tests/integrations/realtiup/transfer_leader_test.go @@ -0,0 +1,73 @@ +// Copyright 2023 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package realtiup + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// https://github.com/tikv/pd/issues/6988#issuecomment-1694924611 +// https://github.com/tikv/pd/issues/6897 +func TestTransferLeader(t *testing.T) { + re := require.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resp, err := pdHTTPCli.GetLeader(ctx) + re.NoError(err) + oldLeader := resp.Name + + var newLeader string + for i := 0; i < 2; i++ { + if resp.Name != fmt.Sprintf("pd-%d", i) { + newLeader = fmt.Sprintf("pd-%d", i) + } + } + + // record scheduler + err = pdHTTPCli.CreateScheduler(ctx, "evict-leader-scheduler", 1) + re.NoError(err) + res, err := pdHTTPCli.GetSchedulers(ctx) + re.NoError(err) + oldSchedulersLen := len(res) + + re.NoError(pdHTTPCli.TransferLeader(ctx, newLeader)) + // wait for transfer leader to new leader + time.Sleep(1 * time.Second) + resp, err = pdHTTPCli.GetLeader(ctx) + re.NoError(err) + re.Equal(newLeader, resp.Name) + + res, err = pdHTTPCli.GetSchedulers(ctx) + re.NoError(err) + re.Equal(oldSchedulersLen, len(res)) + + // transfer leader to old leader + re.NoError(pdHTTPCli.TransferLeader(ctx, oldLeader)) + // wait for transfer leader + time.Sleep(1 * time.Second) + resp, err = pdHTTPCli.GetLeader(ctx) + re.NoError(err) + re.Equal(oldLeader, resp.Name) + + res, err = pdHTTPCli.GetSchedulers(ctx) + re.NoError(err) + re.Equal(oldSchedulersLen, len(res)) +} diff --git a/tests/integrations/realtiup/ts_test.go b/tests/integrations/realtiup/ts_test.go new file mode 100644 index 00000000000..9bf8aee2d49 --- /dev/null +++ b/tests/integrations/realtiup/ts_test.go @@ -0,0 +1,45 @@ +// Copyright 2023 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package realtiup + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTS(t *testing.T) { + re := require.New(t) + + db := OpenTestDB(t) + db.MustExec("use test") + db.MustExec("drop table if exists t") + db.MustExec("create table t(a int, index i(a))") + db.MustExec("insert t values (1), (2), (3)") + var rows int + err := db.inner.Raw("select count(*) from t").Row().Scan(&rows) + re.NoError(err) + re.Equal(3, rows) + + re.NoError(err) + re.Equal(3, rows) + + var ts uint64 + err = db.inner.Begin().Raw("select @@tidb_current_ts").Scan(&ts).Rollback().Error + re.NoError(err) + re.NotEqual(0, GetTimeFromTS(ts)) + + db.MustClose() +} diff --git a/tests/integrations/realtiup/util.go b/tests/integrations/realtiup/util.go new file mode 100644 index 00000000000..66d6127b5c4 --- /dev/null +++ b/tests/integrations/realtiup/util.go @@ -0,0 +1,39 @@ +// Copyright 2023 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package realtiup
+
+import (
+	"time"
+
+	"github.com/tikv/pd/client/http"
+)
+
+const physicalShiftBits = 18
+
+var (
+	pdAddrs   = []string{"127.0.0.1:2379"}
+	pdHTTPCli = http.NewClient(pdAddrs)
+)
+
+// GetTimeFromTS extracts time.Time from a timestamp.
+func GetTimeFromTS(ts uint64) time.Time {
+	ms := ExtractPhysical(ts)
+	return time.Unix(ms/1e3, (ms%1e3)*1e6)
+}
+
+// ExtractPhysical returns a ts's physical part.
+func ExtractPhysical(ts uint64) int64 {
+	return int64(ts >> physicalShiftBits)
+}
diff --git a/tests/integrations/realtiup/wait_tiup.sh b/tests/integrations/realtiup/wait_tiup.sh
new file mode 100755
index 00000000000..497774f9e96
--- /dev/null
+++ b/tests/integrations/realtiup/wait_tiup.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+# Wait until `tiup playground` command runs success
+
+TIUP_BIN_DIR=$HOME/.tiup/bin/tiup
+INTERVAL=$1
+MAX_TIMES=$2
+
+if ([ -z "${INTERVAL}" ] || [ -z "${MAX_TIMES}" ]); then
+    echo "Usage: command <interval> <max_times>"
+    exit 1
+fi
+
+for ((i=0; i<${MAX_TIMES}; i++)); do
+    sleep ${INTERVAL}
+    $TIUP_BIN_DIR playground display --tag pd_test
+    if [ $? -eq 0 ]; then
+        exit 0
+    fi
+    cat ./playground.log
+done
+
+exit 1
\ No newline at end of file
diff --git a/tests/integrations/tso/go.mod b/tests/integrations/tso/go.mod
index 309ea9dbc4d..0d734716ed5 100644
--- a/tests/integrations/tso/go.mod
+++ b/tests/integrations/tso/go.mod
@@ -16,7 +16,7 @@ require (
 	github.com/pingcap/kvproto v0.0.0-20231018065736-c0689aded40c
 	github.com/stretchr/testify v1.8.4
 	github.com/tikv/pd v0.0.0-00010101000000-000000000000
-	github.com/tikv/pd/client v0.0.0-00010101000000-000000000000
+	github.com/tikv/pd/client v0.0.0-20231101084237-a1a1eea8dafd
 	github.com/tikv/pd/tests/integrations/mcs v0.0.0-00010101000000-000000000000
 	google.golang.org/grpc v1.54.0
 )
diff --git a/tools/pd-api-bench/go.mod b/tools/pd-api-bench/go.mod
index 8050f433e8b..77b4891d3d8 100644
--- a/tools/pd-api-bench/go.mod
+++ b/tools/pd-api-bench/go.mod
@@ -4,7 +4,7 @@ go 1.21
 
 require (
 	github.com/tikv/pd v0.0.0-00010101000000-000000000000
-	github.com/tikv/pd/client v0.0.0-00010101000000-000000000000
+	github.com/tikv/pd/client v0.0.0-20231101084237-a1a1eea8dafd
 	go.uber.org/zap v1.24.0
 	google.golang.org/grpc v1.54.0
 )
diff --git a/tools/pd-simulator/main.go b/tools/pd-simulator/main.go
index 60d8874d083..5d781757b39 100644
--- a/tools/pd-simulator/main.go
+++ b/tools/pd-simulator/main.go
@@ -56,7 +56,7 @@ var (
 )
 
 func main() {
-	// wait PD start. Otherwise it will happen error when getting cluster ID.
+	// wait PD start. Otherwise, it will happen error when getting cluster ID.
time.Sleep(3 * time.Second) // ignore some undefined flag flag.CommandLine.ParseErrorsWhitelist.UnknownFlags = true From a16f99ee5cdaaef29acdd8b6d1a7f66be0269d45 Mon Sep 17 00:00:00 2001 From: Hu# Date: Mon, 18 Dec 2023 14:37:52 +0800 Subject: [PATCH 12/21] api: support mcs api for members (#7372) ref tikv/pd#7519 Signed-off-by: husharp --- client/http/api.go | 7 ++ client/http/client.go | 14 +++ pkg/mcs/discovery/discover.go | 42 ++++++++ pkg/mcs/tso/server/apis/v1/api.go | 23 ++++ pkg/mcs/utils/util.go | 6 +- server/apiv2/handlers/micro_service.go | 57 ++++++++++ server/apiv2/router.go | 1 + tests/integrations/mcs/members/member_test.go | 100 ++++++++++++++++++ 8 files changed, 247 insertions(+), 3 deletions(-) create mode 100644 server/apiv2/handlers/micro_service.go create mode 100644 tests/integrations/mcs/members/member_test.go diff --git a/client/http/api.go b/client/http/api.go index 2153cd286e8..4cc169c6a33 100644 --- a/client/http/api.go +++ b/client/http/api.go @@ -75,6 +75,8 @@ const ( MinResolvedTSPrefix = "/pd/api/v1/min-resolved-ts" Status = "/pd/api/v1/status" Version = "/pd/api/v1/version" + // Micro Service + microServicePrefix = "/pd/api/v2/ms" ) // RegionByID returns the path of PD HTTP API to get region by ID. @@ -186,3 +188,8 @@ func PProfProfileAPIWithInterval(interval time.Duration) string { func PProfGoroutineWithDebugLevel(level int) string { return fmt.Sprintf("%s?debug=%d", PProfGoroutine, level) } + +// MicroServiceMembers returns the path of PD HTTP API to get the members of microservice. +func MicroServiceMembers(service string) string { + return fmt.Sprintf("%s/members/%s", microServicePrefix, service) +} diff --git a/client/http/client.go b/client/http/client.go index d74c77571d6..927450b74a2 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -89,6 +89,8 @@ type Client interface { AccelerateScheduleInBatch(context.Context, []*KeyRange) error /* Other interfaces */ GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) + /* Micro Service interfaces */ + GetMicroServiceMembers(context.Context, string) ([]string, error) /* Client-related methods */ // WithCallerID sets and returns a new client with the given caller ID. @@ -844,3 +846,15 @@ func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uin } return resp.MinResolvedTS, resp.StoresMinResolvedTS, nil } + +// GetMicroServiceMembers gets the members of the microservice. +func (c *client) GetMicroServiceMembers(ctx context.Context, service string) ([]string, error) { + var members []string + err := c.requestWithRetry(ctx, + "GetMicroServiceMembers", MicroServiceMembers(service), + http.MethodGet, nil, &members) + if err != nil { + return nil, err + } + return members, nil +} diff --git a/pkg/mcs/discovery/discover.go b/pkg/mcs/discovery/discover.go index 00e168114b0..89c45497a87 100644 --- a/pkg/mcs/discovery/discover.go +++ b/pkg/mcs/discovery/discover.go @@ -15,8 +15,16 @@ package discovery import ( + "strconv" + + "github.com/pingcap/errors" + "github.com/pingcap/log" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/mcs/utils" + "github.com/tikv/pd/pkg/storage/kv" "github.com/tikv/pd/pkg/utils/etcdutil" "go.etcd.io/etcd/clientv3" + "go.uber.org/zap" ) // Discover is used to get all the service instances of the specified service name. 
@@ -35,3 +43,37 @@ func Discover(cli *clientv3.Client, clusterID, serviceName string) ([]string, er } return values, nil } + +// GetMSMembers returns all the members of the specified service name. +func GetMSMembers(name string, client *clientv3.Client) ([]string, error) { + switch name { + case utils.TSOServiceName, utils.SchedulingServiceName, utils.ResourceManagerServiceName: + clusterID, err := etcdutil.GetClusterID(client, utils.ClusterIDPath) + if err != nil { + return nil, err + } + servicePath := ServicePath(strconv.FormatUint(clusterID, 10), name) + resps, err := kv.NewSlowLogTxn(client).Then(clientv3.OpGet(servicePath, clientv3.WithPrefix())).Commit() + if err != nil { + return nil, errs.ErrEtcdKVGet.Wrap(err).GenWithStackByCause() + } + if !resps.Succeeded { + return nil, errs.ErrEtcdTxnConflict.FastGenByArgs() + } + + var addrs []string + for _, resp := range resps.Responses { + for _, keyValue := range resp.GetResponseRange().GetKvs() { + var entry ServiceRegistryEntry + if err = entry.Deserialize(keyValue.Value); err != nil { + log.Error("try to deserialize service registry entry failed", zap.String("key", string(keyValue.Key)), zap.Error(err)) + continue + } + addrs = append(addrs, entry.ServiceAddr) + } + } + return addrs, nil + } + + return nil, errors.Errorf("unknown service name %s", name) +} diff --git a/pkg/mcs/tso/server/apis/v1/api.go b/pkg/mcs/tso/server/apis/v1/api.go index 33e1e0801aa..e5f0dfb5440 100644 --- a/pkg/mcs/tso/server/apis/v1/api.go +++ b/pkg/mcs/tso/server/apis/v1/api.go @@ -102,6 +102,7 @@ func NewService(srv *tsoserver.Service) *Service { } s.RegisterAdminRouter() s.RegisterKeyspaceGroupRouter() + s.RegisterHealth() return s } @@ -118,6 +119,12 @@ func (s *Service) RegisterKeyspaceGroupRouter() { router.GET("/members", GetKeyspaceGroupMembers) } +// RegisterHealth registers the router of the health handler. +func (s *Service) RegisterHealth() { + router := s.root.Group("health") + router.GET("", GetHealth) +} + func changeLogLevel(c *gin.Context) { svr := c.MustGet(multiservicesapi.ServiceContextKey).(*tsoserver.Service) var level string @@ -201,6 +208,22 @@ func ResetTS(c *gin.Context) { c.String(http.StatusOK, "Reset ts successfully.") } +// GetHealth returns the health status of the TSO service. +func GetHealth(c *gin.Context) { + svr := c.MustGet(multiservicesapi.ServiceContextKey).(*tsoserver.Service) + am, err := svr.GetKeyspaceGroupManager().GetAllocatorManager(utils.DefaultKeyspaceGroupID) + if err != nil { + c.String(http.StatusInternalServerError, err.Error()) + return + } + if am.GetMember().IsLeaderElected() { + c.IndentedJSON(http.StatusOK, "ok") + return + } + + c.String(http.StatusInternalServerError, "no leader elected") +} + // KeyspaceGroupMember contains the keyspace group and its member information. type KeyspaceGroupMember struct { Group *endpoint.KeyspaceGroup diff --git a/pkg/mcs/utils/util.go b/pkg/mcs/utils/util.go index 682e73f20ae..a0708f9bf88 100644 --- a/pkg/mcs/utils/util.go +++ b/pkg/mcs/utils/util.go @@ -45,8 +45,8 @@ import ( const ( // maxRetryTimes is the max retry times for initializing the cluster ID. maxRetryTimes = 5 - // clusterIDPath is the path to store cluster id - clusterIDPath = "/pd/cluster_id" + // ClusterIDPath is the path to store cluster id + ClusterIDPath = "/pd/cluster_id" // retryInterval is the interval to retry. 
retryInterval = time.Second ) @@ -56,7 +56,7 @@ func InitClusterID(ctx context.Context, client *clientv3.Client) (id uint64, err ticker := time.NewTicker(retryInterval) defer ticker.Stop() for i := 0; i < maxRetryTimes; i++ { - if clusterID, err := etcdutil.GetClusterID(client, clusterIDPath); err == nil && clusterID != 0 { + if clusterID, err := etcdutil.GetClusterID(client, ClusterIDPath); err == nil && clusterID != 0 { return clusterID, nil } select { diff --git a/server/apiv2/handlers/micro_service.go b/server/apiv2/handlers/micro_service.go new file mode 100644 index 00000000000..3c2be3748d4 --- /dev/null +++ b/server/apiv2/handlers/micro_service.go @@ -0,0 +1,57 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package handlers + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/tikv/pd/pkg/mcs/discovery" + "github.com/tikv/pd/server" + "github.com/tikv/pd/server/apiv2/middlewares" +) + +// RegisterMicroService registers microservice handler to the router. +func RegisterMicroService(r *gin.RouterGroup) { + router := r.Group("ms") + router.Use(middlewares.BootstrapChecker()) + router.GET("members/:service", GetMembers) +} + +// GetMembers gets all members of the cluster for the specified service. +// @Tags members +// @Summary Get all members of the cluster for the specified service. +// @Produce json +// @Success 200 {object} []string +// @Router /ms/members/{service} [get] +func GetMembers(c *gin.Context) { + svr := c.MustGet(middlewares.ServerContextKey).(*server.Server) + if !svr.IsAPIServiceMode() { + c.AbortWithStatusJSON(http.StatusServiceUnavailable, "not support micro service") + return + } + + if service := c.Param("service"); len(service) > 0 { + addrs, err := discovery.GetMSMembers(service, svr.GetClient()) + if err != nil { + c.AbortWithStatusJSON(http.StatusInternalServerError, err.Error()) + return + } + c.IndentedJSON(http.StatusOK, addrs) + return + } + + c.AbortWithStatusJSON(http.StatusInternalServerError, "please specify service") +} diff --git a/server/apiv2/router.go b/server/apiv2/router.go index 383d336caae..fd3ce38c0e4 100644 --- a/server/apiv2/router.go +++ b/server/apiv2/router.go @@ -64,5 +64,6 @@ func NewV2Handler(_ context.Context, svr *server.Server) (http.Handler, apiutil. root := router.Group(apiV2Prefix) handlers.RegisterKeyspace(root) handlers.RegisterTSOKeyspaceGroup(root) + handlers.RegisterMicroService(root) return router, group, nil } diff --git a/tests/integrations/mcs/members/member_test.go b/tests/integrations/mcs/members/member_test.go new file mode 100644 index 00000000000..d1ccb86a1c7 --- /dev/null +++ b/tests/integrations/mcs/members/member_test.go @@ -0,0 +1,100 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package members_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + pdClient "github.com/tikv/pd/client/http" + bs "github.com/tikv/pd/pkg/basicserver" + "github.com/tikv/pd/pkg/mcs/utils" + "github.com/tikv/pd/pkg/utils/tempurl" + "github.com/tikv/pd/pkg/utils/testutil" + "github.com/tikv/pd/tests" +) + +type memberTestSuite struct { + suite.Suite + ctx context.Context + cleanupFunc []testutil.CleanupFunc + cluster *tests.TestCluster + server *tests.TestServer + backendEndpoints string + dialClient pdClient.Client +} + +func TestMemberTestSuite(t *testing.T) { + suite.Run(t, new(memberTestSuite)) +} + +func (suite *memberTestSuite) SetupTest() { + ctx, cancel := context.WithCancel(context.Background()) + suite.ctx = ctx + cluster, err := tests.NewTestAPICluster(suite.ctx, 1) + suite.cluster = cluster + suite.NoError(err) + suite.NoError(cluster.RunInitialServers()) + suite.NotEmpty(cluster.WaitLeader()) + suite.server = cluster.GetLeaderServer() + suite.NoError(suite.server.BootstrapCluster()) + suite.backendEndpoints = suite.server.GetAddr() + suite.dialClient = pdClient.NewClient([]string{suite.server.GetAddr()}) + + // TSO + nodes := make(map[string]bs.Server) + for i := 0; i < utils.DefaultKeyspaceGroupReplicaCount; i++ { + s, cleanup := tests.StartSingleTSOTestServer(suite.ctx, suite.Require(), suite.backendEndpoints, tempurl.Alloc()) + nodes[s.GetAddr()] = s + suite.cleanupFunc = append(suite.cleanupFunc, func() { + cleanup() + }) + } + tests.WaitForPrimaryServing(suite.Require(), nodes) + + // Scheduling + nodes = make(map[string]bs.Server) + for i := 0; i < 3; i++ { + s, cleanup := tests.StartSingleSchedulingTestServer(suite.ctx, suite.Require(), suite.backendEndpoints, tempurl.Alloc()) + nodes[s.GetAddr()] = s + suite.cleanupFunc = append(suite.cleanupFunc, func() { + cleanup() + }) + } + tests.WaitForPrimaryServing(suite.Require(), nodes) + + suite.cleanupFunc = append(suite.cleanupFunc, func() { + cancel() + }) +} + +func (suite *memberTestSuite) TearDownTest() { + for _, cleanup := range suite.cleanupFunc { + cleanup() + } + suite.cluster.Destroy() +} + +func (suite *memberTestSuite) TestMembers() { + re := suite.Require() + members, err := suite.dialClient.GetMicroServiceMembers(suite.ctx, "tso") + re.NoError(err) + re.Len(members, utils.DefaultKeyspaceGroupReplicaCount) + + members, err = suite.dialClient.GetMicroServiceMembers(suite.ctx, "scheduling") + re.NoError(err) + re.Len(members, 3) +} From 5eae459c01a797cbd0c416054c6f0cad16b8740a Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Mon, 18 Dec 2023 16:06:22 +0800 Subject: [PATCH 13/21] *: add pre func in etcdutil and refactor for endpoint (#7555) ref tikv/pd#7418 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/keyspace/tso_keyspace_group.go | 3 +- pkg/mcs/scheduling/server/cluster.go | 2 +- pkg/mcs/scheduling/server/config/watcher.go | 36 +++++++++---------- pkg/mcs/scheduling/server/meta/watcher.go | 7 ++-- pkg/mcs/scheduling/server/rule/watcher.go | 21 +++++------ pkg/mock/mockcluster/mockcluster.go | 2 
+- pkg/schedule/checker/rule_checker.go | 8 ++--- pkg/schedule/placement/rule_manager.go | 5 ++- pkg/schedule/placement/rule_manager_test.go | 7 ++-- pkg/schedule/schedulers/shuffle_region.go | 7 ++-- .../schedulers/shuffle_region_config.go | 4 ++- pkg/statistics/region_collection_test.go | 5 +-- pkg/storage/endpoint/config.go | 8 ++--- pkg/storage/endpoint/gc_safe_point.go | 8 +---- pkg/storage/endpoint/key_path.go | 6 ++++ pkg/storage/endpoint/keyspace.go | 5 --- pkg/storage/endpoint/replication_status.go | 6 +--- pkg/storage/endpoint/rule.go | 25 ------------- pkg/storage/endpoint/safepoint_v2.go | 13 ++----- pkg/storage/endpoint/service_middleware.go | 6 +--- pkg/storage/endpoint/tso_keyspace_group.go | 7 +--- pkg/storage/endpoint/util.go | 32 ++++++++++++++++- pkg/tso/keyspace_group_manager.go | 6 ++-- pkg/utils/etcdutil/etcdutil.go | 35 ++++++++++++------ pkg/utils/etcdutil/etcdutil_test.go | 15 +++++--- server/cluster/cluster.go | 2 +- server/cluster/cluster_test.go | 10 +++--- server/keyspace_service.go | 3 +- server/server.go | 3 +- 29 files changed, 150 insertions(+), 147 deletions(-) diff --git a/pkg/keyspace/tso_keyspace_group.go b/pkg/keyspace/tso_keyspace_group.go index c8694c4a7c6..51a53f75cc2 100644 --- a/pkg/keyspace/tso_keyspace_group.go +++ b/pkg/keyspace/tso_keyspace_group.go @@ -245,9 +245,10 @@ func (m *GroupManager) initTSONodesWatcher(client *clientv3.Client, clusterID ui client, "tso-nodes-watcher", tsoServiceKey, + func([]*clientv3.Event) error { return nil }, putFn, deleteFn, - func() error { return nil }, + func([]*clientv3.Event) error { return nil }, clientv3.WithRange(tsoServiceEndKey), ) } diff --git a/pkg/mcs/scheduling/server/cluster.go b/pkg/mcs/scheduling/server/cluster.go index b24db7ac805..5dd1c9f7fce 100644 --- a/pkg/mcs/scheduling/server/cluster.go +++ b/pkg/mcs/scheduling/server/cluster.go @@ -65,7 +65,7 @@ func NewCluster(parentCtx context.Context, persistConfig *config.PersistConfig, cancel() return nil, err } - ruleManager := placement.NewRuleManager(storage, basicCluster, persistConfig) + ruleManager := placement.NewRuleManager(ctx, storage, basicCluster, persistConfig) c := &Cluster{ ctx: ctx, cancel: cancel, diff --git a/pkg/mcs/scheduling/server/config/watcher.go b/pkg/mcs/scheduling/server/config/watcher.go index 4ded93ceb1b..32028592504 100644 --- a/pkg/mcs/scheduling/server/config/watcher.go +++ b/pkg/mcs/scheduling/server/config/watcher.go @@ -139,14 +139,13 @@ func (cw *Watcher) initializeConfigWatcher() error { deleteFn := func(kv *mvccpb.KeyValue) error { return nil } - postEventFn := func() error { - return nil - } cw.configWatcher = etcdutil.NewLoopWatcher( cw.ctx, &cw.wg, cw.etcdClient, "scheduling-config-watcher", cw.configPath, - putFn, deleteFn, postEventFn, + func([]*clientv3.Event) error { return nil }, + putFn, deleteFn, + func([]*clientv3.Event) error { return nil }, ) cw.configWatcher.StartWatchLoop() return cw.configWatcher.WaitLoad() @@ -154,7 +153,7 @@ func (cw *Watcher) initializeConfigWatcher() error { func (cw *Watcher) initializeTTLConfigWatcher() error { putFn := func(kv *mvccpb.KeyValue) error { - key := string(kv.Key)[len(sc.TTLConfigPrefix)+1:] + key := strings.TrimPrefix(string(kv.Key), sc.TTLConfigPrefix+"/") value := string(kv.Value) leaseID := kv.Lease resp, err := cw.etcdClient.TimeToLive(cw.ctx, clientv3.LeaseID(leaseID)) @@ -166,18 +165,18 @@ func (cw *Watcher) initializeTTLConfigWatcher() error { return nil } deleteFn := func(kv *mvccpb.KeyValue) error { - key := string(kv.Key)[len(sc.TTLConfigPrefix)+1:] 
+ key := strings.TrimPrefix(string(kv.Key), sc.TTLConfigPrefix+"/") cw.ttl.PutWithTTL(key, nil, 0) return nil } - postEventFn := func() error { - return nil - } cw.ttlConfigWatcher = etcdutil.NewLoopWatcher( cw.ctx, &cw.wg, cw.etcdClient, "scheduling-ttl-config-watcher", cw.ttlConfigPrefix, - putFn, deleteFn, postEventFn, clientv3.WithPrefix(), + func([]*clientv3.Event) error { return nil }, + putFn, deleteFn, + func([]*clientv3.Event) error { return nil }, + clientv3.WithPrefix(), ) cw.ttlConfigWatcher.StartWatchLoop() return cw.ttlConfigWatcher.WaitLoad() @@ -186,13 +185,14 @@ func (cw *Watcher) initializeTTLConfigWatcher() error { func (cw *Watcher) initializeSchedulerConfigWatcher() error { prefixToTrim := cw.schedulerConfigPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { - name := strings.TrimPrefix(string(kv.Key), prefixToTrim) + key := string(kv.Key) + name := strings.TrimPrefix(key, prefixToTrim) log.Info("update scheduler config", zap.String("name", name), zap.String("value", string(kv.Value))) err := cw.storage.SaveSchedulerConfig(name, kv.Value) if err != nil { log.Warn("failed to save scheduler config", - zap.String("event-kv-key", string(kv.Key)), + zap.String("event-kv-key", key), zap.String("trimmed-key", name), zap.Error(err)) return err @@ -204,19 +204,19 @@ func (cw *Watcher) initializeSchedulerConfigWatcher() error { return nil } deleteFn := func(kv *mvccpb.KeyValue) error { - log.Info("remove scheduler config", zap.String("key", string(kv.Key))) + key := string(kv.Key) + log.Info("remove scheduler config", zap.String("key", key)) return cw.storage.RemoveSchedulerConfig( - strings.TrimPrefix(string(kv.Key), prefixToTrim), + strings.TrimPrefix(key, prefixToTrim), ) } - postEventFn := func() error { - return nil - } cw.schedulerConfigWatcher = etcdutil.NewLoopWatcher( cw.ctx, &cw.wg, cw.etcdClient, "scheduling-scheduler-config-watcher", cw.schedulerConfigPathPrefix, - putFn, deleteFn, postEventFn, + func([]*clientv3.Event) error { return nil }, + putFn, deleteFn, + func([]*clientv3.Event) error { return nil }, clientv3.WithPrefix(), ) cw.schedulerConfigWatcher.StartWatchLoop() diff --git a/pkg/mcs/scheduling/server/meta/watcher.go b/pkg/mcs/scheduling/server/meta/watcher.go index 6fae537eab9..808e8fc565e 100644 --- a/pkg/mcs/scheduling/server/meta/watcher.go +++ b/pkg/mcs/scheduling/server/meta/watcher.go @@ -104,14 +104,13 @@ func (w *Watcher) initializeStoreWatcher() error { } return nil } - postEventFn := func() error { - return nil - } w.storeWatcher = etcdutil.NewLoopWatcher( w.ctx, &w.wg, w.etcdClient, "scheduling-store-watcher", w.storePathPrefix, - putFn, deleteFn, postEventFn, + func([]*clientv3.Event) error { return nil }, + putFn, deleteFn, + func([]*clientv3.Event) error { return nil }, clientv3.WithPrefix(), ) w.storeWatcher.StartWatchLoop() diff --git a/pkg/mcs/scheduling/server/rule/watcher.go b/pkg/mcs/scheduling/server/rule/watcher.go index 912fb9c01e5..96e19cf5002 100644 --- a/pkg/mcs/scheduling/server/rule/watcher.go +++ b/pkg/mcs/scheduling/server/rule/watcher.go @@ -131,14 +131,13 @@ func (rw *Watcher) initializeRuleWatcher() error { rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) return rw.ruleManager.DeleteRule(rule.GroupID, rule.ID) } - postEventFn := func() error { - return nil - } rw.ruleWatcher = etcdutil.NewLoopWatcher( rw.ctx, &rw.wg, rw.etcdClient, "scheduling-rule-watcher", rw.rulesPathPrefix, - putFn, deleteFn, postEventFn, + func([]*clientv3.Event) error { return nil }, + putFn, deleteFn, + 
func([]*clientv3.Event) error { return nil }, clientv3.WithPrefix(), ) rw.ruleWatcher.StartWatchLoop() @@ -168,14 +167,13 @@ func (rw *Watcher) initializeGroupWatcher() error { } return rw.ruleManager.DeleteRuleGroup(trimmedKey) } - postEventFn := func() error { - return nil - } rw.groupWatcher = etcdutil.NewLoopWatcher( rw.ctx, &rw.wg, rw.etcdClient, "scheduling-rule-group-watcher", rw.ruleGroupPathPrefix, - putFn, deleteFn, postEventFn, + func([]*clientv3.Event) error { return nil }, + putFn, deleteFn, + func([]*clientv3.Event) error { return nil }, clientv3.WithPrefix(), ) rw.groupWatcher.StartWatchLoop() @@ -197,14 +195,13 @@ func (rw *Watcher) initializeRegionLabelWatcher() error { log.Info("delete region label rule", zap.String("key", key)) return rw.regionLabeler.DeleteLabelRule(strings.TrimPrefix(key, prefixToTrim)) } - postEventFn := func() error { - return nil - } rw.labelWatcher = etcdutil.NewLoopWatcher( rw.ctx, &rw.wg, rw.etcdClient, "scheduling-region-label-watcher", rw.regionLabelPathPrefix, - putFn, deleteFn, postEventFn, + func([]*clientv3.Event) error { return nil }, + putFn, deleteFn, + func([]*clientv3.Event) error { return nil }, clientv3.WithPrefix(), ) rw.labelWatcher.StartWatchLoop() diff --git a/pkg/mock/mockcluster/mockcluster.go b/pkg/mock/mockcluster/mockcluster.go index 01282b40534..6cf7ae143df 100644 --- a/pkg/mock/mockcluster/mockcluster.go +++ b/pkg/mock/mockcluster/mockcluster.go @@ -212,7 +212,7 @@ func (mc *Cluster) AllocPeer(storeID uint64) (*metapb.Peer, error) { func (mc *Cluster) initRuleManager() { if mc.RuleManager == nil { - mc.RuleManager = placement.NewRuleManager(mc.GetStorage(), mc, mc.GetSharedConfig()) + mc.RuleManager = placement.NewRuleManager(mc.ctx, mc.GetStorage(), mc, mc.GetSharedConfig()) mc.RuleManager.Initialize(int(mc.GetReplicationConfig().MaxReplicas), mc.GetReplicationConfig().LocationLabels, mc.GetReplicationConfig().IsolationLevel) } } diff --git a/pkg/schedule/checker/rule_checker.go b/pkg/schedule/checker/rule_checker.go index 553ece09e65..95cc77ade5d 100644 --- a/pkg/schedule/checker/rule_checker.go +++ b/pkg/schedule/checker/rule_checker.go @@ -199,7 +199,7 @@ func (c *RuleChecker) fixRulePeer(region *core.RegionInfo, fit *placement.Region if c.isDownPeer(region, peer) { if c.isStoreDownTimeHitMaxDownTime(peer.GetStoreId()) { ruleCheckerReplaceDownCounter.Inc() - return c.replaceUnexpectRulePeer(region, rf, fit, peer, downStatus) + return c.replaceUnexpectedRulePeer(region, rf, fit, peer, downStatus) } // When witness placement rule is enabled, promotes the witness to voter when region has down voter. if c.isWitnessEnabled() && core.IsVoter(peer) { @@ -211,7 +211,7 @@ func (c *RuleChecker) fixRulePeer(region *core.RegionInfo, fit *placement.Region } if c.isOfflinePeer(peer) { ruleCheckerReplaceOfflineCounter.Inc() - return c.replaceUnexpectRulePeer(region, rf, fit, peer, offlineStatus) + return c.replaceUnexpectedRulePeer(region, rf, fit, peer, offlineStatus) } } // fix loose matched peers. @@ -246,7 +246,7 @@ func (c *RuleChecker) addRulePeer(region *core.RegionInfo, fit *placement.Region continue } ruleCheckerNoStoreThenTryReplace.Inc() - op, err := c.replaceUnexpectRulePeer(region, oldPeerRuleFit, fit, p, "swap-fit") + op, err := c.replaceUnexpectedRulePeer(region, oldPeerRuleFit, fit, p, "swap-fit") if err != nil { return nil, err } @@ -267,7 +267,7 @@ func (c *RuleChecker) addRulePeer(region *core.RegionInfo, fit *placement.Region } // The peer's store may in Offline or Down, need to be replace. 
-func (c *RuleChecker) replaceUnexpectRulePeer(region *core.RegionInfo, rf *placement.RuleFit, fit *placement.RegionFit, peer *metapb.Peer, status string) (*operator.Operator, error) { +func (c *RuleChecker) replaceUnexpectedRulePeer(region *core.RegionInfo, rf *placement.RuleFit, fit *placement.RegionFit, peer *metapb.Peer, status string) (*operator.Operator, error) { var fastFailover bool // If the store to which the original peer belongs is TiFlash, the new peer cannot be set to witness, nor can it perform fast failover if c.isWitnessEnabled() && !c.cluster.GetStore(peer.StoreId).IsTiFlash() { diff --git a/pkg/schedule/placement/rule_manager.go b/pkg/schedule/placement/rule_manager.go index e25b8802b45..621c52d738e 100644 --- a/pkg/schedule/placement/rule_manager.go +++ b/pkg/schedule/placement/rule_manager.go @@ -16,6 +16,7 @@ package placement import ( "bytes" + "context" "encoding/hex" "encoding/json" "fmt" @@ -49,6 +50,7 @@ const ( // RuleManager is responsible for the lifecycle of all placement Rules. // It is thread safe. type RuleManager struct { + ctx context.Context storage endpoint.RuleStorage syncutil.RWMutex initialized bool @@ -63,8 +65,9 @@ type RuleManager struct { } // NewRuleManager creates a RuleManager instance. -func NewRuleManager(storage endpoint.RuleStorage, storeSetInformer core.StoreSetInformer, conf config.SharedConfigProvider) *RuleManager { +func NewRuleManager(ctx context.Context, storage endpoint.RuleStorage, storeSetInformer core.StoreSetInformer, conf config.SharedConfigProvider) *RuleManager { return &RuleManager{ + ctx: ctx, storage: storage, storeSetInformer: storeSetInformer, conf: conf, diff --git a/pkg/schedule/placement/rule_manager_test.go b/pkg/schedule/placement/rule_manager_test.go index 68a18b538d4..c0987f6dd33 100644 --- a/pkg/schedule/placement/rule_manager_test.go +++ b/pkg/schedule/placement/rule_manager_test.go @@ -15,6 +15,7 @@ package placement import ( + "context" "encoding/hex" "testing" @@ -32,7 +33,7 @@ func newTestManager(t *testing.T, enableWitness bool) (endpoint.RuleStorage, *Ru re := require.New(t) store := endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) var err error - manager := NewRuleManager(store, nil, mockconfig.NewTestOptions()) + manager := NewRuleManager(context.Background(), store, nil, mockconfig.NewTestOptions()) manager.conf.SetEnableWitness(enableWitness) err = manager.Initialize(3, []string{"zone", "rack", "host"}, "") re.NoError(err) @@ -156,7 +157,7 @@ func TestSaveLoad(t *testing.T) { re.NoError(manager.SetRule(r.Clone())) } - m2 := NewRuleManager(store, nil, nil) + m2 := NewRuleManager(context.Background(), store, nil, nil) err := m2.Initialize(3, []string{"no", "labels"}, "") re.NoError(err) re.Len(m2.GetAllRules(), 3) @@ -174,7 +175,7 @@ func TestSetAfterGet(t *testing.T) { rule.Count = 1 manager.SetRule(rule) - m2 := NewRuleManager(store, nil, nil) + m2 := NewRuleManager(context.Background(), store, nil, nil) err := m2.Initialize(100, []string{}, "") re.NoError(err) rule = m2.GetRule(DefaultGroupID, DefaultRuleID) diff --git a/pkg/schedule/schedulers/shuffle_region.go b/pkg/schedule/schedulers/shuffle_region.go index f1d35e80925..f9bed18d3fa 100644 --- a/pkg/schedule/schedulers/shuffle_region.go +++ b/pkg/schedule/schedulers/shuffle_region.go @@ -139,18 +139,19 @@ func (s *shuffleRegionScheduler) scheduleRemovePeer(cluster sche.SchedulerCluste pendingFilter := filter.NewRegionPendingFilter() downFilter := filter.NewRegionDownFilter() replicaFilter := filter.NewRegionReplicatedFilter(cluster) + ranges := 
s.conf.GetRanges() for _, source := range candidates.Stores { var region *core.RegionInfo if s.conf.IsRoleAllow(roleFollower) { - region = filter.SelectOneRegion(cluster.RandFollowerRegions(source.GetID(), s.conf.Ranges), nil, + region = filter.SelectOneRegion(cluster.RandFollowerRegions(source.GetID(), ranges), nil, pendingFilter, downFilter, replicaFilter) } if region == nil && s.conf.IsRoleAllow(roleLeader) { - region = filter.SelectOneRegion(cluster.RandLeaderRegions(source.GetID(), s.conf.Ranges), nil, + region = filter.SelectOneRegion(cluster.RandLeaderRegions(source.GetID(), ranges), nil, pendingFilter, downFilter, replicaFilter) } if region == nil && s.conf.IsRoleAllow(roleLearner) { - region = filter.SelectOneRegion(cluster.RandLearnerRegions(source.GetID(), s.conf.Ranges), nil, + region = filter.SelectOneRegion(cluster.RandLearnerRegions(source.GetID(), ranges), nil, pendingFilter, downFilter, replicaFilter) } if region != nil { diff --git a/pkg/schedule/schedulers/shuffle_region_config.go b/pkg/schedule/schedulers/shuffle_region_config.go index 7d04879c992..552d7ea8bce 100644 --- a/pkg/schedule/schedulers/shuffle_region_config.go +++ b/pkg/schedule/schedulers/shuffle_region_config.go @@ -58,7 +58,9 @@ func (conf *shuffleRegionSchedulerConfig) GetRoles() []string { func (conf *shuffleRegionSchedulerConfig) GetRanges() []core.KeyRange { conf.RLock() defer conf.RUnlock() - return conf.Ranges + ranges := make([]core.KeyRange, len(conf.Ranges)) + copy(ranges, conf.Ranges) + return ranges } func (conf *shuffleRegionSchedulerConfig) IsRoleAllow(role string) bool { diff --git a/pkg/statistics/region_collection_test.go b/pkg/statistics/region_collection_test.go index f0df9ce6e07..cbbf7672bee 100644 --- a/pkg/statistics/region_collection_test.go +++ b/pkg/statistics/region_collection_test.go @@ -15,6 +15,7 @@ package statistics import ( + "context" "testing" "github.com/pingcap/kvproto/pkg/metapb" @@ -29,7 +30,7 @@ import ( func TestRegionStatistics(t *testing.T) { re := require.New(t) store := storage.NewStorageWithMemoryBackend() - manager := placement.NewRuleManager(store, nil, nil) + manager := placement.NewRuleManager(context.Background(), store, nil, nil) err := manager.Initialize(3, []string{"zone", "rack", "host"}, "") re.NoError(err) opt := mockconfig.NewTestOptions() @@ -118,7 +119,7 @@ func TestRegionStatistics(t *testing.T) { func TestRegionStatisticsWithPlacementRule(t *testing.T) { re := require.New(t) store := storage.NewStorageWithMemoryBackend() - manager := placement.NewRuleManager(store, nil, nil) + manager := placement.NewRuleManager(context.Background(), store, nil, nil) err := manager.Initialize(3, []string{"zone", "rack", "host"}, "") re.NoError(err) opt := mockconfig.NewTestOptions() diff --git a/pkg/storage/endpoint/config.go b/pkg/storage/endpoint/config.go index db5565a4b90..edfdcbca9a3 100644 --- a/pkg/storage/endpoint/config.go +++ b/pkg/storage/endpoint/config.go @@ -51,17 +51,13 @@ func (se *StorageEndpoint) LoadConfig(cfg interface{}) (bool, error) { // SaveConfig stores marshallable cfg to the configPath. func (se *StorageEndpoint) SaveConfig(cfg interface{}) error { - value, err := json.Marshal(cfg) - if err != nil { - return errs.ErrJSONMarshal.Wrap(err).GenWithStackByCause() - } - return se.Save(configPath, string(value)) + return se.saveJSON(configPath, cfg) } // LoadAllSchedulerConfigs loads all schedulers' config. 
func (se *StorageEndpoint) LoadAllSchedulerConfigs() ([]string, []string, error) { prefix := customSchedulerConfigPath + "/" - keys, values, err := se.LoadRange(prefix, clientv3.GetPrefixRangeEnd(prefix), 1000) + keys, values, err := se.LoadRange(prefix, clientv3.GetPrefixRangeEnd(prefix), MinKVRangeLimit) for i, key := range keys { keys[i] = strings.TrimPrefix(key, prefix) } diff --git a/pkg/storage/endpoint/gc_safe_point.go b/pkg/storage/endpoint/gc_safe_point.go index db5c58205c8..c2f09980651 100644 --- a/pkg/storage/endpoint/gc_safe_point.go +++ b/pkg/storage/endpoint/gc_safe_point.go @@ -169,13 +169,7 @@ func (se *StorageEndpoint) SaveServiceGCSafePoint(ssp *ServiceSafePoint) error { return errors.New("TTL of gc_worker's service safe point must be infinity") } - key := gcSafePointServicePath(ssp.ServiceID) - value, err := json.Marshal(ssp) - if err != nil { - return err - } - - return se.Save(key, string(value)) + return se.saveJSON(gcSafePointServicePath(ssp.ServiceID), ssp) } // RemoveServiceGCSafePoint removes a GC safepoint for the service diff --git a/pkg/storage/endpoint/key_path.go b/pkg/storage/endpoint/key_path.go index cac40db29c5..69b8d0f2f8e 100644 --- a/pkg/storage/endpoint/key_path.go +++ b/pkg/storage/endpoint/key_path.go @@ -31,6 +31,7 @@ const ( serviceMiddlewarePath = "service_middleware" schedulePath = "schedule" gcPath = "gc" + ruleCommonPath = "rule" rulesPath = "rules" ruleGroupPath = "rule_group" regionLabelPath = "region_label" @@ -102,6 +103,11 @@ func RulesPathPrefix(clusterID uint64) string { return path.Join(PDRootPath(clusterID), rulesPath) } +// RuleCommonPathPrefix returns the path prefix to save the placement rule common config. +func RuleCommonPathPrefix(clusterID uint64) string { + return path.Join(PDRootPath(clusterID), ruleCommonPath) +} + // RuleGroupPathPrefix returns the path prefix to save the placement rule groups. func RuleGroupPathPrefix(clusterID uint64) string { return path.Join(PDRootPath(clusterID), ruleGroupPath) diff --git a/pkg/storage/endpoint/keyspace.go b/pkg/storage/endpoint/keyspace.go index 09733ad59c1..77c81b2c8d6 100644 --- a/pkg/storage/endpoint/keyspace.go +++ b/pkg/storage/endpoint/keyspace.go @@ -97,11 +97,6 @@ func (se *StorageEndpoint) LoadKeyspaceID(txn kv.Txn, name string) (bool, uint32 return true, uint32(id64), nil } -// RunInTxn runs the given function in a transaction. -func (se *StorageEndpoint) RunInTxn(ctx context.Context, f func(txn kv.Txn) error) error { - return se.Base.RunInTxn(ctx, f) -} - // LoadRangeKeyspace loads keyspaces starting at startID. // limit specifies the limit of loaded keyspaces. func (se *StorageEndpoint) LoadRangeKeyspace(txn kv.Txn, startID uint32, limit int) ([]*keyspacepb.KeyspaceMeta, error) { diff --git a/pkg/storage/endpoint/replication_status.go b/pkg/storage/endpoint/replication_status.go index 4bac51071bc..0a14770ff47 100644 --- a/pkg/storage/endpoint/replication_status.go +++ b/pkg/storage/endpoint/replication_status.go @@ -43,9 +43,5 @@ func (se *StorageEndpoint) LoadReplicationStatus(mode string, status interface{} // SaveReplicationStatus stores replication status by mode. 
func (se *StorageEndpoint) SaveReplicationStatus(mode string, status interface{}) error { - value, err := json.Marshal(status) - if err != nil { - return errs.ErrJSONMarshal.Wrap(err).GenWithStackByArgs() - } - return se.Save(replicationModePath(mode), string(value)) + return se.saveJSON(replicationModePath(mode), status) } diff --git a/pkg/storage/endpoint/rule.go b/pkg/storage/endpoint/rule.go index 125c5bc31eb..80b6fc7c0ff 100644 --- a/pkg/storage/endpoint/rule.go +++ b/pkg/storage/endpoint/rule.go @@ -14,12 +14,6 @@ package endpoint -import ( - "strings" - - "go.etcd.io/etcd/clientv3" -) - // RuleStorage defines the storage operations on the rule. type RuleStorage interface { LoadRule(ruleKey string) (string, error) @@ -103,22 +97,3 @@ func (se *StorageEndpoint) LoadRule(ruleKey string) (string, error) { func (se *StorageEndpoint) LoadRules(f func(k, v string)) error { return se.loadRangeByPrefix(rulesPath+"/", f) } - -// loadRangeByPrefix iterates all key-value pairs in the storage that has the prefix. -func (se *StorageEndpoint) loadRangeByPrefix(prefix string, f func(k, v string)) error { - nextKey := prefix - endKey := clientv3.GetPrefixRangeEnd(prefix) - for { - keys, values, err := se.LoadRange(nextKey, endKey, MinKVRangeLimit) - if err != nil { - return err - } - for i := range keys { - f(strings.TrimPrefix(keys[i], prefix), values[i]) - } - if len(keys) < MinKVRangeLimit { - return nil - } - nextKey = keys[len(keys)-1] + "\x00" - } -} diff --git a/pkg/storage/endpoint/safepoint_v2.go b/pkg/storage/endpoint/safepoint_v2.go index cac2606a470..8d690d07261 100644 --- a/pkg/storage/endpoint/safepoint_v2.go +++ b/pkg/storage/endpoint/safepoint_v2.go @@ -79,12 +79,7 @@ func (se *StorageEndpoint) LoadGCSafePointV2(keyspaceID uint32) (*GCSafePointV2, // SaveGCSafePointV2 saves gc safe point for the given keyspace. func (se *StorageEndpoint) SaveGCSafePointV2(gcSafePoint *GCSafePointV2) error { - key := GCSafePointV2Path(gcSafePoint.KeyspaceID) - value, err := json.Marshal(gcSafePoint) - if err != nil { - return errs.ErrJSONMarshal.Wrap(err).GenWithStackByCause() - } - return se.Save(key, string(value)) + return se.saveJSON(GCSafePointV2Path(gcSafePoint.KeyspaceID), gcSafePoint) } // LoadAllGCSafePoints returns gc safe point for all keyspaces @@ -203,11 +198,7 @@ func (se *StorageEndpoint) SaveServiceSafePointV2(serviceSafePoint *ServiceSafeP } key := ServiceSafePointV2Path(serviceSafePoint.KeyspaceID, serviceSafePoint.ServiceID) - value, err := json.Marshal(serviceSafePoint) - if err != nil { - return errs.ErrJSONMarshal.Wrap(err).GenWithStackByCause() - } - return se.Save(key, string(value)) + return se.saveJSON(key, serviceSafePoint) } // RemoveServiceSafePointV2 removes a service safe point. diff --git a/pkg/storage/endpoint/service_middleware.go b/pkg/storage/endpoint/service_middleware.go index 62cf91c97bf..2becbf3686e 100644 --- a/pkg/storage/endpoint/service_middleware.go +++ b/pkg/storage/endpoint/service_middleware.go @@ -43,9 +43,5 @@ func (se *StorageEndpoint) LoadServiceMiddlewareConfig(cfg interface{}) (bool, e // SaveServiceMiddlewareConfig stores marshallable cfg to the serviceMiddlewarePath. 
func (se *StorageEndpoint) SaveServiceMiddlewareConfig(cfg interface{}) error { - value, err := json.Marshal(cfg) - if err != nil { - return errs.ErrJSONMarshal.Wrap(err).GenWithStackByCause() - } - return se.Save(serviceMiddlewarePath, string(value)) + return se.saveJSON(serviceMiddlewarePath, cfg) } diff --git a/pkg/storage/endpoint/tso_keyspace_group.go b/pkg/storage/endpoint/tso_keyspace_group.go index 498cd878887..39a08afe937 100644 --- a/pkg/storage/endpoint/tso_keyspace_group.go +++ b/pkg/storage/endpoint/tso_keyspace_group.go @@ -177,12 +177,7 @@ func (se *StorageEndpoint) LoadKeyspaceGroup(txn kv.Txn, id uint32) (*KeyspaceGr // SaveKeyspaceGroup saves the keyspace group. func (se *StorageEndpoint) SaveKeyspaceGroup(txn kv.Txn, kg *KeyspaceGroup) error { - key := KeyspaceGroupIDPath(kg.ID) - value, err := json.Marshal(kg) - if err != nil { - return err - } - return txn.Save(key, string(value)) + return saveJSONInTxn(txn, KeyspaceGroupIDPath(kg.ID), kg) } // DeleteKeyspaceGroup deletes the keyspace group. diff --git a/pkg/storage/endpoint/util.go b/pkg/storage/endpoint/util.go index 37f98a55709..3058c059628 100644 --- a/pkg/storage/endpoint/util.go +++ b/pkg/storage/endpoint/util.go @@ -16,9 +16,12 @@ package endpoint import ( "encoding/json" + "strings" "github.com/gogo/protobuf/proto" "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/storage/kv" + "go.etcd.io/etcd/clientv3" ) func (se *StorageEndpoint) loadProto(key string, msg proto.Message) (bool, error) { @@ -42,9 +45,36 @@ func (se *StorageEndpoint) saveProto(key string, msg proto.Message) error { } func (se *StorageEndpoint) saveJSON(key string, data interface{}) error { + return saveJSONInTxn(se /* use the same interface */, key, data) +} + +func saveJSONInTxn(txn kv.Txn, key string, data interface{}) error { value, err := json.Marshal(data) if err != nil { return errs.ErrJSONMarshal.Wrap(err).GenWithStackByArgs() } - return se.Save(key, string(value)) + return txn.Save(key, string(value)) +} + +// loadRangeByPrefix iterates all key-value pairs in the storage that has the prefix. +func (se *StorageEndpoint) loadRangeByPrefix(prefix string, f func(k, v string)) error { + return loadRangeByPrefixInTxn(se /* use the same interface */, prefix, f) +} + +func loadRangeByPrefixInTxn(txn kv.Txn, prefix string, f func(k, v string)) error { + nextKey := prefix + endKey := clientv3.GetPrefixRangeEnd(prefix) + for { + keys, values, err := txn.LoadRange(nextKey, endKey, MinKVRangeLimit) + if err != nil { + return err + } + for i := range keys { + f(strings.TrimPrefix(keys[i], prefix), values[i]) + } + if len(keys) < MinKVRangeLimit { + return nil + } + nextKey = keys[len(keys)-1] + "\x00" + } } diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go index 58534de1642..0e69986f255 100644 --- a/pkg/tso/keyspace_group_manager.go +++ b/pkg/tso/keyspace_group_manager.go @@ -514,9 +514,10 @@ func (kgm *KeyspaceGroupManager) InitializeTSOServerWatchLoop() error { kgm.etcdClient, "tso-nodes-watcher", kgm.tsoServiceKey, + func([]*clientv3.Event) error { return nil }, putFn, deleteFn, - func() error { return nil }, + func([]*clientv3.Event) error { return nil }, clientv3.WithRange(tsoServiceEndKey), ) kgm.tsoNodesWatcher.StartWatchLoop() @@ -558,7 +559,7 @@ func (kgm *KeyspaceGroupManager) InitializeGroupWatchLoop() error { kgm.deleteKeyspaceGroup(groupID) return nil } - postEventFn := func() error { + postEventFn := func([]*clientv3.Event) error { // Retry the groups that are not initialized successfully before. 
for id, group := range kgm.groupUpdateRetryList { delete(kgm.groupUpdateRetryList, id) @@ -572,6 +573,7 @@ func (kgm *KeyspaceGroupManager) InitializeGroupWatchLoop() error { kgm.etcdClient, "keyspace-watcher", startKey, + func([]*clientv3.Event) error { return nil }, putFn, deleteFn, postEventFn, diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go index 03c2374efc6..0e1b2731474 100644 --- a/pkg/utils/etcdutil/etcdutil.go +++ b/pkg/utils/etcdutil/etcdutil.go @@ -587,8 +587,10 @@ type LoopWatcher struct { putFn func(*mvccpb.KeyValue) error // deleteFn is used to handle the delete event. deleteFn func(*mvccpb.KeyValue) error - // postEventFn is used to call after handling all events. - postEventFn func() error + // postEventsFn is used to call after handling all events. + postEventsFn func([]*clientv3.Event) error + // preEventsFn is used to call before handling all events. + preEventsFn func([]*clientv3.Event) error // forceLoadMu is used to ensure two force loads have minimal interval. forceLoadMu syncutil.RWMutex @@ -613,7 +615,9 @@ func NewLoopWatcher( ctx context.Context, wg *sync.WaitGroup, client *clientv3.Client, name, key string, - putFn, deleteFn func(*mvccpb.KeyValue) error, postEventFn func() error, + preEventsFn func([]*clientv3.Event) error, + putFn, deleteFn func(*mvccpb.KeyValue) error, + postEventsFn func([]*clientv3.Event) error, opts ...clientv3.OpOption, ) *LoopWatcher { return &LoopWatcher{ @@ -627,7 +631,8 @@ func NewLoopWatcher( updateClientCh: make(chan *clientv3.Client, 1), putFn: putFn, deleteFn: deleteFn, - postEventFn: postEventFn, + postEventsFn: postEventsFn, + preEventsFn: preEventsFn, opts: opts, lastTimeForceLoad: time.Now(), loadTimeout: defaultLoadDataFromEtcdTimeout, @@ -813,28 +818,34 @@ func (lw *LoopWatcher) watch(ctx context.Context, revision int64) (nextRevision zap.Int64("revision", revision), zap.String("name", lw.name), zap.String("key", lw.key)) goto watchChanLoop } + if err := lw.preEventsFn(wresp.Events); err != nil { + log.Error("run pre event failed in watch loop", zap.Error(err), + zap.Int64("revision", revision), zap.String("name", lw.name), zap.String("key", lw.key)) + } for _, event := range wresp.Events { switch event.Type { case clientv3.EventTypePut: if err := lw.putFn(event.Kv); err != nil { log.Error("put failed in watch loop", zap.Error(err), - zap.Int64("revision", revision), zap.String("name", lw.name), zap.String("key", lw.key)) + zap.Int64("revision", revision), zap.String("name", lw.name), + zap.String("watch-key", lw.key), zap.ByteString("event-kv-key", event.Kv.Key)) } else { - log.Debug("put in watch loop", zap.String("name", lw.name), + log.Debug("put successfully in watch loop", zap.String("name", lw.name), zap.ByteString("key", event.Kv.Key), zap.ByteString("value", event.Kv.Value)) } case clientv3.EventTypeDelete: if err := lw.deleteFn(event.Kv); err != nil { log.Error("delete failed in watch loop", zap.Error(err), - zap.Int64("revision", revision), zap.String("name", lw.name), zap.String("key", lw.key)) + zap.Int64("revision", revision), zap.String("name", lw.name), + zap.String("watch-key", lw.key), zap.ByteString("event-kv-key", event.Kv.Key)) } else { - log.Debug("delete in watch loop", zap.String("name", lw.name), + log.Debug("delete successfully in watch loop", zap.String("name", lw.name), zap.ByteString("key", event.Kv.Key)) } } } - if err := lw.postEventFn(); err != nil { + if err := lw.postEventsFn(wresp.Events); err != nil { log.Error("run post event failed in watch loop", 
zap.Error(err), zap.Int64("revision", revision), zap.String("name", lw.name), zap.String("key", lw.key)) } @@ -864,6 +875,10 @@ func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error) zap.String("key", lw.key), zap.Error(err)) return 0, err } + if err := lw.preEventsFn([]*clientv3.Event{}); err != nil { + log.Error("run pre event failed in watch loop", zap.String("name", lw.name), + zap.String("key", lw.key), zap.Error(err)) + } for i, item := range resp.Kvs { if resp.More && i == len(resp.Kvs)-1 { // The last key is the start key of the next batch. @@ -878,7 +893,7 @@ func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error) } // Note: if there are no keys in etcd, the resp.More is false. It also means the load is finished. if !resp.More { - if err := lw.postEventFn(); err != nil { + if err := lw.postEventsFn([]*clientv3.Event{}); err != nil { log.Error("run post event failed in watch loop", zap.String("name", lw.name), zap.String("key", lw.key), zap.Error(err)) } diff --git a/pkg/utils/etcdutil/etcdutil_test.go b/pkg/utils/etcdutil/etcdutil_test.go index f7fadd3bbf6..861a57cef13 100644 --- a/pkg/utils/etcdutil/etcdutil_test.go +++ b/pkg/utils/etcdutil/etcdutil_test.go @@ -410,6 +410,7 @@ func (suite *loopWatcherTestSuite) TestLoadWithoutKey() { suite.client, "test", "TestLoadWithoutKey", + func([]*clientv3.Event) error { return nil }, func(kv *mvccpb.KeyValue) error { cache.Lock() defer cache.Unlock() @@ -417,7 +418,7 @@ func (suite *loopWatcherTestSuite) TestLoadWithoutKey() { return nil }, func(kv *mvccpb.KeyValue) error { return nil }, - func() error { return nil }, + func([]*clientv3.Event) error { return nil }, ) watcher.StartWatchLoop() err := watcher.WaitLoad() @@ -441,6 +442,7 @@ func (suite *loopWatcherTestSuite) TestCallBack() { suite.client, "test", "TestCallBack", + func([]*clientv3.Event) error { return nil }, func(kv *mvccpb.KeyValue) error { result = append(result, string(kv.Key)) return nil @@ -451,7 +453,7 @@ func (suite *loopWatcherTestSuite) TestCallBack() { delete(cache.data, string(kv.Key)) return nil }, - func() error { + func([]*clientv3.Event) error { cache.Lock() defer cache.Unlock() for _, r := range result { @@ -506,6 +508,7 @@ func (suite *loopWatcherTestSuite) TestWatcherLoadLimit() { suite.client, "test", "TestWatcherLoadLimit", + func([]*clientv3.Event) error { return nil }, func(kv *mvccpb.KeyValue) error { cache.Lock() defer cache.Unlock() @@ -515,7 +518,7 @@ func (suite *loopWatcherTestSuite) TestWatcherLoadLimit() { func(kv *mvccpb.KeyValue) error { return nil }, - func() error { + func([]*clientv3.Event) error { return nil }, clientv3.WithPrefix(), @@ -550,6 +553,7 @@ func (suite *loopWatcherTestSuite) TestWatcherBreak() { suite.client, "test", "TestWatcherBreak", + func([]*clientv3.Event) error { return nil }, func(kv *mvccpb.KeyValue) error { if string(kv.Key) == "TestWatcherBreak" { cache.Lock() @@ -559,7 +563,7 @@ func (suite *loopWatcherTestSuite) TestWatcherBreak() { return nil }, func(kv *mvccpb.KeyValue) error { return nil }, - func() error { return nil }, + func([]*clientv3.Event) error { return nil }, ) watcher.watchChangeRetryInterval = 100 * time.Millisecond watcher.StartWatchLoop() @@ -633,9 +637,10 @@ func (suite *loopWatcherTestSuite) TestWatcherRequestProgress() { suite.client, "test", "TestWatcherChanBlock", + func([]*clientv3.Event) error { return nil }, func(kv *mvccpb.KeyValue) error { return nil }, func(kv *mvccpb.KeyValue) error { return nil }, - func() error { return nil }, + 
func([]*clientv3.Event) error { return nil }, ) suite.wg.Add(1) diff --git a/server/cluster/cluster.go b/server/cluster/cluster.go index 78f6ddd4364..ecbd40e2582 100644 --- a/server/cluster/cluster.go +++ b/server/cluster/cluster.go @@ -269,7 +269,7 @@ func (c *RaftCluster) InitCluster( c.unsafeRecoveryController = unsaferecovery.NewController(c) c.keyspaceGroupManager = keyspaceGroupManager c.hbstreams = hbstreams - c.ruleManager = placement.NewRuleManager(c.storage, c, c.GetOpts()) + c.ruleManager = placement.NewRuleManager(c.ctx, c.storage, c, c.GetOpts()) if c.opt.IsPlacementRulesEnabled() { err := c.ruleManager.Initialize(c.opt.GetMaxReplicas(), c.opt.GetLocationLabels(), c.opt.GetIsolationLevel()) if err != nil { diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 85edf911779..7094fd6b673 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -241,7 +241,7 @@ func TestSetOfflineStore(t *testing.T) { re.NoError(err) cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend()) cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) - cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) + cluster.ruleManager = placement.NewRuleManager(ctx, storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { @@ -438,7 +438,7 @@ func TestUpStore(t *testing.T) { re.NoError(err) cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend()) cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) - cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) + cluster.ruleManager = placement.NewRuleManager(ctx, storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { @@ -541,7 +541,7 @@ func TestDeleteStoreUpdatesClusterVersion(t *testing.T) { re.NoError(err) cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend()) cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) - cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) + cluster.ruleManager = placement.NewRuleManager(ctx, storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { @@ -1268,7 +1268,7 @@ func TestOfflineAndMerge(t *testing.T) { re.NoError(err) cluster := newTestRaftCluster(ctx, mockid.NewIDAllocator(), opt, storage.NewStorageWithMemoryBackend()) cluster.coordinator = schedule.NewCoordinator(ctx, cluster, nil) - cluster.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) + cluster.ruleManager = placement.NewRuleManager(ctx, storage.NewStorageWithMemoryBackend(), cluster, cluster.GetOpts()) if opt.IsPlacementRulesEnabled() { err := cluster.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel()) if err != nil { @@ -2130,7 +2130,7 @@ func newTestRaftCluster( ) 
) *RaftCluster {
 	rc := &RaftCluster{serverCtx: ctx, core: core.NewBasicCluster(), storage: s}
 	rc.InitCluster(id, opt, nil, nil)
-	rc.ruleManager = placement.NewRuleManager(storage.NewStorageWithMemoryBackend(), rc, opt)
+	rc.ruleManager = placement.NewRuleManager(ctx, storage.NewStorageWithMemoryBackend(), rc, opt)
 	if opt.IsPlacementRulesEnabled() {
 		err := rc.ruleManager.Initialize(opt.GetMaxReplicas(), opt.GetLocationLabels(), opt.GetIsolationLevel())
 		if err != nil {
diff --git a/server/keyspace_service.go b/server/keyspace_service.go
index b17239ba0a4..1718108d73b 100644
--- a/server/keyspace_service.go
+++ b/server/keyspace_service.go
@@ -89,7 +89,7 @@ func (s *KeyspaceServer) WatchKeyspaces(request *keyspacepb.WatchKeyspacesReques
 	deleteFn := func(kv *mvccpb.KeyValue) error {
 		return nil
 	}
-	postEventFn := func() error {
+	postEventFn := func([]*clientv3.Event) error {
 		defer func() {
 			keyspaces = keyspaces[:0]
 		}()
@@ -109,6 +109,7 @@ func (s *KeyspaceServer) WatchKeyspaces(request *keyspacepb.WatchKeyspacesReques
 		s.client,
 		"keyspace-server-watcher",
 		startKey,
+		func([]*clientv3.Event) error { return nil },
 		putFn,
 		deleteFn,
 		postEventFn,
diff --git a/server/server.go b/server/server.go
index 187c30dbf7a..fcf71922a09 100644
--- a/server/server.go
+++ b/server/server.go
@@ -2017,9 +2017,10 @@ func (s *Server) initServicePrimaryWatcher(serviceName string, primaryKey string
 		s.client,
 		name,
 		primaryKey,
+		func([]*clientv3.Event) error { return nil },
 		putFn,
 		deleteFn,
-		func() error { return nil },
+		func([]*clientv3.Event) error { return nil },
 	)
 }

From 40c252346cffa7a3d56a859856adddbec4a0a72c Mon Sep 17 00:00:00 2001
From: JmPotato
Date: Mon, 18 Dec 2023 17:22:53 +0800
Subject: [PATCH 14/21] client/http, tests: implement retry mechanism based on the leader and follower (#7554)

ref tikv/pd#7300

- Implement retry mechanism based on the leader and follower.
- Move method definitions into a separate file.
- Use a struct `requestInfo` to gather the parameters.

Signed-off-by: JmPotato
Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
---
 client/http/client.go | 910 +++++-------------
 client/http/client_test.go | 33 +-
 client/http/interface.go | 695 +++++++++++++
 client/http/request_info.go | 123 +++
 tests/integrations/client/http_client_test.go | 17 +-
 5 files changed, 1092 insertions(+), 686 deletions(-)
 create mode 100644 client/http/interface.go
 create mode 100644 client/http/request_info.go

diff --git a/client/http/client.go b/client/http/client.go
index 927450b74a2..21a3727e00f 100644
--- a/client/http/client.go
+++ b/client/http/client.go
@@ -23,265 +23,172 @@ import (
 	"io"
 	"net/http"
 	"strings"
+	"sync"
 	"time"
 	"github.com/pingcap/errors"
-	"github.com/pingcap/kvproto/pkg/pdpb"
 	"github.com/pingcap/log"
 	"github.com/prometheus/client_golang/prometheus"
 	"go.uber.org/zap"
 )
 const (
-	defaultCallerID = "pd-http-client"
-	httpScheme = "http"
-	httpsScheme = "https"
-	networkErrorStatus = "network error"
-
-	defaultTimeout = 30 * time.Second
+	// defaultCallerID marks the default caller ID of the PD HTTP client.
+	defaultCallerID = "pd-http-client"
+	// defaultInnerCallerID marks the default caller ID of the inner PD HTTP client.
+	// It's used to distinguish the requests sent by the inner client via some internal logic.
+ defaultInnerCallerID = "pd-http-client-inner" + httpScheme = "http" + httpsScheme = "https" + networkErrorStatus = "network error" + + defaultMembersInfoUpdateInterval = time.Minute + defaultTimeout = 30 * time.Second ) -// Client is a PD (Placement Driver) HTTP client. -type Client interface { - /* Meta-related interfaces */ - GetRegionByID(context.Context, uint64) (*RegionInfo, error) - GetRegionByKey(context.Context, []byte) (*RegionInfo, error) - GetRegions(context.Context) (*RegionsInfo, error) - GetRegionsByKeyRange(context.Context, *KeyRange, int) (*RegionsInfo, error) - GetRegionsByStoreID(context.Context, uint64) (*RegionsInfo, error) - GetRegionsReplicatedStateByKeyRange(context.Context, *KeyRange) (string, error) - GetHotReadRegions(context.Context) (*StoreHotPeersInfos, error) - GetHotWriteRegions(context.Context) (*StoreHotPeersInfos, error) - GetHistoryHotRegions(context.Context, *HistoryHotRegionsRequest) (*HistoryHotRegions, error) - GetRegionStatusByKeyRange(context.Context, *KeyRange, bool) (*RegionStats, error) - GetStores(context.Context) (*StoresInfo, error) - GetStore(context.Context, uint64) (*StoreInfo, error) - SetStoreLabels(context.Context, int64, map[string]string) error - GetMembers(context.Context) (*MembersInfo, error) - GetLeader(context.Context) (*pdpb.Member, error) - TransferLeader(context.Context, string) error - /* Config-related interfaces */ - GetScheduleConfig(context.Context) (map[string]interface{}, error) - SetScheduleConfig(context.Context, map[string]interface{}) error - GetClusterVersion(context.Context) (string, error) - /* Scheduler-related interfaces */ - GetSchedulers(context.Context) ([]string, error) - CreateScheduler(ctx context.Context, name string, storeID uint64) error - SetSchedulerDelay(context.Context, string, int64) error - /* Rule-related interfaces */ - GetAllPlacementRuleBundles(context.Context) ([]*GroupBundle, error) - GetPlacementRuleBundleByGroup(context.Context, string) (*GroupBundle, error) - GetPlacementRulesByGroup(context.Context, string) ([]*Rule, error) - SetPlacementRule(context.Context, *Rule) error - SetPlacementRuleInBatch(context.Context, []*RuleOp) error - SetPlacementRuleBundles(context.Context, []*GroupBundle, bool) error - DeletePlacementRule(context.Context, string, string) error - GetAllPlacementRuleGroups(context.Context) ([]*RuleGroup, error) - GetPlacementRuleGroupByID(context.Context, string) (*RuleGroup, error) - SetPlacementRuleGroup(context.Context, *RuleGroup) error - DeletePlacementRuleGroupByID(context.Context, string) error - GetAllRegionLabelRules(context.Context) ([]*LabelRule, error) - GetRegionLabelRulesByIDs(context.Context, []string) ([]*LabelRule, error) - SetRegionLabelRule(context.Context, *LabelRule) error - PatchRegionLabelRules(context.Context, *LabelRulePatch) error - /* Scheduling-related interfaces */ - AccelerateSchedule(context.Context, *KeyRange) error - AccelerateScheduleInBatch(context.Context, []*KeyRange) error - /* Other interfaces */ - GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) - /* Micro Service interfaces */ - GetMicroServiceMembers(context.Context, string) ([]string, error) - - /* Client-related methods */ - // WithCallerID sets and returns a new client with the given caller ID. - WithCallerID(string) Client - // WithRespHandler sets and returns a new client with the given HTTP response handler. - // This allows the caller to customize how the response is handled, including error handling logic. 
- // Additionally, it is important for the caller to handle the content of the response body properly - // in order to ensure that it can be read and marshaled correctly into `res`. - WithRespHandler(func(resp *http.Response, res interface{}) error) Client - Close() -} - -var _ Client = (*client)(nil) +// respHandleFunc is the function to handle the HTTP response. +type respHandleFunc func(resp *http.Response, res interface{}) error -// clientInner is the inner implementation of the PD HTTP client, which will -// implement some internal logics, such as HTTP client, service discovery, etc. +// clientInner is the inner implementation of the PD HTTP client, which contains some fundamental fields. +// It is wrapped by the `client` struct to make sure the inner implementation won't be exposed and could +// be consistent during the copy. type clientInner struct { - pdAddrs []string - tlsConf *tls.Config - cli *http.Client -} + ctx context.Context + cancel context.CancelFunc -type client struct { - // Wrap this struct is to make sure the inner implementation - // won't be exposed and cloud be consistent during the copy. - inner *clientInner + sync.RWMutex + pdAddrs []string + leaderAddrIdx int - callerID string - respHandler func(resp *http.Response, res interface{}) error + tlsConf *tls.Config + cli *http.Client requestCounter *prometheus.CounterVec executionDuration *prometheus.HistogramVec } -// ClientOption configures the HTTP client. -type ClientOption func(c *client) - -// WithHTTPClient configures the client with the given initialized HTTP client. -func WithHTTPClient(cli *http.Client) ClientOption { - return func(c *client) { - c.inner.cli = cli - } -} - -// WithTLSConfig configures the client with the given TLS config. -// This option won't work if the client is configured with WithHTTPClient. -func WithTLSConfig(tlsConf *tls.Config) ClientOption { - return func(c *client) { - c.inner.tlsConf = tlsConf - } -} - -// WithMetrics configures the client with metrics. -func WithMetrics( - requestCounter *prometheus.CounterVec, - executionDuration *prometheus.HistogramVec, -) ClientOption { - return func(c *client) { - c.requestCounter = requestCounter - c.executionDuration = executionDuration - } +func newClientInner() *clientInner { + ctx, cancel := context.WithCancel(context.Background()) + return &clientInner{ctx: ctx, cancel: cancel, leaderAddrIdx: -1} } -// NewClient creates a PD HTTP client with the given PD addresses and TLS config. -func NewClient( - pdAddrs []string, - opts ...ClientOption, -) Client { - c := &client{inner: &clientInner{}, callerID: defaultCallerID} - // Apply the options first. - for _, opt := range opts { - opt(c) - } - // Normalize the addresses with correct scheme prefix. - for i, addr := range pdAddrs { - if !strings.HasPrefix(addr, httpScheme) { - var scheme string - if c.inner.tlsConf != nil { - scheme = httpsScheme - } else { - scheme = httpScheme - } - pdAddrs[i] = fmt.Sprintf("%s://%s", scheme, addr) - } - } - c.inner.pdAddrs = pdAddrs +func (ci *clientInner) init() { // Init the HTTP client if it's not configured. 
- if c.inner.cli == nil { - c.inner.cli = &http.Client{Timeout: defaultTimeout} - if c.inner.tlsConf != nil { + if ci.cli == nil { + ci.cli = &http.Client{Timeout: defaultTimeout} + if ci.tlsConf != nil { transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = c.inner.tlsConf - c.inner.cli.Transport = transport + transport.TLSClientConfig = ci.tlsConf + ci.cli.Transport = transport } } - - return c + // Start the members info updater daemon. + go ci.membersInfoUpdater(ci.ctx) } -// Close closes the HTTP client. -func (c *client) Close() { - if c.inner == nil { - return - } - if c.inner.cli != nil { - c.inner.cli.CloseIdleConnections() +func (ci *clientInner) close() { + ci.cancel() + if ci.cli != nil { + ci.cli.CloseIdleConnections() } - log.Info("[pd] http client closed") } -// WithCallerID sets and returns a new client with the given caller ID. -func (c *client) WithCallerID(callerID string) Client { - newClient := *c - newClient.callerID = callerID - return &newClient -} - -// WithRespHandler sets and returns a new client with the given HTTP response handler. -func (c *client) WithRespHandler( - handler func(resp *http.Response, res interface{}) error, -) Client { - newClient := *c - newClient.respHandler = handler - return &newClient +// getPDAddrs returns the current PD addresses and the index of the leader address. +func (ci *clientInner) getPDAddrs() ([]string, int) { + ci.RLock() + defer ci.RUnlock() + return ci.pdAddrs, ci.leaderAddrIdx } -func (c *client) reqCounter(name, status string) { - if c.requestCounter == nil { - return +func (ci *clientInner) setPDAddrs(pdAddrs []string, leaderAddrIdx int) { + ci.Lock() + defer ci.Unlock() + // Normalize the addresses with correct scheme prefix. + var scheme string + if ci.tlsConf == nil { + scheme = httpScheme + } else { + scheme = httpsScheme + } + for i, addr := range pdAddrs { + if strings.HasPrefix(addr, httpScheme) { + continue + } + pdAddrs[i] = fmt.Sprintf("%s://%s", scheme, addr) } - c.requestCounter.WithLabelValues(name, status).Inc() + ci.pdAddrs = pdAddrs + ci.leaderAddrIdx = leaderAddrIdx } -func (c *client) execDuration(name string, duration time.Duration) { - if c.executionDuration == nil { +func (ci *clientInner) reqCounter(name, status string) { + if ci.requestCounter == nil { return } - c.executionDuration.WithLabelValues(name).Observe(duration.Seconds()) + ci.requestCounter.WithLabelValues(name, status).Inc() } -// Header key definition constants. -const ( - pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" - xCallerIDKey = "X-Caller-ID" -) - -// HeaderOption configures the HTTP header. -type HeaderOption func(header http.Header) - -// WithAllowFollowerHandle sets the header field to allow a PD follower to handle this request. -func WithAllowFollowerHandle() HeaderOption { - return func(header http.Header) { - header.Set(pdAllowFollowerHandleKey, "true") +func (ci *clientInner) execDuration(name string, duration time.Duration) { + if ci.executionDuration == nil { + return } + ci.executionDuration.WithLabelValues(name).Observe(duration.Seconds()) } -// At present, we will use the retry strategy of polling by default to keep -// it consistent with the current implementation of some clients (e.g. TiDB). -func (c *client) requestWithRetry( +// requestWithRetry will first try to send the request to the PD leader, if it fails, it will try to send +// the request to the other PD followers to gain a better availability. +// TODO: support custom retry logic, e.g. 
retry with customizable backoffer. +func (ci *clientInner) requestWithRetry( ctx context.Context, - name, uri, method string, - body []byte, res interface{}, + reqInfo *requestInfo, headerOpts ...HeaderOption, ) error { var ( - err error - addr string + err error + addr string + pdAddrs, leaderAddrIdx = ci.getPDAddrs() ) - for idx := 0; idx < len(c.inner.pdAddrs); idx++ { - addr = c.inner.pdAddrs[idx] - err = c.request(ctx, name, fmt.Sprintf("%s%s", addr, uri), method, body, res, headerOpts...) + // Try to send the request to the PD leader first. + if leaderAddrIdx != -1 { + addr = pdAddrs[leaderAddrIdx] + err = ci.doRequest(ctx, addr, reqInfo, headerOpts...) + if err == nil { + return nil + } + log.Debug("[pd] request leader addr failed", + zap.Int("leader-idx", leaderAddrIdx), zap.String("addr", addr), zap.Error(err)) + } + // Try to send the request to the other PD followers. + for idx := 0; idx < len(pdAddrs) && idx != leaderAddrIdx; idx++ { + addr = ci.pdAddrs[idx] + err = ci.doRequest(ctx, addr, reqInfo, headerOpts...) if err == nil { break } - log.Debug("[pd] request one addr failed", + log.Debug("[pd] request follower addr failed", zap.Int("idx", idx), zap.String("addr", addr), zap.Error(err)) } return err } -func (c *client) request( +func (ci *clientInner) doRequest( ctx context.Context, - name, url, method string, - body []byte, res interface{}, + addr string, reqInfo *requestInfo, headerOpts ...HeaderOption, ) error { + var ( + callerID = reqInfo.callerID + name = reqInfo.name + url = reqInfo.getURL(addr) + method = reqInfo.method + body = reqInfo.body + res = reqInfo.res + respHandler = reqInfo.respHandler + ) logFields := []zap.Field{ zap.String("name", name), zap.String("url", url), zap.String("method", method), - zap.String("caller-id", c.callerID), + zap.String("caller-id", callerID), } log.Debug("[pd] request the http url", logFields...) req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(body)) @@ -292,21 +199,21 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } - req.Header.Set(xCallerIDKey, c.callerID) + req.Header.Set(xCallerIDKey, callerID) start := time.Now() - resp, err := c.inner.cli.Do(req) + resp, err := ci.cli.Do(req) if err != nil { - c.reqCounter(name, networkErrorStatus) + ci.reqCounter(name, networkErrorStatus) log.Error("[pd] do http request failed", append(logFields, zap.Error(err))...) return errors.Trace(err) } - c.execDuration(name, time.Since(start)) - c.reqCounter(name, resp.Status) + ci.execDuration(name, time.Since(start)) + ci.reqCounter(name, resp.Status) // Give away the response handling to the caller if the handler is set. - if c.respHandler != nil { - return c.respHandler(resp, res) + if respHandler != nil { + return respHandler(resp, res) } defer func() { @@ -341,520 +248,171 @@ func (c *client) request( return nil } -// GetRegionByID gets the region info by ID. -func (c *client) GetRegionByID(ctx context.Context, regionID uint64) (*RegionInfo, error) { - var region RegionInfo - err := c.requestWithRetry(ctx, - "GetRegionByID", RegionByID(regionID), - http.MethodGet, nil, ®ion) - if err != nil { - return nil, err - } - return ®ion, nil -} - -// GetRegionByKey gets the region info by key. 
-func (c *client) GetRegionByKey(ctx context.Context, key []byte) (*RegionInfo, error) { - var region RegionInfo - err := c.requestWithRetry(ctx, - "GetRegionByKey", RegionByKey(key), - http.MethodGet, nil, ®ion) - if err != nil { - return nil, err - } - return ®ion, nil -} - -// GetRegions gets the regions info. -func (c *client) GetRegions(ctx context.Context) (*RegionsInfo, error) { - var regions RegionsInfo - err := c.requestWithRetry(ctx, - "GetRegions", Regions, - http.MethodGet, nil, ®ions) - if err != nil { - return nil, err - } - return ®ions, nil -} - -// GetRegionsByKeyRange gets the regions info by key range. If the limit is -1, it will return all regions within the range. -// The keys in the key range should be encoded in the UTF-8 bytes format. -func (c *client) GetRegionsByKeyRange(ctx context.Context, keyRange *KeyRange, limit int) (*RegionsInfo, error) { - var regions RegionsInfo - err := c.requestWithRetry(ctx, - "GetRegionsByKeyRange", RegionsByKeyRange(keyRange, limit), - http.MethodGet, nil, ®ions) - if err != nil { - return nil, err - } - return ®ions, nil -} - -// GetRegionsByStoreID gets the regions info by store ID. -func (c *client) GetRegionsByStoreID(ctx context.Context, storeID uint64) (*RegionsInfo, error) { - var regions RegionsInfo - err := c.requestWithRetry(ctx, - "GetRegionsByStoreID", RegionsByStoreID(storeID), - http.MethodGet, nil, ®ions) - if err != nil { - return nil, err - } - return ®ions, nil -} - -// GetRegionsReplicatedStateByKeyRange gets the regions replicated state info by key range. -// The keys in the key range should be encoded in the hex bytes format (without encoding to the UTF-8 bytes). -func (c *client) GetRegionsReplicatedStateByKeyRange(ctx context.Context, keyRange *KeyRange) (string, error) { - var state string - err := c.requestWithRetry(ctx, - "GetRegionsReplicatedStateByKeyRange", RegionsReplicatedByKeyRange(keyRange), - http.MethodGet, nil, &state) - if err != nil { - return "", err - } - return state, nil -} - -// GetHotReadRegions gets the hot read region statistics info. -func (c *client) GetHotReadRegions(ctx context.Context) (*StoreHotPeersInfos, error) { - var hotReadRegions StoreHotPeersInfos - err := c.requestWithRetry(ctx, - "GetHotReadRegions", HotRead, - http.MethodGet, nil, &hotReadRegions) - if err != nil { - return nil, err - } - return &hotReadRegions, nil -} - -// GetHotWriteRegions gets the hot write region statistics info. -func (c *client) GetHotWriteRegions(ctx context.Context) (*StoreHotPeersInfos, error) { - var hotWriteRegions StoreHotPeersInfos - err := c.requestWithRetry(ctx, - "GetHotWriteRegions", HotWrite, - http.MethodGet, nil, &hotWriteRegions) - if err != nil { - return nil, err +func (ci *clientInner) membersInfoUpdater(ctx context.Context) { + ci.updateMembersInfo(ctx) + log.Info("[pd] http client member info updater started") + ticker := time.NewTicker(defaultMembersInfoUpdateInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + log.Info("[pd] http client member info updater stopped") + return + case <-ticker.C: + ci.updateMembersInfo(ctx) + } } - return &hotWriteRegions, nil } -// GetHistoryHotRegions gets the history hot region statistics info. -func (c *client) GetHistoryHotRegions(ctx context.Context, req *HistoryHotRegionsRequest) (*HistoryHotRegions, error) { - reqJSON, err := json.Marshal(req) +func (ci *clientInner) updateMembersInfo(ctx context.Context) { + var membersInfo MembersInfo + err := ci.requestWithRetry(ctx, newRequestInfo(). 
+		WithCallerID(defaultInnerCallerID).
+		WithName(getMembersName).
+		WithURI(membersPrefix).
+		WithMethod(http.MethodGet).
+		WithResp(&membersInfo))
 	if err != nil {
-		return nil, errors.Trace(err)
+		log.Error("[pd] http client get members info failed", zap.Error(err))
+		return
 	}
-	var historyHotRegions HistoryHotRegions
-	err = c.requestWithRetry(ctx,
-		"GetHistoryHotRegions", HotHistory,
-		http.MethodGet, reqJSON, &historyHotRegions,
-		WithAllowFollowerHandle())
-	if err != nil {
-		return nil, err
+	if len(membersInfo.Members) == 0 {
+		log.Error("[pd] http client get empty members info")
+		return
 	}
-	return &historyHotRegions, nil
-}
-
-// GetRegionStatusByKeyRange gets the region status by key range.
-// If the `onlyCount` flag is true, the result will only include the count of regions.
-// The keys in the key range should be encoded in the UTF-8 bytes format.
-func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRange, onlyCount bool) (*RegionStats, error) {
-	var regionStats RegionStats
-	err := c.requestWithRetry(ctx,
-		"GetRegionStatusByKeyRange", RegionStatsByKeyRange(keyRange, onlyCount),
-		http.MethodGet, nil, &regionStats,
+	var (
+		newPDAddrs       []string
+		newLeaderAddrIdx int = -1
 	)
-	if err != nil {
-		return nil, err
-	}
-	return &regionStats, nil
-}
-
-// SetStoreLabels sets the labels of a store.
-func (c *client) SetStoreLabels(ctx context.Context, storeID int64, storeLabels map[string]string) error {
-	jsonInput, err := json.Marshal(storeLabels)
-	if err != nil {
-		return errors.Trace(err)
-	}
-	return c.requestWithRetry(ctx, "SetStoreLabel", LabelByStoreID(storeID),
-		http.MethodPost, jsonInput, nil)
-}
-
-func (c *client) GetMembers(ctx context.Context) (*MembersInfo, error) {
-	var members MembersInfo
-	err := c.requestWithRetry(ctx,
-		"GetMembers", membersPrefix,
-		http.MethodGet, nil, &members)
-	if err != nil {
-		return nil, err
-	}
-	return &members, nil
-}
-
-// GetLeader gets the leader of PD cluster.
-func (c *client) GetLeader(ctx context.Context) (*pdpb.Member, error) {
-	var leader pdpb.Member
-	err := c.requestWithRetry(ctx, "GetLeader", leaderPrefix,
-		http.MethodGet, nil, &leader)
-	if err != nil {
-		return nil, err
-	}
-	return &leader, nil
-}
-
-// TransferLeader transfers the PD leader.
-func (c *client) TransferLeader(ctx context.Context, newLeader string) error {
-	return c.requestWithRetry(ctx, "TransferLeader", TransferLeaderByID(newLeader),
-		http.MethodPost, nil, nil)
-}
-
-// GetScheduleConfig gets the schedule configurations.
-func (c *client) GetScheduleConfig(ctx context.Context) (map[string]interface{}, error) {
-	var config map[string]interface{}
-	err := c.requestWithRetry(ctx,
-		"GetScheduleConfig", ScheduleConfig,
-		http.MethodGet, nil, &config)
-	if err != nil {
-		return nil, err
-	}
-	return config, nil
-}
-
-// SetScheduleConfig sets the schedule configurations.
-func (c *client) SetScheduleConfig(ctx context.Context, config map[string]interface{}) error {
-	configJSON, err := json.Marshal(config)
-	if err != nil {
-		return errors.Trace(err)
-	}
-	return c.requestWithRetry(ctx,
-		"SetScheduleConfig", ScheduleConfig,
-		http.MethodPost, configJSON, nil)
-}
-
-// GetStores gets the stores info.
-func (c *client) GetStores(ctx context.Context) (*StoresInfo, error) {
-	var stores StoresInfo
-	err := c.requestWithRetry(ctx,
-		"GetStores", Stores,
-		http.MethodGet, nil, &stores)
-	if err != nil {
-		return nil, err
-	}
-	return &stores, nil
-}
-
-// GetStore gets the store info by ID.
-func (c *client) GetStore(ctx context.Context, storeID uint64) (*StoreInfo, error) { - var store StoreInfo - err := c.requestWithRetry(ctx, - "GetStore", StoreByID(storeID), - http.MethodGet, nil, &store) - if err != nil { - return nil, err - } - return &store, nil -} - -// GetClusterVersion gets the cluster version. -func (c *client) GetClusterVersion(ctx context.Context) (string, error) { - var version string - err := c.requestWithRetry(ctx, - "GetClusterVersion", ClusterVersion, - http.MethodGet, nil, &version) - if err != nil { - return "", err - } - return version, nil -} - -// GetAllPlacementRuleBundles gets all placement rules bundles. -func (c *client) GetAllPlacementRuleBundles(ctx context.Context) ([]*GroupBundle, error) { - var bundles []*GroupBundle - err := c.requestWithRetry(ctx, - "GetPlacementRuleBundle", PlacementRuleBundle, - http.MethodGet, nil, &bundles) - if err != nil { - return nil, err - } - return bundles, nil -} - -// GetPlacementRuleBundleByGroup gets the placement rules bundle by group. -func (c *client) GetPlacementRuleBundleByGroup(ctx context.Context, group string) (*GroupBundle, error) { - var bundle GroupBundle - err := c.requestWithRetry(ctx, - "GetPlacementRuleBundleByGroup", PlacementRuleBundleByGroup(group), - http.MethodGet, nil, &bundle) - if err != nil { - return nil, err + for _, member := range membersInfo.Members { + if membersInfo.Leader != nil && member.GetMemberId() == membersInfo.Leader.GetMemberId() { + newLeaderAddrIdx = len(newPDAddrs) + } + newPDAddrs = append(newPDAddrs, member.GetClientUrls()...) } - return &bundle, nil -} - -// GetPlacementRulesByGroup gets the placement rules by group. -func (c *client) GetPlacementRulesByGroup(ctx context.Context, group string) ([]*Rule, error) { - var rules []*Rule - err := c.requestWithRetry(ctx, - "GetPlacementRulesByGroup", PlacementRulesByGroup(group), - http.MethodGet, nil, &rules) - if err != nil { - return nil, err + // Prevent setting empty addresses. + if len(newPDAddrs) == 0 { + log.Error("[pd] http client get empty member addresses") + return } - return rules, nil -} - -// SetPlacementRule sets the placement rule. -func (c *client) SetPlacementRule(ctx context.Context, rule *Rule) error { - ruleJSON, err := json.Marshal(rule) - if err != nil { - return errors.Trace(err) + oldPDAddrs, oldLeaderAddrIdx := ci.getPDAddrs() + ci.setPDAddrs(newPDAddrs, newLeaderAddrIdx) + // Log the member info change if it happens. + var oldPDLeaderAddr, newPDLeaderAddr string + if oldLeaderAddrIdx != -1 { + oldPDLeaderAddr = oldPDAddrs[oldLeaderAddrIdx] } - return c.requestWithRetry(ctx, - "SetPlacementRule", PlacementRule, - http.MethodPost, ruleJSON, nil) -} - -// SetPlacementRuleInBatch sets the placement rules in batch. -func (c *client) SetPlacementRuleInBatch(ctx context.Context, ruleOps []*RuleOp) error { - ruleOpsJSON, err := json.Marshal(ruleOps) - if err != nil { - return errors.Trace(err) + if newLeaderAddrIdx != -1 { + newPDLeaderAddr = newPDAddrs[newLeaderAddrIdx] } - return c.requestWithRetry(ctx, - "SetPlacementRuleInBatch", PlacementRulesInBatch, - http.MethodPost, ruleOpsJSON, nil) -} - -// SetPlacementRuleBundles sets the placement rule bundles. -// If `partial` is false, all old configurations will be over-written and dropped. 
-func (c *client) SetPlacementRuleBundles(ctx context.Context, bundles []*GroupBundle, partial bool) error { - bundlesJSON, err := json.Marshal(bundles) - if err != nil { - return errors.Trace(err) + oldMemberNum, newMemberNum := len(oldPDAddrs), len(newPDAddrs) + if oldPDLeaderAddr != newPDLeaderAddr || oldMemberNum != newMemberNum { + log.Info("[pd] http client members info changed", + zap.Int("old-member-num", oldMemberNum), zap.Int("new-member-num", newMemberNum), + zap.Strings("old-addrs", oldPDAddrs), zap.Strings("new-addrs", newPDAddrs), + zap.Int("old-leader-addr-idx", oldLeaderAddrIdx), zap.Int("new-leader-addr-idx", newLeaderAddrIdx), + zap.String("old-leader-addr", oldPDLeaderAddr), zap.String("new-leader-addr", newPDLeaderAddr)) } - return c.requestWithRetry(ctx, - "SetPlacementRuleBundles", PlacementRuleBundleWithPartialParameter(partial), - http.MethodPost, bundlesJSON, nil) } -// DeletePlacementRule deletes the placement rule. -func (c *client) DeletePlacementRule(ctx context.Context, group, id string) error { - return c.requestWithRetry(ctx, - "DeletePlacementRule", PlacementRuleByGroupAndID(group, id), - http.MethodDelete, nil, nil) -} +type client struct { + inner *clientInner -// GetAllPlacementRuleGroups gets all placement rule groups. -func (c *client) GetAllPlacementRuleGroups(ctx context.Context) ([]*RuleGroup, error) { - var ruleGroups []*RuleGroup - err := c.requestWithRetry(ctx, - "GetAllPlacementRuleGroups", placementRuleGroups, - http.MethodGet, nil, &ruleGroups) - if err != nil { - return nil, err - } - return ruleGroups, nil + callerID string + respHandler respHandleFunc } -// GetPlacementRuleGroupByID gets the placement rule group by ID. -func (c *client) GetPlacementRuleGroupByID(ctx context.Context, id string) (*RuleGroup, error) { - var ruleGroup RuleGroup - err := c.requestWithRetry(ctx, - "GetPlacementRuleGroupByID", PlacementRuleGroupByID(id), - http.MethodGet, nil, &ruleGroup) - if err != nil { - return nil, err - } - return &ruleGroup, nil -} +// ClientOption configures the HTTP client. +type ClientOption func(c *client) -// SetPlacementRuleGroup sets the placement rule group. -func (c *client) SetPlacementRuleGroup(ctx context.Context, ruleGroup *RuleGroup) error { - ruleGroupJSON, err := json.Marshal(ruleGroup) - if err != nil { - return errors.Trace(err) +// WithHTTPClient configures the client with the given initialized HTTP client. +func WithHTTPClient(cli *http.Client) ClientOption { + return func(c *client) { + c.inner.cli = cli } - return c.requestWithRetry(ctx, - "SetPlacementRuleGroup", placementRuleGroup, - http.MethodPost, ruleGroupJSON, nil) -} - -// DeletePlacementRuleGroupByID deletes the placement rule group by ID. -func (c *client) DeletePlacementRuleGroupByID(ctx context.Context, id string) error { - return c.requestWithRetry(ctx, - "DeletePlacementRuleGroupByID", PlacementRuleGroupByID(id), - http.MethodDelete, nil, nil) } -// GetAllRegionLabelRules gets all region label rules. -func (c *client) GetAllRegionLabelRules(ctx context.Context) ([]*LabelRule, error) { - var labelRules []*LabelRule - err := c.requestWithRetry(ctx, - "GetAllRegionLabelRules", RegionLabelRules, - http.MethodGet, nil, &labelRules) - if err != nil { - return nil, err +// WithTLSConfig configures the client with the given TLS config. +// This option won't work if the client is configured with WithHTTPClient. 
+func WithTLSConfig(tlsConf *tls.Config) ClientOption { + return func(c *client) { + c.inner.tlsConf = tlsConf } - return labelRules, nil } -// GetRegionLabelRulesByIDs gets the region label rules by IDs. -func (c *client) GetRegionLabelRulesByIDs(ctx context.Context, ruleIDs []string) ([]*LabelRule, error) { - idsJSON, err := json.Marshal(ruleIDs) - if err != nil { - return nil, errors.Trace(err) - } - var labelRules []*LabelRule - err = c.requestWithRetry(ctx, - "GetRegionLabelRulesByIDs", RegionLabelRulesByIDs, - http.MethodGet, idsJSON, &labelRules) - if err != nil { - return nil, err +// WithMetrics configures the client with metrics. +func WithMetrics( + requestCounter *prometheus.CounterVec, + executionDuration *prometheus.HistogramVec, +) ClientOption { + return func(c *client) { + c.inner.requestCounter = requestCounter + c.inner.executionDuration = executionDuration } - return labelRules, nil } -// SetRegionLabelRule sets the region label rule. -func (c *client) SetRegionLabelRule(ctx context.Context, labelRule *LabelRule) error { - labelRuleJSON, err := json.Marshal(labelRule) - if err != nil { - return errors.Trace(err) +// NewClient creates a PD HTTP client with the given PD addresses and TLS config. +func NewClient( + pdAddrs []string, + opts ...ClientOption, +) Client { + c := &client{inner: newClientInner(), callerID: defaultCallerID} + // Apply the options first. + for _, opt := range opts { + opt(c) } - return c.requestWithRetry(ctx, - "SetRegionLabelRule", RegionLabelRule, - http.MethodPost, labelRuleJSON, nil) + c.inner.setPDAddrs(pdAddrs, -1) + c.inner.init() + return c } -// PatchRegionLabelRules patches the region label rules. -func (c *client) PatchRegionLabelRules(ctx context.Context, labelRulePatch *LabelRulePatch) error { - labelRulePatchJSON, err := json.Marshal(labelRulePatch) - if err != nil { - return errors.Trace(err) - } - return c.requestWithRetry(ctx, - "PatchRegionLabelRules", RegionLabelRules, - http.MethodPatch, labelRulePatchJSON, nil) +// Close gracefully closes the HTTP client. +func (c *client) Close() { + c.inner.close() + log.Info("[pd] http client closed") } -// GetSchedulers gets the schedulers from PD cluster. -func (c *client) GetSchedulers(ctx context.Context) ([]string, error) { - var schedulers []string - err := c.requestWithRetry(ctx, "GetSchedulers", Schedulers, - http.MethodGet, nil, &schedulers) - if err != nil { - return nil, err - } - return schedulers, nil +// WithCallerID sets and returns a new client with the given caller ID. +func (c *client) WithCallerID(callerID string) Client { + newClient := *c + newClient.callerID = callerID + return &newClient } -// CreateScheduler creates a scheduler to PD cluster. -func (c *client) CreateScheduler(ctx context.Context, name string, storeID uint64) error { - inputJSON, err := json.Marshal(map[string]interface{}{ - "name": name, - "store_id": storeID, - }) - if err != nil { - return errors.Trace(err) - } - return c.requestWithRetry(ctx, - "CreateScheduler", Schedulers, - http.MethodPost, inputJSON, nil) +// WithRespHandler sets and returns a new client with the given HTTP response handler. +func (c *client) WithRespHandler( + handler func(resp *http.Response, res interface{}) error, +) Client { + newClient := *c + newClient.respHandler = handler + return &newClient } -// AccelerateSchedule accelerates the scheduling of the regions within the given key range. -// The keys in the key range should be encoded in the hex bytes format (without encoding to the UTF-8 bytes). 
-func (c *client) AccelerateSchedule(ctx context.Context, keyRange *KeyRange) error { - startKey, endKey := keyRange.EscapeAsHexStr() - inputJSON, err := json.Marshal(map[string]string{ - "start_key": startKey, - "end_key": endKey, - }) - if err != nil { - return errors.Trace(err) - } - return c.requestWithRetry(ctx, - "AccelerateSchedule", AccelerateSchedule, - http.MethodPost, inputJSON, nil) -} +// Header key definition constants. +const ( + pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" + xCallerIDKey = "X-Caller-ID" +) -// AccelerateScheduleInBatch accelerates the scheduling of the regions within the given key ranges in batch. -// The keys in the key ranges should be encoded in the hex bytes format (without encoding to the UTF-8 bytes). -func (c *client) AccelerateScheduleInBatch(ctx context.Context, keyRanges []*KeyRange) error { - input := make([]map[string]string, 0, len(keyRanges)) - for _, keyRange := range keyRanges { - startKey, endKey := keyRange.EscapeAsHexStr() - input = append(input, map[string]string{ - "start_key": startKey, - "end_key": endKey, - }) - } - inputJSON, err := json.Marshal(input) - if err != nil { - return errors.Trace(err) - } - return c.requestWithRetry(ctx, - "AccelerateScheduleInBatch", AccelerateScheduleInBatch, - http.MethodPost, inputJSON, nil) -} +// HeaderOption configures the HTTP header. +type HeaderOption func(header http.Header) -// SetSchedulerDelay sets the delay of given scheduler. -func (c *client) SetSchedulerDelay(ctx context.Context, scheduler string, delaySec int64) error { - m := map[string]int64{ - "delay": delaySec, - } - inputJSON, err := json.Marshal(m) - if err != nil { - return errors.Trace(err) +// WithAllowFollowerHandle sets the header field to allow a PD follower to handle this request. +func WithAllowFollowerHandle() HeaderOption { + return func(header http.Header) { + header.Set(pdAllowFollowerHandleKey, "true") } - return c.requestWithRetry(ctx, - "SetSchedulerDelay", SchedulerByName(scheduler), - http.MethodPost, inputJSON, nil) } -// GetMinResolvedTSByStoresIDs get min-resolved-ts by stores IDs. -// - When storeIDs has zero length, it will return (cluster-level's min_resolved_ts, nil, nil) when no error. -// - When storeIDs is {"cluster"}, it will return (cluster-level's min_resolved_ts, stores_min_resolved_ts, nil) when no error. -// - When storeID is specified to ID lists, it will return (min_resolved_ts of given stores, stores_min_resolved_ts, nil) when no error. -func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) { - uri := MinResolvedTSPrefix - // scope is an optional parameter, it can be `cluster` or specified store IDs. - // - When no scope is given, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be nil. - // - When scope is `cluster`, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be filled. - // - When scope given a list of stores, min_resolved_ts will be provided for each store - // and the scope-specific min_resolved_ts will be returned. 
- if len(storeIDs) != 0 { - storeIDStrs := make([]string, len(storeIDs)) - for idx, id := range storeIDs { - storeIDStrs[idx] = fmt.Sprintf("%d", id) - } - uri = fmt.Sprintf("%s?scope=%s", uri, strings.Join(storeIDStrs, ",")) - } - resp := struct { - MinResolvedTS uint64 `json:"min_resolved_ts"` - IsRealTime bool `json:"is_real_time,omitempty"` - StoresMinResolvedTS map[uint64]uint64 `json:"stores_min_resolved_ts"` - }{} - err := c.requestWithRetry(ctx, - "GetMinResolvedTSByStoresIDs", uri, - http.MethodGet, nil, &resp) - if err != nil { - return 0, nil, err - } - if !resp.IsRealTime { - return 0, nil, errors.Trace(errors.New("min resolved ts is not enabled")) - } - return resp.MinResolvedTS, resp.StoresMinResolvedTS, nil +func (c *client) request(ctx context.Context, reqInfo *requestInfo, headerOpts ...HeaderOption) error { + return c.inner.requestWithRetry(ctx, reqInfo. + WithCallerID(c.callerID). + WithRespHandler(c.respHandler), + headerOpts...) } -// GetMicroServiceMembers gets the members of the microservice. -func (c *client) GetMicroServiceMembers(ctx context.Context, service string) ([]string, error) { - var members []string - err := c.requestWithRetry(ctx, - "GetMicroServiceMembers", MicroServiceMembers(service), - http.MethodGet, nil, &members) - if err != nil { - return nil, err - } - return members, nil +// UpdateMembersInfo updates the members info of the PD cluster in the inner client. +// Exported for testing. +func (c *client) UpdateMembersInfo() { + c.inner.updateMembersInfo(c.inner.ctx) } diff --git a/client/http/client_test.go b/client/http/client_test.go index 70c2ddee08b..7c7da80827c 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -16,12 +16,28 @@ package http import ( "context" + "crypto/tls" "net/http" "testing" "github.com/stretchr/testify/require" + "go.uber.org/atomic" ) +func TestPDAddrNormalization(t *testing.T) { + re := require.New(t) + c := NewClient([]string{"127.0.0.1"}) + pdAddrs, leaderAddrIdx := c.(*client).inner.getPDAddrs() + re.Equal(1, len(pdAddrs)) + re.Equal(-1, leaderAddrIdx) + re.Contains(pdAddrs[0], httpScheme) + c = NewClient([]string{"127.0.0.1"}, WithTLSConfig(&tls.Config{})) + pdAddrs, leaderAddrIdx = c.(*client).inner.getPDAddrs() + re.Equal(1, len(pdAddrs)) + re.Equal(-1, leaderAddrIdx) + re.Contains(pdAddrs[0], httpsScheme) +} + // requestChecker is used to check the HTTP request sent by the client. 
type requestChecker struct { checker func(req *http.Request) error @@ -40,8 +56,11 @@ func newHTTPClientWithRequestChecker(checker func(req *http.Request) error) *htt func TestPDAllowFollowerHandleHeader(t *testing.T) { re := require.New(t) - var expectedVal string httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { + var expectedVal string + if req.URL.Path == HotHistory { + expectedVal = "true" + } val := req.Header.Get(pdAllowFollowerHandleKey) if val != expectedVal { re.Failf("PD allow follower handler header check failed", @@ -51,16 +70,17 @@ func TestPDAllowFollowerHandleHeader(t *testing.T) { }) c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) c.GetRegions(context.Background()) - expectedVal = "true" c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{}) + c.Close() } func TestCallerID(t *testing.T) { re := require.New(t) - expectedVal := defaultCallerID + expectedVal := atomic.NewString(defaultCallerID) httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { val := req.Header.Get(xCallerIDKey) - if val != expectedVal { + // Exclude the request sent by the inner client. + if val != defaultInnerCallerID && val != expectedVal.Load() { re.Failf("Caller ID header check failed", "should be %s, but got %s", expectedVal, val) } @@ -68,6 +88,7 @@ func TestCallerID(t *testing.T) { }) c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) c.GetRegions(context.Background()) - expectedVal = "test" - c.WithCallerID(expectedVal).GetRegions(context.Background()) + expectedVal.Store("test") + c.WithCallerID(expectedVal.Load()).GetRegions(context.Background()) + c.Close() } diff --git a/client/http/interface.go b/client/http/interface.go new file mode 100644 index 00000000000..25fa99fa0bd --- /dev/null +++ b/client/http/interface.go @@ -0,0 +1,695 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/pdpb" +) + +// Client is a PD (Placement Driver) HTTP client. 
+type Client interface { + /* Member-related interfaces */ + GetMembers(context.Context) (*MembersInfo, error) + GetLeader(context.Context) (*pdpb.Member, error) + TransferLeader(context.Context, string) error + /* Meta-related interfaces */ + GetRegionByID(context.Context, uint64) (*RegionInfo, error) + GetRegionByKey(context.Context, []byte) (*RegionInfo, error) + GetRegions(context.Context) (*RegionsInfo, error) + GetRegionsByKeyRange(context.Context, *KeyRange, int) (*RegionsInfo, error) + GetRegionsByStoreID(context.Context, uint64) (*RegionsInfo, error) + GetRegionsReplicatedStateByKeyRange(context.Context, *KeyRange) (string, error) + GetHotReadRegions(context.Context) (*StoreHotPeersInfos, error) + GetHotWriteRegions(context.Context) (*StoreHotPeersInfos, error) + GetHistoryHotRegions(context.Context, *HistoryHotRegionsRequest) (*HistoryHotRegions, error) + GetRegionStatusByKeyRange(context.Context, *KeyRange, bool) (*RegionStats, error) + GetStores(context.Context) (*StoresInfo, error) + GetStore(context.Context, uint64) (*StoreInfo, error) + SetStoreLabels(context.Context, int64, map[string]string) error + /* Config-related interfaces */ + GetScheduleConfig(context.Context) (map[string]interface{}, error) + SetScheduleConfig(context.Context, map[string]interface{}) error + GetClusterVersion(context.Context) (string, error) + /* Scheduler-related interfaces */ + GetSchedulers(context.Context) ([]string, error) + CreateScheduler(ctx context.Context, name string, storeID uint64) error + SetSchedulerDelay(context.Context, string, int64) error + /* Rule-related interfaces */ + GetAllPlacementRuleBundles(context.Context) ([]*GroupBundle, error) + GetPlacementRuleBundleByGroup(context.Context, string) (*GroupBundle, error) + GetPlacementRulesByGroup(context.Context, string) ([]*Rule, error) + SetPlacementRule(context.Context, *Rule) error + SetPlacementRuleInBatch(context.Context, []*RuleOp) error + SetPlacementRuleBundles(context.Context, []*GroupBundle, bool) error + DeletePlacementRule(context.Context, string, string) error + GetAllPlacementRuleGroups(context.Context) ([]*RuleGroup, error) + GetPlacementRuleGroupByID(context.Context, string) (*RuleGroup, error) + SetPlacementRuleGroup(context.Context, *RuleGroup) error + DeletePlacementRuleGroupByID(context.Context, string) error + GetAllRegionLabelRules(context.Context) ([]*LabelRule, error) + GetRegionLabelRulesByIDs(context.Context, []string) ([]*LabelRule, error) + SetRegionLabelRule(context.Context, *LabelRule) error + PatchRegionLabelRules(context.Context, *LabelRulePatch) error + /* Scheduling-related interfaces */ + AccelerateSchedule(context.Context, *KeyRange) error + AccelerateScheduleInBatch(context.Context, []*KeyRange) error + /* Other interfaces */ + GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) + /* Micro Service interfaces */ + GetMicroServiceMembers(context.Context, string) ([]string, error) + + /* Client-related methods */ + // WithCallerID sets and returns a new client with the given caller ID. + WithCallerID(string) Client + // WithRespHandler sets and returns a new client with the given HTTP response handler. + // This allows the caller to customize how the response is handled, including error handling logic. + // Additionally, it is important for the caller to handle the content of the response body properly + // in order to ensure that it can be read and marshaled correctly into `res`. 
+	WithRespHandler(func(resp *http.Response, res interface{}) error) Client
+	// Close gracefully closes the HTTP client.
+	Close()
+}
+
+var _ Client = (*client)(nil)
+
+// GetMembers gets the members info of PD cluster.
+func (c *client) GetMembers(ctx context.Context) (*MembersInfo, error) {
+	var members MembersInfo
+	err := c.request(ctx, newRequestInfo().
+		WithName(getMembersName).
+		WithURI(membersPrefix).
+		WithMethod(http.MethodGet).
+		WithResp(&members))
+	if err != nil {
+		return nil, err
+	}
+	return &members, nil
+}
+
+// GetLeader gets the leader of PD cluster.
+func (c *client) GetLeader(ctx context.Context) (*pdpb.Member, error) {
+	var leader pdpb.Member
+	err := c.request(ctx, newRequestInfo().
+		WithName(getLeaderName).
+		WithURI(leaderPrefix).
+		WithMethod(http.MethodGet).
+		WithResp(&leader))
+	if err != nil {
+		return nil, err
+	}
+	return &leader, nil
+}
+
+// TransferLeader transfers the PD leader.
+func (c *client) TransferLeader(ctx context.Context, newLeader string) error {
+	return c.request(ctx, newRequestInfo().
+		WithName(transferLeaderName).
+		WithURI(TransferLeaderByID(newLeader)).
+		WithMethod(http.MethodPost))
+}
+
+// GetRegionByID gets the region info by ID.
+func (c *client) GetRegionByID(ctx context.Context, regionID uint64) (*RegionInfo, error) {
+	var region RegionInfo
+	err := c.request(ctx, newRequestInfo().
+		WithName(getRegionByIDName).
+		WithURI(RegionByID(regionID)).
+		WithMethod(http.MethodGet).
+		WithResp(&region))
+	if err != nil {
+		return nil, err
+	}
+	return &region, nil
+}
+
+// GetRegionByKey gets the region info by key.
+func (c *client) GetRegionByKey(ctx context.Context, key []byte) (*RegionInfo, error) {
+	var region RegionInfo
+	err := c.request(ctx, newRequestInfo().
+		WithName(getRegionByKeyName).
+		WithURI(RegionByKey(key)).
+		WithMethod(http.MethodGet).
+		WithResp(&region))
+	if err != nil {
+		return nil, err
+	}
+	return &region, nil
+}
+
+// GetRegions gets the regions info.
+func (c *client) GetRegions(ctx context.Context) (*RegionsInfo, error) {
+	var regions RegionsInfo
+	err := c.request(ctx, newRequestInfo().
+		WithName(getRegionsName).
+		WithURI(Regions).
+		WithMethod(http.MethodGet).
+		WithResp(&regions))
+	if err != nil {
+		return nil, err
+	}
+	return &regions, nil
+}
+
+// GetRegionsByKeyRange gets the regions info by key range. If the limit is -1, it will return all regions within the range.
+// The keys in the key range should be encoded in the UTF-8 bytes format.
+func (c *client) GetRegionsByKeyRange(ctx context.Context, keyRange *KeyRange, limit int) (*RegionsInfo, error) {
+	var regions RegionsInfo
+	err := c.request(ctx, newRequestInfo().
+		WithName(getRegionsByKeyRangeName).
+		WithURI(RegionsByKeyRange(keyRange, limit)).
+		WithMethod(http.MethodGet).
+		WithResp(&regions))
+	if err != nil {
+		return nil, err
+	}
+	return &regions, nil
+}
+
+// GetRegionsByStoreID gets the regions info by store ID.
+func (c *client) GetRegionsByStoreID(ctx context.Context, storeID uint64) (*RegionsInfo, error) {
+	var regions RegionsInfo
+	err := c.request(ctx, newRequestInfo().
+		WithName(getRegionsByStoreIDName).
+		WithURI(RegionsByStoreID(storeID)).
+		WithMethod(http.MethodGet).
+		WithResp(&regions))
+	if err != nil {
+		return nil, err
+	}
+	return &regions, nil
+}
+
+// GetRegionsReplicatedStateByKeyRange gets the regions replicated state info by key range.
+// The keys in the key range should be encoded in the hex bytes format (without encoding to the UTF-8 bytes).
+func (c *client) GetRegionsReplicatedStateByKeyRange(ctx context.Context, keyRange *KeyRange) (string, error) {
+	var state string
+	err := c.request(ctx, newRequestInfo().
+		WithName(getRegionsReplicatedStateByKeyRangeName).
+		WithURI(RegionsReplicatedByKeyRange(keyRange)).
+		WithMethod(http.MethodGet).
+		WithResp(&state))
+	if err != nil {
+		return "", err
+	}
+	return state, nil
+}
+
+// GetHotReadRegions gets the hot read region statistics info.
+func (c *client) GetHotReadRegions(ctx context.Context) (*StoreHotPeersInfos, error) {
+	var hotReadRegions StoreHotPeersInfos
+	err := c.request(ctx, newRequestInfo().
+		WithName(getHotReadRegionsName).
+		WithURI(HotRead).
+		WithMethod(http.MethodGet).
+		WithResp(&hotReadRegions))
+	if err != nil {
+		return nil, err
+	}
+	return &hotReadRegions, nil
+}
+
+// GetHotWriteRegions gets the hot write region statistics info.
+func (c *client) GetHotWriteRegions(ctx context.Context) (*StoreHotPeersInfos, error) {
+	var hotWriteRegions StoreHotPeersInfos
+	err := c.request(ctx, newRequestInfo().
+		WithName(getHotWriteRegionsName).
+		WithURI(HotWrite).
+		WithMethod(http.MethodGet).
+		WithResp(&hotWriteRegions))
+	if err != nil {
+		return nil, err
+	}
+	return &hotWriteRegions, nil
+}
+
+// GetHistoryHotRegions gets the history hot region statistics info.
+func (c *client) GetHistoryHotRegions(ctx context.Context, req *HistoryHotRegionsRequest) (*HistoryHotRegions, error) {
+	reqJSON, err := json.Marshal(req)
+	if err != nil {
+		return nil, errors.Trace(err)
+	}
+	var historyHotRegions HistoryHotRegions
+	err = c.request(ctx, newRequestInfo().
+		WithName(getHistoryHotRegionsName).
+		WithURI(HotHistory).
+		WithMethod(http.MethodGet).
+		WithBody(reqJSON).
+		WithResp(&historyHotRegions),
+		WithAllowFollowerHandle())
+	if err != nil {
+		return nil, err
+	}
+	return &historyHotRegions, nil
+}
+
+// GetRegionStatusByKeyRange gets the region status by key range.
+// If the `onlyCount` flag is true, the result will only include the count of regions.
+// The keys in the key range should be encoded in the UTF-8 bytes format.
+func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRange, onlyCount bool) (*RegionStats, error) {
+	var regionStats RegionStats
+	err := c.request(ctx, newRequestInfo().
+		WithName(getRegionStatusByKeyRangeName).
+		WithURI(RegionStatsByKeyRange(keyRange, onlyCount)).
+		WithMethod(http.MethodGet).
+		WithResp(&regionStats))
+	if err != nil {
+		return nil, err
+	}
+	return &regionStats, nil
+}
+
+// SetStoreLabels sets the labels of a store.
+func (c *client) SetStoreLabels(ctx context.Context, storeID int64, storeLabels map[string]string) error {
+	jsonInput, err := json.Marshal(storeLabels)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	return c.request(ctx, newRequestInfo().
+		WithName(setStoreLabelsName).
+		WithURI(LabelByStoreID(storeID)).
+		WithMethod(http.MethodPost).
+		WithBody(jsonInput))
+}
+
+// GetScheduleConfig gets the schedule configurations.
+func (c *client) GetScheduleConfig(ctx context.Context) (map[string]interface{}, error) {
+	var config map[string]interface{}
+	err := c.request(ctx, newRequestInfo().
+		WithName(getScheduleConfigName).
+		WithURI(ScheduleConfig).
+		WithMethod(http.MethodGet).
+		WithResp(&config))
+	if err != nil {
+		return nil, err
+	}
+	return config, nil
+}
+
+// SetScheduleConfig sets the schedule configurations.
+func (c *client) SetScheduleConfig(ctx context.Context, config map[string]interface{}) error { + configJSON, err := json.Marshal(config) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(setScheduleConfigName). + WithURI(ScheduleConfig). + WithMethod(http.MethodPost). + WithBody(configJSON)) +} + +// GetStores gets the stores info. +func (c *client) GetStores(ctx context.Context) (*StoresInfo, error) { + var stores StoresInfo + err := c.request(ctx, newRequestInfo(). + WithName(getStoresName). + WithURI(Stores). + WithMethod(http.MethodGet). + WithResp(&stores)) + if err != nil { + return nil, err + } + return &stores, nil +} + +// GetStore gets the store info by ID. +func (c *client) GetStore(ctx context.Context, storeID uint64) (*StoreInfo, error) { + var store StoreInfo + err := c.request(ctx, newRequestInfo(). + WithName(getStoreName). + WithURI(StoreByID(storeID)). + WithMethod(http.MethodGet). + WithResp(&store)) + if err != nil { + return nil, err + } + return &store, nil +} + +// GetClusterVersion gets the cluster version. +func (c *client) GetClusterVersion(ctx context.Context) (string, error) { + var version string + err := c.request(ctx, newRequestInfo(). + WithName(getClusterVersionName). + WithURI(ClusterVersion). + WithMethod(http.MethodGet). + WithResp(&version)) + if err != nil { + return "", err + } + return version, nil +} + +// GetAllPlacementRuleBundles gets all placement rules bundles. +func (c *client) GetAllPlacementRuleBundles(ctx context.Context) ([]*GroupBundle, error) { + var bundles []*GroupBundle + err := c.request(ctx, newRequestInfo(). + WithName(getAllPlacementRuleBundlesName). + WithURI(PlacementRuleBundle). + WithMethod(http.MethodGet). + WithResp(&bundles)) + if err != nil { + return nil, err + } + return bundles, nil +} + +// GetPlacementRuleBundleByGroup gets the placement rules bundle by group. +func (c *client) GetPlacementRuleBundleByGroup(ctx context.Context, group string) (*GroupBundle, error) { + var bundle GroupBundle + err := c.request(ctx, newRequestInfo(). + WithName(getPlacementRuleBundleByGroupName). + WithURI(PlacementRuleBundleByGroup(group)). + WithMethod(http.MethodGet). + WithResp(&bundle)) + if err != nil { + return nil, err + } + return &bundle, nil +} + +// GetPlacementRulesByGroup gets the placement rules by group. +func (c *client) GetPlacementRulesByGroup(ctx context.Context, group string) ([]*Rule, error) { + var rules []*Rule + err := c.request(ctx, newRequestInfo(). + WithName(getPlacementRulesByGroupName). + WithURI(PlacementRulesByGroup(group)). + WithMethod(http.MethodGet). + WithResp(&rules)) + if err != nil { + return nil, err + } + return rules, nil +} + +// SetPlacementRule sets the placement rule. +func (c *client) SetPlacementRule(ctx context.Context, rule *Rule) error { + ruleJSON, err := json.Marshal(rule) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(setPlacementRuleName). + WithURI(PlacementRule). + WithMethod(http.MethodPost). + WithBody(ruleJSON)) +} + +// SetPlacementRuleInBatch sets the placement rules in batch. +func (c *client) SetPlacementRuleInBatch(ctx context.Context, ruleOps []*RuleOp) error { + ruleOpsJSON, err := json.Marshal(ruleOps) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(setPlacementRuleInBatchName). + WithURI(PlacementRulesInBatch). + WithMethod(http.MethodPost). 
+ WithBody(ruleOpsJSON)) +} + +// SetPlacementRuleBundles sets the placement rule bundles. +// If `partial` is false, all old configurations will be over-written and dropped. +func (c *client) SetPlacementRuleBundles(ctx context.Context, bundles []*GroupBundle, partial bool) error { + bundlesJSON, err := json.Marshal(bundles) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(setPlacementRuleBundlesName). + WithURI(PlacementRuleBundleWithPartialParameter(partial)). + WithMethod(http.MethodPost). + WithBody(bundlesJSON)) +} + +// DeletePlacementRule deletes the placement rule. +func (c *client) DeletePlacementRule(ctx context.Context, group, id string) error { + return c.request(ctx, newRequestInfo(). + WithName(deletePlacementRuleName). + WithURI(PlacementRuleByGroupAndID(group, id)). + WithMethod(http.MethodDelete)) +} + +// GetAllPlacementRuleGroups gets all placement rule groups. +func (c *client) GetAllPlacementRuleGroups(ctx context.Context) ([]*RuleGroup, error) { + var ruleGroups []*RuleGroup + err := c.request(ctx, newRequestInfo(). + WithName(getAllPlacementRuleGroupsName). + WithURI(placementRuleGroups). + WithMethod(http.MethodGet). + WithResp(&ruleGroups)) + if err != nil { + return nil, err + } + return ruleGroups, nil +} + +// GetPlacementRuleGroupByID gets the placement rule group by ID. +func (c *client) GetPlacementRuleGroupByID(ctx context.Context, id string) (*RuleGroup, error) { + var ruleGroup RuleGroup + err := c.request(ctx, newRequestInfo(). + WithName(getPlacementRuleGroupByIDName). + WithURI(PlacementRuleGroupByID(id)). + WithMethod(http.MethodGet). + WithResp(&ruleGroup)) + if err != nil { + return nil, err + } + return &ruleGroup, nil +} + +// SetPlacementRuleGroup sets the placement rule group. +func (c *client) SetPlacementRuleGroup(ctx context.Context, ruleGroup *RuleGroup) error { + ruleGroupJSON, err := json.Marshal(ruleGroup) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(setPlacementRuleGroupName). + WithURI(placementRuleGroup). + WithMethod(http.MethodPost). + WithBody(ruleGroupJSON)) +} + +// DeletePlacementRuleGroupByID deletes the placement rule group by ID. +func (c *client) DeletePlacementRuleGroupByID(ctx context.Context, id string) error { + return c.request(ctx, newRequestInfo(). + WithName(deletePlacementRuleGroupByIDName). + WithURI(PlacementRuleGroupByID(id)). + WithMethod(http.MethodDelete)) +} + +// GetAllRegionLabelRules gets all region label rules. +func (c *client) GetAllRegionLabelRules(ctx context.Context) ([]*LabelRule, error) { + var labelRules []*LabelRule + err := c.request(ctx, newRequestInfo(). + WithName(getAllRegionLabelRulesName). + WithURI(RegionLabelRules). + WithMethod(http.MethodGet). + WithResp(&labelRules)) + if err != nil { + return nil, err + } + return labelRules, nil +} + +// GetRegionLabelRulesByIDs gets the region label rules by IDs. +func (c *client) GetRegionLabelRulesByIDs(ctx context.Context, ruleIDs []string) ([]*LabelRule, error) { + idsJSON, err := json.Marshal(ruleIDs) + if err != nil { + return nil, errors.Trace(err) + } + var labelRules []*LabelRule + err = c.request(ctx, newRequestInfo(). + WithName(getRegionLabelRulesByIDsName). + WithURI(RegionLabelRules). + WithMethod(http.MethodGet). + WithBody(idsJSON). + WithResp(&labelRules)) + if err != nil { + return nil, err + } + return labelRules, nil +} + +// SetRegionLabelRule sets the region label rule. 
+func (c *client) SetRegionLabelRule(ctx context.Context, labelRule *LabelRule) error { + labelRuleJSON, err := json.Marshal(labelRule) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(setRegionLabelRuleName). + WithURI(RegionLabelRule). + WithMethod(http.MethodPost). + WithBody(labelRuleJSON)) +} + +// PatchRegionLabelRules patches the region label rules. +func (c *client) PatchRegionLabelRules(ctx context.Context, labelRulePatch *LabelRulePatch) error { + labelRulePatchJSON, err := json.Marshal(labelRulePatch) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(patchRegionLabelRulesName). + WithURI(RegionLabelRules). + WithMethod(http.MethodPatch). + WithBody(labelRulePatchJSON)) +} + +// GetSchedulers gets the schedulers from PD cluster. +func (c *client) GetSchedulers(ctx context.Context) ([]string, error) { + var schedulers []string + err := c.request(ctx, newRequestInfo(). + WithName(getSchedulersName). + WithURI(Schedulers). + WithMethod(http.MethodGet). + WithResp(&schedulers)) + if err != nil { + return nil, err + } + return schedulers, nil +} + +// CreateScheduler creates a scheduler to PD cluster. +func (c *client) CreateScheduler(ctx context.Context, name string, storeID uint64) error { + inputJSON, err := json.Marshal(map[string]interface{}{ + "name": name, + "store_id": storeID, + }) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(createSchedulerName). + WithURI(Schedulers). + WithMethod(http.MethodPost). + WithBody(inputJSON)) +} + +// AccelerateSchedule accelerates the scheduling of the regions within the given key range. +// The keys in the key range should be encoded in the hex bytes format (without encoding to the UTF-8 bytes). +func (c *client) AccelerateSchedule(ctx context.Context, keyRange *KeyRange) error { + startKey, endKey := keyRange.EscapeAsHexStr() + inputJSON, err := json.Marshal(map[string]string{ + "start_key": startKey, + "end_key": endKey, + }) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(accelerateScheduleName). + WithURI(AccelerateSchedule). + WithMethod(http.MethodPost). + WithBody(inputJSON)) +} + +// AccelerateScheduleInBatch accelerates the scheduling of the regions within the given key ranges in batch. +// The keys in the key ranges should be encoded in the hex bytes format (without encoding to the UTF-8 bytes). +func (c *client) AccelerateScheduleInBatch(ctx context.Context, keyRanges []*KeyRange) error { + input := make([]map[string]string, 0, len(keyRanges)) + for _, keyRange := range keyRanges { + startKey, endKey := keyRange.EscapeAsHexStr() + input = append(input, map[string]string{ + "start_key": startKey, + "end_key": endKey, + }) + } + inputJSON, err := json.Marshal(input) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(accelerateScheduleInBatchName). + WithURI(AccelerateScheduleInBatch). + WithMethod(http.MethodPost). + WithBody(inputJSON)) +} + +// SetSchedulerDelay sets the delay of given scheduler. +func (c *client) SetSchedulerDelay(ctx context.Context, scheduler string, delaySec int64) error { + m := map[string]int64{ + "delay": delaySec, + } + inputJSON, err := json.Marshal(m) + if err != nil { + return errors.Trace(err) + } + return c.request(ctx, newRequestInfo(). + WithName(setSchedulerDelayName). + WithURI(SchedulerByName(scheduler)). + WithMethod(http.MethodPost). 
+ WithBody(inputJSON)) +} + +// GetMinResolvedTSByStoresIDs get min-resolved-ts by stores IDs. +// - When storeIDs has zero length, it will return (cluster-level's min_resolved_ts, nil, nil) when no error. +// - When storeIDs is {"cluster"}, it will return (cluster-level's min_resolved_ts, stores_min_resolved_ts, nil) when no error. +// - When storeID is specified to ID lists, it will return (min_resolved_ts of given stores, stores_min_resolved_ts, nil) when no error. +func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) { + uri := MinResolvedTSPrefix + // scope is an optional parameter, it can be `cluster` or specified store IDs. + // - When no scope is given, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be nil. + // - When scope is `cluster`, cluster-level's min_resolved_ts will be returned and storesMinResolvedTS will be filled. + // - When scope given a list of stores, min_resolved_ts will be provided for each store + // and the scope-specific min_resolved_ts will be returned. + if len(storeIDs) != 0 { + storeIDStrs := make([]string, len(storeIDs)) + for idx, id := range storeIDs { + storeIDStrs[idx] = fmt.Sprintf("%d", id) + } + uri = fmt.Sprintf("%s?scope=%s", uri, strings.Join(storeIDStrs, ",")) + } + resp := struct { + MinResolvedTS uint64 `json:"min_resolved_ts"` + IsRealTime bool `json:"is_real_time,omitempty"` + StoresMinResolvedTS map[uint64]uint64 `json:"stores_min_resolved_ts"` + }{} + err := c.request(ctx, newRequestInfo(). + WithName(getMinResolvedTSByStoresIDsName). + WithURI(uri). + WithMethod(http.MethodGet). + WithResp(&resp)) + if err != nil { + return 0, nil, err + } + if !resp.IsRealTime { + return 0, nil, errors.Trace(errors.New("min resolved ts is not enabled")) + } + return resp.MinResolvedTS, resp.StoresMinResolvedTS, nil +} + +// GetMicroServiceMembers gets the members of the microservice. +func (c *client) GetMicroServiceMembers(ctx context.Context, service string) ([]string, error) { + var members []string + err := c.request(ctx, newRequestInfo(). + WithName(getMicroServiceMembersName). + WithURI(MicroServiceMembers(service)). + WithMethod(http.MethodGet). + WithResp(&members)) + if err != nil { + return nil, err + } + return members, nil +} diff --git a/client/http/request_info.go b/client/http/request_info.go new file mode 100644 index 00000000000..dc811618d9c --- /dev/null +++ b/client/http/request_info.go @@ -0,0 +1,123 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import "fmt" + +// The following constants are the names of the requests. 
+const ( + getMembersName = "GetMembers" + getLeaderName = "GetLeader" + transferLeaderName = "TransferLeader" + getRegionByIDName = "GetRegionByID" + getRegionByKeyName = "GetRegionByKey" + getRegionsName = "GetRegions" + getRegionsByKeyRangeName = "GetRegionsByKeyRange" + getRegionsByStoreIDName = "GetRegionsByStoreID" + getRegionsReplicatedStateByKeyRangeName = "GetRegionsReplicatedStateByKeyRange" + getHotReadRegionsName = "GetHotReadRegions" + getHotWriteRegionsName = "GetHotWriteRegions" + getHistoryHotRegionsName = "GetHistoryHotRegions" + getRegionStatusByKeyRangeName = "GetRegionStatusByKeyRange" + getStoresName = "GetStores" + getStoreName = "GetStore" + setStoreLabelsName = "SetStoreLabels" + getScheduleConfigName = "GetScheduleConfig" + setScheduleConfigName = "SetScheduleConfig" + getClusterVersionName = "GetClusterVersion" + getSchedulersName = "GetSchedulers" + createSchedulerName = "CreateScheduler" + setSchedulerDelayName = "SetSchedulerDelay" + getAllPlacementRuleBundlesName = "GetAllPlacementRuleBundles" + getPlacementRuleBundleByGroupName = "GetPlacementRuleBundleByGroup" + getPlacementRulesByGroupName = "GetPlacementRulesByGroup" + setPlacementRuleName = "SetPlacementRule" + setPlacementRuleInBatchName = "SetPlacementRuleInBatch" + setPlacementRuleBundlesName = "SetPlacementRuleBundles" + deletePlacementRuleName = "DeletePlacementRule" + getAllPlacementRuleGroupsName = "GetAllPlacementRuleGroups" + getPlacementRuleGroupByIDName = "GetPlacementRuleGroupByID" + setPlacementRuleGroupName = "SetPlacementRuleGroup" + deletePlacementRuleGroupByIDName = "DeletePlacementRuleGroupByID" + getAllRegionLabelRulesName = "GetAllRegionLabelRules" + getRegionLabelRulesByIDsName = "GetRegionLabelRulesByIDs" + setRegionLabelRuleName = "SetRegionLabelRule" + patchRegionLabelRulesName = "PatchRegionLabelRules" + accelerateScheduleName = "AccelerateSchedule" + accelerateScheduleInBatchName = "AccelerateScheduleInBatch" + getMinResolvedTSByStoresIDsName = "GetMinResolvedTSByStoresIDs" + getMicroServiceMembersName = "GetMicroServiceMembers" +) + +type requestInfo struct { + callerID string + name string + uri string + method string + body []byte + res interface{} + respHandler respHandleFunc +} + +// newRequestInfo creates a new request info. +func newRequestInfo() *requestInfo { + return &requestInfo{} +} + +// WithCallerID sets the caller ID of the request. +func (ri *requestInfo) WithCallerID(callerID string) *requestInfo { + ri.callerID = callerID + return ri +} + +// WithName sets the name of the request. +func (ri *requestInfo) WithName(name string) *requestInfo { + ri.name = name + return ri +} + +// WithURI sets the URI of the request. +func (ri *requestInfo) WithURI(uri string) *requestInfo { + ri.uri = uri + return ri +} + +// WithMethod sets the method of the request. +func (ri *requestInfo) WithMethod(method string) *requestInfo { + ri.method = method + return ri +} + +// WithBody sets the body of the request. +func (ri *requestInfo) WithBody(body []byte) *requestInfo { + ri.body = body + return ri +} + +// WithResp sets the response struct of the request. +func (ri *requestInfo) WithResp(res interface{}) *requestInfo { + ri.res = res + return ri +} + +// WithRespHandler sets the response handle function of the request. 
+func (ri *requestInfo) WithRespHandler(respHandler respHandleFunc) *requestInfo { + ri.respHandler = respHandler + return ri +} + +func (ri *requestInfo) getURL(addr string) string { + return fmt.Sprintf("%s%s", addr, ri.uri) +} diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 7c8f66f4826..1a20ae2e784 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -438,13 +438,13 @@ func (suite *httpClientTestSuite) TestTransferLeader() { re.NoError(err) re.Len(members.Members, 2) - oldLeader, err := suite.client.GetLeader(suite.ctx) + leader, err := suite.client.GetLeader(suite.ctx) re.NoError(err) // Transfer leader to another pd for _, member := range members.Members { - if member.Name != oldLeader.Name { - err = suite.client.TransferLeader(suite.ctx, member.Name) + if member.GetName() != leader.GetName() { + err = suite.client.TransferLeader(suite.ctx, member.GetName()) re.NoError(err) break } @@ -453,5 +453,14 @@ func (suite *httpClientTestSuite) TestTransferLeader() { newLeader := suite.cluster.WaitLeader() re.NotEmpty(newLeader) re.NoError(err) - re.NotEqual(oldLeader.Name, newLeader) + re.NotEqual(leader.GetName(), newLeader) + // Force to update the members info. + suite.client.(interface{ UpdateMembersInfo() }).UpdateMembersInfo() + leader, err = suite.client.GetLeader(suite.ctx) + re.NoError(err) + re.Equal(newLeader, leader.GetName()) + members, err = suite.client.GetMembers(suite.ctx) + re.NoError(err) + re.Len(members.Members, 2) + re.Equal(leader.GetName(), members.Leader.GetName()) } From cfb1d8f94134b598f924aca231fee66326106ce8 Mon Sep 17 00:00:00 2001 From: lucasliang Date: Mon, 18 Dec 2023 18:27:24 +0800 Subject: [PATCH 15/21] scheduler: enable `evict-slow-store` by default. (#7505) close tikv/pd#7564, ref tikv/tikv#15909 Enable `evict-slow-store` scheduler by default. Signed-off-by: lucasliang --- pkg/schedule/config/config.go | 1 + server/cluster/cluster_test.go | 18 ++++++++++------- tests/integrations/mcs/scheduling/api_test.go | 5 +++-- .../mcs/scheduling/server_test.go | 4 ++-- tests/pdctl/scheduler/scheduler_test.go | 20 ++++++++++++++++++- tests/server/api/scheduler_test.go | 4 ++++ tests/server/cluster/cluster_test.go | 8 +++++--- 7 files changed, 45 insertions(+), 15 deletions(-) diff --git a/pkg/schedule/config/config.go b/pkg/schedule/config/config.go index 27b8917d1bf..90a37c93d91 100644 --- a/pkg/schedule/config/config.go +++ b/pkg/schedule/config/config.go @@ -556,6 +556,7 @@ var DefaultSchedulers = SchedulerConfigs{ {Type: "balance-witness"}, {Type: "hot-region"}, {Type: "transfer-witness-leader"}, + {Type: "evict-slow-store"}, } // IsDefaultScheduler checks whether the scheduler is enable by default. 
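Editor's note on the change above: with `evict-slow-store` now listed in `DefaultSchedulers`, a freshly bootstrapped cluster registers the scheduler without any manual step. One quick way to confirm this is to read the scheduler list back through the HTTP client introduced earlier in this series. The sketch below is illustrative only; the `github.com/tikv/pd/client/http` import path and the PD endpoint are assumptions, not part of this patch.

package main

import (
	"context"
	"fmt"

	pdhttp "github.com/tikv/pd/client/http" // assumed import path for the client/http package
)

func main() {
	// Assumed endpoint; point this at a real PD address.
	cli := pdhttp.NewClient([]string{"http://127.0.0.1:2379"}).WithCallerID("example-caller")
	defer cli.Close()

	// GetSchedulers is part of the Client interface added in this series.
	// With evict-slow-store enabled by default, "evict-slow-store-scheduler"
	// is expected to appear in this list on a new cluster.
	schedulers, err := cli.GetSchedulers(context.Background())
	if err != nil {
		panic(err)
	}
	for _, name := range schedulers {
		fmt.Println(name)
	}
}

The test updates in the remaining hunks adjust the expected default-scheduler counts accordingly.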
diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 7094fd6b673..8c889923ea7 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -3017,6 +3017,7 @@ func TestAddScheduler(t *testing.T) { re.NoError(controller.RemoveScheduler(schedulers.HotRegionName)) re.NoError(controller.RemoveScheduler(schedulers.BalanceWitnessName)) re.NoError(controller.RemoveScheduler(schedulers.TransferWitnessLeaderName)) + re.NoError(controller.RemoveScheduler(schedulers.EvictSlowStoreName)) re.Empty(controller.GetSchedulerNames()) stream := mockhbstream.NewHeartbeatStream() @@ -3107,13 +3108,15 @@ func TestPersistScheduler(t *testing.T) { re.NoError(err) re.Len(sches, defaultCount+2) - // remove 5 schedulers + // remove all default schedulers re.NoError(controller.RemoveScheduler(schedulers.BalanceLeaderName)) re.NoError(controller.RemoveScheduler(schedulers.BalanceRegionName)) re.NoError(controller.RemoveScheduler(schedulers.HotRegionName)) re.NoError(controller.RemoveScheduler(schedulers.BalanceWitnessName)) re.NoError(controller.RemoveScheduler(schedulers.TransferWitnessLeaderName)) - re.Len(controller.GetSchedulerNames(), defaultCount-3) + re.NoError(controller.RemoveScheduler(schedulers.EvictSlowStoreName)) + // only remains 2 items with independent config. + re.Len(controller.GetSchedulerNames(), 2) re.NoError(co.GetCluster().GetSchedulerConfig().Persist(storage)) co.Stop() co.GetSchedulersController().Wait() @@ -3137,7 +3140,7 @@ func TestPersistScheduler(t *testing.T) { re.NoError(err) re.Len(sches, 3) - // option have 6 items because the default scheduler do not remove. + // option have 9 items because the default scheduler do not remove. re.Len(newOpt.GetSchedulers(), defaultCount+3) re.NoError(newOpt.Persist(storage)) tc.RaftCluster.SetScheduleConfig(newOpt.GetScheduleConfig()) @@ -3164,9 +3167,9 @@ func TestPersistScheduler(t *testing.T) { brs, err := schedulers.CreateScheduler(schedulers.BalanceRegionType, oc, storage, schedulers.ConfigSliceDecoder(schedulers.BalanceRegionType, []string{"", ""})) re.NoError(err) re.NoError(controller.AddScheduler(brs)) - re.Len(controller.GetSchedulerNames(), defaultCount) + re.Len(controller.GetSchedulerNames(), 5) - // the scheduler option should contain 6 items + // the scheduler option should contain 9 items // the `hot scheduler` are disabled re.Len(co.GetCluster().GetSchedulerConfig().(*config.PersistOptions).GetSchedulers(), defaultCount+3) re.NoError(controller.RemoveScheduler(schedulers.GrantLeaderName)) @@ -3185,9 +3188,9 @@ func TestPersistScheduler(t *testing.T) { co.Run() controller = co.GetSchedulersController() - re.Len(controller.GetSchedulerNames(), defaultCount-1) + re.Len(controller.GetSchedulerNames(), 4) re.NoError(controller.RemoveScheduler(schedulers.EvictLeaderName)) - re.Len(controller.GetSchedulerNames(), defaultCount-2) + re.Len(controller.GetSchedulerNames(), 3) } func TestRemoveScheduler(t *testing.T) { @@ -3225,6 +3228,7 @@ func TestRemoveScheduler(t *testing.T) { re.NoError(controller.RemoveScheduler(schedulers.GrantLeaderName)) re.NoError(controller.RemoveScheduler(schedulers.BalanceWitnessName)) re.NoError(controller.RemoveScheduler(schedulers.TransferWitnessLeaderName)) + re.NoError(controller.RemoveScheduler(schedulers.EvictSlowStoreName)) // all removed sches, _, err = storage.LoadAllSchedulerConfigs() re.NoError(err) diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 4c71f8f14a3..4ad6680a7cd 100644 --- 
a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -382,13 +382,14 @@ func (suite *apiTestSuite) checkConfig(cluster *tests.TestCluster) { suite.Equal(cfg.DataDir, s.GetConfig().DataDir) testutil.Eventually(re, func() bool { // wait for all schedulers to be loaded in scheduling server. - return len(cfg.Schedule.SchedulersPayload) == 5 + return len(cfg.Schedule.SchedulersPayload) == 6 }) suite.Contains(cfg.Schedule.SchedulersPayload, "balance-leader-scheduler") suite.Contains(cfg.Schedule.SchedulersPayload, "balance-region-scheduler") suite.Contains(cfg.Schedule.SchedulersPayload, "balance-hot-region-scheduler") suite.Contains(cfg.Schedule.SchedulersPayload, "balance-witness-scheduler") suite.Contains(cfg.Schedule.SchedulersPayload, "transfer-witness-leader-scheduler") + suite.Contains(cfg.Schedule.SchedulersPayload, "evict-slow-store-scheduler") } func (suite *apiTestSuite) TestConfigForward() { @@ -412,7 +413,7 @@ func (suite *apiTestSuite) checkConfigForward(cluster *tests.TestCluster) { re.Equal(cfg["replication"].(map[string]interface{})["max-replicas"], float64(opts.GetReplicationConfig().MaxReplicas)) schedulers := cfg["schedule"].(map[string]interface{})["schedulers-payload"].(map[string]interface{}) - return len(schedulers) == 5 + return len(schedulers) == 6 }) // Test to change config in api server diff --git a/tests/integrations/mcs/scheduling/server_test.go b/tests/integrations/mcs/scheduling/server_test.go index c65352114df..0c0442300d9 100644 --- a/tests/integrations/mcs/scheduling/server_test.go +++ b/tests/integrations/mcs/scheduling/server_test.go @@ -136,7 +136,7 @@ func (suite *serverTestSuite) TestPrimaryChange() { testutil.Eventually(re, func() bool { watchedAddr, ok := suite.pdLeader.GetServicePrimaryAddr(suite.ctx, mcs.SchedulingServiceName) return ok && oldPrimaryAddr == watchedAddr && - len(primary.GetCluster().GetCoordinator().GetSchedulersController().GetSchedulerNames()) == 5 + len(primary.GetCluster().GetCoordinator().GetSchedulersController().GetSchedulerNames()) == 6 }) // change primary primary.Close() @@ -147,7 +147,7 @@ func (suite *serverTestSuite) TestPrimaryChange() { testutil.Eventually(re, func() bool { watchedAddr, ok := suite.pdLeader.GetServicePrimaryAddr(suite.ctx, mcs.SchedulingServiceName) return ok && newPrimaryAddr == watchedAddr && - len(primary.GetCluster().GetCoordinator().GetSchedulersController().GetSchedulerNames()) == 5 + len(primary.GetCluster().GetCoordinator().GetSchedulersController().GetSchedulerNames()) == 6 }) } diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index fb7c239b431..140ee7a7c44 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -56,6 +56,7 @@ func (suite *schedulerTestSuite) SetupSuite() { "balance-hot-region-scheduler", "balance-witness-scheduler", "transfer-witness-leader-scheduler", + "evict-slow-store-scheduler", } } @@ -173,6 +174,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { "balance-hot-region-scheduler": true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } checkSchedulerCommand(nil, expected) @@ -183,6 +185,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { "balance-hot-region-scheduler": true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } 
checkSchedulerCommand(args, expected) @@ -228,6 +231,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { schedulers[idx]: true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } checkSchedulerCommand(args, expected) @@ -245,6 +249,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { schedulers[idx]: true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } // check update success @@ -260,6 +265,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { "balance-hot-region-scheduler": true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } checkSchedulerCommand(args, expected) checkStorePause([]uint64{}, schedulers[idx]) @@ -272,6 +278,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { schedulers[idx]: true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } checkSchedulerCommand(args, expected) checkStorePause([]uint64{2}, schedulers[idx]) @@ -284,6 +291,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { schedulers[idx]: true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } checkSchedulerCommand(args, expected) @@ -300,6 +308,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { schedulers[idx]: true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } checkSchedulerCommand(args, expected) @@ -315,6 +324,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { "balance-hot-region-scheduler": true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, } checkSchedulerCommand(args, expected) checkStorePause([]uint64{}, schedulers[idx]) @@ -327,6 +337,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { "shuffle-region-scheduler": true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, }) var roles []string mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "config", "shuffle-region-scheduler", "show-roles"}, &roles) @@ -348,6 +359,7 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { "grant-hot-region-scheduler": true, "transfer-witness-leader-scheduler": true, "balance-witness-scheduler": true, + "evict-slow-store-scheduler": true, }) var conf3 map[string]interface{} expected3 := map[string]interface{}{ @@ -527,7 +539,11 @@ func (suite *schedulerTestSuite) checkScheduler(cluster *tests.TestCluster) { evictSlownessSchedulers := []string{"evict-slow-store-scheduler", "evict-slow-trend-scheduler"} for _, schedulerName := range evictSlownessSchedulers { echo = mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "add", schedulerName}, nil) - re.Contains(echo, "Success!") + if strings.Contains(echo, "Success!") { + re.Contains(echo, "Success!") + } else { + re.Contains(echo, "scheduler existed") + } testutil.Eventually(re, func() bool { echo = mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "show"}, nil) return strings.Contains(echo, schedulerName) @@ -546,6 +562,8 @@ func (suite *schedulerTestSuite) 
checkScheduler(cluster *tests.TestCluster) { return !strings.Contains(echo, schedulerName) }) } + echo = mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "add", "evict-slow-store-scheduler"}, nil) + re.Contains(echo, "Success!") // test shuffle hot region scheduler echo = mustExec(re, cmd, []string{"-u", pdAddr, "scheduler", "add", "shuffle-hot-region-scheduler"}, nil) diff --git a/tests/server/api/scheduler_test.go b/tests/server/api/scheduler_test.go index b3810da154a..d0472795f93 100644 --- a/tests/server/api/scheduler_test.go +++ b/tests/server/api/scheduler_test.go @@ -477,6 +477,10 @@ func (suite *scheduleTestSuite) checkAPI(cluster *tests.TestCluster) { suite.NoError(err) }, }, + { + name: "evict-slow-store-scheduler", + createdName: "evict-slow-store-scheduler", + }, } for _, testCase := range testCases { input := make(map[string]interface{}) diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 18a82bcf0fe..0b0779d9434 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -1310,6 +1310,7 @@ func TestTransferLeaderForScheduler(t *testing.T) { time.Sleep(time.Second) re.True(leaderServer.GetRaftCluster().IsPrepared()) + schedsNum := len(rc.GetCoordinator().GetSchedulersController().GetSchedulerNames()) // Add evict leader scheduler api.MustAddScheduler(re, leaderServer.GetAddr(), schedulers.EvictLeaderName, map[string]interface{}{ "store_id": 1, @@ -1318,8 +1319,9 @@ func TestTransferLeaderForScheduler(t *testing.T) { "store_id": 2, }) // Check scheduler updated. + schedsNum += 1 schedulersController := rc.GetCoordinator().GetSchedulersController() - re.Len(schedulersController.GetSchedulerNames(), 6) + re.Len(schedulersController.GetSchedulerNames(), schedsNum) checkEvictLeaderSchedulerExist(re, schedulersController, true) checkEvictLeaderStoreIDs(re, schedulersController, []uint64{1, 2}) @@ -1339,7 +1341,7 @@ func TestTransferLeaderForScheduler(t *testing.T) { re.True(leaderServer.GetRaftCluster().IsPrepared()) // Check scheduler updated. schedulersController = rc1.GetCoordinator().GetSchedulersController() - re.Len(schedulersController.GetSchedulerNames(), 6) + re.Len(schedulersController.GetSchedulerNames(), schedsNum) checkEvictLeaderSchedulerExist(re, schedulersController, true) checkEvictLeaderStoreIDs(re, schedulersController, []uint64{1, 2}) @@ -1358,7 +1360,7 @@ func TestTransferLeaderForScheduler(t *testing.T) { re.True(leaderServer.GetRaftCluster().IsPrepared()) // Check scheduler updated schedulersController = rc.GetCoordinator().GetSchedulersController() - re.Len(schedulersController.GetSchedulerNames(), 6) + re.Len(schedulersController.GetSchedulerNames(), schedsNum) checkEvictLeaderSchedulerExist(re, schedulersController, true) checkEvictLeaderStoreIDs(re, schedulersController, []uint64{1, 2}) From 0a332a95f603309f97176008779b39627acf99d3 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 19 Dec 2023 08:27:22 +0800 Subject: [PATCH 16/21] Makefile: add test-real-cluster command to the root Makefile (#7567) ref tikv/pd#7298 Add `test-real-cluster` command to the root Makefile. 
Signed-off-by: JmPotato --- Makefile | 9 ++++++++- pd.code-workspace | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 2a506eb576f..133f99cfac8 100644 --- a/Makefile +++ b/Makefile @@ -265,7 +265,13 @@ test-tso-consistency: install-tools CGO_ENABLED=1 go test -race -tags without_dashboard,tso_consistency_test,deadlock $(TSO_INTEGRATION_TEST_PKGS) || { $(FAILPOINT_DISABLE); exit 1; } @$(FAILPOINT_DISABLE) -.PHONY: test basic-test test-with-cover test-tso-function test-tso-consistency +REAL_CLUSTER_TEST_PATH := $(ROOT_PATH)/tests/integrations/realtiup + +test-real-cluster: + # testing with the real cluster... + cd $(REAL_CLUSTER_TEST_PATH) && $(MAKE) check + +.PHONY: test basic-test test-with-cover test-tso-function test-tso-consistency test-real-cluster #### Daily CI coverage analyze #### @@ -297,6 +303,7 @@ clean-test: rm -rf /tmp/test_pd* rm -rf /tmp/pd-tests* rm -rf /tmp/test_etcd* + rm -f $(REAL_CLUSTER_TEST_PATH)/playground.log go clean -testcache clean-build: diff --git a/pd.code-workspace b/pd.code-workspace index d6110b56a09..54a8ea324aa 100644 --- a/pd.code-workspace +++ b/pd.code-workspace @@ -20,6 +20,10 @@ "name": "tso-tests", "path": "tests/integrations/tso" }, + { + "name": "real-cluster-tests", + "path": "tests/integrations/realtiup" + }, { "name": "pd-tso-bench", "path": "tools/pd-tso-bench" From 25f48f0bdd27b0d454a41ff0132ed643c6dd51f9 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 19 Dec 2023 11:19:51 +0800 Subject: [PATCH 17/21] client/http: require source mark when initializing the http client (#7565) ref tikv/pd#7300 When creating a client, it is required to pass in the `source` parameter to further distinguish the source of logs and requests. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/http/client.go | 32 +++++++++++-------- client/http/client_test.go | 11 ++++--- tests/integrations/client/http_client_test.go | 2 +- tests/integrations/mcs/members/member_test.go | 2 +- tests/integrations/realtiup/util.go | 2 +- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/client/http/client.go b/client/http/client.go index 21a3727e00f..bf8e9af9bbe 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -60,6 +60,9 @@ type clientInner struct { pdAddrs []string leaderAddrIdx int + // source is used to mark the source of the client creation, + // it will also be used in the caller ID of the inner client. + source string tlsConf *tls.Config cli *http.Client @@ -67,9 +70,9 @@ type clientInner struct { executionDuration *prometheus.HistogramVec } -func newClientInner() *clientInner { +func newClientInner(source string) *clientInner { ctx, cancel := context.WithCancel(context.Background()) - return &clientInner{ctx: ctx, cancel: cancel, leaderAddrIdx: -1} + return &clientInner{ctx: ctx, cancel: cancel, leaderAddrIdx: -1, source: source} } func (ci *clientInner) init() { @@ -155,7 +158,7 @@ func (ci *clientInner) requestWithRetry( return nil } log.Debug("[pd] request leader addr failed", - zap.Int("leader-idx", leaderAddrIdx), zap.String("addr", addr), zap.Error(err)) + zap.String("source", ci.source), zap.Int("leader-idx", leaderAddrIdx), zap.String("addr", addr), zap.Error(err)) } // Try to send the request to the other PD followers. 
for idx := 0; idx < len(pdAddrs) && idx != leaderAddrIdx; idx++ { @@ -165,7 +168,7 @@ func (ci *clientInner) requestWithRetry( break } log.Debug("[pd] request follower addr failed", - zap.Int("idx", idx), zap.String("addr", addr), zap.Error(err)) + zap.String("source", ci.source), zap.Int("idx", idx), zap.String("addr", addr), zap.Error(err)) } return err } @@ -176,6 +179,7 @@ func (ci *clientInner) doRequest( headerOpts ...HeaderOption, ) error { var ( + source = ci.source callerID = reqInfo.callerID name = reqInfo.name url = reqInfo.getURL(addr) @@ -185,6 +189,7 @@ func (ci *clientInner) doRequest( respHandler = reqInfo.respHandler ) logFields := []zap.Field{ + zap.String("source", source), zap.String("name", name), zap.String("url", url), zap.String("method", method), @@ -250,13 +255,13 @@ func (ci *clientInner) doRequest( func (ci *clientInner) membersInfoUpdater(ctx context.Context) { ci.updateMembersInfo(ctx) - log.Info("[pd] http client member info updater started") + log.Info("[pd] http client member info updater started", zap.String("source", ci.source)) ticker := time.NewTicker(defaultMembersInfoUpdateInterval) defer ticker.Stop() for { select { case <-ctx.Done(): - log.Info("[pd] http client member info updater stopped") + log.Info("[pd] http client member info updater stopped", zap.String("source", ci.source)) return case <-ticker.C: ci.updateMembersInfo(ctx) @@ -267,17 +272,17 @@ func (ci *clientInner) membersInfoUpdater(ctx context.Context) { func (ci *clientInner) updateMembersInfo(ctx context.Context) { var membersInfo MembersInfo err := ci.requestWithRetry(ctx, newRequestInfo(). - WithCallerID(defaultInnerCallerID). + WithCallerID(fmt.Sprintf("%s-%s", ci.source, defaultInnerCallerID)). WithName(getMembersName). WithURI(membersPrefix). WithMethod(http.MethodGet). WithResp(&membersInfo)) if err != nil { - log.Error("[pd] http client get members info failed", zap.Error(err)) + log.Error("[pd] http client get members info failed", zap.String("source", ci.source), zap.Error(err)) return } if len(membersInfo.Members) == 0 { - log.Error("[pd] http client get empty members info") + log.Error("[pd] http client get empty members info", zap.String("source", ci.source)) return } var ( @@ -292,7 +297,7 @@ func (ci *clientInner) updateMembersInfo(ctx context.Context) { } // Prevent setting empty addresses. if len(newPDAddrs) == 0 { - log.Error("[pd] http client get empty member addresses") + log.Error("[pd] http client get empty member addresses", zap.String("source", ci.source)) return } oldPDAddrs, oldLeaderAddrIdx := ci.getPDAddrs() @@ -307,7 +312,7 @@ func (ci *clientInner) updateMembersInfo(ctx context.Context) { } oldMemberNum, newMemberNum := len(oldPDAddrs), len(newPDAddrs) if oldPDLeaderAddr != newPDLeaderAddr || oldMemberNum != newMemberNum { - log.Info("[pd] http client members info changed", + log.Info("[pd] http client members info changed", zap.String("source", ci.source), zap.Int("old-member-num", oldMemberNum), zap.Int("new-member-num", newMemberNum), zap.Strings("old-addrs", oldPDAddrs), zap.Strings("new-addrs", newPDAddrs), zap.Int("old-leader-addr-idx", oldLeaderAddrIdx), zap.Int("new-leader-addr-idx", newLeaderAddrIdx), @@ -353,10 +358,11 @@ func WithMetrics( // NewClient creates a PD HTTP client with the given PD addresses and TLS config. 
func NewClient( + source string, pdAddrs []string, opts ...ClientOption, ) Client { - c := &client{inner: newClientInner(), callerID: defaultCallerID} + c := &client{inner: newClientInner(source), callerID: defaultCallerID} // Apply the options first. for _, opt := range opts { opt(c) @@ -369,7 +375,7 @@ func NewClient( // Close gracefully closes the HTTP client. func (c *client) Close() { c.inner.close() - log.Info("[pd] http client closed") + log.Info("[pd] http client closed", zap.String("source", c.inner.source)) } // WithCallerID sets and returns a new client with the given caller ID. diff --git a/client/http/client_test.go b/client/http/client_test.go index 7c7da80827c..af16ac649b5 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "net/http" + "strings" "testing" "github.com/stretchr/testify/require" @@ -26,12 +27,12 @@ import ( func TestPDAddrNormalization(t *testing.T) { re := require.New(t) - c := NewClient([]string{"127.0.0.1"}) + c := NewClient("test-http-pd-addr", []string{"127.0.0.1"}) pdAddrs, leaderAddrIdx := c.(*client).inner.getPDAddrs() re.Equal(1, len(pdAddrs)) re.Equal(-1, leaderAddrIdx) re.Contains(pdAddrs[0], httpScheme) - c = NewClient([]string{"127.0.0.1"}, WithTLSConfig(&tls.Config{})) + c = NewClient("test-https-pd-addr", []string{"127.0.0.1"}, WithTLSConfig(&tls.Config{})) pdAddrs, leaderAddrIdx = c.(*client).inner.getPDAddrs() re.Equal(1, len(pdAddrs)) re.Equal(-1, leaderAddrIdx) @@ -68,7 +69,7 @@ func TestPDAllowFollowerHandleHeader(t *testing.T) { } return nil }) - c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c := NewClient("test-header", []string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) c.GetRegions(context.Background()) c.GetHistoryHotRegions(context.Background(), &HistoryHotRegionsRequest{}) c.Close() @@ -80,13 +81,13 @@ func TestCallerID(t *testing.T) { httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { val := req.Header.Get(xCallerIDKey) // Exclude the request sent by the inner client. 
- if val != defaultInnerCallerID && val != expectedVal.Load() { + if !strings.Contains(val, defaultInnerCallerID) && val != expectedVal.Load() { re.Failf("Caller ID header check failed", "should be %s, but got %s", expectedVal, val) } return nil }) - c := NewClient([]string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) + c := NewClient("test-caller-id", []string{"http://127.0.0.1"}, WithHTTPClient(httpClient)) c.GetRegions(context.Background()) expectedVal.Store("test") c.WithCallerID(expectedVal.Load()).GetRegions(context.Background()) diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 1a20ae2e784..12064c989a1 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -72,7 +72,7 @@ func (suite *httpClientTestSuite) SetupSuite() { for _, s := range testServers { endpoints = append(endpoints, s.GetConfig().AdvertiseClientUrls) } - suite.client = pd.NewClient(endpoints) + suite.client = pd.NewClient("pd-http-client-it", endpoints) } func (suite *httpClientTestSuite) TearDownSuite() { diff --git a/tests/integrations/mcs/members/member_test.go b/tests/integrations/mcs/members/member_test.go index d1ccb86a1c7..048bcf72154 100644 --- a/tests/integrations/mcs/members/member_test.go +++ b/tests/integrations/mcs/members/member_test.go @@ -52,7 +52,7 @@ func (suite *memberTestSuite) SetupTest() { suite.server = cluster.GetLeaderServer() suite.NoError(suite.server.BootstrapCluster()) suite.backendEndpoints = suite.server.GetAddr() - suite.dialClient = pdClient.NewClient([]string{suite.server.GetAddr()}) + suite.dialClient = pdClient.NewClient("mcs-member-test", []string{suite.server.GetAddr()}) // TSO nodes := make(map[string]bs.Server) diff --git a/tests/integrations/realtiup/util.go b/tests/integrations/realtiup/util.go index 66d6127b5c4..8f0c71038d6 100644 --- a/tests/integrations/realtiup/util.go +++ b/tests/integrations/realtiup/util.go @@ -24,7 +24,7 @@ const physicalShiftBits = 18 var ( pdAddrs = []string{"127.0.0.1:2379"} - pdHTTPCli = http.NewClient(pdAddrs) + pdHTTPCli = http.NewClient("pd-realtiup-test", pdAddrs) ) // GetTimeFromTS extracts time.Time from a timestamp. 
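Note on the patch above: since `NewClient` now takes a mandatory `source` argument, callers outside this series need a small adjustment. The following is a minimal sketch of the new construction path, not taken from the patch itself; the component name "my-component", the endpoint address, and the import alias are placeholder assumptions (the package path is assumed to be github.com/tikv/pd/client/http).

package main

import (
	"context"

	pdhttp "github.com/tikv/pd/client/http"
)

func main() {
	// The first argument marks the creation source of this client; per the
	// patch it is attached to the client's log fields and combined with the
	// inner caller ID for member-info requests ("<source>-<inner caller ID>").
	cli := pdhttp.NewClient("my-component", []string{"http://127.0.0.1:2379"})
	defer cli.Close()

	// Ordinary requests still carry the default caller ID, which can be
	// overridden per call chain with WithCallerID as before.
	cli.GetRegions(context.Background())
	cli.WithCallerID("my-component-admin").GetRegions(context.Background())
}
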
From 83da4c26b10e7f4e371ebc1a934fddd5981c10b0 Mon Sep 17 00:00:00 2001 From: Connor Date: Tue, 19 Dec 2023 16:10:52 +0800 Subject: [PATCH 18/21] dashboard: Pass TLS info to dashboard to fix TiKV heap profiling (#7563) close tikv/pd#7561 Pass TLS info to dashboard to fix TiKV heap profiling Signed-off-by: Connor1996 --- go.mod | 2 +- go.sum | 4 ++-- pkg/dashboard/adapter/config.go | 3 +++ pkg/utils/grpcutil/grpcutil.go | 31 +++++++++++++++++++++---------- tests/integrations/client/go.mod | 2 +- tests/integrations/client/go.sum | 4 ++-- tests/integrations/mcs/go.mod | 2 +- tests/integrations/mcs/go.sum | 4 ++-- tests/integrations/tso/go.mod | 2 +- tests/integrations/tso/go.sum | 4 ++-- 10 files changed, 36 insertions(+), 22 deletions(-) diff --git a/go.mod b/go.mod index d5cbc41f654..e60364d9c57 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/pingcap/kvproto v0.0.0-20231018065736-c0689aded40c github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 - github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e + github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c github.com/prometheus/client_golang v1.11.1 github.com/prometheus/common v0.26.0 github.com/sasha-s/go-deadlock v0.2.0 diff --git a/go.sum b/go.sum index bf35be7eb8c..b1136a6f2a2 100644 --- a/go.sum +++ b/go.sum @@ -466,8 +466,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e h1:SJUSDejvKtj9vSh5ptRHh4iMrvPV3oKO8yp6/SYE8vc= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e/go.mod h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c h1:iEZwsxxOxXaH0zEfzVAn6fjveOlPh3v3DsYlhWJAVi0= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c/go.mod h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/pkg/dashboard/adapter/config.go b/pkg/dashboard/adapter/config.go index 63c900acf77..a1661b84f2b 100644 --- a/pkg/dashboard/adapter/config.go +++ b/pkg/dashboard/adapter/config.go @@ -38,6 +38,9 @@ func GenDashboardConfig(srv *server.Server) (*config.Config, error) { if dashboardCfg.ClusterTLSConfig, err = cfg.Security.ToTLSConfig(); err != nil { return nil, err } + if dashboardCfg.ClusterTLSInfo, err = cfg.Security.ToTLSInfo(); err != nil { + return nil, err + } if dashboardCfg.TiDBTLSConfig, err = cfg.Dashboard.ToTiDBTLSConfig(); err != nil { return nil, err } diff --git a/pkg/utils/grpcutil/grpcutil.go b/pkg/utils/grpcutil/grpcutil.go index a001ec4bd03..0030551d0fc 100644 --- a/pkg/utils/grpcutil/grpcutil.go +++ b/pkg/utils/grpcutil/grpcutil.go @@ -56,6 +56,24 @@ type TLSConfig struct { SSLKEYBytes []byte } +// ToTLSInfo converts TLSConfig to 
transport.TLSInfo. +func (s TLSConfig) ToTLSInfo() (*transport.TLSInfo, error) { + if len(s.CertPath) == 0 && len(s.KeyPath) == 0 { + return nil, nil + } + allowedCN, err := s.GetOneAllowedCN() + if err != nil { + return nil, err + } + + return &transport.TLSInfo{ + CertFile: s.CertPath, + KeyFile: s.KeyPath, + TrustedCAFile: s.CAPath, + AllowedCN: allowedCN, + }, nil +} + // ToTLSConfig generates tls config. func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { if len(s.SSLCABytes) != 0 || len(s.SSLCertBytes) != 0 || len(s.SSLKEYBytes) != 0 { @@ -77,19 +95,12 @@ func (s TLSConfig) ToTLSConfig() (*tls.Config, error) { }, nil } - if len(s.CertPath) == 0 && len(s.KeyPath) == 0 { + tlsInfo, err := s.ToTLSInfo() + if tlsInfo == nil { return nil, nil } - allowedCN, err := s.GetOneAllowedCN() if err != nil { - return nil, err - } - - tlsInfo := transport.TLSInfo{ - CertFile: s.CertPath, - KeyFile: s.KeyPath, - TrustedCAFile: s.CAPath, - AllowedCN: allowedCN, + return nil, errs.ErrEtcdTLSConfig.Wrap(err).GenWithStackByCause() } tlsConfig, err := tlsInfo.ClientConfig() diff --git a/tests/integrations/client/go.mod b/tests/integrations/client/go.mod index da130278ae0..5c73abfb8ed 100644 --- a/tests/integrations/client/go.mod +++ b/tests/integrations/client/go.mod @@ -123,7 +123,7 @@ require ( github.com/pingcap/errcode v0.3.0 // indirect github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/client/go.sum b/tests/integrations/client/go.sum index e13da5d8375..f70f08f366e 100644 --- a/tests/integrations/client/go.sum +++ b/tests/integrations/client/go.sum @@ -430,8 +430,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e h1:SJUSDejvKtj9vSh5ptRHh4iMrvPV3oKO8yp6/SYE8vc= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e/go.mod h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c h1:iEZwsxxOxXaH0zEfzVAn6fjveOlPh3v3DsYlhWJAVi0= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c/go.mod h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/mcs/go.mod b/tests/integrations/mcs/go.mod index 1823a224fa1..2e41e87b746 100644 --- a/tests/integrations/mcs/go.mod +++ b/tests/integrations/mcs/go.mod @@ -123,7 +123,7 @@ require ( github.com/pingcap/errcode v0.3.0 // indirect 
github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/mcs/go.sum b/tests/integrations/mcs/go.sum index dfead54afe1..65e8cf72aab 100644 --- a/tests/integrations/mcs/go.sum +++ b/tests/integrations/mcs/go.sum @@ -434,8 +434,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e h1:SJUSDejvKtj9vSh5ptRHh4iMrvPV3oKO8yp6/SYE8vc= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e/go.mod h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c h1:iEZwsxxOxXaH0zEfzVAn6fjveOlPh3v3DsYlhWJAVi0= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c/go.mod h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/tests/integrations/tso/go.mod b/tests/integrations/tso/go.mod index 0d734716ed5..af38fcf4241 100644 --- a/tests/integrations/tso/go.mod +++ b/tests/integrations/tso/go.mod @@ -121,7 +121,7 @@ require ( github.com/pingcap/errors v0.11.5-0.20211224045212-9687c2b0f87c // indirect github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 // indirect github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 // indirect - github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e // indirect + github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c // indirect github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect diff --git a/tests/integrations/tso/go.sum b/tests/integrations/tso/go.sum index 94fbde2ad57..c078eab2f6d 100644 --- a/tests/integrations/tso/go.sum +++ b/tests/integrations/tso/go.sum @@ -428,8 +428,8 @@ github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3 h1:HR/ylkkLmGdSSDaD8 github.com/pingcap/log v1.1.1-0.20221110025148-ca232912c9f3/go.mod h1:DWQW5jICDR7UJh4HtxXSM20Churx4CQL0fwL/SoOSA4= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21 h1:QV6jqlfOkh8hqvEAgwBZa+4bSgO0EeKC7s5c6Luam2I= github.com/pingcap/sysutil v1.0.1-0.20230407040306-fb007c5aff21/go.mod h1:QYnjfA95ZaMefyl1NO8oPtKeb8pYUdnDVhQgf+qdpjM= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e h1:SJUSDejvKtj9vSh5ptRHh4iMrvPV3oKO8yp6/SYE8vc= -github.com/pingcap/tidb-dashboard v0.0.0-20231127105651-ce4097837c5e/go.mod 
h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c h1:iEZwsxxOxXaH0zEfzVAn6fjveOlPh3v3DsYlhWJAVi0= +github.com/pingcap/tidb-dashboard v0.0.0-20231218095437-aa621ed4de2c/go.mod h1:ucZBRz52icb23T/5Z4CsuUHmarYiin7p2MeiVBe+o8c= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e h1:FBaTXU8C3xgt/drM58VHxojHo/QoG1oPsgWTGvaSpO4= github.com/pingcap/tipb v0.0.0-20220718022156-3e2483c20a9e/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= From 59c9d0475fca695f441f525a1969d404715f752c Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 20 Dec 2023 10:42:22 +0800 Subject: [PATCH 19/21] tests: fix some errors detected by testifylint (#7580) ref tikv/pd#4813 Fix some errors detected by `testifylint`. Signed-off-by: JmPotato --- .golangci.yml | 19 ++- Makefile | 8 +- pkg/autoscaling/calculation_test.go | 2 +- pkg/autoscaling/prometheus_test.go | 2 +- pkg/balancer/balancer_test.go | 6 +- pkg/cache/cache_test.go | 32 ++-- pkg/core/storelimit/limit_test.go | 2 +- pkg/dashboard/adapter/redirector_test.go | 12 +- pkg/election/leadership_test.go | 2 +- pkg/encryption/config_test.go | 6 +- pkg/encryption/crypter_test.go | 10 +- pkg/encryption/key_manager_test.go | 2 +- pkg/keyspace/keyspace_test.go | 31 ++-- pkg/keyspace/tso_keyspace_group_test.go | 9 +- pkg/mcs/resourcemanager/server/config_test.go | 4 +- .../server/token_buckets_test.go | 14 +- pkg/ratelimit/controller_test.go | 38 ++--- pkg/ratelimit/limiter_test.go | 30 ++-- pkg/schedule/checker/rule_checker_test.go | 14 +- pkg/schedule/filter/counter_test.go | 10 +- pkg/schedule/labeler/rule_test.go | 4 +- pkg/schedule/operator/builder_test.go | 157 +++++++++--------- pkg/schedule/placement/rule_manager_test.go | 10 +- pkg/schedule/plan/balance_plan_test.go | 42 +++-- pkg/schedule/scatter/region_scatterer_test.go | 4 +- .../schedulers/balance_benchmark_test.go | 2 +- pkg/schedule/schedulers/balance_test.go | 38 +++-- .../schedulers/balance_witness_test.go | 2 +- pkg/schedule/schedulers/evict_leader_test.go | 2 +- .../schedulers/evict_slow_trend_test.go | 4 +- pkg/schedule/schedulers/hot_region_test.go | 24 +-- pkg/schedule/schedulers/hot_region_v2_test.go | 8 +- pkg/schedule/schedulers/scheduler_test.go | 2 +- pkg/storage/storage_gc_test.go | 2 +- pkg/tso/keyspace_group_manager_test.go | 2 +- .../unsafe_recovery_controller_test.go | 6 +- pkg/utils/etcdutil/etcdutil_test.go | 80 ++++----- pkg/utils/syncutil/lock_group_test.go | 4 +- pkg/utils/typeutil/duration_test.go | 2 +- pkg/window/policy_test.go | 6 +- pkg/window/window_test.go | 23 +-- scripts/check-test.sh | 37 ----- server/api/admin_test.go | 124 +++++++------- server/api/region_test.go | 2 +- server/cluster/cluster_test.go | 18 +- server/config/config_test.go | 4 +- tests/pdctl/hot/hot_test.go | 18 +- tests/pdctl/keyspace/keyspace_group_test.go | 16 +- tests/pdctl/keyspace/keyspace_test.go | 16 +- tests/pdctl/log/log_test.go | 7 +- tests/pdctl/operator/operator_test.go | 2 +- .../resource_manager_command_test.go | 22 +-- tests/pdctl/scheduler/scheduler_test.go | 6 +- tests/server/api/api_test.go | 42 ++--- tests/server/api/rule_test.go | 22 +-- tests/server/apiv2/handlers/keyspace_test.go | 16 +- .../apiv2/handlers/tso_keyspace_group_test.go | 9 +- tests/server/cluster/cluster_test.go | 2 +- tests/server/cluster/cluster_work_test.go | 5 +- tests/server/keyspace/keyspace_test.go | 2 +- 
tools/pd-backup/pdbackup/backup_test.go | 24 +-- .../simulator/simutil/key_test.go | 4 +- 62 files changed, 541 insertions(+), 534 deletions(-) delete mode 100755 scripts/check-test.sh diff --git a/.golangci.yml b/.golangci.yml index 079e25ec2b3..59954cecee3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -11,15 +11,12 @@ linters: - makezero - gosec - bodyclose + # TODO: enable when all existing errors are fixed + # - testifylint disable: - errcheck linters-settings: gocritic: - # Which checks should be enabled; can't be combined with 'disabled-checks'; - # See https://go-critic.github.io/overview#checks-overview - # To check which checks are enabled run `GL_DEBUG=gocritic golangci-lint run` - # By default list of stable checks is used. - enabled-checks: # Which checks should be disabled; can't be combined with 'enabled-checks'; default is empty disabled-checks: - regexpMust @@ -33,3 +30,15 @@ linters-settings: - G402 - G404 - G601 + testifylint: + enable: + - bool-compare + - compares + - empty + - error-is-as + - error-nil + - expected-actual + - len + - require-error + - suite-dont-use-pkg + - suite-extra-assert-call diff --git a/Makefile b/Makefile index 133f99cfac8..bf25730c0d9 100644 --- a/Makefile +++ b/Makefile @@ -172,7 +172,7 @@ install-tools: #### Static checks #### -check: install-tools tidy static generate-errdoc check-test +check: install-tools tidy static generate-errdoc static: install-tools @ echo "gofmt ..." @@ -199,11 +199,7 @@ check-plugin: @echo "checking plugin..." cd ./plugin/scheduler_example && $(MAKE) evictLeaderPlugin.so && rm evictLeaderPlugin.so -check-test: - @echo "checking test..." - ./scripts/check-test.sh - -.PHONY: check static tidy generate-errdoc check-plugin check-test +.PHONY: check static tidy generate-errdoc check-plugin #### Test utils #### diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index de3be68d68c..85f723b562c 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -233,7 +233,7 @@ func TestGetTotalCPUUseTime(t *testing.T) { } totalCPUUseTime, _ := getTotalCPUUseTime(querier, TiDB, instances, time.Now(), 0) expected := mockResultValue * float64(len(instances)) - re.True(math.Abs(expected-totalCPUUseTime) < 1e-6) + re.Less(math.Abs(expected-totalCPUUseTime), 1e-6) } func TestGetTotalCPUQuota(t *testing.T) { diff --git a/pkg/autoscaling/prometheus_test.go b/pkg/autoscaling/prometheus_test.go index 6c30e3ead4c..2efdc348ead 100644 --- a/pkg/autoscaling/prometheus_test.go +++ b/pkg/autoscaling/prometheus_test.go @@ -196,7 +196,7 @@ func TestRetrieveCPUMetrics(t *testing.T) { for i := 0; i < len(addresses)-1; i++ { value, ok := result[addresses[i]] re.True(ok) - re.True(math.Abs(value-mockResultValue) < 1e-6) + re.Less(math.Abs(value-mockResultValue), 1e-6) } _, ok := result[addresses[len(addresses)-1]] diff --git a/pkg/balancer/balancer_test.go b/pkg/balancer/balancer_test.go index f95487a4cc7..996b4f1da35 100644 --- a/pkg/balancer/balancer_test.go +++ b/pkg/balancer/balancer_test.go @@ -62,7 +62,7 @@ func TestBalancerDuplicate(t *testing.T) { NewRoundRobin[uint32](), } for _, balancer := range balancers { - re.Len(balancer.GetAll(), 0) + re.Empty(balancer.GetAll()) // test duplicate put balancer.Put(1) re.Len(balancer.GetAll(), 1) @@ -70,9 +70,9 @@ func TestBalancerDuplicate(t *testing.T) { re.Len(balancer.GetAll(), 1) // test duplicate delete balancer.Delete(1) - re.Len(balancer.GetAll(), 0) + re.Empty(balancer.GetAll()) balancer.Delete(1) - 
re.Len(balancer.GetAll(), 0) + re.Empty(balancer.GetAll()) } } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index b02e8823398..fe9f84223c1 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -77,7 +77,7 @@ func TestExpireRegionCache(t *testing.T) { re.Equal(3, cache.Len()) - re.Equal(sortIDs(cache.GetAllID()), []uint64{1, 2, 3}) + re.Equal([]uint64{1, 2, 3}, sortIDs(cache.GetAllID())) // after 20ms, the key 1 will be expired time.Sleep(20 * time.Millisecond) @@ -98,7 +98,7 @@ func TestExpireRegionCache(t *testing.T) { // we can't ensure whether gc is executed, so we check the length of cache in a loop. return cache.Len() == 2 }, testutil.WithWaitFor(50*time.Millisecond), testutil.WithTickInterval(time.Millisecond)) - re.Equal(sortIDs(cache.GetAllID()), []uint64{2, 3}) + re.Equal([]uint64{2, 3}, sortIDs(cache.GetAllID())) cache.Remove(2) @@ -111,7 +111,7 @@ func TestExpireRegionCache(t *testing.T) { re.Equal(3.0, value) re.Equal(1, cache.Len()) - re.Equal(sortIDs(cache.GetAllID()), []uint64{3}) + re.Equal([]uint64{3}, sortIDs(cache.GetAllID())) } func sortIDs(ids []uint64) []uint64 { @@ -131,15 +131,15 @@ func TestLRUCache(t *testing.T) { val, ok := cache.Get(3) re.True(ok) - re.Equal(val, "3") + re.Equal("3", val) val, ok = cache.Get(2) re.True(ok) - re.Equal(val, "2") + re.Equal("2", val) val, ok = cache.Get(1) re.True(ok) - re.Equal(val, "1") + re.Equal("1", val) re.Equal(3, cache.Len()) @@ -153,27 +153,27 @@ func TestLRUCache(t *testing.T) { val, ok = cache.Get(1) re.True(ok) - re.Equal(val, "1") + re.Equal("1", val) val, ok = cache.Get(2) re.True(ok) - re.Equal(val, "2") + re.Equal("2", val) val, ok = cache.Get(4) re.True(ok) - re.Equal(val, "4") + re.Equal("4", val) re.Equal(3, cache.Len()) val, ok = cache.Peek(1) re.True(ok) - re.Equal(val, "1") + re.Equal("1", val) elems := cache.Elems() re.Len(elems, 3) - re.Equal(elems[0].Value, "4") - re.Equal(elems[1].Value, "2") - re.Equal(elems[2].Value, "1") + re.Equal("4", elems[0].Value) + re.Equal("2", elems[1].Value) + re.Equal("1", elems[2].Value) cache.Remove(1) cache.Remove(2) @@ -247,15 +247,15 @@ func TestFifoFromLastSameElems(t *testing.T) { }) } items := fun() - re.Equal(1, len(items)) + re.Len(items, 1) cache.Put(1, &testStruct{value: "3"}) cache.Put(2, &testStruct{value: "3"}) items = fun() - re.Equal(3, len(items)) + re.Len(items, 3) re.Equal("3", items[0].Value.(*testStruct).value) cache.Put(1, &testStruct{value: "2"}) items = fun() - re.Equal(1, len(items)) + re.Len(items, 1) re.Equal("2", items[0].Value.(*testStruct).value) } diff --git a/pkg/core/storelimit/limit_test.go b/pkg/core/storelimit/limit_test.go index 6f57c01eccb..946729f8ce2 100644 --- a/pkg/core/storelimit/limit_test.go +++ b/pkg/core/storelimit/limit_test.go @@ -30,7 +30,7 @@ func TestStoreLimit(t *testing.T) { re := require.New(t) rate := int64(15) limit := NewStoreRateLimit(float64(rate)).(*StoreRateLimit) - re.Equal(limit.Rate(AddPeer), float64(15)) + re.Equal(float64(15), limit.Rate(AddPeer)) re.True(limit.Available(influence*rate, AddPeer, constant.Low)) re.True(limit.Take(influence*rate, AddPeer, constant.Low)) re.False(limit.Take(influence, AddPeer, constant.Low)) diff --git a/pkg/dashboard/adapter/redirector_test.go b/pkg/dashboard/adapter/redirector_test.go index c5d837507fc..5fc9ea5ea99 100644 --- a/pkg/dashboard/adapter/redirector_test.go +++ b/pkg/dashboard/adapter/redirector_test.go @@ -65,37 +65,39 @@ func (suite *redirectorTestSuite) TearDownSuite() { } func (suite *redirectorTestSuite) TestReverseProxy() { + 
re := suite.Require() redirectorServer := httptest.NewServer(http.HandlerFunc(suite.redirector.ReverseProxy)) defer redirectorServer.Close() suite.redirector.SetAddress(suite.tempServer.URL) // Test normal forwarding req, err := http.NewRequest(http.MethodGet, redirectorServer.URL, http.NoBody) - suite.NoError(err) + re.NoError(err) checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusOK, suite.tempText) // Test the requests that are forwarded by others req, err = http.NewRequest(http.MethodGet, redirectorServer.URL, http.NoBody) - suite.NoError(err) + re.NoError(err) req.Header.Set(proxyHeader, "other") checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusOK, suite.tempText) // Test LoopDetected suite.redirector.SetAddress(redirectorServer.URL) req, err = http.NewRequest(http.MethodGet, redirectorServer.URL, http.NoBody) - suite.NoError(err) + re.NoError(err) checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusLoopDetected, "") } func (suite *redirectorTestSuite) TestTemporaryRedirect() { + re := suite.Require() redirectorServer := httptest.NewServer(http.HandlerFunc(suite.redirector.TemporaryRedirect)) defer redirectorServer.Close() suite.redirector.SetAddress(suite.tempServer.URL) // Test TemporaryRedirect req, err := http.NewRequest(http.MethodGet, redirectorServer.URL, http.NoBody) - suite.NoError(err) + re.NoError(err) checkHTTPRequest(suite.Require(), suite.noRedirectHTTPClient, req, http.StatusTemporaryRedirect, "") // Test Response req, err = http.NewRequest(http.MethodGet, redirectorServer.URL, http.NoBody) - suite.NoError(err) + re.NoError(err) checkHTTPRequest(suite.Require(), http.DefaultClient, req, http.StatusOK, suite.tempText) } diff --git a/pkg/election/leadership_test.go b/pkg/election/leadership_test.go index c259476e44e..be1922fe381 100644 --- a/pkg/election/leadership_test.go +++ b/pkg/election/leadership_test.go @@ -175,7 +175,7 @@ func TestExitWatch(t *testing.T) { resp2, err := client.MemberList(context.Background()) re.NoError(err) - re.Equal(3, len(resp2.Members)) + re.Len(resp2.Members, 3) etcd2.Server.HardStop() etcd3.Server.HardStop() diff --git a/pkg/encryption/config_test.go b/pkg/encryption/config_test.go index 30c9c9dded8..6f7e4a41b03 100644 --- a/pkg/encryption/config_test.go +++ b/pkg/encryption/config_test.go @@ -38,19 +38,19 @@ func TestAdjustInvalidDataEncryptionMethod(t *testing.T) { t.Parallel() re := require.New(t) config := &Config{DataEncryptionMethod: "unknown"} - re.NotNil(config.Adjust()) + re.Error(config.Adjust()) } func TestAdjustNegativeRotationDuration(t *testing.T) { t.Parallel() re := require.New(t) config := &Config{DataKeyRotationPeriod: typeutil.NewDuration(time.Duration(int64(-1)))} - re.NotNil(config.Adjust()) + re.Error(config.Adjust()) } func TestAdjustInvalidMasterKeyType(t *testing.T) { t.Parallel() re := require.New(t) config := &Config{MasterKey: MasterKeyConfig{Type: "unknown"}} - re.NotNil(config.Adjust()) + re.Error(config.Adjust()) } diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index 2f952d5b729..12a851d1563 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -26,11 +26,11 @@ import ( func TestEncryptionMethodSupported(t *testing.T) { t.Parallel() re := require.New(t) - re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) - re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) - 
re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR)) - re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR)) - re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR)) + re.Error(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) + re.Error(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) + re.NoError(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR)) + re.NoError(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR)) + re.NoError(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR)) } func TestKeyLength(t *testing.T) { diff --git a/pkg/encryption/key_manager_test.go b/pkg/encryption/key_manager_test.go index 3134e714543..96bdb3c0eb5 100644 --- a/pkg/encryption/key_manager_test.go +++ b/pkg/encryption/key_manager_test.go @@ -313,7 +313,7 @@ func TestLoadKeyEmpty(t *testing.T) { // Simulate keys get deleted. _, err = client.Delete(context.Background(), EncryptionKeysPath) re.NoError(err) - re.NotNil(m.loadKeys()) + re.Error(m.loadKeys()) } func TestWatcher(t *testing.T) { diff --git a/pkg/keyspace/keyspace_test.go b/pkg/keyspace/keyspace_test.go index 27e7de359ee..552adc8d83e 100644 --- a/pkg/keyspace/keyspace_test.go +++ b/pkg/keyspace/keyspace_test.go @@ -75,13 +75,14 @@ func (m *mockConfig) GetCheckRegionSplitInterval() time.Duration { } func (suite *keyspaceTestSuite) SetupTest() { + re := suite.Require() suite.ctx, suite.cancel = context.WithCancel(context.Background()) store := endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) allocator := mockid.NewIDAllocator() kgm := NewKeyspaceGroupManager(suite.ctx, store, nil, 0) suite.manager = NewKeyspaceManager(suite.ctx, store, nil, allocator, &mockConfig{}, kgm) - suite.NoError(kgm.Bootstrap(suite.ctx)) - suite.NoError(suite.manager.Bootstrap()) + re.NoError(kgm.Bootstrap(suite.ctx)) + re.NoError(suite.manager.Bootstrap()) } func (suite *keyspaceTestSuite) TearDownTest() { @@ -89,11 +90,13 @@ func (suite *keyspaceTestSuite) TearDownTest() { } func (suite *keyspaceTestSuite) SetupSuite() { - suite.NoError(failpoint.Enable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion", "return(true)")) + re := suite.Require() + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion", "return(true)")) } func (suite *keyspaceTestSuite) TearDownSuite() { - suite.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion")) + re := suite.Require() + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion")) } func makeCreateKeyspaceRequests(count int) []*CreateKeyspaceRequest { @@ -205,20 +208,20 @@ func (suite *keyspaceTestSuite) TestUpdateKeyspaceState() { // Disabling an ENABLED keyspace is allowed. Should update StateChangedAt. updated, err := manager.UpdateKeyspaceState(createRequest.Name, keyspacepb.KeyspaceState_DISABLED, oldTime) re.NoError(err) - re.Equal(updated.State, keyspacepb.KeyspaceState_DISABLED) - re.Equal(updated.StateChangedAt, oldTime) + re.Equal(keyspacepb.KeyspaceState_DISABLED, updated.State) + re.Equal(oldTime, updated.StateChangedAt) newTime := time.Now().Unix() // Disabling an DISABLED keyspace is allowed. Should NOT update StateChangedAt. 
updated, err = manager.UpdateKeyspaceState(createRequest.Name, keyspacepb.KeyspaceState_DISABLED, newTime) re.NoError(err) - re.Equal(updated.State, keyspacepb.KeyspaceState_DISABLED) - re.Equal(updated.StateChangedAt, oldTime) + re.Equal(keyspacepb.KeyspaceState_DISABLED, updated.State) + re.Equal(oldTime, updated.StateChangedAt) // Archiving a DISABLED keyspace is allowed. Should update StateChangeAt. updated, err = manager.UpdateKeyspaceState(createRequest.Name, keyspacepb.KeyspaceState_ARCHIVED, newTime) re.NoError(err) - re.Equal(updated.State, keyspacepb.KeyspaceState_ARCHIVED) - re.Equal(updated.StateChangedAt, newTime) + re.Equal(keyspacepb.KeyspaceState_ARCHIVED, updated.State) + re.Equal(newTime, updated.StateChangedAt) // Changing state of an ARCHIVED keyspace is not allowed. _, err = manager.UpdateKeyspaceState(createRequest.Name, keyspacepb.KeyspaceState_ENABLED, newTime) re.Error(err) @@ -244,7 +247,7 @@ func (suite *keyspaceTestSuite) TestLoadRangeKeyspace() { // Load all keyspaces including the default keyspace. keyspaces, err := manager.LoadRangeKeyspace(0, 0) re.NoError(err) - re.Equal(total+1, len(keyspaces)) + re.Len(keyspaces, total+1) for i := range keyspaces { re.Equal(uint32(i), keyspaces[i].Id) if i != 0 { @@ -256,7 +259,7 @@ func (suite *keyspaceTestSuite) TestLoadRangeKeyspace() { // Result should be keyspaces with id 0 - 49. keyspaces, err = manager.LoadRangeKeyspace(0, 50) re.NoError(err) - re.Equal(50, len(keyspaces)) + re.Len(keyspaces, 50) for i := range keyspaces { re.Equal(uint32(i), keyspaces[i].Id) if i != 0 { @@ -269,7 +272,7 @@ func (suite *keyspaceTestSuite) TestLoadRangeKeyspace() { loadStart := 33 keyspaces, err = manager.LoadRangeKeyspace(uint32(loadStart), 20) re.NoError(err) - re.Equal(20, len(keyspaces)) + re.Len(keyspaces, 20) for i := range keyspaces { re.Equal(uint32(loadStart+i), keyspaces[i].Id) checkCreateRequest(re, requests[i+loadStart-1], keyspaces[i]) @@ -280,7 +283,7 @@ func (suite *keyspaceTestSuite) TestLoadRangeKeyspace() { loadStart = 90 keyspaces, err = manager.LoadRangeKeyspace(uint32(loadStart), 30) re.NoError(err) - re.Equal(11, len(keyspaces)) + re.Len(keyspaces, 11) for i := range keyspaces { re.Equal(uint32(loadStart+i), keyspaces[i].Id) checkCreateRequest(re, requests[i+loadStart-1], keyspaces[i]) diff --git a/pkg/keyspace/tso_keyspace_group_test.go b/pkg/keyspace/tso_keyspace_group_test.go index 993923d2fd7..2dec780c3c8 100644 --- a/pkg/keyspace/tso_keyspace_group_test.go +++ b/pkg/keyspace/tso_keyspace_group_test.go @@ -43,13 +43,14 @@ func TestKeyspaceGroupTestSuite(t *testing.T) { } func (suite *keyspaceGroupTestSuite) SetupTest() { + re := suite.Require() suite.ctx, suite.cancel = context.WithCancel(context.Background()) store := endpoint.NewStorageEndpoint(kv.NewMemoryKV(), nil) suite.kgm = NewKeyspaceGroupManager(suite.ctx, store, nil, 0) idAllocator := mockid.NewIDAllocator() cluster := mockcluster.NewCluster(suite.ctx, mockconfig.NewTestOptions()) suite.kg = NewKeyspaceManager(suite.ctx, store, cluster, idAllocator, &mockConfig{}, suite.kgm) - suite.NoError(suite.kgm.Bootstrap(suite.ctx)) + re.NoError(suite.kgm.Bootstrap(suite.ctx)) } func (suite *keyspaceGroupTestSuite) TearDownTest() { @@ -191,7 +192,7 @@ func (suite *keyspaceGroupTestSuite) TestUpdateKeyspace() { re.Len(kg2.Keyspaces, 1) kg3, err := suite.kgm.GetKeyspaceGroupByID(3) re.NoError(err) - re.Len(kg3.Keyspaces, 0) + re.Empty(kg3.Keyspaces) _, err = suite.kg.UpdateKeyspaceConfig("test", []*Mutation{ { @@ -211,7 +212,7 @@ func (suite 
*keyspaceGroupTestSuite) TestUpdateKeyspace() { re.Len(kg2.Keyspaces, 1) kg3, err = suite.kgm.GetKeyspaceGroupByID(3) re.NoError(err) - re.Len(kg3.Keyspaces, 0) + re.Empty(kg3.Keyspaces) _, err = suite.kg.UpdateKeyspaceConfig("test", []*Mutation{ { Op: OpPut, @@ -227,7 +228,7 @@ func (suite *keyspaceGroupTestSuite) TestUpdateKeyspace() { re.NoError(err) kg2, err = suite.kgm.GetKeyspaceGroupByID(2) re.NoError(err) - re.Len(kg2.Keyspaces, 0) + re.Empty(kg2.Keyspaces) kg3, err = suite.kgm.GetKeyspaceGroupByID(3) re.NoError(err) re.Len(kg3.Keyspaces, 1) diff --git a/pkg/mcs/resourcemanager/server/config_test.go b/pkg/mcs/resourcemanager/server/config_test.go index dd8dd2d2814..64fd133ea73 100644 --- a/pkg/mcs/resourcemanager/server/config_test.go +++ b/pkg/mcs/resourcemanager/server/config_test.go @@ -42,8 +42,8 @@ read-cpu-ms-cost = 5.0 err = cfg.Adjust(&meta, false) re.NoError(err) - re.Equal(cfg.Controller.DegradedModeWaitDuration.Duration, time.Second*2) - re.Equal(cfg.Controller.LTBMaxWaitDuration.Duration, time.Second*60) + re.Equal(time.Second*2, cfg.Controller.DegradedModeWaitDuration.Duration) + re.Equal(time.Second*60, cfg.Controller.LTBMaxWaitDuration.Duration) re.LessOrEqual(math.Abs(cfg.Controller.RequestUnit.CPUMsCost-5), 1e-7) re.LessOrEqual(math.Abs(cfg.Controller.RequestUnit.WriteCostPerByte-4), 1e-7) re.LessOrEqual(math.Abs(cfg.Controller.RequestUnit.WriteBaseCost-3), 1e-7) diff --git a/pkg/mcs/resourcemanager/server/token_buckets_test.go b/pkg/mcs/resourcemanager/server/token_buckets_test.go index a7d3b9e3bad..4138be5d66e 100644 --- a/pkg/mcs/resourcemanager/server/token_buckets_test.go +++ b/pkg/mcs/resourcemanager/server/token_buckets_test.go @@ -70,27 +70,27 @@ func TestGroupTokenBucketRequest(t *testing.T) { clientUniqueID := uint64(0) tb, trickle := gtb.request(time1, 190000, uint64(time.Second)*10/uint64(time.Millisecond), clientUniqueID) re.LessOrEqual(math.Abs(tb.Tokens-190000), 1e-7) - re.Equal(trickle, int64(0)) + re.Zero(trickle) // need to lend token tb, trickle = gtb.request(time1, 11000, uint64(time.Second)*10/uint64(time.Millisecond), clientUniqueID) re.LessOrEqual(math.Abs(tb.Tokens-11000), 1e-7) - re.Equal(trickle, int64(time.Second)*11000./4000./int64(time.Millisecond)) + re.Equal(int64(time.Second)*11000./4000./int64(time.Millisecond), trickle) tb, trickle = gtb.request(time1, 35000, uint64(time.Second)*10/uint64(time.Millisecond), clientUniqueID) re.LessOrEqual(math.Abs(tb.Tokens-35000), 1e-7) - re.Equal(trickle, int64(time.Second)*10/int64(time.Millisecond)) + re.Equal(int64(time.Second)*10/int64(time.Millisecond), trickle) tb, trickle = gtb.request(time1, 60000, uint64(time.Second)*10/uint64(time.Millisecond), clientUniqueID) re.LessOrEqual(math.Abs(tb.Tokens-22000), 1e-7) - re.Equal(trickle, int64(time.Second)*10/int64(time.Millisecond)) + re.Equal(int64(time.Second)*10/int64(time.Millisecond), trickle) // Get reserved 10000 tokens = fillrate(2000) * 10 * defaultReserveRatio(0.5) // Max loan tokens is 60000. 
tb, trickle = gtb.request(time1, 3000, uint64(time.Second)*10/uint64(time.Millisecond), clientUniqueID) re.LessOrEqual(math.Abs(tb.Tokens-3000), 1e-7) - re.Equal(trickle, int64(time.Second)*10/int64(time.Millisecond)) + re.Equal(int64(time.Second)*10/int64(time.Millisecond), trickle) tb, trickle = gtb.request(time1, 12000, uint64(time.Second)*10/uint64(time.Millisecond), clientUniqueID) re.LessOrEqual(math.Abs(tb.Tokens-10000), 1e-7) - re.Equal(trickle, int64(time.Second)*10/int64(time.Millisecond)) + re.Equal(int64(time.Second)*10/int64(time.Millisecond), trickle) time2 := time1.Add(20 * time.Second) tb, trickle = gtb.request(time2, 20000, uint64(time.Second)*10/uint64(time.Millisecond), clientUniqueID) re.LessOrEqual(math.Abs(tb.Tokens-20000), 1e-7) - re.Equal(trickle, int64(time.Second)*10/int64(time.Millisecond)) + re.Equal(int64(time.Second)*10/int64(time.Millisecond), trickle) } diff --git a/pkg/ratelimit/controller_test.go b/pkg/ratelimit/controller_test.go index a830217cb9f..59cc0c16445 100644 --- a/pkg/ratelimit/controller_test.go +++ b/pkg/ratelimit/controller_test.go @@ -88,7 +88,7 @@ func TestControllerWithConcurrencyLimiter(t *testing.T) { opt: UpdateConcurrencyLimiter(10), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&ConcurrencyChanged != 0) + re.NotZero(status & ConcurrencyChanged) }, totalRequest: 15, fail: 5, @@ -105,7 +105,7 @@ func TestControllerWithConcurrencyLimiter(t *testing.T) { opt: UpdateConcurrencyLimiter(10), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&ConcurrencyNoChange != 0) + re.NotZero(status & ConcurrencyNoChange) }, checkStatusFunc: func(label string) {}, }, @@ -113,7 +113,7 @@ func TestControllerWithConcurrencyLimiter(t *testing.T) { opt: UpdateConcurrencyLimiter(5), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&ConcurrencyChanged != 0) + re.NotZero(status & ConcurrencyChanged) }, totalRequest: 15, fail: 10, @@ -130,7 +130,7 @@ func TestControllerWithConcurrencyLimiter(t *testing.T) { opt: UpdateConcurrencyLimiter(0), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&ConcurrencyDeleted != 0) + re.NotZero(status & ConcurrencyDeleted) }, totalRequest: 15, fail: 0, @@ -152,7 +152,7 @@ func TestControllerWithConcurrencyLimiter(t *testing.T) { opt: UpdateConcurrencyLimiter(15), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&ConcurrencyChanged != 0) + re.NotZero(status & ConcurrencyChanged) }, totalRequest: 10, fail: 0, @@ -169,7 +169,7 @@ func TestControllerWithConcurrencyLimiter(t *testing.T) { opt: UpdateConcurrencyLimiter(10), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&ConcurrencyChanged != 0) + re.NotZero(status & ConcurrencyChanged) }, totalRequest: 10, fail: 10, @@ -202,7 +202,7 @@ func TestBlockList(t *testing.T) { re.True(limiter.IsInAllowList(label)) status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) - re.True(status&InAllowList != 0) + re.NotZero(status & InAllowList) for i := 0; i < 10; i++ { _, err := limiter.Allow(label) re.NoError(err) @@ -221,7 +221,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSChanged != 0) + 
re.NotZero(status & QPSChanged) }, totalRequest: 3, fail: 2, @@ -237,7 +237,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSNoChange != 0) + re.NotZero(status & QPSNoChange) }, checkStatusFunc: func(label string) {}, }, @@ -245,7 +245,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { opt: UpdateQPSLimiter(5, 5), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) }, totalRequest: 10, fail: 5, @@ -261,7 +261,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { opt: UpdateQPSLimiter(0, 0), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSDeleted != 0) + re.NotZero(status & QPSDeleted) }, totalRequest: 10, fail: 0, @@ -271,7 +271,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { checkStatusFunc: func(label string) { limit, burst := limiter.GetQPSLimiterStatus(label) re.Equal(rate.Limit(0), limit) - re.Equal(0, burst) + re.Zero(burst) }, }, }, @@ -283,7 +283,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { opt: UpdateQPSLimiter(50, 5), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) }, totalRequest: 10, fail: 5, @@ -299,7 +299,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { opt: UpdateQPSLimiter(0, 0), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSDeleted != 0) + re.NotZero(status & QPSDeleted) }, totalRequest: 10, fail: 0, @@ -309,7 +309,7 @@ func TestControllerWithQPSLimiter(t *testing.T) { checkStatusFunc: func(label string) { limit, burst := limiter.GetQPSLimiterStatus(label) re.Equal(rate.Limit(0), limit) - re.Equal(0, burst) + re.Zero(burst) }, }, }, @@ -334,7 +334,7 @@ func TestControllerWithTwoLimiters(t *testing.T) { }), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) }, totalRequest: 200, fail: 100, @@ -354,7 +354,7 @@ func TestControllerWithTwoLimiters(t *testing.T) { opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) }, totalRequest: 200, fail: 199, @@ -376,7 +376,7 @@ func TestControllerWithTwoLimiters(t *testing.T) { opt: UpdateQPSLimiter(50, 5), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) }, totalRequest: 10, fail: 5, @@ -392,7 +392,7 @@ func TestControllerWithTwoLimiters(t *testing.T) { opt: UpdateQPSLimiter(0, 0), checkOptionStatus: func(label string, o Option) { status := limiter.Update(label, o) - re.True(status&QPSDeleted != 0) + re.NotZero(status & QPSDeleted) }, totalRequest: 10, fail: 0, diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index 8834495f3e9..88da865879b 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -45,7 +45,7 @@ func TestWithConcurrencyLimiter(t *testing.T) { limiter := newLimiter() status := limiter.updateConcurrencyConfig(10) - re.True(status&ConcurrencyChanged != 0) + re.NotZero(status & ConcurrencyChanged) var lock 
syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup @@ -68,10 +68,10 @@ func TestWithConcurrencyLimiter(t *testing.T) { re.Equal(uint64(0), current) status = limiter.updateConcurrencyConfig(10) - re.True(status&ConcurrencyNoChange != 0) + re.NotZero(status & ConcurrencyNoChange) status = limiter.updateConcurrencyConfig(5) - re.True(status&ConcurrencyChanged != 0) + re.NotZero(status & ConcurrencyChanged) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { @@ -86,7 +86,7 @@ func TestWithConcurrencyLimiter(t *testing.T) { } status = limiter.updateConcurrencyConfig(0) - re.True(status&ConcurrencyDeleted != 0) + re.NotZero(status & ConcurrencyDeleted) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { @@ -107,7 +107,7 @@ func TestWithQPSLimiter(t *testing.T) { re := require.New(t) limiter := newLimiter() status := limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) var lock syncutil.Mutex successCount, failedCount := 0, 0 @@ -126,10 +126,10 @@ func TestWithQPSLimiter(t *testing.T) { re.Equal(1, burst) status = limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) - re.True(status&QPSNoChange != 0) + re.NotZero(status & QPSNoChange) status = limiter.updateQPSConfig(5, 5) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) limit, burst = limiter.getQPSLimiterStatus() re.Equal(rate.Limit(5), limit) re.Equal(5, burst) @@ -147,19 +147,19 @@ func TestWithQPSLimiter(t *testing.T) { time.Sleep(time.Second) status = limiter.updateQPSConfig(0, 0) - re.True(status&QPSDeleted != 0) + re.NotZero(status & QPSDeleted) for i := 0; i < 10; i++ { _, err := limiter.allow() re.NoError(err) } qLimit, qCurrent := limiter.getQPSLimiterStatus() re.Equal(rate.Limit(0), qLimit) - re.Equal(0, qCurrent) + re.Zero(qCurrent) successCount = 0 failedCount = 0 status = limiter.updateQPSConfig(float64(rate.Every(3*time.Second)), 100) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) wg.Add(200) for i := 0; i < 200; i++ { go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) @@ -186,8 +186,8 @@ func TestWithTwoLimiters(t *testing.T) { } limiter := newLimiter() status := limiter.updateDimensionConfig(cfg) - re.True(status&QPSChanged != 0) - re.True(status&ConcurrencyChanged != 0) + re.NotZero(status & QPSChanged) + re.NotZero(status & ConcurrencyChanged) var lock syncutil.Mutex successCount, failedCount := 0, 0 @@ -214,7 +214,7 @@ func TestWithTwoLimiters(t *testing.T) { r.release() } status = limiter.updateQPSConfig(float64(rate.Every(10*time.Second)), 1) - re.True(status&QPSChanged != 0) + re.NotZero(status & QPSChanged) wg.Add(100) for i := 0; i < 100; i++ { go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) @@ -228,8 +228,8 @@ func TestWithTwoLimiters(t *testing.T) { cfg = &DimensionConfig{} status = limiter.updateDimensionConfig(cfg) - re.True(status&ConcurrencyDeleted != 0) - re.True(status&QPSDeleted != 0) + re.NotZero(status & ConcurrencyDeleted) + re.NotZero(status & QPSDeleted) } func countSingleLimiterHandleResult(limiter *limiter, successCount *int, diff --git a/pkg/schedule/checker/rule_checker_test.go b/pkg/schedule/checker/rule_checker_test.go index e77830fac49..72d3e7e5ec4 100644 --- a/pkg/schedule/checker/rule_checker_test.go +++ b/pkg/schedule/checker/rule_checker_test.go @@ -426,6 +426,7 @@ func (suite *ruleCheckerTestSuite) TestFixRoleLeaderIssue3130() { } func (suite *ruleCheckerTestSuite) 
TestFixLeaderRoleWithUnhealthyRegion() { + re := suite.Require() suite.cluster.AddLabelsStore(1, 1, map[string]string{"rule": "follower"}) suite.cluster.AddLabelsStore(2, 1, map[string]string{"rule": "follower"}) suite.cluster.AddLabelsStore(3, 1, map[string]string{"rule": "leader"}) @@ -456,12 +457,12 @@ func (suite *ruleCheckerTestSuite) TestFixLeaderRoleWithUnhealthyRegion() { }, }, }) - suite.NoError(err) + re.NoError(err) // no Leader suite.cluster.AddNoLeaderRegion(1, 1, 2, 3) r := suite.cluster.GetRegion(1) op := suite.rc.Check(r) - suite.Nil(op) + re.Nil(op) } func (suite *ruleCheckerTestSuite) TestFixRuleWitness() { @@ -532,6 +533,7 @@ func (suite *ruleCheckerTestSuite) TestFixRuleWitness3() { } func (suite *ruleCheckerTestSuite) TestFixRuleWitness4() { + re := suite.Require() suite.cluster.AddLabelsStore(1, 1, map[string]string{"A": "leader"}) suite.cluster.AddLabelsStore(2, 1, map[string]string{"B": "voter"}) suite.cluster.AddLabelsStore(3, 1, map[string]string{"C": "learner"}) @@ -565,12 +567,12 @@ func (suite *ruleCheckerTestSuite) TestFixRuleWitness4() { }, }, }) - suite.NoError(err) + re.NoError(err) op := suite.rc.Check(r) - suite.NotNil(op) - suite.Equal("fix-non-witness-peer", op.Desc()) - suite.Equal(uint64(3), op.Step(0).(operator.BecomeNonWitness).StoreID) + re.NotNil(op) + re.Equal("fix-non-witness-peer", op.Desc()) + re.Equal(uint64(3), op.Step(0).(operator.BecomeNonWitness).StoreID) } func (suite *ruleCheckerTestSuite) TestFixRuleWitness5() { diff --git a/pkg/schedule/filter/counter_test.go b/pkg/schedule/filter/counter_test.go index 067a07f138b..78a1ef5395b 100644 --- a/pkg/schedule/filter/counter_test.go +++ b/pkg/schedule/filter/counter_test.go @@ -34,7 +34,7 @@ func TestString(t *testing.T) { for _, data := range testcases { re.Equal(data.expected, filterType(data.filterType).String()) } - re.Equal(int(filtersLen), len(filters)) + re.Len(filters, int(filtersLen)) } func TestCounter(t *testing.T) { @@ -42,9 +42,9 @@ func TestCounter(t *testing.T) { counter := NewCounter(BalanceLeader.String()) counter.inc(source, storeStateTombstone, 1, 2) counter.inc(target, storeStateTombstone, 1, 2) - re.Equal(counter.counter[source][storeStateTombstone][1][2], 1) - re.Equal(counter.counter[target][storeStateTombstone][1][2], 1) + re.Equal(1, counter.counter[source][storeStateTombstone][1][2]) + re.Equal(1, counter.counter[target][storeStateTombstone][1][2]) counter.Flush() - re.Equal(counter.counter[source][storeStateTombstone][1][2], 0) - re.Equal(counter.counter[target][storeStateTombstone][1][2], 0) + re.Zero(counter.counter[source][storeStateTombstone][1][2]) + re.Zero(counter.counter[target][storeStateTombstone][1][2]) } diff --git a/pkg/schedule/labeler/rule_test.go b/pkg/schedule/labeler/rule_test.go index 0b341754007..00c179b36b8 100644 --- a/pkg/schedule/labeler/rule_test.go +++ b/pkg/schedule/labeler/rule_test.go @@ -42,7 +42,7 @@ func TestRegionLabelTTL(t *testing.T) { label.TTL = "10h10m10s10ms" err = label.checkAndAdjustExpire() re.NoError(err) - re.Greater(len(label.StartAt), 0) + re.NotEmpty(label.StartAt) re.False(label.expireBefore(time.Now().Add(time.Hour))) re.True(label.expireBefore(time.Now().Add(24 * time.Hour))) @@ -56,5 +56,5 @@ func TestRegionLabelTTL(t *testing.T) { re.Equal(label.TTL, label2.TTL) label2.checkAndAdjustExpire() // The `expire` should be the same with minor inaccuracies. 
- re.True(math.Abs(label2.expire.Sub(*label.expire).Seconds()) < 1) + re.Less(math.Abs(label2.expire.Sub(*label.expire).Seconds()), 1.0) } diff --git a/pkg/schedule/operator/builder_test.go b/pkg/schedule/operator/builder_test.go index 864734eb5ff..b010dcf935b 100644 --- a/pkg/schedule/operator/builder_test.go +++ b/pkg/schedule/operator/builder_test.go @@ -62,21 +62,22 @@ func (suite *operatorBuilderTestSuite) TearDownTest() { } func (suite *operatorBuilderTestSuite) TestNewBuilder() { + re := suite.Require() peers := []*metapb.Peer{{Id: 11, StoreId: 1}, {Id: 12, StoreId: 2, Role: metapb.PeerRole_Learner}} region := core.NewRegionInfo(&metapb.Region{Id: 42, Peers: peers}, peers[0]) builder := NewBuilder("test", suite.cluster, region) - suite.NoError(builder.err) - suite.Len(builder.originPeers, 2) - suite.Equal(peers[0], builder.originPeers[1]) - suite.Equal(peers[1], builder.originPeers[2]) - suite.Equal(uint64(1), builder.originLeaderStoreID) - suite.Len(builder.targetPeers, 2) - suite.Equal(peers[0], builder.targetPeers[1]) - suite.Equal(peers[1], builder.targetPeers[2]) + re.NoError(builder.err) + re.Len(builder.originPeers, 2) + re.Equal(peers[0], builder.originPeers[1]) + re.Equal(peers[1], builder.originPeers[2]) + re.Equal(uint64(1), builder.originLeaderStoreID) + re.Len(builder.targetPeers, 2) + re.Equal(peers[0], builder.targetPeers[1]) + re.Equal(peers[1], builder.targetPeers[2]) region = region.Clone(core.WithLeader(nil)) builder = NewBuilder("test", suite.cluster, region) - suite.Error(builder.err) + re.Error(builder.err) } func (suite *operatorBuilderTestSuite) newBuilder() *Builder { @@ -90,18 +91,19 @@ func (suite *operatorBuilderTestSuite) newBuilder() *Builder { } func (suite *operatorBuilderTestSuite) TestRecord() { - suite.Error(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 1}).err) - suite.NoError(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 4}).err) - suite.Error(suite.newBuilder().PromoteLearner(1).err) - suite.NoError(suite.newBuilder().PromoteLearner(3).err) - suite.NoError(suite.newBuilder().SetLeader(1).SetLeader(2).err) - suite.Error(suite.newBuilder().SetLeader(3).err) - suite.Error(suite.newBuilder().RemovePeer(4).err) - suite.NoError(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 4, Role: metapb.PeerRole_Learner}).RemovePeer(4).err) - suite.Error(suite.newBuilder().SetLeader(2).RemovePeer(2).err) - suite.Error(suite.newBuilder().PromoteLearner(4).err) - suite.Error(suite.newBuilder().SetLeader(4).err) - suite.Error(suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{2: {Id: 2}}).err) + re := suite.Require() + re.Error(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 1}).err) + re.NoError(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 4}).err) + re.Error(suite.newBuilder().PromoteLearner(1).err) + re.NoError(suite.newBuilder().PromoteLearner(3).err) + re.NoError(suite.newBuilder().SetLeader(1).SetLeader(2).err) + re.Error(suite.newBuilder().SetLeader(3).err) + re.Error(suite.newBuilder().RemovePeer(4).err) + re.NoError(suite.newBuilder().AddPeer(&metapb.Peer{StoreId: 4, Role: metapb.PeerRole_Learner}).RemovePeer(4).err) + re.Error(suite.newBuilder().SetLeader(2).RemovePeer(2).err) + re.Error(suite.newBuilder().PromoteLearner(4).err) + re.Error(suite.newBuilder().SetLeader(4).err) + re.Error(suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{2: {Id: 2}}).err) m := map[uint64]*metapb.Peer{ 2: {StoreId: 2}, @@ -109,18 +111,19 @@ func (suite *operatorBuilderTestSuite) TestRecord() { 4: {StoreId: 4}, } builder := 
suite.newBuilder().SetPeers(m).SetAddLightPeer() - suite.Len(builder.targetPeers, 3) - suite.Equal(m[2], builder.targetPeers[2]) - suite.Equal(m[3], builder.targetPeers[3]) - suite.Equal(m[4], builder.targetPeers[4]) - suite.Equal(uint64(0), builder.targetLeaderStoreID) - suite.True(builder.addLightPeer) + re.Len(builder.targetPeers, 3) + re.Equal(m[2], builder.targetPeers[2]) + re.Equal(m[3], builder.targetPeers[3]) + re.Equal(m[4], builder.targetPeers[4]) + re.Equal(uint64(0), builder.targetLeaderStoreID) + re.True(builder.addLightPeer) } func (suite *operatorBuilderTestSuite) TestPrepareBuild() { + re := suite.Require() // no voter. _, err := suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{4: {StoreId: 4, Role: metapb.PeerRole_Learner}}).prepareBuild() - suite.Error(err) + re.Error(err) // use joint consensus builder := suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{ @@ -130,19 +133,19 @@ func (suite *operatorBuilderTestSuite) TestPrepareBuild() { 5: {StoreId: 5, Role: metapb.PeerRole_Learner}, }) _, err = builder.prepareBuild() - suite.NoError(err) - suite.Len(builder.toAdd, 2) - suite.NotEqual(metapb.PeerRole_Learner, builder.toAdd[4].GetRole()) - suite.Equal(uint64(14), builder.toAdd[4].GetId()) - suite.Equal(metapb.PeerRole_Learner, builder.toAdd[5].GetRole()) - suite.NotEqual(uint64(0), builder.toAdd[5].GetId()) - suite.Len(builder.toRemove, 1) - suite.NotNil(builder.toRemove[2]) - suite.Len(builder.toPromote, 1) - suite.NotNil(builder.toPromote[3]) - suite.Len(builder.toDemote, 1) - suite.NotNil(builder.toDemote[1]) - suite.Equal(uint64(1), builder.currentLeaderStoreID) + re.NoError(err) + re.Len(builder.toAdd, 2) + re.NotEqual(metapb.PeerRole_Learner, builder.toAdd[4].GetRole()) + re.Equal(uint64(14), builder.toAdd[4].GetId()) + re.Equal(metapb.PeerRole_Learner, builder.toAdd[5].GetRole()) + re.NotEqual(uint64(0), builder.toAdd[5].GetId()) + re.Len(builder.toRemove, 1) + re.NotNil(builder.toRemove[2]) + re.Len(builder.toPromote, 1) + re.NotNil(builder.toPromote[3]) + re.Len(builder.toDemote, 1) + re.NotNil(builder.toDemote[1]) + re.Equal(uint64(1), builder.currentLeaderStoreID) // do not use joint consensus builder = suite.newBuilder().SetPeers(map[uint64]*metapb.Peer{ @@ -154,22 +157,23 @@ func (suite *operatorBuilderTestSuite) TestPrepareBuild() { }) builder.useJointConsensus = false _, err = builder.prepareBuild() - suite.NoError(err) - suite.Len(builder.toAdd, 3) - suite.Equal(metapb.PeerRole_Learner, builder.toAdd[1].GetRole()) - suite.NotEqual(uint64(0), builder.toAdd[1].GetId()) - suite.NotEqual(metapb.PeerRole_Learner, builder.toAdd[4].GetRole()) - suite.Equal(uint64(14), builder.toAdd[4].GetId()) - suite.Equal(metapb.PeerRole_Learner, builder.toAdd[5].GetRole()) - suite.NotEqual(uint64(0), builder.toAdd[5].GetId()) - suite.Len(builder.toRemove, 1) - suite.NotNil(builder.toRemove[1]) - suite.Len(builder.toPromote, 1) - suite.NotNil(builder.toPromote[3]) - suite.Equal(uint64(1), builder.currentLeaderStoreID) + re.NoError(err) + re.Len(builder.toAdd, 3) + re.Equal(metapb.PeerRole_Learner, builder.toAdd[1].GetRole()) + re.NotEqual(uint64(0), builder.toAdd[1].GetId()) + re.NotEqual(metapb.PeerRole_Learner, builder.toAdd[4].GetRole()) + re.Equal(uint64(14), builder.toAdd[4].GetId()) + re.Equal(metapb.PeerRole_Learner, builder.toAdd[5].GetRole()) + re.NotEqual(uint64(0), builder.toAdd[5].GetId()) + re.Len(builder.toRemove, 1) + re.NotNil(builder.toRemove[1]) + re.Len(builder.toPromote, 1) + re.NotNil(builder.toPromote[3]) + re.Equal(uint64(1), 
builder.currentLeaderStoreID) } func (suite *operatorBuilderTestSuite) TestBuild() { + re := suite.Require() type testCase struct { name string useJointConsensus bool @@ -545,42 +549,42 @@ func (suite *operatorBuilderTestSuite) TestBuild() { builder.SetPeers(m).SetLeader(testCase.targetPeers[0].GetStoreId()) op, err := builder.Build(0) if len(testCase.steps) == 0 { - suite.Error(err) + re.Error(err) continue } - suite.NoError(err) - suite.Equal(testCase.kind, op.Kind()) - suite.Len(testCase.steps, op.Len()) + re.NoError(err) + re.Equal(testCase.kind, op.Kind()) + re.Len(testCase.steps, op.Len()) for i := 0; i < op.Len(); i++ { switch step := op.Step(i).(type) { case TransferLeader: - suite.Equal(testCase.steps[i].(TransferLeader).FromStore, step.FromStore) - suite.Equal(testCase.steps[i].(TransferLeader).ToStore, step.ToStore) + re.Equal(testCase.steps[i].(TransferLeader).FromStore, step.FromStore) + re.Equal(testCase.steps[i].(TransferLeader).ToStore, step.ToStore) case AddPeer: - suite.Equal(testCase.steps[i].(AddPeer).ToStore, step.ToStore) + re.Equal(testCase.steps[i].(AddPeer).ToStore, step.ToStore) case RemovePeer: - suite.Equal(testCase.steps[i].(RemovePeer).FromStore, step.FromStore) + re.Equal(testCase.steps[i].(RemovePeer).FromStore, step.FromStore) case AddLearner: - suite.Equal(testCase.steps[i].(AddLearner).ToStore, step.ToStore) + re.Equal(testCase.steps[i].(AddLearner).ToStore, step.ToStore) case PromoteLearner: - suite.Equal(testCase.steps[i].(PromoteLearner).ToStore, step.ToStore) + re.Equal(testCase.steps[i].(PromoteLearner).ToStore, step.ToStore) case ChangePeerV2Enter: - suite.Len(step.PromoteLearners, len(testCase.steps[i].(ChangePeerV2Enter).PromoteLearners)) - suite.Len(step.DemoteVoters, len(testCase.steps[i].(ChangePeerV2Enter).DemoteVoters)) + re.Len(step.PromoteLearners, len(testCase.steps[i].(ChangePeerV2Enter).PromoteLearners)) + re.Len(step.DemoteVoters, len(testCase.steps[i].(ChangePeerV2Enter).DemoteVoters)) for j, p := range testCase.steps[i].(ChangePeerV2Enter).PromoteLearners { - suite.Equal(p.ToStore, step.PromoteLearners[j].ToStore) + re.Equal(p.ToStore, step.PromoteLearners[j].ToStore) } for j, d := range testCase.steps[i].(ChangePeerV2Enter).DemoteVoters { - suite.Equal(d.ToStore, step.DemoteVoters[j].ToStore) + re.Equal(d.ToStore, step.DemoteVoters[j].ToStore) } case ChangePeerV2Leave: - suite.Len(step.PromoteLearners, len(testCase.steps[i].(ChangePeerV2Leave).PromoteLearners)) - suite.Len(step.DemoteVoters, len(testCase.steps[i].(ChangePeerV2Leave).DemoteVoters)) + re.Len(step.PromoteLearners, len(testCase.steps[i].(ChangePeerV2Leave).PromoteLearners)) + re.Len(step.DemoteVoters, len(testCase.steps[i].(ChangePeerV2Leave).DemoteVoters)) for j, p := range testCase.steps[i].(ChangePeerV2Leave).PromoteLearners { - suite.Equal(p.ToStore, step.PromoteLearners[j].ToStore) + re.Equal(p.ToStore, step.PromoteLearners[j].ToStore) } for j, d := range testCase.steps[i].(ChangePeerV2Leave).DemoteVoters { - suite.Equal(d.ToStore, step.DemoteVoters[j].ToStore) + re.Equal(d.ToStore, step.DemoteVoters[j].ToStore) } } } @@ -588,26 +592,27 @@ func (suite *operatorBuilderTestSuite) TestBuild() { } func (suite *operatorBuilderTestSuite) TestTargetUnhealthyPeer() { + re := suite.Require() p := &metapb.Peer{Id: 2, StoreId: 2, Role: metapb.PeerRole_Learner} region := core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithPendingPeers([]*metapb.Peer{p})) builder := NewBuilder("test", suite.cluster, 
region) builder.PromoteLearner(2) - suite.Error(builder.err) + re.Error(builder.err) region = core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithDownPeers([]*pdpb.PeerStats{{Peer: p}})) builder = NewBuilder("test", suite.cluster, region) builder.PromoteLearner(2) - suite.Error(builder.err) + re.Error(builder.err) p = &metapb.Peer{Id: 2, StoreId: 2, Role: metapb.PeerRole_Voter} region = core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithPendingPeers([]*metapb.Peer{p})) builder = NewBuilder("test", suite.cluster, region) builder.SetLeader(2) - suite.Error(builder.err) + re.Error(builder.err) region = core.NewRegionInfo(&metapb.Region{Id: 1, Peers: []*metapb.Peer{{Id: 1, StoreId: 1}, p}}, &metapb.Peer{Id: 1, StoreId: 1}, core.WithDownPeers([]*pdpb.PeerStats{{Peer: p}})) builder = NewBuilder("test", suite.cluster, region) builder.SetLeader(2) - suite.Error(builder.err) + re.Error(builder.err) } diff --git a/pkg/schedule/placement/rule_manager_test.go b/pkg/schedule/placement/rule_manager_test.go index c0987f6dd33..0539e935113 100644 --- a/pkg/schedule/placement/rule_manager_test.go +++ b/pkg/schedule/placement/rule_manager_test.go @@ -161,11 +161,11 @@ func TestSaveLoad(t *testing.T) { err := m2.Initialize(3, []string{"no", "labels"}, "") re.NoError(err) re.Len(m2.GetAllRules(), 3) - re.Equal(rules[0].String(), m2.GetRule(DefaultGroupID, DefaultRuleID).String()) - re.Equal(rules[1].String(), m2.GetRule("foo", "baz").String()) - re.Equal(rules[2].String(), m2.GetRule("foo", "bar").String()) - re.Equal(manager.GetRulesCount(), 3) - re.Equal(manager.GetGroupsCount(), 2) + re.Equal(m2.GetRule(DefaultGroupID, DefaultRuleID).String(), rules[0].String()) + re.Equal(m2.GetRule("foo", "baz").String(), rules[1].String()) + re.Equal(m2.GetRule("foo", "bar").String(), rules[2].String()) + re.Equal(3, manager.GetRulesCount()) + re.Equal(2, manager.GetGroupsCount()) } func TestSetAfterGet(t *testing.T) { diff --git a/pkg/schedule/plan/balance_plan_test.go b/pkg/schedule/plan/balance_plan_test.go index 59ad637d5c8..59f2acc689a 100644 --- a/pkg/schedule/plan/balance_plan_test.go +++ b/pkg/schedule/plan/balance_plan_test.go @@ -114,6 +114,7 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TearDownSuite() { } func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult1() { + re := suite.Require() plans := make([]Plan, 0) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[4], Step: 2, Target: suite.stores[0], Status: NewStatus(StatusStoreScoreDisallowed)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[4], Step: 2, Target: suite.stores[1], Status: NewStatus(StatusStoreScoreDisallowed)}) @@ -141,9 +142,9 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult1() { plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Step: 2, Target: suite.stores[3], Status: NewStatus(StatusStoreNotMatchRule)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Step: 2, Target: suite.stores[4], Status: NewStatus(StatusStoreScoreDisallowed)}) statuses, isNormal, err := BalancePlanSummary(plans) - suite.NoError(err) - suite.True(isNormal) - suite.True(suite.check(statuses, + re.NoError(err) + re.True(isNormal) + re.True(suite.check(statuses, map[uint64]*Status{ 1: NewStatus(StatusStoreNotMatchRule), 2: NewStatus(StatusStoreNotMatchRule), @@ -154,6 +155,7 @@ func (suite 
*balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult1() { } func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult2() { + re := suite.Require() plans := make([]Plan, 0) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[4], Step: 0, Status: NewStatus(StatusStoreDown)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[3], Step: 0, Status: NewStatus(StatusStoreDown)}) @@ -161,9 +163,9 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult2() { plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[1], Step: 0, Status: NewStatus(StatusStoreDown)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Step: 0, Status: NewStatus(StatusStoreDown)}) statuses, isNormal, err := BalancePlanSummary(plans) - suite.NoError(err) - suite.False(isNormal) - suite.True(suite.check(statuses, + re.NoError(err) + re.False(isNormal) + re.True(suite.check(statuses, map[uint64]*Status{ 1: NewStatus(StatusStoreDown), 2: NewStatus(StatusStoreDown), @@ -174,6 +176,7 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult2() { } func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult3() { + re := suite.Require() plans := make([]Plan, 0) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[4], Step: 0, Status: NewStatus(StatusStoreDown)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[3], Region: suite.regions[0], Step: 1, Status: NewStatus(StatusRegionNotMatchRule)}) @@ -181,9 +184,9 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult3() { plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[1], Region: suite.regions[1], Step: 1, Status: NewStatus(StatusRegionNotMatchRule)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Region: suite.regions[1], Step: 1, Status: NewStatus(StatusRegionNotMatchRule)}) statuses, isNormal, err := BalancePlanSummary(plans) - suite.NoError(err) - suite.False(isNormal) - suite.True(suite.check(statuses, + re.NoError(err) + re.False(isNormal) + re.True(suite.check(statuses, map[uint64]*Status{ 1: NewStatus(StatusRegionNotMatchRule), 2: NewStatus(StatusRegionNotMatchRule), @@ -193,6 +196,7 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult3() { } func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult4() { + re := suite.Require() plans := make([]Plan, 0) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[4], Step: 0, Status: NewStatus(StatusStoreDown)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[3], Region: suite.regions[0], Step: 1, Status: NewStatus(StatusRegionNotMatchRule)}) @@ -208,9 +212,9 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult4() { plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Target: suite.stores[3], Step: 2, Status: NewStatus(StatusStoreNotMatchRule)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Target: suite.stores[4], Step: 2, Status: NewStatus(StatusStoreDown)}) statuses, isNormal, err := BalancePlanSummary(plans) - suite.NoError(err) - suite.False(isNormal) - suite.True(suite.check(statuses, + re.NoError(err) + re.False(isNormal) + re.True(suite.check(statuses, map[uint64]*Status{ 1: NewStatus(StatusStoreAlreadyHasPeer), 2: NewStatus(StatusStoreAlreadyHasPeer), @@ -221,6 +225,7 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult4() { } func (suite *balanceSchedulerPlanAnalyzeTestSuite) 
TestAnalyzerResult5() { + re := suite.Require() plans := make([]Plan, 0) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[4], Step: 0, Status: NewStatus(StatusStoreRemoveLimitThrottled)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[3], Region: suite.regions[0], Step: 1, Status: NewStatus(StatusRegionNotMatchRule)}) @@ -234,9 +239,9 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult5() { plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Target: suite.stores[2], Step: 2, Status: NewStatus(StatusStoreNotMatchRule)}) plans = append(plans, &BalanceSchedulerPlan{Source: suite.stores[0], Target: suite.stores[3], Step: 2, Status: NewStatus(StatusStoreNotMatchRule)}) statuses, isNormal, err := BalancePlanSummary(plans) - suite.NoError(err) - suite.False(isNormal) - suite.True(suite.check(statuses, + re.NoError(err) + re.False(isNormal) + re.True(suite.check(statuses, map[uint64]*Status{ 1: NewStatus(StatusStoreAlreadyHasPeer), 2: NewStatus(StatusStoreAlreadyHasPeer), @@ -247,6 +252,7 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult5() { } func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult6() { + re := suite.Require() basePlan := NewBalanceSchedulerPlan() collector := NewCollector(basePlan) collector.Collect(SetResourceWithStep(suite.stores[0], 2), SetStatus(NewStatus(StatusStoreDown))) @@ -258,9 +264,9 @@ func (suite *balanceSchedulerPlanAnalyzeTestSuite) TestAnalyzerResult6() { basePlan.Step++ collector.Collect(SetResource(suite.regions[0]), SetStatus(NewStatus(StatusRegionNoLeader))) statuses, isNormal, err := BalancePlanSummary(collector.GetPlans()) - suite.NoError(err) - suite.False(isNormal) - suite.True(suite.check(statuses, + re.NoError(err) + re.False(isNormal) + re.True(suite.check(statuses, map[uint64]*Status{ 1: NewStatus(StatusStoreDown), 2: NewStatus(StatusStoreDown), diff --git a/pkg/schedule/scatter/region_scatterer_test.go b/pkg/schedule/scatter/region_scatterer_test.go index 70517d23fee..af41ed04b76 100644 --- a/pkg/schedule/scatter/region_scatterer_test.go +++ b/pkg/schedule/scatter/region_scatterer_test.go @@ -350,7 +350,7 @@ func TestSomeStoresFilteredScatterGroupInConcurrency(t *testing.T) { // prevent store from being disconnected tc.SetStoreLastHeartbeatInterval(i, 40*time.Minute) } - re.Equal(tc.GetStore(uint64(6)).IsDisconnected(), true) + re.True(tc.GetStore(uint64(6)).IsDisconnected()) scatterer := NewRegionScatterer(ctx, tc, oc, tc.AddSuspectRegions) var wg sync.WaitGroup for j := 0; j < 10; j++ { @@ -466,7 +466,7 @@ func TestScatterForManyRegion(t *testing.T) { re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/schedule/scatter/scatterHbStreamsDrain", `return(true)`)) scatterer.scatterRegions(regions, failures, group, 3, false) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/schedule/scatter/scatterHbStreamsDrain")) - re.Len(failures, 0) + re.Empty(failures) } func TestScattersGroup(t *testing.T) { diff --git a/pkg/schedule/schedulers/balance_benchmark_test.go b/pkg/schedule/schedulers/balance_benchmark_test.go index 694d5edb658..2d7befd27af 100644 --- a/pkg/schedule/schedulers/balance_benchmark_test.go +++ b/pkg/schedule/schedulers/balance_benchmark_test.go @@ -163,7 +163,7 @@ func BenchmarkPlacementRule(b *testing.B) { ops, plans = sc.Schedule(tc, false) } b.StopTimer() - re.Len(plans, 0) + re.Empty(plans) re.Len(ops, 1) re.Contains(ops[0].String(), "to [191]") } diff --git a/pkg/schedule/schedulers/balance_test.go 
b/pkg/schedule/schedulers/balance_test.go index 54fe8ff489b..dafe810b2b7 100644 --- a/pkg/schedule/schedulers/balance_test.go +++ b/pkg/schedule/schedulers/balance_test.go @@ -237,9 +237,10 @@ func TestBalanceLeaderSchedulerTestSuite(t *testing.T) { } func (suite *balanceLeaderSchedulerTestSuite) SetupTest() { + re := suite.Require() suite.cancel, suite.conf, suite.tc, suite.oc = prepareSchedulersTest() lb, err := CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) - suite.NoError(err) + re.NoError(err) suite.lb = lb } @@ -560,6 +561,7 @@ func (suite *balanceLeaderRangeSchedulerTestSuite) TearDownTest() { } func (suite *balanceLeaderRangeSchedulerTestSuite) TestSingleRangeBalance() { + re := suite.Require() // Stores: 1 2 3 4 // Leaders: 10 10 10 10 // Weight: 0.5 0.9 1 2 @@ -573,36 +575,36 @@ func (suite *balanceLeaderRangeSchedulerTestSuite) TestSingleRangeBalance() { suite.tc.UpdateStoreLeaderWeight(4, 2) suite.tc.AddLeaderRegionWithRange(1, "a", "g", 1, 2, 3, 4) lb, err := CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"", ""})) - suite.NoError(err) + re.NoError(err) ops, _ := lb.Schedule(suite.tc, false) - suite.NotEmpty(ops) - suite.Len(ops, 1) - suite.Len(ops[0].Counters, 1) - suite.Len(ops[0].FinishedCounters, 1) + re.NotEmpty(ops) + re.Len(ops, 1) + re.Len(ops[0].Counters, 1) + re.Len(ops[0].FinishedCounters, 1) lb, err = CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"h", "n"})) - suite.NoError(err) + re.NoError(err) ops, _ = lb.Schedule(suite.tc, false) - suite.Empty(ops) + re.Empty(ops) lb, err = CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"b", "f"})) - suite.NoError(err) + re.NoError(err) ops, _ = lb.Schedule(suite.tc, false) - suite.Empty(ops) + re.Empty(ops) lb, err = CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"", "a"})) - suite.NoError(err) + re.NoError(err) ops, _ = lb.Schedule(suite.tc, false) - suite.Empty(ops) + re.Empty(ops) lb, err = CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"g", ""})) - suite.NoError(err) + re.NoError(err) ops, _ = lb.Schedule(suite.tc, false) - suite.Empty(ops) + re.Empty(ops) lb, err = CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"", "f"})) - suite.NoError(err) + re.NoError(err) ops, _ = lb.Schedule(suite.tc, false) - suite.Empty(ops) + re.Empty(ops) lb, err = CreateScheduler(BalanceLeaderType, suite.oc, storage.NewStorageWithMemoryBackend(), ConfigSliceDecoder(BalanceLeaderType, []string{"b", ""})) - suite.NoError(err) + re.NoError(err) ops, _ = lb.Schedule(suite.tc, false) - suite.Empty(ops) + re.Empty(ops) } func (suite *balanceLeaderRangeSchedulerTestSuite) TestMultiRangeBalance() { diff --git a/pkg/schedule/schedulers/balance_witness_test.go b/pkg/schedule/schedulers/balance_witness_test.go index 59bf04c2303..9bde7e33438 100644 --- a/pkg/schedule/schedulers/balance_witness_test.go +++ b/pkg/schedule/schedulers/balance_witness_test.go @@ -125,7 +125,7 @@ func (suite *balanceWitnessSchedulerTestSuite) TestTransferWitnessOut() { } } } - suite.Equal(3, 
len(regions)) + suite.Len(regions, 3) for _, count := range targets { suite.Zero(count) } diff --git a/pkg/schedule/schedulers/evict_leader_test.go b/pkg/schedule/schedulers/evict_leader_test.go index d804561f11c..a91b1c3c937 100644 --- a/pkg/schedule/schedulers/evict_leader_test.go +++ b/pkg/schedule/schedulers/evict_leader_test.go @@ -97,7 +97,7 @@ func TestConfigClone(t *testing.T) { con3 := con2.Clone() con3.StoreIDWithRanges[1], _ = getKeyRanges([]string{"a", "b", "c", "d"}) re.Empty(emptyConf.getKeyRangesByID(1)) - re.False(len(con3.getRanges(1)) == len(con2.getRanges(1))) + re.NotEqual(len(con3.getRanges(1)), len(con2.getRanges(1))) con4 := con3.Clone() re.True(bytes.Equal(con4.StoreIDWithRanges[1][0].StartKey, con3.StoreIDWithRanges[1][0].StartKey)) diff --git a/pkg/schedule/schedulers/evict_slow_trend_test.go b/pkg/schedule/schedulers/evict_slow_trend_test.go index 65a70962a20..aed41e83ecd 100644 --- a/pkg/schedule/schedulers/evict_slow_trend_test.go +++ b/pkg/schedule/schedulers/evict_slow_trend_test.go @@ -100,10 +100,10 @@ func (suite *evictSlowTrendTestSuite) TestEvictSlowTrendBasicFuncs() { // Pop captured store 1 and mark it has recovered. time.Sleep(50 * time.Millisecond) suite.Equal(es2.conf.popCandidate(true), store.GetID()) - suite.True(es2.conf.evictCandidate == (slowCandidate{})) + suite.Equal(slowCandidate{}, es2.conf.evictCandidate) es2.conf.markCandidateRecovered() lastCapturedCandidate = es2.conf.lastCapturedCandidate() - suite.True(lastCapturedCandidate.recoverTS.Compare(recoverTS) > 0) + suite.Greater(lastCapturedCandidate.recoverTS.Compare(recoverTS), 0) suite.Equal(lastCapturedCandidate.storeID, store.GetID()) // Test capture another store 2 diff --git a/pkg/schedule/schedulers/hot_region_test.go b/pkg/schedule/schedulers/hot_region_test.go index 6e7208e4251..5b1bc3db4b4 100644 --- a/pkg/schedule/schedulers/hot_region_test.go +++ b/pkg/schedule/schedulers/hot_region_test.go @@ -180,11 +180,11 @@ func checkGCPendingOpInfos(re *require.Assertions, enablePlacementRules bool) { kind := hb.regionPendings[regionID].op.Kind() switch typ { case transferLeader: - re.True(kind&operator.OpLeader != 0) - re.True(kind&operator.OpRegion == 0) + re.NotZero(kind & operator.OpLeader) + re.Zero(kind & operator.OpRegion) case movePeer: - re.True(kind&operator.OpLeader == 0) - re.True(kind&operator.OpRegion != 0) + re.Zero(kind & operator.OpLeader) + re.NotZero(kind & operator.OpRegion) } } } @@ -257,7 +257,7 @@ func TestSplitIfRegionTooHot(t *testing.T) { re.Equal(expectOp.Kind(), ops[0].Kind()) ops, _ = hb.Schedule(tc, false) - re.Len(ops, 0) + re.Empty(ops) tc.UpdateStorageWrittenBytes(1, 6*units.MiB*utils.StoreHeartBeatReportInterval) tc.UpdateStorageWrittenBytes(2, 1*units.MiB*utils.StoreHeartBeatReportInterval) @@ -276,7 +276,7 @@ func TestSplitIfRegionTooHot(t *testing.T) { re.Equal(operator.OpSplit, ops[0].Kind()) ops, _ = hb.Schedule(tc, false) - re.Len(ops, 0) + re.Empty(ops) } func TestSplitBucketsBySize(t *testing.T) { @@ -319,10 +319,10 @@ func TestSplitBucketsBySize(t *testing.T) { region.UpdateBuckets(b, region.GetBuckets()) ops := solve.createSplitOperator([]*core.RegionInfo{region}, bySize) if data.splitKeys == nil { - re.Equal(0, len(ops)) + re.Empty(ops) continue } - re.Equal(1, len(ops)) + re.Len(ops, 1) op := ops[0] re.Equal(splitHotReadBuckets, op.Desc()) @@ -380,10 +380,10 @@ func TestSplitBucketsByLoad(t *testing.T) { time.Sleep(time.Millisecond * 10) ops := solve.createSplitOperator([]*core.RegionInfo{region}, byLoad) if data.splitKeys == nil { - 
re.Equal(0, len(ops)) + re.Empty(ops) continue } - re.Equal(1, len(ops)) + re.Len(ops, 1) op := ops[0] re.Equal(splitHotReadBuckets, op.Desc()) @@ -731,7 +731,7 @@ func TestHotWriteRegionScheduleByteRateOnlyWithTiFlash(t *testing.T) { loadsEqual( hb.stLoadInfos[writeLeader][1].LoadPred.Expect.Loads, []float64{hotRegionBytesSum / allowLeaderTiKVCount, hotRegionKeysSum / allowLeaderTiKVCount, tikvQuerySum / allowLeaderTiKVCount})) - re.True(tikvQuerySum != hotRegionQuerySum) + re.NotEqual(tikvQuerySum, hotRegionQuerySum) re.True( loadsEqual( hb.stLoadInfos[writePeer][1].LoadPred.Expect.Loads, @@ -1574,7 +1574,7 @@ func TestHotReadWithEvictLeaderScheduler(t *testing.T) { // two dim are both enough uniform among three stores tc.SetStoreEvictLeader(4, true) ops, _ = hb.Schedule(tc, false) - re.Len(ops, 0) + re.Empty(ops) clearPendingInfluence(hb.(*hotScheduler)) } diff --git a/pkg/schedule/schedulers/hot_region_v2_test.go b/pkg/schedule/schedulers/hot_region_v2_test.go index d11ac44dde9..f5e21e02981 100644 --- a/pkg/schedule/schedulers/hot_region_v2_test.go +++ b/pkg/schedule/schedulers/hot_region_v2_test.go @@ -309,7 +309,7 @@ func TestSkipUniformStore(t *testing.T) { // when there is uniform store filter, not schedule stddevThreshold = 0.1 ops, _ = hb.Schedule(tc, false) - re.Len(ops, 0) + re.Empty(ops) clearPendingInfluence(hb.(*hotScheduler)) // Case2: the first dim is enough uniform, we should schedule the second dim @@ -380,7 +380,7 @@ func TestHotReadRegionScheduleWithSmallHotRegion(t *testing.T) { ops = checkHotReadRegionScheduleWithSmallHotRegion(re, highLoad, lowLoad, emptyFunc) re.Len(ops, 1) ops = checkHotReadRegionScheduleWithSmallHotRegion(re, lowLoad, highLoad, emptyFunc) - re.Len(ops, 0) + re.Empty(ops) // Case3: If there is larger hot region, we will schedule it. hotRegionID := uint64(100) @@ -418,7 +418,7 @@ func TestHotReadRegionScheduleWithSmallHotRegion(t *testing.T) { tc.AddRegionWithReadInfo(hotRegionID+1, 2, bigHotRegionByte, 0, bigHotRegionQuery, utils.StoreHeartBeatReportInterval, []uint64{1, 3}) tc.AddRegionWithReadInfo(hotRegionID+1, 1, bigHotRegionByte, 0, bigHotRegionQuery, utils.StoreHeartBeatReportInterval, []uint64{2, 3}) }) - re.Len(ops, 0) + re.Empty(ops) topnPosition = origin // Case7: If there are more than topnPosition hot regions, but them are pending, @@ -430,7 +430,7 @@ func TestHotReadRegionScheduleWithSmallHotRegion(t *testing.T) { tc.AddRegionWithReadInfo(hotRegionID+1, 1, bigHotRegionByte, 0, bigHotRegionQuery, utils.StoreHeartBeatReportInterval, []uint64{2, 3}) hb.regionPendings[hotRegionID+1] = &pendingInfluence{} }) - re.Len(ops, 0) + re.Empty(ops) topnPosition = origin } diff --git a/pkg/schedule/schedulers/scheduler_test.go b/pkg/schedule/schedulers/scheduler_test.go index 57f1fcf1e3f..77c190ad943 100644 --- a/pkg/schedule/schedulers/scheduler_test.go +++ b/pkg/schedule/schedulers/scheduler_test.go @@ -484,7 +484,7 @@ func TestBalanceLeaderWithConflictRule(t *testing.T) { } for _, testCase := range testCases { - re.Nil(tc.SetRule(testCase.rule)) + re.NoError(tc.SetRule(testCase.rule)) ops, _ := lb.Schedule(tc, false) if testCase.schedule { re.Len(ops, 1) diff --git a/pkg/storage/storage_gc_test.go b/pkg/storage/storage_gc_test.go index 141777d441e..77f7c7dbf65 100644 --- a/pkg/storage/storage_gc_test.go +++ b/pkg/storage/storage_gc_test.go @@ -93,7 +93,7 @@ func TestLoadMinServiceSafePoint(t *testing.T) { // gc_worker service safepoint will not be removed. 
ssp, err := storage.LoadMinServiceSafePointV2(testKeyspaceID, currentTime.Add(5000*time.Second)) re.NoError(err) - re.Equal(ssp.ServiceID, endpoint.GCWorkerServiceSafePointID) + re.Equal(endpoint.GCWorkerServiceSafePointID, ssp.ServiceID) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/storage/endpoint/removeExpiredKeys")) } diff --git a/pkg/tso/keyspace_group_manager_test.go b/pkg/tso/keyspace_group_manager_test.go index 0c1b017d7aa..54a1adc6b34 100644 --- a/pkg/tso/keyspace_group_manager_test.go +++ b/pkg/tso/keyspace_group_manager_test.go @@ -988,7 +988,7 @@ func (suite *keyspaceGroupManagerTestSuite) TestUpdateKeyspaceGroupMembership() re.Equal(len(keyspaces), len(newGroup.Keyspaces)) for i := 0; i < len(newGroup.Keyspaces); i++ { if i > 0 { - re.True(newGroup.Keyspaces[i-1] < newGroup.Keyspaces[i]) + re.Less(newGroup.Keyspaces[i-1], newGroup.Keyspaces[i]) } } } diff --git a/pkg/unsaferecovery/unsafe_recovery_controller_test.go b/pkg/unsaferecovery/unsafe_recovery_controller_test.go index 44c4e4a7b4d..956b9b8729c 100644 --- a/pkg/unsaferecovery/unsafe_recovery_controller_test.go +++ b/pkg/unsaferecovery/unsafe_recovery_controller_test.go @@ -1158,7 +1158,7 @@ func TestExecutionTimeout(t *testing.T) { re.Equal(Failed, recoveryController.GetStage()) output := recoveryController.Show() - re.Equal(len(output), 3) + re.Len(output, 3) re.Contains(output[1].Details[0], "triggered by error: Exceeds timeout") } @@ -1768,7 +1768,7 @@ func TestEpochComparsion(t *testing.T) { cluster.PutStore(store) } recoveryController := NewController(cluster) - re.Nil(recoveryController.RemoveFailedStores(map[uint64]struct{}{ + re.NoError(recoveryController.RemoveFailedStores(map[uint64]struct{}{ 2: {}, 3: {}, }, 60, false)) @@ -1829,7 +1829,7 @@ func TestEpochComparsion(t *testing.T) { if expect, ok := expects[storeID]; ok { re.Equal(expect.PeerReports, report.PeerReports) } else { - re.Empty(len(report.PeerReports)) + re.Empty(report.PeerReports) } } } diff --git a/pkg/utils/etcdutil/etcdutil_test.go b/pkg/utils/etcdutil/etcdutil_test.go index 861a57cef13..d8b38e7b045 100644 --- a/pkg/utils/etcdutil/etcdutil_test.go +++ b/pkg/utils/etcdutil/etcdutil_test.go @@ -154,11 +154,11 @@ func TestInitClusterID(t *testing.T) { // Get any cluster key to parse the cluster ID. resp, err := EtcdKVGet(client, pdClusterIDPath) re.NoError(err) - re.Equal(0, len(resp.Kvs)) + re.Empty(resp.Kvs) clusterID, err := InitClusterID(client, pdClusterIDPath) re.NoError(err) - re.NotEqual(0, clusterID) + re.NotZero(clusterID) clusterID1, err := InitClusterID(client, pdClusterIDPath) re.NoError(err) @@ -375,15 +375,15 @@ func TestLoopWatcherTestSuite(t *testing.T) { } func (suite *loopWatcherTestSuite) SetupSuite() { + re := suite.Require() var err error - t := suite.T() suite.ctx, suite.cancel = context.WithCancel(context.Background()) suite.cleans = make([]func(), 0) // Start a etcd server and create a client with etcd1 as endpoint. 
- suite.config = newTestSingleConfig(t) - suite.startEtcd() + suite.config = newTestSingleConfig(suite.T()) + suite.startEtcd(re) suite.client, err = CreateEtcdClient(nil, suite.config.LCUrls) - suite.NoError(err) + re.NoError(err) suite.cleans = append(suite.cleans, func() { suite.client.Close() }) @@ -398,6 +398,7 @@ func (suite *loopWatcherTestSuite) TearDownSuite() { } func (suite *loopWatcherTestSuite) TestLoadWithoutKey() { + re := suite.Require() cache := struct { syncutil.RWMutex data map[string]struct{} @@ -422,13 +423,14 @@ func (suite *loopWatcherTestSuite) TestLoadWithoutKey() { ) watcher.StartWatchLoop() err := watcher.WaitLoad() - suite.NoError(err) // although no key, watcher returns no error + re.NoError(err) // although no key, watcher returns no error cache.RLock() defer cache.RUnlock() - suite.Len(cache.data, 0) + suite.Empty(cache.data) } func (suite *loopWatcherTestSuite) TestCallBack() { + re := suite.Require() cache := struct { syncutil.RWMutex data map[string]struct{} @@ -466,35 +468,36 @@ func (suite *loopWatcherTestSuite) TestCallBack() { ) watcher.StartWatchLoop() err := watcher.WaitLoad() - suite.NoError(err) + re.NoError(err) // put 10 keys for i := 0; i < 10; i++ { - suite.put(fmt.Sprintf("TestCallBack%d", i), "") + suite.put(re, fmt.Sprintf("TestCallBack%d", i), "") } time.Sleep(time.Second) cache.RLock() - suite.Len(cache.data, 10) + re.Len(cache.data, 10) cache.RUnlock() // delete 10 keys for i := 0; i < 10; i++ { key := fmt.Sprintf("TestCallBack%d", i) _, err = suite.client.Delete(suite.ctx, key) - suite.NoError(err) + re.NoError(err) } time.Sleep(time.Second) cache.RLock() - suite.Empty(cache.data) + re.Empty(cache.data) cache.RUnlock() } func (suite *loopWatcherTestSuite) TestWatcherLoadLimit() { + re := suite.Require() for count := 1; count < 10; count++ { for limit := 0; limit < 10; limit++ { ctx, cancel := context.WithCancel(suite.ctx) for i := 0; i < count; i++ { - suite.put(fmt.Sprintf("TestWatcherLoadLimit%d", i), "") + suite.put(re, fmt.Sprintf("TestWatcherLoadLimit%d", i), "") } cache := struct { syncutil.RWMutex @@ -525,9 +528,9 @@ func (suite *loopWatcherTestSuite) TestWatcherLoadLimit() { ) watcher.StartWatchLoop() err := watcher.WaitLoad() - suite.NoError(err) + re.NoError(err) cache.RLock() - suite.Len(cache.data, count) + re.Len(cache.data, count) cache.RUnlock() cancel() } @@ -535,6 +538,7 @@ func (suite *loopWatcherTestSuite) TestWatcherLoadLimit() { } func (suite *loopWatcherTestSuite) TestWatcherBreak() { + re := suite.Require() cache := struct { syncutil.RWMutex data string @@ -568,51 +572,51 @@ func (suite *loopWatcherTestSuite) TestWatcherBreak() { watcher.watchChangeRetryInterval = 100 * time.Millisecond watcher.StartWatchLoop() err := watcher.WaitLoad() - suite.NoError(err) + re.NoError(err) checkCache("") // we use close client and update client in failpoint to simulate the network error and recover - failpoint.Enable("github.com/tikv/pd/pkg/utils/etcdutil/updateClient", "return(true)") + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/etcdutil/updateClient", "return(true)")) // Case1: restart the etcd server suite.etcd.Close() - suite.startEtcd() - suite.put("TestWatcherBreak", "0") + suite.startEtcd(re) + suite.put(re, "TestWatcherBreak", "0") checkCache("0") suite.etcd.Server.Stop() time.Sleep(DefaultRequestTimeout) suite.etcd.Close() - suite.startEtcd() - suite.put("TestWatcherBreak", "1") + suite.startEtcd(re) + suite.put(re, "TestWatcherBreak", "1") checkCache("1") // Case2: close the etcd client and put a new 
value after watcher restarts suite.client.Close() suite.client, err = CreateEtcdClient(nil, suite.config.LCUrls) - suite.NoError(err) + re.NoError(err) watcher.updateClientCh <- suite.client - suite.put("TestWatcherBreak", "2") + suite.put(re, "TestWatcherBreak", "2") checkCache("2") // Case3: close the etcd client and put a new value before watcher restarts suite.client.Close() suite.client, err = CreateEtcdClient(nil, suite.config.LCUrls) - suite.NoError(err) - suite.put("TestWatcherBreak", "3") + re.NoError(err) + suite.put(re, "TestWatcherBreak", "3") watcher.updateClientCh <- suite.client checkCache("3") // Case4: close the etcd client and put a new value with compact suite.client.Close() suite.client, err = CreateEtcdClient(nil, suite.config.LCUrls) - suite.NoError(err) - suite.put("TestWatcherBreak", "4") + re.NoError(err) + suite.put(re, "TestWatcherBreak", "4") resp, err := EtcdKVGet(suite.client, "TestWatcherBreak") - suite.NoError(err) + re.NoError(err) revision := resp.Header.Revision resp2, err := suite.etcd.Server.Compact(suite.ctx, &etcdserverpb.CompactionRequest{Revision: revision}) - suite.NoError(err) - suite.Equal(revision, resp2.Header.Revision) + re.NoError(err) + re.Equal(revision, resp2.Header.Revision) watcher.updateClientCh <- suite.client checkCache("4") @@ -623,7 +627,7 @@ func (suite *loopWatcherTestSuite) TestWatcherBreak() { watcher.ForceLoad() checkCache("4") - failpoint.Disable("github.com/tikv/pd/pkg/utils/etcdutil/updateClient") + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/etcdutil/updateClient")) } func (suite *loopWatcherTestSuite) TestWatcherRequestProgress() { @@ -669,9 +673,9 @@ func (suite *loopWatcherTestSuite) TestWatcherRequestProgress() { checkWatcherRequestProgress(true) } -func (suite *loopWatcherTestSuite) startEtcd() { +func (suite *loopWatcherTestSuite) startEtcd(re *require.Assertions) { etcd1, err := embed.StartEtcd(suite.config) - suite.NoError(err) + re.NoError(err) suite.etcd = etcd1 <-etcd1.Server.ReadyNotify() suite.cleans = append(suite.cleans, func() { @@ -679,11 +683,11 @@ func (suite *loopWatcherTestSuite) startEtcd() { }) } -func (suite *loopWatcherTestSuite) put(key, value string) { +func (suite *loopWatcherTestSuite) put(re *require.Assertions, key, value string) { kv := clientv3.NewKV(suite.client) _, err := kv.Put(suite.ctx, key, value) - suite.NoError(err) + re.NoError(err) resp, err := kv.Get(suite.ctx, key) - suite.NoError(err) - suite.Equal(value, string(resp.Kvs[0].Value)) + re.NoError(err) + re.Equal(value, string(resp.Kvs[0].Value)) } diff --git a/pkg/utils/syncutil/lock_group_test.go b/pkg/utils/syncutil/lock_group_test.go index ff306983e05..897e6b777a6 100644 --- a/pkg/utils/syncutil/lock_group_test.go +++ b/pkg/utils/syncutil/lock_group_test.go @@ -60,14 +60,14 @@ func TestLockGroupWithRemoveEntryOnUnlock(t *testing.T) { for i := 0; i < maxID; i++ { group.Lock(uint32(i)) } - re.Equal(len(group.entries), maxID) + re.Len(group.entries, maxID) for i := 0; i < maxID; i++ { group.Unlock(uint32(i)) } wg.Wait() // Check that size of the lock group is limited. - re.Equal(len(group.entries), 0) + re.Empty(group.entries) } // mustSequentialUpdateSingle checks that for any given update, update is sequential. 
diff --git a/pkg/utils/typeutil/duration_test.go b/pkg/utils/typeutil/duration_test.go index 9a0beda7979..cff7c3cd66c 100644 --- a/pkg/utils/typeutil/duration_test.go +++ b/pkg/utils/typeutil/duration_test.go @@ -46,6 +46,6 @@ func TestDurationTOML(t *testing.T) { example := &example{} text := []byte(`interval = "1h1m1s"`) - re.Nil(toml.Unmarshal(text, example)) + re.NoError(toml.Unmarshal(text, example)) re.Equal(float64(60*60+60+1), example.Interval.Seconds()) } diff --git a/pkg/window/policy_test.go b/pkg/window/policy_test.go index 489c8428c9a..a81ef0ef82d 100644 --- a/pkg/window/policy_test.go +++ b/pkg/window/policy_test.go @@ -111,7 +111,7 @@ func TestRollingPolicy_AddWithTimespan(t *testing.T) { t.Logf("%+v", bkt) } - re.Equal(0, len(policy.window.buckets[0].Points)) + re.Empty(policy.window.buckets[0].Points) re.Equal(4, int(policy.window.buckets[1].Points[0])) re.Equal(2, int(policy.window.buckets[2].Points[0])) }) @@ -137,8 +137,8 @@ func TestRollingPolicy_AddWithTimespan(t *testing.T) { t.Logf("%+v", bkt) } - re.Equal(0, len(policy.window.buckets[0].Points)) + re.Zero(len(policy.window.buckets[0].Points)) re.Equal(4, int(policy.window.buckets[1].Points[0])) - re.Equal(0, len(policy.window.buckets[2].Points)) + re.Zero(len(policy.window.buckets[2].Points)) }) } diff --git a/pkg/window/window_test.go b/pkg/window/window_test.go index 0205aae47a3..f4df861fc2f 100644 --- a/pkg/window/window_test.go +++ b/pkg/window/window_test.go @@ -20,7 +20,6 @@ package window import ( "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -33,7 +32,7 @@ func TestWindowResetWindow(t *testing.T) { } window.ResetWindow() for i := 0; i < opts.Size; i++ { - re.Equal(len(window.Bucket(i).Points), 0) + re.Empty(window.Bucket(i).Points) } } @@ -45,9 +44,9 @@ func TestWindowResetBucket(t *testing.T) { window.Append(i, 1.0) } window.ResetBucket(1) - re.Equal(len(window.Bucket(1).Points), 0) - re.Equal(window.Bucket(0).Points[0], float64(1.0)) - re.Equal(window.Bucket(2).Points[0], float64(1.0)) + re.Empty(window.Bucket(1).Points) + re.Equal(float64(1.0), window.Bucket(0).Points[0]) + re.Equal(float64(1.0), window.Bucket(2).Points[0]) } func TestWindowResetBuckets(t *testing.T) { @@ -59,7 +58,7 @@ func TestWindowResetBuckets(t *testing.T) { } window.ResetBuckets(0, 3) for i := 0; i < opts.Size; i++ { - re.Equal(len(window.Bucket(i).Points), 0) + re.Empty(window.Bucket(i).Points) } } @@ -74,28 +73,30 @@ func TestWindowAppend(t *testing.T) { window.Append(i, 2.0) } for i := 0; i < opts.Size; i++ { - re.Equal(window.Bucket(i).Points[0], float64(1.0)) + re.Equal(float64(1.0), window.Bucket(i).Points[0]) } for i := 1; i < opts.Size; i++ { - re.Equal(window.Bucket(i).Points[1], float64(2.0)) + re.Equal(float64(2.0), window.Bucket(i).Points[1]) } } func TestWindowAdd(t *testing.T) { + re := require.New(t) opts := Options{Size: 3} window := NewWindow(opts) window.Append(0, 1.0) window.Add(0, 1.0) - assert.Equal(t, window.Bucket(0).Points[0], float64(2.0)) + re.Equal(float64(2.0), window.Bucket(0).Points[0]) window = NewWindow(opts) window.Add(0, 1.0) window.Add(0, 1.0) - assert.Equal(t, window.Bucket(0).Points[0], float64(2.0)) + re.Equal(float64(2.0), window.Bucket(0).Points[0]) } func TestWindowSize(t *testing.T) { + re := require.New(t) opts := Options{Size: 3} window := NewWindow(opts) - assert.Equal(t, window.Size(), 3) + re.Equal(3, window.Size()) } diff --git a/scripts/check-test.sh b/scripts/check-test.sh deleted file mode 100755 index c3168066e3d..00000000000 --- 
a/scripts/check-test.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/bin/bash - -# Check if there is any inefficient assert function usage in package. - -res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(True|False)\((t, )?reflect\.DeepEqual\(" . | sort -u) \ - -if [ "$res" ]; then - echo "following packages use the inefficient assert function: please replace reflect.DeepEqual with require.Equal" - echo "$res" - exit 1 -fi - -res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(True|False)\((t, )?strings\.Contains\(" . | sort -u) - -if [ "$res" ]; then - echo "following packages use the inefficient assert function: please replace strings.Contains with require.Contains" - echo "$res" - exit 1 -fi - -res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(Nil|NotNil)\((t, )?(err|error)" . | sort -u) - -if [ "$res" ]; then - echo "following packages use the inefficient assert function: please replace require.Nil/NotNil with require.NoError/Error" - echo "$res" - exit 1 -fi - -res=$(grep -rn --include=\*_test.go -E "(re|suite|require)\.(Equal|NotEqual)\((t, )?(true|false)" . | sort -u) - -if [ "$res" ]; then - echo "following packages use the inefficient assert function: please replace require.Equal/NotEqual(true, xxx) with require.True/False" - echo "$res" - exit 1 -fi - -exit 0 diff --git a/server/api/admin_test.go b/server/api/admin_test.go index 76c5e729eb0..050aa9cfb32 100644 --- a/server/api/admin_test.go +++ b/server/api/admin_test.go @@ -60,6 +60,7 @@ func (suite *adminTestSuite) TearDownSuite() { } func (suite *adminTestSuite) TestDropRegion() { + re := suite.Require() cluster := suite.svr.GetRaftCluster() // Update region's epoch to (100, 100). @@ -73,7 +74,7 @@ func (suite *adminTestSuite) TestDropRegion() { }, })) err := cluster.HandleRegionHeartbeat(region) - suite.NoError(err) + re.NoError(err) // Region epoch cannot decrease. region = region.Clone( @@ -81,25 +82,26 @@ func (suite *adminTestSuite) TestDropRegion() { core.SetRegionVersion(50), ) err = cluster.HandleRegionHeartbeat(region) - suite.Error(err) + re.Error(err) // After drop region from cache, lower version is accepted. url := fmt.Sprintf("%s/admin/cache/region/%d", suite.urlPrefix, region.GetID()) req, err := http.NewRequest(http.MethodDelete, url, http.NoBody) - suite.NoError(err) + re.NoError(err) res, err := testDialClient.Do(req) - suite.NoError(err) - suite.Equal(http.StatusOK, res.StatusCode) + re.NoError(err) + re.Equal(http.StatusOK, res.StatusCode) res.Body.Close() err = cluster.HandleRegionHeartbeat(region) - suite.NoError(err) + re.NoError(err) region = cluster.GetRegionByKey([]byte("foo")) - suite.Equal(uint64(50), region.GetRegionEpoch().ConfVer) - suite.Equal(uint64(50), region.GetRegionEpoch().Version) + re.Equal(uint64(50), region.GetRegionEpoch().ConfVer) + re.Equal(uint64(50), region.GetRegionEpoch().Version) } func (suite *adminTestSuite) TestDropRegions() { + re := suite.Require() cluster := suite.svr.GetRaftCluster() n := uint64(10000) @@ -124,7 +126,7 @@ func (suite *adminTestSuite) TestDropRegions() { regions = append(regions, region) err := cluster.HandleRegionHeartbeat(region) - suite.NoError(err) + re.NoError(err) } // Region epoch cannot decrease. 
@@ -135,46 +137,46 @@ func (suite *adminTestSuite) TestDropRegions() { ) regions[i] = region err := cluster.HandleRegionHeartbeat(region) - suite.Error(err) + re.Error(err) } for i := uint64(0); i < n; i++ { region := cluster.GetRegionByKey([]byte(fmt.Sprintf("%d", i))) - suite.Equal(uint64(100), region.GetRegionEpoch().ConfVer) - suite.Equal(uint64(100), region.GetRegionEpoch().Version) + re.Equal(uint64(100), region.GetRegionEpoch().ConfVer) + re.Equal(uint64(100), region.GetRegionEpoch().Version) } // After drop all regions from cache, lower version is accepted. url := fmt.Sprintf("%s/admin/cache/regions", suite.urlPrefix) req, err := http.NewRequest(http.MethodDelete, url, http.NoBody) - suite.NoError(err) + re.NoError(err) res, err := testDialClient.Do(req) - suite.NoError(err) - suite.Equal(http.StatusOK, res.StatusCode) + re.NoError(err) + re.Equal(http.StatusOK, res.StatusCode) res.Body.Close() for _, region := range regions { err := cluster.HandleRegionHeartbeat(region) - suite.NoError(err) + re.NoError(err) } for i := uint64(0); i < n; i++ { region := cluster.GetRegionByKey([]byte(fmt.Sprintf("%d", i))) - suite.Equal(uint64(50), region.GetRegionEpoch().ConfVer) - suite.Equal(uint64(50), region.GetRegionEpoch().Version) + re.Equal(uint64(50), region.GetRegionEpoch().ConfVer) + re.Equal(uint64(50), region.GetRegionEpoch().Version) } } func (suite *adminTestSuite) TestPersistFile() { - data := []byte("#!/bin/sh\nrm -rf /") re := suite.Require() + data := []byte("#!/bin/sh\nrm -rf /") err := tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusNotOK(re)) - suite.NoError(err) + re.NoError(err) data = []byte(`{"foo":"bar"}`) err = tu.CheckPostJSON(testDialClient, suite.urlPrefix+"/admin/persist-file/"+replication.DrStatusFile, data, tu.StatusOK(re)) - suite.NoError(err) + re.NoError(err) } func makeTS(offset time.Duration) uint64 { @@ -183,13 +185,13 @@ func makeTS(offset time.Duration) uint64 { } func (suite *adminTestSuite) TestResetTS() { + re := suite.Require() args := make(map[string]interface{}) t1 := makeTS(time.Hour) url := fmt.Sprintf("%s/admin/reset-ts", suite.urlPrefix) args["tso"] = fmt.Sprintf("%d", t1) values, err := json.Marshal(args) - suite.NoError(err) - re := suite.Require() + re.NoError(err) tu.Eventually(re, func() bool { resp, err := apiutil.PostJSON(testDialClient, url, values) re.NoError(err) @@ -208,128 +210,128 @@ func (suite *adminTestSuite) TestResetTS() { return false } }) - suite.NoError(err) + re.NoError(err) t2 := makeTS(32 * time.Hour) args["tso"] = fmt.Sprintf("%d", t2) values, err = json.Marshal(args) - suite.NoError(err) + re.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, tu.Status(re, http.StatusForbidden), tu.StringContain(re, "too large")) - suite.NoError(err) + re.NoError(err) t3 := makeTS(-2 * time.Hour) args["tso"] = fmt.Sprintf("%d", t3) values, err = json.Marshal(args) - suite.NoError(err) + re.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, tu.Status(re, http.StatusForbidden), tu.StringContain(re, "small")) - suite.NoError(err) + re.NoError(err) args["tso"] = "" values, err = json.Marshal(args) - suite.NoError(err) + re.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"invalid tso value\"\n")) - suite.NoError(err) + re.NoError(err) args["tso"] = "test" values, err = json.Marshal(args) - suite.NoError(err) + re.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, 
tu.Status(re, http.StatusBadRequest), tu.StringEqual(re, "\"invalid tso value\"\n")) - suite.NoError(err) + re.NoError(err) t4 := makeTS(32 * time.Hour) args["tso"] = fmt.Sprintf("%d", t4) args["force-use-larger"] = "xxx" values, err = json.Marshal(args) - suite.NoError(err) + re.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, tu.Status(re, http.StatusBadRequest), tu.StringContain(re, "invalid force-use-larger value")) - suite.NoError(err) + re.NoError(err) args["force-use-larger"] = false values, err = json.Marshal(args) - suite.NoError(err) + re.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, tu.Status(re, http.StatusForbidden), tu.StringContain(re, "too large")) - suite.NoError(err) + re.NoError(err) args["force-use-larger"] = true values, err = json.Marshal(args) - suite.NoError(err) + re.NoError(err) err = tu.CheckPostJSON(testDialClient, url, values, tu.StatusOK(re), tu.StringEqual(re, "\"Reset ts successfully.\"\n")) - suite.NoError(err) + re.NoError(err) } func (suite *adminTestSuite) TestMarkSnapshotRecovering() { re := suite.Require() url := fmt.Sprintf("%s/admin/cluster/markers/snapshot-recovering", suite.urlPrefix) // default to false - suite.NoError(tu.CheckGetJSON(testDialClient, url, nil, + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "false"))) // mark - suite.NoError(tu.CheckPostJSON(testDialClient, url, nil, + re.NoError(tu.CheckPostJSON(testDialClient, url, nil, tu.StatusOK(re))) - suite.NoError(tu.CheckGetJSON(testDialClient, url, nil, + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "true"))) // test using grpc call grpcServer := server.GrpcServer{Server: suite.svr} resp, err2 := grpcServer.IsSnapshotRecovering(context.Background(), &pdpb.IsSnapshotRecoveringRequest{}) - suite.NoError(err2) - suite.True(resp.Marked) + re.NoError(err2) + re.True(resp.Marked) // unmark err := tu.CheckDelete(testDialClient, url, tu.StatusOK(re)) - suite.NoError(err) - suite.NoError(tu.CheckGetJSON(testDialClient, url, nil, + re.NoError(err) + re.NoError(tu.CheckGetJSON(testDialClient, url, nil, tu.StatusOK(re), tu.StringContain(re, "false"))) } func (suite *adminTestSuite) TestRecoverAllocID() { re := suite.Require() url := fmt.Sprintf("%s/admin/base-alloc-id", suite.urlPrefix) - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte("invalid json"), tu.Status(re, http.StatusBadRequest))) + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte("invalid json"), tu.Status(re, http.StatusBadRequest))) // no id or invalid id - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{}`), tu.Status(re, http.StatusBadRequest), tu.StringContain(re, "invalid id value"))) - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": ""}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": ""}`), tu.Status(re, http.StatusBadRequest), tu.StringContain(re, "invalid id value"))) - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": 11}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": 11}`), tu.Status(re, http.StatusBadRequest), tu.StringContain(re, "invalid id value"))) - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "aa"}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "aa"}`), tu.Status(re, http.StatusBadRequest), tu.StringContain(re, "invalid syntax"))) // snapshot recovering=false - 
suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "100000"}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "100000"}`), tu.Status(re, http.StatusForbidden), tu.StringContain(re, "can only recover alloc id when recovering"))) // mark and recover alloc id markRecoveringURL := fmt.Sprintf("%s/admin/cluster/markers/snapshot-recovering", suite.urlPrefix) - suite.NoError(tu.CheckPostJSON(testDialClient, markRecoveringURL, nil, + re.NoError(tu.CheckPostJSON(testDialClient, markRecoveringURL, nil, tu.StatusOK(re))) - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "1000000"}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "1000000"}`), tu.StatusOK(re))) id, err2 := suite.svr.GetAllocator().Alloc() - suite.NoError(err2) - suite.Equal(id, uint64(1000001)) + re.NoError(err2) + re.Equal(uint64(1000001), id) // recover alloc id again - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "99000000"}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "99000000"}`), tu.StatusOK(re))) id, err2 = suite.svr.GetAllocator().Alloc() - suite.NoError(err2) - suite.Equal(id, uint64(99000001)) + re.NoError(err2) + re.Equal(uint64(99000001), id) // unmark err := tu.CheckDelete(testDialClient, markRecoveringURL, tu.StatusOK(re)) - suite.NoError(err) - suite.NoError(tu.CheckGetJSON(testDialClient, markRecoveringURL, nil, + re.NoError(err) + re.NoError(tu.CheckGetJSON(testDialClient, markRecoveringURL, nil, tu.StatusOK(re), tu.StringContain(re, "false"))) - suite.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "100000"}`), + re.NoError(tu.CheckPostJSON(testDialClient, url, []byte(`{"id": "100000"}`), tu.Status(re, http.StatusForbidden), tu.StringContain(re, "can only recover alloc id when recovering"))) } diff --git a/server/api/region_test.go b/server/api/region_test.go index ea2f2871a95..7e48c80d7bc 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -224,7 +224,7 @@ func (suite *regionTestSuite) TestRegionCheck() { func (suite *regionTestSuite) TestRegions() { r := NewAPIRegionInfo(core.NewRegionInfo(&metapb.Region{Id: 1}, nil)) suite.Nil(r.Leader.Peer) - suite.Len(r.Leader.RoleName, 0) + suite.Empty(r.Leader.RoleName) rs := []*core.RegionInfo{ core.NewTestRegionInfo(2, 1, []byte("a"), []byte("b"), core.SetApproximateKeys(10), core.SetApproximateSize(10)), diff --git a/server/cluster/cluster_test.go b/server/cluster/cluster_test.go index 8c889923ea7..d5931394c1b 100644 --- a/server/cluster/cluster_test.go +++ b/server/cluster/cluster_test.go @@ -179,10 +179,10 @@ func TestStoreHeartbeat(t *testing.T) { time.Sleep(20 * time.Millisecond) storeStats = cluster.hotStat.RegionStats(utils.Read, 0) re.Empty(storeStats[1]) - re.Nil(cluster.HandleStoreHeartbeat(hotReq, hotResp)) + re.NoError(cluster.HandleStoreHeartbeat(hotReq, hotResp)) time.Sleep(20 * time.Millisecond) storeStats = cluster.hotStat.RegionStats(utils.Read, 1) - re.Len(storeStats[1], 0) + re.Empty(storeStats[1]) storeStats = cluster.hotStat.RegionStats(utils.Read, 3) re.Empty(storeStats[1]) // after 2 hot heartbeats, wo can find region 1 peer again @@ -2239,7 +2239,7 @@ func checkRegions(re *require.Assertions, cache *core.BasicCluster, regions []*c } } - re.Equal(len(regions), cache.GetTotalRegionCount()) + re.Len(regions, cache.GetTotalRegionCount()) for id, count := range regionCount { re.Equal(count, cache.GetStoreRegionCount(id)) } @@ -2744,7 +2744,7 @@ func TestMergeRegionCancelOneOperator(t *testing.T) { 
re.Len(ops, co.GetOperatorController().AddWaitingOperator(ops...)) // Cancel source operator. co.GetOperatorController().RemoveOperator(co.GetOperatorController().GetOperator(source.GetID())) - re.Len(co.GetOperatorController().GetOperators(), 0) + re.Empty(co.GetOperatorController().GetOperators()) // Cancel target region. ops, err = operator.CreateMergeRegionOperator("merge-region", tc, source, target, operator.OpMerge) @@ -2752,7 +2752,7 @@ func TestMergeRegionCancelOneOperator(t *testing.T) { re.Len(ops, co.GetOperatorController().AddWaitingOperator(ops...)) // Cancel target operator. co.GetOperatorController().RemoveOperator(co.GetOperatorController().GetOperator(target.GetID())) - re.Len(co.GetOperatorController().GetOperators(), 0) + re.Empty(co.GetOperatorController().GetOperators()) } func TestReplica(t *testing.T) { @@ -3047,8 +3047,8 @@ func TestAddScheduler(t *testing.T) { re.Equal(4, int(batch)) gls, err := schedulers.CreateScheduler(schedulers.GrantLeaderType, oc, storage.NewStorageWithMemoryBackend(), schedulers.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"0"}), controller.RemoveScheduler) re.NoError(err) - re.NotNil(controller.AddScheduler(gls)) - re.NotNil(controller.RemoveScheduler(gls.GetName())) + re.Error(controller.AddScheduler(gls)) + re.Error(controller.RemoveScheduler(gls.GetName())) gls, err = schedulers.CreateScheduler(schedulers.GrantLeaderType, oc, storage.NewStorageWithMemoryBackend(), schedulers.ConfigSliceDecoder(schedulers.GrantLeaderType, []string{"1"}), controller.RemoveScheduler) re.NoError(err) @@ -3445,7 +3445,7 @@ func TestStoreOverloaded(t *testing.T) { time.Sleep(time.Second) for i := 0; i < 100; i++ { ops, _ := lb.Schedule(tc, false /* dryRun */) - re.Greater(len(ops), 0) + re.NotEmpty(ops) } } @@ -3480,7 +3480,7 @@ func TestStoreOverloadedWithReplace(t *testing.T) { // sleep 2 seconds to make sure that token is filled up time.Sleep(2 * time.Second) ops, _ = lb.Schedule(tc, false /* dryRun */) - re.Greater(len(ops), 0) + re.NotEmpty(ops) } func TestDownStoreLimit(t *testing.T) { diff --git a/server/config/config_test.go b/server/config/config_test.go index 07cdc966409..69cfafd8d36 100644 --- a/server/config/config_test.go +++ b/server/config/config_test.go @@ -492,7 +492,7 @@ func TestRateLimitClone(t *testing.T) { ConcurrencyLimit: 200, } dc := cfg.LimiterConfig["test"] - re.Equal(dc.ConcurrencyLimit, uint64(0)) + re.Zero(dc.ConcurrencyLimit) gCfg := &GRPCRateLimitConfig{ EnableRateLimit: defaultEnableGRPCRateLimitMiddleware, @@ -503,5 +503,5 @@ func TestRateLimitClone(t *testing.T) { ConcurrencyLimit: 300, } gdc := gCfg.LimiterConfig["test"] - re.Equal(gdc.ConcurrencyLimit, uint64(0)) + re.Zero(gdc.ConcurrencyLimit) } diff --git a/tests/pdctl/hot/hot_test.go b/tests/pdctl/hot/hot_test.go index 03c26f40441..366887e19aa 100644 --- a/tests/pdctl/hot/hot_test.go +++ b/tests/pdctl/hot/hot_test.go @@ -368,11 +368,11 @@ func (suite *hotTestSuite) checkHotWithoutHotPeer(cluster *tests.TestCluster) { re.NoError(err) re.NoError(json.Unmarshal(output, &hotRegion)) re.NotNil(hotRegion.AsPeer[1]) - re.Equal(hotRegion.AsPeer[1].Count, 0) - re.Equal(0.0, hotRegion.AsPeer[1].TotalBytesRate) + re.Zero(hotRegion.AsPeer[1].Count) + re.Zero(hotRegion.AsPeer[1].TotalBytesRate) re.Equal(load, hotRegion.AsPeer[1].StoreByteRate) - re.Equal(hotRegion.AsLeader[1].Count, 0) - re.Equal(0.0, hotRegion.AsLeader[1].TotalBytesRate) + re.Zero(hotRegion.AsLeader[1].Count) + re.Zero(hotRegion.AsLeader[1].TotalBytesRate) re.Equal(load, 
hotRegion.AsLeader[1].StoreByteRate) } { @@ -381,12 +381,12 @@ func (suite *hotTestSuite) checkHotWithoutHotPeer(cluster *tests.TestCluster) { hotRegion := statistics.StoreHotPeersInfos{} re.NoError(err) re.NoError(json.Unmarshal(output, &hotRegion)) - re.Equal(0, hotRegion.AsPeer[1].Count) - re.Equal(0.0, hotRegion.AsPeer[1].TotalBytesRate) + re.Zero(hotRegion.AsPeer[1].Count) + re.Zero(hotRegion.AsPeer[1].TotalBytesRate) re.Equal(load, hotRegion.AsPeer[1].StoreByteRate) - re.Equal(0, hotRegion.AsLeader[1].Count) - re.Equal(0.0, hotRegion.AsLeader[1].TotalBytesRate) - re.Equal(0.0, hotRegion.AsLeader[1].StoreByteRate) // write leader sum + re.Zero(hotRegion.AsLeader[1].Count) + re.Zero(hotRegion.AsLeader[1].TotalBytesRate) + re.Zero(hotRegion.AsLeader[1].StoreByteRate) // write leader sum } } diff --git a/tests/pdctl/keyspace/keyspace_group_test.go b/tests/pdctl/keyspace/keyspace_group_test.go index cbfdf1d099a..0de48a85c64 100644 --- a/tests/pdctl/keyspace/keyspace_group_test.go +++ b/tests/pdctl/keyspace/keyspace_group_test.go @@ -78,14 +78,14 @@ func TestKeyspaceGroup(t *testing.T) { err = json.Unmarshal(output, &keyspaceGroup) re.NoError(err) re.Equal(uint32(1), keyspaceGroup.ID) - re.Equal(keyspaceGroup.Keyspaces, []uint32{111}) + re.Equal([]uint32{111}, keyspaceGroup.Keyspaces) output, err = pdctl.ExecuteCommand(cmd, append(args, "2")...) re.NoError(err) keyspaceGroup = endpoint.KeyspaceGroup{} err = json.Unmarshal(output, &keyspaceGroup) re.NoError(err) re.Equal(uint32(2), keyspaceGroup.ID) - re.Equal(keyspaceGroup.Keyspaces, []uint32{222, 333}) + re.Equal([]uint32{222, 333}, keyspaceGroup.Keyspaces) } func TestSplitKeyspaceGroup(t *testing.T) { @@ -133,8 +133,8 @@ func TestSplitKeyspaceGroup(t *testing.T) { err = json.Unmarshal(output, &keyspaceGroups) re.NoError(err) re.Len(keyspaceGroups, 2) - re.Equal(keyspaceGroups[0].ID, uint32(0)) - re.Equal(keyspaceGroups[1].ID, uint32(1)) + re.Zero(keyspaceGroups[0].ID) + re.Equal(uint32(1), keyspaceGroups[1].ID) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/acceleratedAllocNodes")) re.NoError(failpoint.Disable("github.com/tikv/pd/server/delayStartServerLoop")) @@ -448,7 +448,7 @@ func TestKeyspaceGroupState(t *testing.T) { var keyspaceGroups []*endpoint.KeyspaceGroup err = json.Unmarshal(output, &keyspaceGroups) re.NoError(err) - re.Len(keyspaceGroups, 0) + re.Empty(keyspaceGroups) testutil.Eventually(re, func() bool { args := []string{"-u", pdAddr, "keyspace-group", "split", "0", "2", "3"} output, err := pdctl.ExecuteCommand(cmd, args...) @@ -462,8 +462,8 @@ func TestKeyspaceGroupState(t *testing.T) { err = json.Unmarshal(output, &keyspaceGroups) re.NoError(err) re.Len(keyspaceGroups, 2) - re.Equal(keyspaceGroups[0].ID, uint32(0)) - re.Equal(keyspaceGroups[1].ID, uint32(2)) + re.Equal(uint32(0), keyspaceGroups[0].ID) + re.Equal(uint32(2), keyspaceGroups[1].ID) args = []string{"-u", pdAddr, "keyspace-group", "finish-split", "2"} output, err = pdctl.ExecuteCommand(cmd, args...) 
@@ -486,7 +486,7 @@ func TestKeyspaceGroupState(t *testing.T) { err = json.Unmarshal(output, &keyspaceGroups) re.NoError(err) re.Len(keyspaceGroups, 1) - re.Equal(keyspaceGroups[0].ID, uint32(0)) + re.Equal(uint32(0), keyspaceGroups[0].ID) re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/acceleratedAllocNodes")) re.NoError(failpoint.Disable("github.com/tikv/pd/server/delayStartServerLoop")) diff --git a/tests/pdctl/keyspace/keyspace_test.go b/tests/pdctl/keyspace/keyspace_test.go index 3ff755fe601..f83d09760a9 100644 --- a/tests/pdctl/keyspace/keyspace_test.go +++ b/tests/pdctl/keyspace/keyspace_test.go @@ -147,22 +147,24 @@ func TestKeyspaceTestSuite(t *testing.T) { } func (suite *keyspaceTestSuite) SetupTest() { + re := suite.Require() suite.ctx, suite.cancel = context.WithCancel(context.Background()) - suite.NoError(failpoint.Enable("github.com/tikv/pd/server/delayStartServerLoop", `return(true)`)) - suite.NoError(failpoint.Enable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion", "return(true)")) + re.NoError(failpoint.Enable("github.com/tikv/pd/server/delayStartServerLoop", `return(true)`)) + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion", "return(true)")) tc, err := tests.NewTestAPICluster(suite.ctx, 1) - suite.NoError(err) - suite.NoError(tc.RunInitialServers()) + re.NoError(err) + re.NoError(tc.RunInitialServers()) tc.WaitLeader() leaderServer := tc.GetLeaderServer() - suite.NoError(leaderServer.BootstrapCluster()) + re.NoError(leaderServer.BootstrapCluster()) suite.cluster = tc suite.pdAddr = tc.GetConfig().GetClientURL() } func (suite *keyspaceTestSuite) TearDownTest() { - suite.NoError(failpoint.Disable("github.com/tikv/pd/server/delayStartServerLoop")) - suite.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion")) + re := suite.Require() + re.NoError(failpoint.Disable("github.com/tikv/pd/server/delayStartServerLoop")) + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion")) suite.cancel() } diff --git a/tests/pdctl/log/log_test.go b/tests/pdctl/log/log_test.go index e6995231329..08df4a78bea 100644 --- a/tests/pdctl/log/log_test.go +++ b/tests/pdctl/log/log_test.go @@ -39,11 +39,12 @@ func TestLogTestSuite(t *testing.T) { } func (suite *logTestSuite) SetupSuite() { + re := suite.Require() suite.ctx, suite.cancel = context.WithCancel(context.Background()) var err error suite.cluster, err = tests.NewTestCluster(suite.ctx, 3) - suite.NoError(err) - suite.NoError(suite.cluster.RunInitialServers()) + re.NoError(err) + re.NoError(suite.cluster.RunInitialServers()) suite.cluster.WaitLeader() suite.pdAddrs = suite.cluster.GetConfig().GetClientURLs() @@ -53,7 +54,7 @@ func (suite *logTestSuite) SetupSuite() { LastHeartbeat: time.Now().UnixNano(), } leaderServer := suite.cluster.GetLeaderServer() - suite.NoError(leaderServer.BootstrapCluster()) + re.NoError(leaderServer.BootstrapCluster()) tests.MustPutStore(suite.Require(), suite.cluster, store) } diff --git a/tests/pdctl/operator/operator_test.go b/tests/pdctl/operator/operator_test.go index aa2fe5d1304..1de61dca880 100644 --- a/tests/pdctl/operator/operator_test.go +++ b/tests/pdctl/operator/operator_test.go @@ -107,7 +107,7 @@ func (suite *operatorTestSuite) checkOperator(cluster *tests.TestCluster) { output, err := pdctl.ExecuteCommand(cmd, args...) 
re.NoError(err) re.NoError(json.Unmarshal(output, &slice)) - re.Len(slice, 0) + re.Empty(slice) args = []string{"-u", pdAddr, "operator", "check", "2"} output, err = pdctl.ExecuteCommand(cmd, args...) re.NoError(err) diff --git a/tests/pdctl/resourcemanager/resource_manager_command_test.go b/tests/pdctl/resourcemanager/resource_manager_command_test.go index ad43e0abca9..cbd9b481869 100644 --- a/tests/pdctl/resourcemanager/resource_manager_command_test.go +++ b/tests/pdctl/resourcemanager/resource_manager_command_test.go @@ -41,9 +41,10 @@ type testResourceManagerSuite struct { } func (s *testResourceManagerSuite) SetupSuite() { + re := s.Require() s.ctx, s.cancel = context.WithCancel(context.Background()) cluster, err := tests.NewTestCluster(s.ctx, 1) - s.Nil(err) + re.NoError(err) s.cluster = cluster s.cluster.RunInitialServers() cluster.WaitLeader() @@ -56,18 +57,19 @@ func (s *testResourceManagerSuite) TearDownSuite() { } func (s *testResourceManagerSuite) TestConfigController() { + re := s.Require() expectCfg := server.ControllerConfig{} expectCfg.Adjust(nil) // Show controller config checkShow := func() { args := []string{"-u", s.pdAddr, "resource-manager", "config", "controller", "show"} output, err := pdctl.ExecuteCommand(pdctlCmd.GetRootCmd(), args...) - s.Nil(err) + re.NoError(err) actualCfg := server.ControllerConfig{} err = json.Unmarshal(output, &actualCfg) - s.Nil(err) - s.Equal(expectCfg, actualCfg) + re.NoError(err) + re.Equal(expectCfg, actualCfg) } // Check default config @@ -76,22 +78,22 @@ func (s *testResourceManagerSuite) TestConfigController() { // Set controller config args := []string{"-u", s.pdAddr, "resource-manager", "config", "controller", "set", "ltb-max-wait-duration", "1h"} output, err := pdctl.ExecuteCommand(pdctlCmd.GetRootCmd(), args...) - s.Nil(err) - s.Contains(string(output), "Success!") + re.NoError(err) + re.Contains(string(output), "Success!") expectCfg.LTBMaxWaitDuration = typeutil.Duration{Duration: 1 * time.Hour} checkShow() args = []string{"-u", s.pdAddr, "resource-manager", "config", "controller", "set", "enable-controller-trace-log", "true"} output, err = pdctl.ExecuteCommand(pdctlCmd.GetRootCmd(), args...) - s.Nil(err) - s.Contains(string(output), "Success!") + re.NoError(err) + re.Contains(string(output), "Success!") expectCfg.EnableControllerTraceLog = true checkShow() args = []string{"-u", s.pdAddr, "resource-manager", "config", "controller", "set", "write-base-cost", "2"} output, err = pdctl.ExecuteCommand(pdctlCmd.GetRootCmd(), args...) 
- s.Nil(err) - s.Contains(string(output), "Success!") + re.NoError(err) + re.Contains(string(output), "Success!") expectCfg.RequestUnit.WriteBaseCost = 2 checkShow() } diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index 140ee7a7c44..585a5fd1199 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -48,7 +48,8 @@ func TestSchedulerTestSuite(t *testing.T) { } func (suite *schedulerTestSuite) SetupSuite() { - suite.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/skipStoreConfigSync", `return(true)`)) + re := suite.Require() + re.NoError(failpoint.Enable("github.com/tikv/pd/server/cluster/skipStoreConfigSync", `return(true)`)) suite.env = tests.NewSchedulingTestEnvironment(suite.T()) suite.defaultSchedulers = []string{ "balance-leader-scheduler", @@ -61,8 +62,9 @@ func (suite *schedulerTestSuite) SetupSuite() { } func (suite *schedulerTestSuite) TearDownSuite() { + re := suite.Require() suite.env.Cleanup() - suite.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/skipStoreConfigSync")) + re.NoError(failpoint.Disable("github.com/tikv/pd/server/cluster/skipStoreConfigSync")) } func (suite *schedulerTestSuite) TearDownTest() { diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index f5db6bb2513..946e65bc6e4 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -218,7 +218,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { resp, err := dialClient.Do(req) suite.NoError(err) resp.Body.Close() - suite.Equal(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), true) + suite.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) // returns StatusOK when no rate-limit config req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) @@ -227,7 +227,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { _, err = io.ReadAll(resp.Body) resp.Body.Close() suite.NoError(err) - suite.Equal(resp.StatusCode, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) input = make(map[string]interface{}) input["type"] = "label" input["label"] = "SetLogLevel" @@ -241,7 +241,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { _, err = io.ReadAll(resp.Body) resp.Body.Close() suite.NoError(err) - suite.Equal(resp.StatusCode, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) @@ -251,10 +251,10 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { resp.Body.Close() suite.NoError(err) if i > 0 { - suite.Equal(resp.StatusCode, http.StatusTooManyRequests) + suite.Equal(http.StatusTooManyRequests, resp.StatusCode) suite.Equal(string(data), fmt.Sprintf("%s\n", http.StatusText(http.StatusTooManyRequests))) } else { - suite.Equal(resp.StatusCode, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) } } @@ -268,10 +268,10 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { resp.Body.Close() suite.NoError(err) if i > 0 { - suite.Equal(resp.StatusCode, http.StatusTooManyRequests) + suite.Equal(http.StatusTooManyRequests, resp.StatusCode) suite.Equal(string(data), fmt.Sprintf("%s\n", http.StatusText(http.StatusTooManyRequests))) } else { - suite.Equal(resp.StatusCode, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) } } @@ 
-284,7 +284,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { data, err := io.ReadAll(resp.Body) resp.Body.Close() suite.NoError(err) - suite.Equal(resp.StatusCode, http.StatusTooManyRequests) + suite.Equal(http.StatusTooManyRequests, resp.StatusCode) suite.Equal(string(data), fmt.Sprintf("%s\n", http.StatusText(http.StatusTooManyRequests))) } @@ -297,12 +297,12 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { } server.MustWaitLeader(suite.Require(), servers) leader = suite.cluster.GetLeaderServer() - suite.Equal(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), true) + suite.True(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) cfg, ok := leader.GetServer().GetRateLimitConfig().LimiterConfig["SetLogLevel"] - suite.Equal(ok, true) - suite.Equal(cfg.ConcurrencyLimit, uint64(1)) - suite.Equal(cfg.QPS, 0.5) - suite.Equal(cfg.QPSBurst, 1) + suite.True(ok) + suite.Equal(uint64(1), cfg.ConcurrencyLimit) + suite.Equal(0.5, cfg.QPS) + suite.Equal(1, cfg.QPSBurst) for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) @@ -312,10 +312,10 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { resp.Body.Close() suite.NoError(err) if i > 0 { - suite.Equal(resp.StatusCode, http.StatusTooManyRequests) + suite.Equal(http.StatusTooManyRequests, resp.StatusCode) suite.Equal(string(data), fmt.Sprintf("%s\n", http.StatusText(http.StatusTooManyRequests))) } else { - suite.Equal(resp.StatusCode, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) } } @@ -329,10 +329,10 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { resp.Body.Close() suite.NoError(err) if i > 0 { - suite.Equal(resp.StatusCode, http.StatusTooManyRequests) + suite.Equal(http.StatusTooManyRequests, resp.StatusCode) suite.Equal(string(data), fmt.Sprintf("%s\n", http.StatusText(http.StatusTooManyRequests))) } else { - suite.Equal(resp.StatusCode, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) } } @@ -345,7 +345,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { data, err := io.ReadAll(resp.Body) resp.Body.Close() suite.NoError(err) - suite.Equal(resp.StatusCode, http.StatusTooManyRequests) + suite.Equal(http.StatusTooManyRequests, resp.StatusCode) suite.Equal(string(data), fmt.Sprintf("%s\n", http.StatusText(http.StatusTooManyRequests))) } @@ -358,7 +358,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { resp, err = dialClient.Do(req) suite.NoError(err) resp.Body.Close() - suite.Equal(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled(), false) + suite.False(leader.GetServer().GetServiceMiddlewarePersistOptions().IsRateLimitEnabled()) for i := 0; i < 3; i++ { req, _ = http.NewRequest(http.MethodPost, leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) @@ -367,7 +367,7 @@ func (suite *middlewareTestSuite) TestRateLimitMiddleware() { _, err = io.ReadAll(resp.Body) resp.Body.Close() suite.NoError(err) - suite.Equal(resp.StatusCode, http.StatusOK) + suite.Equal(http.StatusOK, resp.StatusCode) } } @@ -377,7 +377,7 @@ func (suite *middlewareTestSuite) TestSwaggerUrl() { req, _ := http.NewRequest(http.MethodGet, leader.GetAddr()+"/swagger/ui/index", http.NoBody) resp, err := dialClient.Do(req) suite.NoError(err) - suite.True(resp.StatusCode == http.StatusNotFound) + suite.Equal(http.StatusNotFound, resp.StatusCode) resp.Body.Close() } diff --git 
a/tests/server/api/rule_test.go b/tests/server/api/rule_test.go index 0a0c3f2fb2e..eaa41cc11bc 100644 --- a/tests/server/api/rule_test.go +++ b/tests/server/api/rule_test.go @@ -256,7 +256,7 @@ func (suite *ruleTestSuite) checkGetAll(cluster *tests.TestCluster) { var resp2 []*placement.Rule err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/rules", &resp2) suite.NoError(err) - suite.GreaterOrEqual(len(resp2), 1) + suite.NotEmpty(resp2) } func (suite *ruleTestSuite) TestSetAll() { @@ -1039,40 +1039,40 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl u := fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 1) err := tu.ReadGetJSON(re, testDialClient, u, fit) suite.NoError(err) - suite.Equal(len(fit.RuleFits), 1) - suite.Equal(len(fit.OrphanPeers), 1) + suite.Len(fit.RuleFits, 1) + suite.Len(fit.OrphanPeers, 1) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 2) fit = &placement.RegionFit{} err = tu.ReadGetJSON(re, testDialClient, u, fit) suite.NoError(err) - suite.Equal(len(fit.RuleFits), 2) - suite.Equal(len(fit.OrphanPeers), 0) + suite.Len(fit.RuleFits, 2) + suite.Empty(fit.OrphanPeers) u = fmt.Sprintf("%s/config/rules/region/%d/detail", urlPrefix, 3) fit = &placement.RegionFit{} err = tu.ReadGetJSON(re, testDialClient, u, fit) suite.NoError(err) - suite.Equal(len(fit.RuleFits), 0) - suite.Equal(len(fit.OrphanPeers), 2) + suite.Empty(fit.RuleFits) + suite.Len(fit.OrphanPeers, 2) var label labeler.LabelRule escapedID := url.PathEscape("keyspaces/0") u = fmt.Sprintf("%s/config/region-label/rule/%s", urlPrefix, escapedID) err = tu.ReadGetJSON(re, testDialClient, u, &label) suite.NoError(err) - suite.Equal(label.ID, "keyspaces/0") + suite.Equal("keyspaces/0", label.ID) var labels []labeler.LabelRule u = fmt.Sprintf("%s/config/region-label/rules", urlPrefix) err = tu.ReadGetJSON(re, testDialClient, u, &labels) suite.NoError(err) suite.Len(labels, 1) - suite.Equal(labels[0].ID, "keyspaces/0") + suite.Equal("keyspaces/0", labels[0].ID) u = fmt.Sprintf("%s/config/region-label/rules/ids", urlPrefix) err = tu.CheckGetJSON(testDialClient, u, []byte(`["rule1", "rule3"]`), func(resp []byte, statusCode int, _ http.Header) { err := json.Unmarshal(resp, &labels) suite.NoError(err) - suite.Len(labels, 0) + suite.Empty(labels) }) suite.NoError(err) @@ -1080,7 +1080,7 @@ func (suite *regionRuleTestSuite) checkRegionPlacementRule(cluster *tests.TestCl err := json.Unmarshal(resp, &labels) suite.NoError(err) suite.Len(labels, 1) - suite.Equal(labels[0].ID, "keyspaces/0") + suite.Equal("keyspaces/0", labels[0].ID) }) suite.NoError(err) diff --git a/tests/server/apiv2/handlers/keyspace_test.go b/tests/server/apiv2/handlers/keyspace_test.go index f7b43ab194d..535f01cc33e 100644 --- a/tests/server/apiv2/handlers/keyspace_test.go +++ b/tests/server/apiv2/handlers/keyspace_test.go @@ -46,22 +46,24 @@ func TestKeyspaceTestSuite(t *testing.T) { } func (suite *keyspaceTestSuite) SetupTest() { + re := suite.Require() ctx, cancel := context.WithCancel(context.Background()) suite.cleanup = cancel cluster, err := tests.NewTestCluster(ctx, 1) suite.cluster = cluster - suite.NoError(err) - suite.NoError(cluster.RunInitialServers()) - suite.NotEmpty(cluster.WaitLeader()) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) suite.server = cluster.GetLeaderServer() - suite.NoError(suite.server.BootstrapCluster()) - suite.NoError(failpoint.Enable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion", "return(true)")) + 
re.NoError(suite.server.BootstrapCluster()) + re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion", "return(true)")) } func (suite *keyspaceTestSuite) TearDownTest() { + re := suite.Require() suite.cleanup() suite.cluster.Destroy() - suite.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion")) + re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/keyspace/skipSplitRegion")) } func (suite *keyspaceTestSuite) TestCreateLoadKeyspace() { @@ -133,7 +135,7 @@ func (suite *keyspaceTestSuite) TestLoadRangeKeyspace() { loadResponse := sendLoadRangeRequest(re, suite.server, "", "") re.Empty(loadResponse.NextPageToken) // Load response should contain no more pages. // Load response should contain all created keyspace and a default. - re.Equal(len(keyspaces)+1, len(loadResponse.Keyspaces)) + re.Len(loadResponse.Keyspaces, len(keyspaces)+1) for i, created := range keyspaces { re.Equal(created, loadResponse.Keyspaces[i+1].KeyspaceMeta) } diff --git a/tests/server/apiv2/handlers/tso_keyspace_group_test.go b/tests/server/apiv2/handlers/tso_keyspace_group_test.go index 214de6e95ef..2bf2db715fa 100644 --- a/tests/server/apiv2/handlers/tso_keyspace_group_test.go +++ b/tests/server/apiv2/handlers/tso_keyspace_group_test.go @@ -39,14 +39,15 @@ func TestKeyspaceGroupTestSuite(t *testing.T) { } func (suite *keyspaceGroupTestSuite) SetupTest() { + re := suite.Require() suite.ctx, suite.cancel = context.WithCancel(context.Background()) cluster, err := tests.NewTestAPICluster(suite.ctx, 1) suite.cluster = cluster - suite.NoError(err) - suite.NoError(cluster.RunInitialServers()) - suite.NotEmpty(cluster.WaitLeader()) + re.NoError(err) + re.NoError(cluster.RunInitialServers()) + re.NotEmpty(cluster.WaitLeader()) suite.server = cluster.GetLeaderServer() - suite.NoError(suite.server.BootstrapCluster()) + re.NoError(suite.server.BootstrapCluster()) } func (suite *keyspaceGroupTestSuite) TearDownTest() { diff --git a/tests/server/cluster/cluster_test.go b/tests/server/cluster/cluster_test.go index 0b0779d9434..67c798d7f69 100644 --- a/tests/server/cluster/cluster_test.go +++ b/tests/server/cluster/cluster_test.go @@ -1485,7 +1485,7 @@ func TestMinResolvedTS(t *testing.T) { } // default run job - re.NotEqual(rc.GetPDServerConfig().MinResolvedTSPersistenceInterval.Duration, 0) + re.NotZero(rc.GetPDServerConfig().MinResolvedTSPersistenceInterval.Duration) setMinResolvedTSPersistenceInterval(re, rc, svr, 0) re.Equal(time.Duration(0), rc.GetPDServerConfig().MinResolvedTSPersistenceInterval.Duration) diff --git a/tests/server/cluster/cluster_work_test.go b/tests/server/cluster/cluster_work_test.go index eabecf8e29b..f503563dbb1 100644 --- a/tests/server/cluster/cluster_work_test.go +++ b/tests/server/cluster/cluster_work_test.go @@ -16,7 +16,6 @@ package cluster_test import ( "context" - "errors" "sort" "testing" "time" @@ -112,9 +111,9 @@ func TestAskSplit(t *testing.T) { re.NoError(leaderServer.GetServer().SaveTTLConfig(map[string]interface{}{"schedule.enable-tikv-split-region": 0}, time.Minute)) _, err = rc.HandleAskSplit(req) - re.True(errors.Is(err, errs.ErrSchedulerTiKVSplitDisabled)) + re.ErrorIs(err, errs.ErrSchedulerTiKVSplitDisabled) _, err = rc.HandleAskBatchSplit(req1) - re.True(errors.Is(err, errs.ErrSchedulerTiKVSplitDisabled)) + re.ErrorIs(err, errs.ErrSchedulerTiKVSplitDisabled) re.NoError(leaderServer.GetServer().SaveTTLConfig(map[string]interface{}{"schedule.enable-tikv-split-region": 0}, 0)) // wait ttl config takes effect time.Sleep(time.Second) diff --git 
a/tests/server/keyspace/keyspace_test.go b/tests/server/keyspace/keyspace_test.go index 86b8f6fd37c..3ee15e1edc1 100644 --- a/tests/server/keyspace/keyspace_test.go +++ b/tests/server/keyspace/keyspace_test.go @@ -100,7 +100,7 @@ func checkLabelRule(re *require.Assertions, id uint32, regionLabeler *labeler.Re rangeRule, ok := loadedLabel.Data.([]*labeler.KeyRangeRule) re.True(ok) - re.Equal(2, len(rangeRule)) + re.Len(rangeRule, 2) keyspaceIDBytes := make([]byte, 4) nextKeyspaceIDBytes := make([]byte, 4) diff --git a/tools/pd-backup/pdbackup/backup_test.go b/tools/pd-backup/pdbackup/backup_test.go index 40e4190f5d4..b35bf1e8a70 100644 --- a/tools/pd-backup/pdbackup/backup_test.go +++ b/tools/pd-backup/pdbackup/backup_test.go @@ -99,6 +99,7 @@ func setupServer() (*httptest.Server, *config.Config) { } func (s *backupTestSuite) BeforeTest(suiteName, testName string) { + re := s.Require() ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) defer cancel() @@ -106,21 +107,21 @@ func (s *backupTestSuite) BeforeTest(suiteName, testName string) { ctx, pdClusterIDPath, string(typeutil.Uint64ToBytes(clusterID))) - s.NoError(err) + re.NoError(err) var ( rootPath = path.Join(pdRootPath, strconv.FormatUint(clusterID, 10)) allocTimestampMaxBytes = typeutil.Uint64ToBytes(allocTimestampMax) ) _, err = s.etcdClient.Put(ctx, endpoint.TimestampPath(rootPath), string(allocTimestampMaxBytes)) - s.NoError(err) + re.NoError(err) var ( allocIDPath = path.Join(rootPath, "alloc_id") allocIDMaxBytes = typeutil.Uint64ToBytes(allocIDMax) ) _, err = s.etcdClient.Put(ctx, allocIDPath, string(allocIDMaxBytes)) - s.NoError(err) + re.NoError(err) } func (s *backupTestSuite) AfterTest(suiteName, testName string) { @@ -128,8 +129,9 @@ func (s *backupTestSuite) AfterTest(suiteName, testName string) { } func (s *backupTestSuite) TestGetBackupInfo() { + re := s.Require() actual, err := GetBackupInfo(s.etcdClient, s.server.URL) - s.NoError(err) + re.NoError(err) expected := &BackupInfo{ ClusterID: clusterID, @@ -137,22 +139,22 @@ func (s *backupTestSuite) TestGetBackupInfo() { AllocTimestampMax: allocTimestampMax, Config: s.serverConfig, } - s.Equal(expected, actual) + re.Equal(expected, actual) tmpFile, err := os.CreateTemp(os.TempDir(), "pd_backup_info_test.json") - s.NoError(err) + re.NoError(err) defer os.RemoveAll(tmpFile.Name()) - s.NoError(OutputToFile(actual, tmpFile)) + re.NoError(OutputToFile(actual, tmpFile)) _, err = tmpFile.Seek(0, 0) - s.NoError(err) + re.NoError(err) b, err := io.ReadAll(tmpFile) - s.NoError(err) + re.NoError(err) var restored BackupInfo err = json.Unmarshal(b, &restored) - s.NoError(err) + re.NoError(err) - s.Equal(expected, &restored) + re.Equal(expected, &restored) } diff --git a/tools/pd-simulator/simulator/simutil/key_test.go b/tools/pd-simulator/simulator/simutil/key_test.go index b34f1bb3809..6f71bd12d14 100644 --- a/tools/pd-simulator/simulator/simutil/key_test.go +++ b/tools/pd-simulator/simulator/simutil/key_test.go @@ -102,13 +102,13 @@ func TestGenerateKeys(t *testing.T) { numKeys := 10 actual := GenerateKeys(numKeys) - re.Equal(len(actual), numKeys) + re.Len(actual, numKeys) // make sure every key: // i. has length `keyLen` // ii. 
has only characters from `keyChars` for _, key := range actual { - re.Equal(len(key), keyLen) + re.Len(key, keyLen) for _, char := range key { re.True(strings.ContainsRune(keyChars, char)) } From cfbc9b96cdc23d4f7c723b0b4130f6612f9ed00a Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Wed, 20 Dec 2023 12:01:52 +0800 Subject: [PATCH 20/21] mcs: watch rule change with txn (#7550) close tikv/pd#7418 Signed-off-by: lhy1024 Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/core/basic_cluster.go | 21 ++ pkg/mcs/scheduling/server/apis/v1/api.go | 2 +- pkg/mcs/scheduling/server/rule/watcher.go | 154 +++++---- pkg/schedule/placement/config.go | 32 +- pkg/schedule/placement/config_test.go | 40 +-- pkg/schedule/placement/rule_manager.go | 254 +++++++------- pkg/schedule/placement/rule_manager_test.go | 14 +- pkg/storage/endpoint/rule.go | 63 ++-- pkg/tso/keyspace_group_manager.go | 4 +- pkg/utils/etcdutil/etcdutil.go | 24 +- server/keyspace_service.go | 7 +- tests/server/api/region_test.go | 56 +++- tests/server/api/rule_test.go | 345 ++++++++++++++++++-- 13 files changed, 711 insertions(+), 305 deletions(-) diff --git a/pkg/core/basic_cluster.go b/pkg/core/basic_cluster.go index 2258a816324..d70b620db3b 100644 --- a/pkg/core/basic_cluster.go +++ b/pkg/core/basic_cluster.go @@ -309,3 +309,24 @@ func NewKeyRange(startKey, endKey string) KeyRange { EndKey: []byte(endKey), } } + +// KeyRanges is a slice of KeyRange. +type KeyRanges struct { + krs []*KeyRange +} + +// Append appends a KeyRange. +func (rs *KeyRanges) Append(startKey, endKey []byte) { + rs.krs = append(rs.krs, &KeyRange{ + StartKey: startKey, + EndKey: endKey, + }) +} + +// Ranges returns the slice of KeyRange. +func (rs *KeyRanges) Ranges() []*KeyRange { + if rs == nil { + return nil + } + return rs.krs +} diff --git a/pkg/mcs/scheduling/server/apis/v1/api.go b/pkg/mcs/scheduling/server/apis/v1/api.go index b59780b7a61..e6881f2f85c 100644 --- a/pkg/mcs/scheduling/server/apis/v1/api.go +++ b/pkg/mcs/scheduling/server/apis/v1/api.go @@ -1330,5 +1330,5 @@ func checkRegionsReplicated(c *gin.Context) { c.String(http.StatusBadRequest, err.Error()) return } - c.String(http.StatusOK, state) + c.IndentedJSON(http.StatusOK, state) } diff --git a/pkg/mcs/scheduling/server/rule/watcher.go b/pkg/mcs/scheduling/server/rule/watcher.go index 96e19cf5002..3e11cf9ff9d 100644 --- a/pkg/mcs/scheduling/server/rule/watcher.go +++ b/pkg/mcs/scheduling/server/rule/watcher.go @@ -20,6 +20,7 @@ import ( "sync" "github.com/pingcap/log" + "github.com/tikv/pd/pkg/core" "github.com/tikv/pd/pkg/schedule/checker" "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/schedule/placement" @@ -36,6 +37,10 @@ type Watcher struct { cancel context.CancelFunc wg sync.WaitGroup + // ruleCommonPathPrefix: + // - Key: /pd/{cluster_id}/rule + // - Value: placement.Rule or placement.RuleGroup + ruleCommonPathPrefix string // rulesPathPrefix: // - Key: /pd/{cluster_id}/rules/{group_id}-{rule_id} // - Value: placement.Rule @@ -60,8 +65,10 @@ type Watcher struct { regionLabeler *labeler.RegionLabeler ruleWatcher *etcdutil.LoopWatcher - groupWatcher *etcdutil.LoopWatcher labelWatcher *etcdutil.LoopWatcher + + // patch is used to cache the placement rule changes. + patch *placement.RuleConfigPatch } // NewWatcher creates a new watcher to watch the Placement Rule change from PD API server. 
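The struct changes above fold the separate rule and rule-group watchers into one watcher on the common `/pd/{cluster_id}/rule` prefix, staging every change in a `placement.RuleConfigPatch` and committing it once per event batch. Below is a minimal, self-contained sketch of that stage-then-commit idea in plain Go; the `patch` and `ruleStore` types here are illustrative stand-ins only, not the actual `RuleManager`/`RuleConfigPatch` API shown later in this patch.

// Sketch only: the staged-patch flow, reduced to plain Go with no PD types.
package main

import (
	"fmt"
	"sync"
)

// patch buffers rule mutations; nothing is visible to readers until commit.
type patch struct {
	puts    map[string]string
	deletes map[string]struct{}
}

func (p *patch) setRule(key, value string) {
	delete(p.deletes, key)
	p.puts[key] = value
}

func (p *patch) deleteRule(key string) {
	delete(p.puts, key)
	p.deletes[key] = struct{}{}
}

// ruleStore stands in for the rule manager's in-memory configuration.
type ruleStore struct {
	mu    sync.RWMutex
	rules map[string]string
}

func (s *ruleStore) beginPatch() *patch {
	return &patch{puts: map[string]string{}, deletes: map[string]struct{}{}}
}

// commit applies a whole event batch under one lock, so a reader never sees
// only part of the batch applied.
func (s *ruleStore) commit(p *patch) {
	s.mu.Lock()
	defer s.mu.Unlock()
	for k := range p.deletes {
		delete(s.rules, k)
	}
	for k, v := range p.puts {
		s.rules[k] = v
	}
}

func main() {
	store := &ruleStore{rules: map[string]string{"pd/default": "voter/3"}}
	// One watch response may carry both rule and rule-group events; stage
	// them together, then commit once.
	p := store.beginPatch()
	p.setRule("pd/tiflash-learner", "learner/2")
	p.deleteRule("pd/default")
	store.commit(p)
	fmt.Println(store.rules) // map[pd/tiflash-learner:learner/2]
}

This is also why the old initializeGroupWatcher is deleted further down: with a single watcher and a single patch per batch, rule and rule-group updates written together on the API server side are applied together here instead of through two independent watchers.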
@@ -79,6 +86,7 @@ func NewWatcher( ctx: ctx, cancel: cancel, rulesPathPrefix: endpoint.RulesPathPrefix(clusterID), + ruleCommonPathPrefix: endpoint.RuleCommonPathPrefix(clusterID), ruleGroupPathPrefix: endpoint.RuleGroupPathPrefix(clusterID), regionLabelPathPrefix: endpoint.RegionLabelPathPrefix(clusterID), etcdClient: etcdClient, @@ -91,10 +99,6 @@ func NewWatcher( if err != nil { return nil, err } - err = rw.initializeGroupWatcher() - if err != nil { - return nil, err - } err = rw.initializeRegionLabelWatcher() if err != nil { return nil, err @@ -103,83 +107,109 @@ func NewWatcher( } func (rw *Watcher) initializeRuleWatcher() error { - prefixToTrim := rw.rulesPathPrefix + "/" + var suspectKeyRanges *core.KeyRanges + + preEventsFn := func(events []*clientv3.Event) error { + // It will be locked until the postFn is finished. + rw.ruleManager.Lock() + rw.patch = rw.ruleManager.BeginPatch() + suspectKeyRanges = &core.KeyRanges{} + return nil + } + putFn := func(kv *mvccpb.KeyValue) error { - log.Info("update placement rule", zap.String("key", string(kv.Key)), zap.String("value", string(kv.Value))) - rule, err := placement.NewRuleFromJSON(kv.Value) - if err != nil { - return err - } - // Update the suspect key ranges in the checker. - rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) - if oldRule := rw.ruleManager.GetRule(rule.GroupID, rule.ID); oldRule != nil { - rw.checkerController.AddSuspectKeyRange(oldRule.StartKey, oldRule.EndKey) + key := string(kv.Key) + if strings.HasPrefix(key, rw.rulesPathPrefix) { + log.Info("update placement rule", zap.String("key", key), zap.String("value", string(kv.Value))) + rule, err := placement.NewRuleFromJSON(kv.Value) + if err != nil { + return err + } + // Try to add the rule change to the patch. + if err := rw.ruleManager.AdjustRule(rule, ""); err != nil { + return err + } + rw.patch.SetRule(rule) + // Update the suspect key ranges in lock. + suspectKeyRanges.Append(rule.StartKey, rule.EndKey) + if oldRule := rw.ruleManager.GetRuleLocked(rule.GroupID, rule.ID); oldRule != nil { + suspectKeyRanges.Append(oldRule.StartKey, oldRule.EndKey) + } + return nil + } else if strings.HasPrefix(key, rw.ruleGroupPathPrefix) { + log.Info("update placement rule group", zap.String("key", key), zap.String("value", string(kv.Value))) + ruleGroup, err := placement.NewRuleGroupFromJSON(kv.Value) + if err != nil { + return err + } + // Try to add the rule group change to the patch. + rw.patch.SetGroup(ruleGroup) + // Update the suspect key ranges + for _, rule := range rw.ruleManager.GetRulesByGroupLocked(ruleGroup.ID) { + suspectKeyRanges.Append(rule.StartKey, rule.EndKey) + } + return nil + } else { + log.Warn("unknown key when updating placement rule", zap.String("key", key)) + return nil } - return rw.ruleManager.SetRule(rule) } deleteFn := func(kv *mvccpb.KeyValue) error { key := string(kv.Key) - log.Info("delete placement rule", zap.String("key", key)) - ruleJSON, err := rw.ruleStorage.LoadRule(strings.TrimPrefix(key, prefixToTrim)) - if err != nil { + if strings.HasPrefix(key, rw.rulesPathPrefix) { + log.Info("delete placement rule", zap.String("key", key)) + ruleJSON, err := rw.ruleStorage.LoadRule(strings.TrimPrefix(key, rw.rulesPathPrefix+"/")) + if err != nil { + return err + } + rule, err := placement.NewRuleFromJSON([]byte(ruleJSON)) + if err != nil { + return err + } + // Try to add the rule change to the patch. 
+ rw.patch.DeleteRule(rule.GroupID, rule.ID) + // Update the suspect key ranges + suspectKeyRanges.Append(rule.StartKey, rule.EndKey) return err + } else if strings.HasPrefix(key, rw.ruleGroupPathPrefix) { + log.Info("delete placement rule group", zap.String("key", key)) + trimmedKey := strings.TrimPrefix(key, rw.ruleGroupPathPrefix+"/") + // Try to add the rule group change to the patch. + rw.patch.DeleteGroup(trimmedKey) + // Update the suspect key ranges + for _, rule := range rw.ruleManager.GetRulesByGroupLocked(trimmedKey) { + suspectKeyRanges.Append(rule.StartKey, rule.EndKey) + } + return nil + } else { + log.Warn("unknown key when deleting placement rule", zap.String("key", key)) + return nil } - rule, err := placement.NewRuleFromJSON([]byte(ruleJSON)) - if err != nil { + } + postEventsFn := func(events []*clientv3.Event) error { + defer rw.ruleManager.Unlock() + if err := rw.ruleManager.TryCommitPatch(rw.patch); err != nil { + log.Error("failed to commit patch", zap.Error(err)) return err } - rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) - return rw.ruleManager.DeleteRule(rule.GroupID, rule.ID) + for _, kr := range suspectKeyRanges.Ranges() { + rw.checkerController.AddSuspectKeyRange(kr.StartKey, kr.EndKey) + } + return nil } rw.ruleWatcher = etcdutil.NewLoopWatcher( rw.ctx, &rw.wg, rw.etcdClient, - "scheduling-rule-watcher", rw.rulesPathPrefix, - func([]*clientv3.Event) error { return nil }, + "scheduling-rule-watcher", rw.ruleCommonPathPrefix, + preEventsFn, putFn, deleteFn, - func([]*clientv3.Event) error { return nil }, + postEventsFn, clientv3.WithPrefix(), ) rw.ruleWatcher.StartWatchLoop() return rw.ruleWatcher.WaitLoad() } -func (rw *Watcher) initializeGroupWatcher() error { - prefixToTrim := rw.ruleGroupPathPrefix + "/" - putFn := func(kv *mvccpb.KeyValue) error { - log.Info("update placement rule group", zap.String("key", string(kv.Key)), zap.String("value", string(kv.Value))) - ruleGroup, err := placement.NewRuleGroupFromJSON(kv.Value) - if err != nil { - return err - } - // Add all rule key ranges within the group to the suspect key ranges. 
- for _, rule := range rw.ruleManager.GetRulesByGroup(ruleGroup.ID) { - rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) - } - return rw.ruleManager.SetRuleGroup(ruleGroup) - } - deleteFn := func(kv *mvccpb.KeyValue) error { - key := string(kv.Key) - log.Info("delete placement rule group", zap.String("key", key)) - trimmedKey := strings.TrimPrefix(key, prefixToTrim) - for _, rule := range rw.ruleManager.GetRulesByGroup(trimmedKey) { - rw.checkerController.AddSuspectKeyRange(rule.StartKey, rule.EndKey) - } - return rw.ruleManager.DeleteRuleGroup(trimmedKey) - } - rw.groupWatcher = etcdutil.NewLoopWatcher( - rw.ctx, &rw.wg, - rw.etcdClient, - "scheduling-rule-group-watcher", rw.ruleGroupPathPrefix, - func([]*clientv3.Event) error { return nil }, - putFn, deleteFn, - func([]*clientv3.Event) error { return nil }, - clientv3.WithPrefix(), - ) - rw.groupWatcher.StartWatchLoop() - return rw.groupWatcher.WaitLoad() -} - func (rw *Watcher) initializeRegionLabelWatcher() error { prefixToTrim := rw.regionLabelPathPrefix + "/" putFn := func(kv *mvccpb.KeyValue) error { diff --git a/pkg/schedule/placement/config.go b/pkg/schedule/placement/config.go index 878db4b2e0a..00c0f94b94e 100644 --- a/pkg/schedule/placement/config.go +++ b/pkg/schedule/placement/config.go @@ -79,28 +79,30 @@ func (c *ruleConfig) getGroup(id string) *RuleGroup { return &RuleGroup{ID: id} } -func (c *ruleConfig) beginPatch() *ruleConfigPatch { - return &ruleConfigPatch{ +func (c *ruleConfig) beginPatch() *RuleConfigPatch { + return &RuleConfigPatch{ c: c, mut: newRuleConfig(), } } -// A helper data structure to update ruleConfig. -type ruleConfigPatch struct { +// RuleConfigPatch is a helper data structure to update ruleConfig. +type RuleConfigPatch struct { c *ruleConfig // original configuration to be updated mut *ruleConfig // record all to-commit rules and groups } -func (p *ruleConfigPatch) setRule(r *Rule) { +// SetRule sets a rule to the patch. +func (p *RuleConfigPatch) SetRule(r *Rule) { p.mut.rules[r.Key()] = r } -func (p *ruleConfigPatch) deleteRule(group, id string) { +// DeleteRule deletes a rule from the patch. +func (p *RuleConfigPatch) DeleteRule(group, id string) { p.mut.rules[[2]string{group, id}] = nil } -func (p *ruleConfigPatch) getGroup(id string) *RuleGroup { +func (p *RuleConfigPatch) getGroup(id string) *RuleGroup { if g, ok := p.mut.groups[id]; ok { return g } @@ -110,15 +112,17 @@ func (p *ruleConfigPatch) getGroup(id string) *RuleGroup { return &RuleGroup{ID: id} } -func (p *ruleConfigPatch) setGroup(g *RuleGroup) { +// SetGroup sets a group to the patch. +func (p *RuleConfigPatch) SetGroup(g *RuleGroup) { p.mut.groups[g.ID] = g } -func (p *ruleConfigPatch) deleteGroup(id string) { - p.setGroup(&RuleGroup{ID: id}) +// DeleteGroup deletes a group from the patch. +func (p *RuleConfigPatch) DeleteGroup(id string) { + p.SetGroup(&RuleGroup{ID: id}) } -func (p *ruleConfigPatch) iterateRules(f func(*Rule)) { +func (p *RuleConfigPatch) iterateRules(f func(*Rule)) { for _, r := range p.mut.rules { if r != nil { // nil means delete. f(r) @@ -131,13 +135,13 @@ func (p *ruleConfigPatch) iterateRules(f func(*Rule)) { } } -func (p *ruleConfigPatch) adjust() { +func (p *RuleConfigPatch) adjust() { // setup rule.group for `buildRuleList` use. p.iterateRules(func(r *Rule) { r.group = p.getGroup(r.GroupID) }) } // trim unnecessary updates. For example, remove a rule then insert the same rule. 
-func (p *ruleConfigPatch) trim() { +func (p *RuleConfigPatch) trim() { for key, rule := range p.mut.rules { if jsonEquals(rule, p.c.getRule(key)) { delete(p.mut.rules, key) @@ -151,7 +155,7 @@ func (p *ruleConfigPatch) trim() { } // merge all mutations to ruleConfig. -func (p *ruleConfigPatch) commit() { +func (p *RuleConfigPatch) commit() { for key, rule := range p.mut.rules { if rule == nil { delete(p.c.rules, key) diff --git a/pkg/schedule/placement/config_test.go b/pkg/schedule/placement/config_test.go index 8f7161a56d7..ccee8837331 100644 --- a/pkg/schedule/placement/config_test.go +++ b/pkg/schedule/placement/config_test.go @@ -30,40 +30,40 @@ func TestTrim(t *testing.T) { rc.setGroup(&RuleGroup{ID: "g2", Index: 2}) testCases := []struct { - ops func(p *ruleConfigPatch) + ops func(p *RuleConfigPatch) mutRules map[[2]string]*Rule mutGroups map[string]*RuleGroup }{ { - func(p *ruleConfigPatch) { - p.setRule(&Rule{GroupID: "g1", ID: "id1", Index: 100}) - p.setRule(&Rule{GroupID: "g1", ID: "id2"}) - p.setGroup(&RuleGroup{ID: "g1", Index: 100}) - p.setGroup(&RuleGroup{ID: "g2", Index: 2}) + func(p *RuleConfigPatch) { + p.SetRule(&Rule{GroupID: "g1", ID: "id1", Index: 100}) + p.SetRule(&Rule{GroupID: "g1", ID: "id2"}) + p.SetGroup(&RuleGroup{ID: "g1", Index: 100}) + p.SetGroup(&RuleGroup{ID: "g2", Index: 2}) }, map[[2]string]*Rule{{"g1", "id1"}: {GroupID: "g1", ID: "id1", Index: 100}}, map[string]*RuleGroup{"g1": {ID: "g1", Index: 100}}, }, { - func(p *ruleConfigPatch) { - p.deleteRule("g1", "id1") - p.deleteGroup("g2") - p.deleteRule("g3", "id3") - p.deleteGroup("g3") + func(p *RuleConfigPatch) { + p.DeleteRule("g1", "id1") + p.DeleteGroup("g2") + p.DeleteRule("g3", "id3") + p.DeleteGroup("g3") }, map[[2]string]*Rule{{"g1", "id1"}: nil}, map[string]*RuleGroup{"g2": {ID: "g2"}}, }, { - func(p *ruleConfigPatch) { - p.setRule(&Rule{GroupID: "g1", ID: "id2", Index: 200}) - p.setRule(&Rule{GroupID: "g1", ID: "id2"}) - p.setRule(&Rule{GroupID: "g3", ID: "id3"}) - p.deleteRule("g3", "id3") - p.setGroup(&RuleGroup{ID: "g1", Index: 100}) - p.setGroup(&RuleGroup{ID: "g1", Index: 1}) - p.setGroup(&RuleGroup{ID: "g3", Index: 3}) - p.deleteGroup("g3") + func(p *RuleConfigPatch) { + p.SetRule(&Rule{GroupID: "g1", ID: "id2", Index: 200}) + p.SetRule(&Rule{GroupID: "g1", ID: "id2"}) + p.SetRule(&Rule{GroupID: "g3", ID: "id3"}) + p.DeleteRule("g3", "id3") + p.SetGroup(&RuleGroup{ID: "g1", Index: 100}) + p.SetGroup(&RuleGroup{ID: "g1", Index: 1}) + p.SetGroup(&RuleGroup{ID: "g3", Index: 3}) + p.DeleteGroup("g3") }, map[[2]string]*Rule{}, map[string]*RuleGroup{}, diff --git a/pkg/schedule/placement/rule_manager.go b/pkg/schedule/placement/rule_manager.go index 621c52d738e..ea85911462b 100644 --- a/pkg/schedule/placement/rule_manager.go +++ b/pkg/schedule/placement/rule_manager.go @@ -33,6 +33,7 @@ import ( "github.com/tikv/pd/pkg/schedule/config" "github.com/tikv/pd/pkg/slice" "github.com/tikv/pd/pkg/storage/endpoint" + "github.com/tikv/pd/pkg/storage/kv" "github.com/tikv/pd/pkg/utils/syncutil" "go.uber.org/zap" "golang.org/x/exp/slices" @@ -128,12 +129,17 @@ func (m *RuleManager) Initialize(maxReplica int, locationLabels []string, isolat IsolationLevel: isolationLevel, }) } - for _, defaultRule := range defaultRules { - if err := m.storage.SaveRule(defaultRule.StoreKey(), defaultRule); err != nil { - // TODO: Need to delete the previously successfully saved Rules? 
- return err + if err := m.storage.RunInTxn(m.ctx, func(txn kv.Txn) (err error) { + for _, defaultRule := range defaultRules { + if err := m.storage.SaveRule(txn, defaultRule.StoreKey(), defaultRule); err != nil { + // TODO: Need to delete the previously successfully saved Rules? + return err + } + m.ruleConfig.setRule(defaultRule) } - m.ruleConfig.setRule(defaultRule) + return nil + }); err != nil { + return err } } m.ruleConfig.adjust() @@ -151,61 +157,66 @@ func (m *RuleManager) loadRules() error { toSave []*Rule toDelete []string ) - err := m.storage.LoadRules(func(k, v string) { - r, err := NewRuleFromJSON([]byte(v)) - if err != nil { - log.Error("failed to unmarshal rule value", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule)) - toDelete = append(toDelete, k) - return - } - err = m.adjustRule(r, "") + return m.storage.RunInTxn(m.ctx, func(txn kv.Txn) (err error) { + err = m.storage.LoadRules(txn, func(k, v string) { + r, err := NewRuleFromJSON([]byte(v)) + if err != nil { + log.Error("failed to unmarshal rule value", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule)) + toDelete = append(toDelete, k) + return + } + err = m.AdjustRule(r, "") + if err != nil { + log.Error("rule is in bad format", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule, err)) + toDelete = append(toDelete, k) + return + } + _, ok := m.ruleConfig.rules[r.Key()] + if ok { + log.Error("duplicated rule key", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule)) + toDelete = append(toDelete, k) + return + } + if k != r.StoreKey() { + log.Error("mismatch data key, need to restore", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule)) + toDelete = append(toDelete, k) + toSave = append(toSave, r) + } + m.ruleConfig.rules[r.Key()] = r + }) if err != nil { - log.Error("rule is in bad format", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule, err)) - toDelete = append(toDelete, k) - return + return err } - _, ok := m.ruleConfig.rules[r.Key()] - if ok { - log.Error("duplicated rule key", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule)) - toDelete = append(toDelete, k) - return + + for _, s := range toSave { + if err = m.storage.SaveRule(txn, s.StoreKey(), s); err != nil { + return err + } } - if k != r.StoreKey() { - log.Error("mismatch data key, need to restore", zap.String("rule-key", k), zap.String("rule-value", v), errs.ZapError(errs.ErrLoadRule)) - toDelete = append(toDelete, k) - toSave = append(toSave, r) + for _, d := range toDelete { + if err = m.storage.DeleteRule(txn, d); err != nil { + return err + } } - m.ruleConfig.rules[r.Key()] = r + return nil }) - if err != nil { - return err - } - for _, s := range toSave { - if err = m.storage.SaveRule(s.StoreKey(), s); err != nil { - return err - } - } - for _, d := range toDelete { - if err = m.storage.DeleteRule(d); err != nil { - return err - } - } - return nil } func (m *RuleManager) loadGroups() error { - return m.storage.LoadRuleGroups(func(k, v string) { - g, err := NewRuleGroupFromJSON([]byte(v)) - if err != nil { - log.Error("failed to unmarshal rule group", zap.String("group-id", k), errs.ZapError(errs.ErrLoadRuleGroup, err)) - return - } - m.ruleConfig.groups[g.ID] = g + return m.storage.RunInTxn(m.ctx, func(txn kv.Txn) (err error) { + return m.storage.LoadRuleGroups(txn, func(k, v string) { + g, 
err := NewRuleGroupFromJSON([]byte(v)) + if err != nil { + log.Error("failed to unmarshal rule group", zap.String("group-id", k), errs.ZapError(errs.ErrLoadRuleGroup, err)) + return + } + m.ruleConfig.groups[g.ID] = g + }) }) } -// check and adjust rule from client or storage. -func (m *RuleManager) adjustRule(r *Rule, groupID string) (err error) { +// AdjustRule check and adjust rule from client or storage. +func (m *RuleManager) AdjustRule(r *Rule, groupID string) (err error) { r.StartKey, err = hex.DecodeString(r.StartKeyHex) if err != nil { return errs.ErrHexDecodingString.FastGenByArgs(r.StartKeyHex) @@ -279,6 +290,11 @@ func (m *RuleManager) adjustRule(r *Rule, groupID string) (err error) { func (m *RuleManager) GetRule(group, id string) *Rule { m.RLock() defer m.RUnlock() + return m.GetRuleLocked(group, id) +} + +// GetRuleLocked returns the Rule with the same (group, id). +func (m *RuleManager) GetRuleLocked(group, id string) *Rule { if r := m.ruleConfig.getRule([2]string{group, id}); r != nil { return r.Clone() } @@ -287,14 +303,14 @@ func (m *RuleManager) GetRule(group, id string) *Rule { // SetRule inserts or updates a Rule. func (m *RuleManager) SetRule(rule *Rule) error { - if err := m.adjustRule(rule, ""); err != nil { + if err := m.AdjustRule(rule, ""); err != nil { return err } m.Lock() defer m.Unlock() - p := m.beginPatch() - p.setRule(rule) - if err := m.tryCommitPatch(p); err != nil { + p := m.BeginPatch() + p.SetRule(rule) + if err := m.TryCommitPatch(p); err != nil { return err } log.Info("placement rule updated", zap.String("rule", fmt.Sprint(rule))) @@ -305,9 +321,9 @@ func (m *RuleManager) SetRule(rule *Rule) error { func (m *RuleManager) DeleteRule(group, id string) error { m.Lock() defer m.Unlock() - p := m.beginPatch() - p.deleteRule(group, id) - if err := m.tryCommitPatch(p); err != nil { + p := m.BeginPatch() + p.DeleteRule(group, id) + if err := m.TryCommitPatch(p); err != nil { return err } log.Info("placement rule is removed", zap.String("group", group), zap.String("id", id)) @@ -351,6 +367,11 @@ func (m *RuleManager) GetGroupsCount() int { func (m *RuleManager) GetRulesByGroup(group string) []*Rule { m.RLock() defer m.RUnlock() + return m.GetRulesByGroupLocked(group) +} + +// GetRulesByGroupLocked returns sorted rules of a group. +func (m *RuleManager) GetRulesByGroupLocked(group string) []*Rule { var rules []*Rule for _, r := range m.ruleConfig.rules { if r.GroupID == group { @@ -442,11 +463,13 @@ func (m *RuleManager) CheckIsCachedDirectly(regionID uint64) bool { return ok } -func (m *RuleManager) beginPatch() *ruleConfigPatch { +// BeginPatch returns a patch for multiple changes. +func (m *RuleManager) BeginPatch() *RuleConfigPatch { return m.ruleConfig.beginPatch() } -func (m *RuleManager) tryCommitPatch(patch *ruleConfigPatch) error { +// TryCommitPatch tries to commit a patch. +func (m *RuleManager) TryCommitPatch(patch *RuleConfigPatch) error { patch.adjust() ruleList, err := buildRuleList(patch) @@ -469,49 +492,44 @@ func (m *RuleManager) tryCommitPatch(patch *ruleConfigPatch) error { } func (m *RuleManager) savePatch(p *ruleConfig) error { - // TODO: it is not completely safe - // 1. in case that half of rules applied, error.. we have to cancel persisted rules - // but that may fail too, causing memory/disk inconsistency - // either rely a transaction API, or clients to request again until success - // 2. 
in case that PD is suddenly down in the loop, inconsistency again - // now we can only rely clients to request again - var err error - for key, r := range p.rules { - if r == nil { - r = &Rule{GroupID: key[0], ID: key[1]} - err = m.storage.DeleteRule(r.StoreKey()) - } else { - err = m.storage.SaveRule(r.StoreKey(), r) - } - if err != nil { - return err - } - } - for id, g := range p.groups { - if g.isDefault() { - err = m.storage.DeleteRuleGroup(id) - } else { - err = m.storage.SaveRuleGroup(id, g) + return m.storage.RunInTxn(m.ctx, func(txn kv.Txn) (err error) { + for key, r := range p.rules { + if r == nil { + r = &Rule{GroupID: key[0], ID: key[1]} + err = m.storage.DeleteRule(txn, r.StoreKey()) + } else { + err = m.storage.SaveRule(txn, r.StoreKey(), r) + } + if err != nil { + return err + } } - if err != nil { - return err + for id, g := range p.groups { + if g.isDefault() { + err = m.storage.DeleteRuleGroup(txn, id) + } else { + err = m.storage.SaveRuleGroup(txn, id, g) + } + if err != nil { + return err + } } - } - return nil + return nil + }) } // SetRules inserts or updates lots of Rules at once. func (m *RuleManager) SetRules(rules []*Rule) error { m.Lock() defer m.Unlock() - p := m.beginPatch() + p := m.BeginPatch() for _, r := range rules { - if err := m.adjustRule(r, ""); err != nil { + if err := m.AdjustRule(r, ""); err != nil { return err } - p.setRule(r) + p.SetRule(r) } - if err := m.tryCommitPatch(p); err != nil { + if err := m.TryCommitPatch(p); err != nil { return err } @@ -546,7 +564,7 @@ func (r RuleOp) String() string { func (m *RuleManager) Batch(todo []RuleOp) error { for _, t := range todo { if t.Action == RuleOpAdd { - err := m.adjustRule(t.Rule, "") + err := m.AdjustRule(t.Rule, "") if err != nil { return err } @@ -556,25 +574,25 @@ func (m *RuleManager) Batch(todo []RuleOp) error { m.Lock() defer m.Unlock() - patch := m.beginPatch() + patch := m.BeginPatch() for _, t := range todo { switch t.Action { case RuleOpAdd: - patch.setRule(t.Rule) + patch.SetRule(t.Rule) case RuleOpDel: if !t.DeleteByIDPrefix { - patch.deleteRule(t.GroupID, t.ID) + patch.DeleteRule(t.GroupID, t.ID) } else { m.ruleConfig.iterateRules(func(r *Rule) { if r.GroupID == t.GroupID && strings.HasPrefix(r.ID, t.ID) { - patch.deleteRule(r.GroupID, r.ID) + patch.DeleteRule(r.GroupID, r.ID) } }) } } } - if err := m.tryCommitPatch(patch); err != nil { + if err := m.TryCommitPatch(patch); err != nil { return err } @@ -608,9 +626,9 @@ func (m *RuleManager) GetRuleGroups() []*RuleGroup { func (m *RuleManager) SetRuleGroup(group *RuleGroup) error { m.Lock() defer m.Unlock() - p := m.beginPatch() - p.setGroup(group) - if err := m.tryCommitPatch(p); err != nil { + p := m.BeginPatch() + p.SetGroup(group) + if err := m.TryCommitPatch(p); err != nil { return err } log.Info("group config updated", zap.String("group", fmt.Sprint(group))) @@ -621,9 +639,9 @@ func (m *RuleManager) SetRuleGroup(group *RuleGroup) error { func (m *RuleManager) DeleteRuleGroup(id string) error { m.Lock() defer m.Unlock() - p := m.beginPatch() - p.deleteGroup(id) - if err := m.tryCommitPatch(p); err != nil { + p := m.BeginPatch() + p.DeleteGroup(id) + if err := m.TryCommitPatch(p); err != nil { return err } log.Info("group config reset", zap.String("group", id)) @@ -681,7 +699,7 @@ func (m *RuleManager) GetGroupBundle(id string) (b GroupBundle) { func (m *RuleManager) SetAllGroupBundles(groups []GroupBundle, override bool) error { m.Lock() defer m.Unlock() - p := m.beginPatch() + p := m.BeginPatch() matchID := func(a string) bool { 
for _, g := range groups { if g.ID == a { @@ -692,28 +710,28 @@ func (m *RuleManager) SetAllGroupBundles(groups []GroupBundle, override bool) er } for k := range m.ruleConfig.rules { if override || matchID(k[0]) { - p.deleteRule(k[0], k[1]) + p.DeleteRule(k[0], k[1]) } } for id := range m.ruleConfig.groups { if override || matchID(id) { - p.deleteGroup(id) + p.DeleteGroup(id) } } for _, g := range groups { - p.setGroup(&RuleGroup{ + p.SetGroup(&RuleGroup{ ID: g.ID, Index: g.Index, Override: g.Override, }) for _, r := range g.Rules { - if err := m.adjustRule(r, g.ID); err != nil { + if err := m.AdjustRule(r, g.ID); err != nil { return err } - p.setRule(r) + p.SetRule(r) } } - if err := m.tryCommitPatch(p); err != nil { + if err := m.TryCommitPatch(p); err != nil { return err } log.Info("full config reset", zap.String("config", fmt.Sprint(groups))) @@ -725,26 +743,26 @@ func (m *RuleManager) SetAllGroupBundles(groups []GroupBundle, override bool) er func (m *RuleManager) SetGroupBundle(group GroupBundle) error { m.Lock() defer m.Unlock() - p := m.beginPatch() + p := m.BeginPatch() if _, ok := m.ruleConfig.groups[group.ID]; ok { for k := range m.ruleConfig.rules { if k[0] == group.ID { - p.deleteRule(k[0], k[1]) + p.DeleteRule(k[0], k[1]) } } } - p.setGroup(&RuleGroup{ + p.SetGroup(&RuleGroup{ ID: group.ID, Index: group.Index, Override: group.Override, }) for _, r := range group.Rules { - if err := m.adjustRule(r, group.ID); err != nil { + if err := m.AdjustRule(r, group.ID); err != nil { return err } - p.setRule(r) + p.SetRule(r) } - if err := m.tryCommitPatch(p); err != nil { + if err := m.TryCommitPatch(p); err != nil { return err } log.Info("group is reset", zap.String("group", fmt.Sprint(group))) @@ -765,18 +783,18 @@ func (m *RuleManager) DeleteGroupBundle(id string, regex bool) error { matchID = r.MatchString } - p := m.beginPatch() + p := m.BeginPatch() for k := range m.ruleConfig.rules { if matchID(k[0]) { - p.deleteRule(k[0], k[1]) + p.DeleteRule(k[0], k[1]) } } for _, g := range m.ruleConfig.groups { if matchID(g.ID) { - p.deleteGroup(g.ID) + p.DeleteGroup(g.ID) } } - if err := m.tryCommitPatch(p); err != nil { + if err := m.TryCommitPatch(p); err != nil { return err } log.Info("groups are removed", zap.String("id", id), zap.Bool("regexp", regex)) diff --git a/pkg/schedule/placement/rule_manager_test.go b/pkg/schedule/placement/rule_manager_test.go index 0539e935113..5494b3c5a9d 100644 --- a/pkg/schedule/placement/rule_manager_test.go +++ b/pkg/schedule/placement/rule_manager_test.go @@ -91,23 +91,23 @@ func TestAdjustRule(t *testing.T) { {GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: Voter, Count: -1}, {GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: Voter, Count: 3, LabelConstraints: []LabelConstraint{{Op: "foo"}}}, } - re.NoError(manager.adjustRule(&rules[0], "group")) + re.NoError(manager.AdjustRule(&rules[0], "group")) re.Equal([]byte{0x12, 0x3a, 0xbc}, rules[0].StartKey) re.Equal([]byte{0x12, 0x3a, 0xbf}, rules[0].EndKey) - re.Error(manager.adjustRule(&rules[1], "")) + re.Error(manager.AdjustRule(&rules[1], "")) for i := 2; i < len(rules); i++ { - re.Error(manager.adjustRule(&rules[i], "group")) + re.Error(manager.AdjustRule(&rules[i], "group")) } manager.SetKeyType(constant.Table.String()) - re.Error(manager.adjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: Voter, Count: 3}, "group")) + re.Error(manager.AdjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", 
EndKeyHex: "123abf", Role: Voter, Count: 3}, "group")) manager.SetKeyType(constant.Txn.String()) - re.Error(manager.adjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: Voter, Count: 3}, "group")) + re.Error(manager.AdjustRule(&Rule{GroupID: "group", ID: "id", StartKeyHex: "123abc", EndKeyHex: "123abf", Role: Voter, Count: 3}, "group")) - re.Error(manager.adjustRule(&Rule{ + re.Error(manager.AdjustRule(&Rule{ GroupID: "group", ID: "id", StartKeyHex: hex.EncodeToString(codec.EncodeBytes([]byte{0})), @@ -116,7 +116,7 @@ func TestAdjustRule(t *testing.T) { Count: 3, }, "group")) - re.Error(manager.adjustRule(&Rule{ + re.Error(manager.AdjustRule(&Rule{ GroupID: "tiflash", ID: "id", StartKeyHex: hex.EncodeToString(codec.EncodeBytes([]byte{0})), diff --git a/pkg/storage/endpoint/rule.go b/pkg/storage/endpoint/rule.go index 80b6fc7c0ff..b18360040ea 100644 --- a/pkg/storage/endpoint/rule.go +++ b/pkg/storage/endpoint/rule.go @@ -14,58 +14,54 @@ package endpoint +import ( + "context" + + "github.com/tikv/pd/pkg/storage/kv" +) + // RuleStorage defines the storage operations on the rule. type RuleStorage interface { + LoadRules(txn kv.Txn, f func(k, v string)) error + SaveRule(txn kv.Txn, ruleKey string, rule interface{}) error + DeleteRule(txn kv.Txn, ruleKey string) error + LoadRuleGroups(txn kv.Txn, f func(k, v string)) error + SaveRuleGroup(txn kv.Txn, groupID string, group interface{}) error + DeleteRuleGroup(txn kv.Txn, groupID string) error + // LoadRule is used only in rule watcher. LoadRule(ruleKey string) (string, error) - LoadRules(f func(k, v string)) error - SaveRule(ruleKey string, rule interface{}) error - SaveRuleJSON(ruleKey, rule string) error - DeleteRule(ruleKey string) error - LoadRuleGroups(f func(k, v string)) error - SaveRuleGroup(groupID string, group interface{}) error - SaveRuleGroupJSON(groupID, group string) error - DeleteRuleGroup(groupID string) error + LoadRegionRules(f func(k, v string)) error SaveRegionRule(ruleKey string, rule interface{}) error - SaveRegionRuleJSON(ruleKey, rule string) error DeleteRegionRule(ruleKey string) error + RunInTxn(ctx context.Context, f func(txn kv.Txn) error) error } var _ RuleStorage = (*StorageEndpoint)(nil) // SaveRule stores a rule cfg to the rulesPath. -func (se *StorageEndpoint) SaveRule(ruleKey string, rule interface{}) error { - return se.saveJSON(ruleKeyPath(ruleKey), rule) -} - -// SaveRuleJSON stores a rule cfg JSON to the rulesPath. -func (se *StorageEndpoint) SaveRuleJSON(ruleKey, rule string) error { - return se.Save(ruleKeyPath(ruleKey), rule) +func (se *StorageEndpoint) SaveRule(txn kv.Txn, ruleKey string, rule interface{}) error { + return saveJSONInTxn(txn, ruleKeyPath(ruleKey), rule) } // DeleteRule removes a rule from storage. -func (se *StorageEndpoint) DeleteRule(ruleKey string) error { - return se.Remove(ruleKeyPath(ruleKey)) +func (se *StorageEndpoint) DeleteRule(txn kv.Txn, ruleKey string) error { + return txn.Remove(ruleKeyPath(ruleKey)) } // LoadRuleGroups loads all rule groups from storage. -func (se *StorageEndpoint) LoadRuleGroups(f func(k, v string)) error { - return se.loadRangeByPrefix(ruleGroupPath+"/", f) +func (se *StorageEndpoint) LoadRuleGroups(txn kv.Txn, f func(k, v string)) error { + return loadRangeByPrefixInTxn(txn, ruleGroupPath+"/", f) } // SaveRuleGroup stores a rule group config to storage. 
-func (se *StorageEndpoint) SaveRuleGroup(groupID string, group interface{}) error { - return se.saveJSON(ruleGroupIDPath(groupID), group) -} - -// SaveRuleGroupJSON stores a rule group config JSON to storage. -func (se *StorageEndpoint) SaveRuleGroupJSON(groupID, group string) error { - return se.Save(ruleGroupIDPath(groupID), group) +func (se *StorageEndpoint) SaveRuleGroup(txn kv.Txn, groupID string, group interface{}) error { + return saveJSONInTxn(txn, ruleGroupIDPath(groupID), group) } // DeleteRuleGroup removes a rule group from storage. -func (se *StorageEndpoint) DeleteRuleGroup(groupID string) error { - return se.Remove(ruleGroupIDPath(groupID)) +func (se *StorageEndpoint) DeleteRuleGroup(txn kv.Txn, groupID string) error { + return txn.Remove(ruleGroupIDPath(groupID)) } // LoadRegionRules loads region rules from storage. @@ -78,11 +74,6 @@ func (se *StorageEndpoint) SaveRegionRule(ruleKey string, rule interface{}) erro return se.saveJSON(regionLabelKeyPath(ruleKey), rule) } -// SaveRegionRuleJSON saves a region rule JSON to the storage. -func (se *StorageEndpoint) SaveRegionRuleJSON(ruleKey, rule string) error { - return se.Save(regionLabelKeyPath(ruleKey), rule) -} - // DeleteRegionRule removes a region rule from storage. func (se *StorageEndpoint) DeleteRegionRule(ruleKey string) error { return se.Remove(regionLabelKeyPath(ruleKey)) @@ -94,6 +85,6 @@ func (se *StorageEndpoint) LoadRule(ruleKey string) (string, error) { } // LoadRules loads placement rules from storage. -func (se *StorageEndpoint) LoadRules(f func(k, v string)) error { - return se.loadRangeByPrefix(rulesPath+"/", f) +func (se *StorageEndpoint) LoadRules(txn kv.Txn, f func(k, v string)) error { + return loadRangeByPrefixInTxn(txn, rulesPath+"/", f) } diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go index 0e69986f255..c48c066a2aa 100644 --- a/pkg/tso/keyspace_group_manager.go +++ b/pkg/tso/keyspace_group_manager.go @@ -559,7 +559,7 @@ func (kgm *KeyspaceGroupManager) InitializeGroupWatchLoop() error { kgm.deleteKeyspaceGroup(groupID) return nil } - postEventFn := func([]*clientv3.Event) error { + postEventsFn := func([]*clientv3.Event) error { // Retry the groups that are not initialized successfully before. 
for id, group := range kgm.groupUpdateRetryList { delete(kgm.groupUpdateRetryList, id) @@ -576,7 +576,7 @@ func (kgm *KeyspaceGroupManager) InitializeGroupWatchLoop() error { func([]*clientv3.Event) error { return nil }, putFn, deleteFn, - postEventFn, + postEventsFn, clientv3.WithRange(endKey), ) if kgm.loadKeyspaceGroupsTimeout > 0 { diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go index 0e1b2731474..f6beafee511 100644 --- a/pkg/utils/etcdutil/etcdutil.go +++ b/pkg/utils/etcdutil/etcdutil.go @@ -865,6 +865,16 @@ func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error) if limit != 0 { limit++ } + if err := lw.preEventsFn([]*clientv3.Event{}); err != nil { + log.Error("run pre event failed in watch loop", zap.String("name", lw.name), + zap.String("key", lw.key), zap.Error(err)) + } + defer func() { + if err := lw.postEventsFn([]*clientv3.Event{}); err != nil { + log.Error("run post event failed in watch loop", zap.String("name", lw.name), + zap.String("key", lw.key), zap.Error(err)) + } + }() for { // Sort by key to get the next key and we don't need to worry about the performance, // Because the default sort is just SortByKey and SortAscend @@ -875,10 +885,6 @@ func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error) zap.String("key", lw.key), zap.Error(err)) return 0, err } - if err := lw.preEventsFn([]*clientv3.Event{}); err != nil { - log.Error("run pre event failed in watch loop", zap.String("name", lw.name), - zap.String("key", lw.key), zap.Error(err)) - } for i, item := range resp.Kvs { if resp.More && i == len(resp.Kvs)-1 { // The last key is the start key of the next batch. @@ -888,15 +894,15 @@ func (lw *LoopWatcher) load(ctx context.Context) (nextRevision int64, err error) } err = lw.putFn(item) if err != nil { - log.Error("put failed in watch loop when loading", zap.String("name", lw.name), zap.String("key", lw.key), zap.Error(err)) + log.Error("put failed in watch loop when loading", zap.String("name", lw.name), zap.String("watch-key", lw.key), + zap.ByteString("key", item.Key), zap.ByteString("value", item.Value), zap.Error(err)) + } else { + log.Debug("put successfully in watch loop when loading", zap.String("name", lw.name), zap.String("watch-key", lw.key), + zap.ByteString("key", item.Key), zap.ByteString("value", item.Value)) } } // Note: if there are no keys in etcd, the resp.More is false. It also means the load is finished. 
if !resp.More { - if err := lw.postEventsFn([]*clientv3.Event{}); err != nil { - log.Error("run post event failed in watch loop", zap.String("name", lw.name), - zap.String("key", lw.key), zap.Error(err)) - } return resp.Header.Revision + 1, err } } diff --git a/server/keyspace_service.go b/server/keyspace_service.go index 1718108d73b..11d912a5f54 100644 --- a/server/keyspace_service.go +++ b/server/keyspace_service.go @@ -89,7 +89,10 @@ func (s *KeyspaceServer) WatchKeyspaces(request *keyspacepb.WatchKeyspacesReques deleteFn := func(kv *mvccpb.KeyValue) error { return nil } - postEventFn := func([]*clientv3.Event) error { + postEventsFn := func([]*clientv3.Event) error { + if len(keyspaces) == 0 { + return nil + } defer func() { keyspaces = keyspaces[:0] }() @@ -112,7 +115,7 @@ func (s *KeyspaceServer) WatchKeyspaces(request *keyspacepb.WatchKeyspacesReques func([]*clientv3.Event) error { return nil }, putFn, deleteFn, - postEventFn, + postEventsFn, clientv3.WithRange(clientv3.GetPrefixRangeEnd(startKey)), ) watcher.StartWatchLoop() diff --git a/tests/server/api/region_test.go b/tests/server/api/region_test.go index 450995a6e5e..328c0fcd885 100644 --- a/tests/server/api/region_test.go +++ b/tests/server/api/region_test.go @@ -248,8 +248,7 @@ func (suite *regionTestSuite) checkScatterRegions(cluster *tests.TestCluster) { } func (suite *regionTestSuite) TestCheckRegionsReplicated() { - // Fixme: after delete+set rule, the key range will be empty, so the test will fail in api mode. - suite.env.RunTestInPDMode(suite.checkRegionsReplicated) + suite.env.RunTestInTwoModes(suite.checkRegionsReplicated) } func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) { @@ -304,6 +303,14 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) suite.NoError(err) + tu.Eventually(re, func() bool { + respBundle := make([]placement.GroupBundle, 0) + err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) + suite.NoError(err) + return len(respBundle) == 1 && respBundle[0].ID == "5" + }) + tu.Eventually(re, func() bool { err = tu.ReadGetJSON(re, testDialClient, url, &status) suite.NoError(err) @@ -328,9 +335,19 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) suite.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, url, &status) - suite.NoError(err) - suite.Equal("REPLICATED", status) + tu.Eventually(re, func() bool { + respBundle := make([]placement.GroupBundle, 0) + err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) + suite.NoError(err) + return len(respBundle) == 1 && len(respBundle[0].Rules) == 2 + }) + + tu.Eventually(re, func() bool { + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + return status == "REPLICATED" + }) // test multiple bundles bundle = append(bundle, placement.GroupBundle{ @@ -347,17 +364,34 @@ func (suite *regionTestSuite) checkRegionsReplicated(cluster *tests.TestCluster) err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) suite.NoError(err) - err = tu.ReadGetJSON(re, testDialClient, url, &status) - suite.NoError(err) - suite.Equal("INPROGRESS", status) + tu.Eventually(re, func() 
bool { + respBundle := make([]placement.GroupBundle, 0) + err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) + suite.NoError(err) + if len(respBundle) != 2 { + return false + } + s1 := respBundle[0].ID == "5" && respBundle[1].ID == "6" + s2 := respBundle[0].ID == "6" && respBundle[1].ID == "5" + return s1 || s2 + }) + + tu.Eventually(re, func() bool { + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + return status == "INPROGRESS" + }) r1 = core.NewTestRegionInfo(2, 1, []byte("a"), []byte("b")) r1.GetMeta().Peers = append(r1.GetMeta().Peers, &metapb.Peer{Id: 5, StoreId: 1}, &metapb.Peer{Id: 6, StoreId: 1}, &metapb.Peer{Id: 7, StoreId: 1}) tests.MustPutRegionInfo(re, cluster, r1) - err = tu.ReadGetJSON(re, testDialClient, url, &status) - suite.NoError(err) - suite.Equal("REPLICATED", status) + tu.Eventually(re, func() bool { + err = tu.ReadGetJSON(re, testDialClient, url, &status) + suite.NoError(err) + return status == "REPLICATED" + }) } func (suite *regionTestSuite) checkRegionCount(cluster *tests.TestCluster, count uint64) { diff --git a/tests/server/api/rule_test.go b/tests/server/api/rule_test.go index eaa41cc11bc..af70d5afed9 100644 --- a/tests/server/api/rule_test.go +++ b/tests/server/api/rule_test.go @@ -20,6 +20,9 @@ import ( "fmt" "net/http" "net/url" + "sort" + "strconv" + "sync" "testing" "github.com/pingcap/kvproto/pkg/metapb" @@ -28,6 +31,7 @@ import ( "github.com/tikv/pd/pkg/errs" "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/schedule/placement" + "github.com/tikv/pd/pkg/utils/syncutil" tu "github.com/tikv/pd/pkg/utils/testutil" "github.com/tikv/pd/server/config" "github.com/tikv/pd/tests" @@ -777,7 +781,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { err := tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 1) - suite.compareBundle(bundles[0], b1) + suite.assertBundleEqual(bundles[0], b1) // Set b2 := placement.GroupBundle{ @@ -797,14 +801,14 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { var bundle placement.GroupBundle err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule/foo", &bundle) suite.NoError(err) - suite.compareBundle(bundle, b2) + suite.assertBundleEqual(bundle, b2) // GetAll again err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 2) - suite.compareBundle(bundles[0], b1) - suite.compareBundle(bundles[1], b2) + suite.assertBundleEqual(bundles[0], b1) + suite.assertBundleEqual(bundles[1], b2) // Delete err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/pd", tu.StatusOK(suite.Require())) @@ -814,7 +818,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 1) - suite.compareBundle(bundles[0], b2) + suite.assertBundleEqual(bundles[0], b2) // SetAll b2.Rules = append(b2.Rules, &placement.Rule{GroupID: "foo", ID: "baz", Index: 2, Role: placement.Follower, Count: 1}) @@ -829,9 +833,9 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 3) - suite.compareBundle(bundles[0], b2) - suite.compareBundle(bundles[1], b1) - suite.compareBundle(bundles[2], b3) + 
suite.assertBundleEqual(bundles[0], b2) + suite.assertBundleEqual(bundles[1], b1) + suite.assertBundleEqual(bundles[2], b3) // Delete using regexp err = tu.CheckDelete(testDialClient, urlPrefix+"/placement-rule/"+url.PathEscape("foo.*")+"?regexp", tu.StatusOK(suite.Require())) @@ -841,7 +845,7 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 1) - suite.compareBundle(bundles[0], b1) + suite.assertBundleEqual(bundles[0], b1) // Set id := "rule-without-group-id" @@ -862,14 +866,14 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { // Get err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule/"+id, &bundle) suite.NoError(err) - suite.compareBundle(bundle, b4) + suite.assertBundleEqual(bundle, b4) // GetAll again err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 2) - suite.compareBundle(bundles[0], b1) - suite.compareBundle(bundles[1], b4) + suite.assertBundleEqual(bundles[0], b1) + suite.assertBundleEqual(bundles[1], b4) // SetAll b5 := placement.GroupBundle{ @@ -890,9 +894,9 @@ func (suite *ruleTestSuite) checkBundle(cluster *tests.TestCluster) { err = tu.ReadGetJSON(re, testDialClient, urlPrefix+"/placement-rule", &bundles) suite.NoError(err) suite.Len(bundles, 3) - suite.compareBundle(bundles[0], b1) - suite.compareBundle(bundles[1], b4) - suite.compareBundle(bundles[2], b5) + suite.assertBundleEqual(bundles[0], b1) + suite.assertBundleEqual(bundles[1], b4) + suite.assertBundleEqual(bundles[2], b5) } func (suite *ruleTestSuite) TestBundleBadRequest() { @@ -925,20 +929,315 @@ func (suite *ruleTestSuite) checkBundleBadRequest(cluster *tests.TestCluster) { } } -func (suite *ruleTestSuite) compareBundle(b1, b2 placement.GroupBundle) { - tu.Eventually(suite.Require(), func() bool { - if b2.ID != b1.ID || b2.Index != b1.Index || b2.Override != b1.Override || len(b2.Rules) != len(b1.Rules) { +func (suite *ruleTestSuite) TestLeaderAndVoter() { + suite.env.RunTestInTwoModes(suite.checkLeaderAndVoter) +} + +func (suite *ruleTestSuite) checkLeaderAndVoter(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1", pdAddr, apiPrefix) + + stores := []*metapb.Store{ + { + Id: 1, + Address: "tikv1", + State: metapb.StoreState_Up, + NodeState: metapb.NodeState_Serving, + Version: "7.5.0", + Labels: []*metapb.StoreLabel{{Key: "zone", Value: "z1"}}, + }, + { + Id: 2, + Address: "tikv2", + State: metapb.StoreState_Up, + NodeState: metapb.NodeState_Serving, + Version: "7.5.0", + Labels: []*metapb.StoreLabel{{Key: "zone", Value: "z2"}}, + }, + } + + for _, store := range stores { + tests.MustPutStore(re, cluster, store) + } + + bundles := [][]placement.GroupBundle{ + { + { + ID: "1", + Index: 1, + Rules: []*placement.Rule{ + { + ID: "rule_1", Index: 1, Role: placement.Voter, Count: 1, GroupID: "1", + LabelConstraints: []placement.LabelConstraint{ + {Key: "zone", Op: "in", Values: []string{"z1"}}, + }, + }, + { + ID: "rule_2", Index: 2, Role: placement.Leader, Count: 1, GroupID: "1", + LabelConstraints: []placement.LabelConstraint{ + {Key: "zone", Op: "in", Values: []string{"z2"}}, + }, + }, + }, + }, + }, + { + { + ID: "1", + Index: 1, + Rules: []*placement.Rule{ + { + ID: "rule_1", Index: 1, Role: placement.Leader, Count: 1, GroupID: "1", + 
LabelConstraints: []placement.LabelConstraint{ + {Key: "zone", Op: "in", Values: []string{"z2"}}, + }, + }, + { + ID: "rule_2", Index: 2, Role: placement.Voter, Count: 1, GroupID: "1", + LabelConstraints: []placement.LabelConstraint{ + {Key: "zone", Op: "in", Values: []string{"z1"}}, + }, + }, + }, + }, + }} + for _, bundle := range bundles { + data, err := json.Marshal(bundle) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) + + tu.Eventually(re, func() bool { + respBundle := make([]placement.GroupBundle, 0) + err := tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) + suite.NoError(err) + suite.Len(respBundle, 1) + if bundle[0].Rules[0].Role == placement.Leader { + return respBundle[0].Rules[0].Role == placement.Leader + } + if bundle[0].Rules[0].Role == placement.Voter { + return respBundle[0].Rules[0].Role == placement.Voter + } return false - } - for i := range b1.Rules { - if !suite.compareRule(b1.Rules[i], b2.Rules[i]) { + }) + } +} + +func (suite *ruleTestSuite) TestDeleteAndUpdate() { + suite.env.RunTestInTwoModes(suite.checkDeleteAndUpdate) +} + +func (suite *ruleTestSuite) checkDeleteAndUpdate(cluster *tests.TestCluster) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1", pdAddr, apiPrefix) + + bundles := [][]placement.GroupBundle{ + // 1 rule group with 1 rule + {{ + ID: "1", + Index: 1, + Rules: []*placement.Rule{ + { + ID: "foo", Index: 1, Role: placement.Voter, Count: 1, GroupID: "1", + }, + }, + }}, + // 2 rule groups with different range rules + {{ + ID: "1", + Index: 1, + Rules: []*placement.Rule{ + { + ID: "foo", Index: 1, Role: placement.Voter, Count: 1, GroupID: "1", + StartKey: []byte("a"), EndKey: []byte("b"), + }, + }, + }, { + ID: "2", + Index: 2, + Rules: []*placement.Rule{ + { + ID: "foo", Index: 2, Role: placement.Voter, Count: 1, GroupID: "2", + StartKey: []byte("b"), EndKey: []byte("c"), + }, + }, + }}, + // 2 rule groups with 1 rule and 2 rules + {{ + ID: "3", + Index: 3, + Rules: []*placement.Rule{ + { + ID: "foo", Index: 3, Role: placement.Voter, Count: 1, GroupID: "3", + }, + }, + }, { + ID: "4", + Index: 4, + Rules: []*placement.Rule{ + { + ID: "foo", Index: 4, Role: placement.Voter, Count: 1, GroupID: "4", + }, + { + ID: "bar", Index: 6, Role: placement.Voter, Count: 1, GroupID: "4", + }, + }, + }}, + // 1 rule group with 2 rules + {{ + ID: "5", + Index: 5, + Rules: []*placement.Rule{ + { + ID: "foo", Index: 5, Role: placement.Voter, Count: 1, GroupID: "5", + }, + { + ID: "bar", Index: 6, Role: placement.Voter, Count: 1, GroupID: "5", + }, + }, + }}, + } + + for _, bundle := range bundles { + data, err := json.Marshal(bundle) + suite.NoError(err) + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) + + tu.Eventually(re, func() bool { + respBundle := make([]placement.GroupBundle, 0) + err = tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) + suite.NoError(err) + if len(respBundle) != len(bundle) { return false } - } - return true + sort.Slice(respBundle, func(i, j int) bool { return respBundle[i].ID < respBundle[j].ID }) + sort.Slice(bundle, func(i, j int) bool { return bundle[i].ID < bundle[j].ID }) + for i := range respBundle { + if 
!suite.compareBundle(respBundle[i], bundle[i]) { + return false + } + } + return true + }) + } +} + +func (suite *ruleTestSuite) TestConcurrency() { + suite.env.RunTestInTwoModes(suite.checkConcurrency) +} + +func (suite *ruleTestSuite) checkConcurrency(cluster *tests.TestCluster) { + // test concurrency of set rule group with different group id + suite.checkConcurrencyWith(cluster, + func(i int) []placement.GroupBundle { + return []placement.GroupBundle{ + { + ID: strconv.Itoa(i), + Index: i, + Rules: []*placement.Rule{ + { + ID: "foo", Index: i, Role: placement.Voter, Count: 1, GroupID: strconv.Itoa(i), + }, + }, + }, + } + }, + func(resp []placement.GroupBundle, i int) bool { + return len(resp) == 1 && resp[0].ID == strconv.Itoa(i) + }, + ) + // test concurrency of set rule with different id + suite.checkConcurrencyWith(cluster, + func(i int) []placement.GroupBundle { + return []placement.GroupBundle{ + { + ID: "pd", + Index: 1, + Rules: []*placement.Rule{ + { + ID: strconv.Itoa(i), Index: i, Role: placement.Voter, Count: 1, GroupID: "pd", + }, + }, + }, + } + }, + func(resp []placement.GroupBundle, i int) bool { + return len(resp) == 1 && resp[0].ID == "pd" && resp[0].Rules[0].ID == strconv.Itoa(i) + }, + ) +} + +func (suite *ruleTestSuite) checkConcurrencyWith(cluster *tests.TestCluster, + genBundle func(int) []placement.GroupBundle, + checkBundle func([]placement.GroupBundle, int) bool) { + re := suite.Require() + leaderServer := cluster.GetLeaderServer() + pdAddr := leaderServer.GetAddr() + urlPrefix := fmt.Sprintf("%s%s/api/v1", pdAddr, apiPrefix) + expectResult := struct { + syncutil.RWMutex + val int + }{} + wg := sync.WaitGroup{} + + for i := 1; i <= 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + bundle := genBundle(i) + data, err := json.Marshal(bundle) + suite.NoError(err) + for j := 0; j < 10; j++ { + expectResult.Lock() + err = tu.CheckPostJSON(testDialClient, urlPrefix+"/config/placement-rule", data, tu.StatusOK(re)) + suite.NoError(err) + expectResult.val = i + expectResult.Unlock() + } + }(i) + } + + wg.Wait() + expectResult.RLock() + defer expectResult.RUnlock() + suite.NotZero(expectResult.val) + tu.Eventually(re, func() bool { + respBundle := make([]placement.GroupBundle, 0) + err := tu.CheckGetJSON(testDialClient, urlPrefix+"/config/placement-rule", nil, + tu.StatusOK(re), tu.ExtractJSON(re, &respBundle)) + suite.NoError(err) + suite.Len(respBundle, 1) + return checkBundle(respBundle, expectResult.val) + }) +} + +func (suite *ruleTestSuite) assertBundleEqual(b1, b2 placement.GroupBundle) { + tu.Eventually(suite.Require(), func() bool { + return suite.compareBundle(b1, b2) }) } +func (suite *ruleTestSuite) compareBundle(b1, b2 placement.GroupBundle) bool { + if b2.ID != b1.ID || b2.Index != b1.Index || b2.Override != b1.Override || len(b2.Rules) != len(b1.Rules) { + return false + } + sort.Slice(b1.Rules, func(i, j int) bool { return b1.Rules[i].ID < b1.Rules[j].ID }) + sort.Slice(b2.Rules, func(i, j int) bool { return b2.Rules[i].ID < b2.Rules[j].ID }) + for i := range b1.Rules { + if !suite.compareRule(b1.Rules[i], b2.Rules[i]) { + return false + } + } + return true +} + func (suite *ruleTestSuite) compareRule(r1 *placement.Rule, r2 *placement.Rule) bool { return r2.GroupID == r1.GroupID && r2.ID == r1.ID && From 48dbce1a7f4509c38b2920448a9f37c1963c8093 Mon Sep 17 00:00:00 2001 From: tongjian <1045931706@qq.com> Date: Wed, 20 Dec 2023 14:03:22 +0800 Subject: [PATCH 21/21] check: remove orphan peer only when the peers is greater than the rule count 
(#7581) close tikv/pd#7584 The healthy orphan peer should be the last one to be removed only if there are extra peers to keep the high availablility. Signed-off-by: bufferflies <1045931706@qq.com> Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- pkg/schedule/checker/rule_checker.go | 4 ++- pkg/schedule/checker/rule_checker_test.go | 36 +++++++++++++++++++++++ pkg/schedule/placement/fit.go | 9 ++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/pkg/schedule/checker/rule_checker.go b/pkg/schedule/checker/rule_checker.go index 95cc77ade5d..464f5e97be8 100644 --- a/pkg/schedule/checker/rule_checker.go +++ b/pkg/schedule/checker/rule_checker.go @@ -560,6 +560,7 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg } } + extra := fit.ExtraCount() // If hasUnhealthyFit is true, try to remove unhealthy orphan peers only if number of OrphanPeers is >= 2. // Ref https://github.com/tikv/pd/issues/4045 if len(fit.OrphanPeers) >= 2 { @@ -576,7 +577,8 @@ func (c *RuleChecker) fixOrphanPeers(region *core.RegionInfo, fit *placement.Reg ruleCheckerRemoveOrphanPeerCounter.Inc() return operator.CreateRemovePeerOperator("remove-unhealthy-orphan-peer", c.cluster, 0, region, orphanPeer.StoreId) } - if hasHealthPeer { + // The healthy orphan peer can be removed to keep the high availability only if the peer count is greater than the rule requirement. + if hasHealthPeer && extra > 0 { // there already exists a healthy orphan peer, so we can remove other orphan Peers. ruleCheckerRemoveOrphanPeerCounter.Inc() // if there exists a disconnected orphan peer, we will pick it to remove firstly. diff --git a/pkg/schedule/checker/rule_checker_test.go b/pkg/schedule/checker/rule_checker_test.go index 72d3e7e5ec4..6418483db7f 100644 --- a/pkg/schedule/checker/rule_checker_test.go +++ b/pkg/schedule/checker/rule_checker_test.go @@ -2029,3 +2029,39 @@ func (suite *ruleCheckerTestAdvancedSuite) TestReplaceAnExistingPeerCases() { suite.ruleManager.DeleteGroupBundle(groupName, false) } } + +func (suite *ruleCheckerTestSuite) TestRemoveOrphanPeer() { + suite.cluster.AddLabelsStore(1, 1, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(2, 1, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(3, 1, map[string]string{"zone": "z1", "host": "h1"}) + suite.cluster.AddLabelsStore(4, 1, map[string]string{"zone": "z2", "host": "h1"}) + suite.cluster.AddLabelsStore(5, 1, map[string]string{"zone": "z2", "host": "h2"}) + suite.cluster.AddLabelsStore(6, 1, map[string]string{"zone": "z2", "host": "h2"}) + rule := &placement.Rule{ + GroupID: "pd", + ID: "test2", + Role: placement.Voter, + Count: 3, + LabelConstraints: []placement.LabelConstraint{ + { + Key: "zone", + Op: placement.In, + Values: []string{"z2"}, + }, + }, + } + suite.ruleManager.SetRule(rule) + suite.ruleManager.DeleteRule("pd", "default") + + // case1: regionA has 3 peers but not extra peer can be removed, so it needs to add peer first + suite.cluster.AddLeaderRegionWithRange(1, "200", "300", 1, 2, 3) + op := suite.rc.Check(suite.cluster.GetRegion(1)) + suite.NotNil(op) + suite.Equal("add-rule-peer", op.Desc()) + + // case2: regionB has 4 peers and one extra peer can be removed, so it needs to remove extra peer first + suite.cluster.AddLeaderRegionWithRange(2, "300", "400", 1, 2, 3, 4) + op = suite.rc.Check(suite.cluster.GetRegion(2)) + suite.NotNil(op) + suite.Equal("remove-orphan-peer", op.Desc()) +} diff --git 
a/pkg/schedule/placement/fit.go b/pkg/schedule/placement/fit.go index 45afc5bcfa3..d907bcd011a 100644 --- a/pkg/schedule/placement/fit.go +++ b/pkg/schedule/placement/fit.go @@ -93,6 +93,15 @@ func (f *RegionFit) IsSatisfied() bool { return len(f.OrphanPeers) == 0 } +// ExtraCount return the extra count. +func (f *RegionFit) ExtraCount() int { + desired := 0 + for _, r := range f.RuleFits { + desired += r.Rule.Count + } + return len(f.regionStores) - desired +} + // GetRuleFit returns the RuleFit that contains the peer. func (f *RegionFit) GetRuleFit(peerID uint64) *RuleFit { for _, rf := range f.RuleFits {
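// Editor's note: a self-contained toy sketch (not PD code) of the arithmetic behind
// RegionFit.ExtraCount above and the new `hasHealthPeer && extra > 0` guard in
// rule_checker.go: desired is the sum of all rule counts, and only when the region
// carries more peers than that may a healthy orphan peer be removed; otherwise the
// checker adds the missing rule peer first, matching the two cases in TestRemoveOrphanPeer.
package main

import "fmt"

// extraCount mirrors ExtraCount: peers beyond what the placement rules ask for.
func extraCount(ruleCounts []int, regionPeers int) int {
	desired := 0
	for _, c := range ruleCounts {
		desired += c
	}
	return regionPeers - desired
}

func main() {
	rules := []int{3}                     // one voter rule with Count: 3, as in the test above
	fmt.Println(extraCount(rules, 3) > 0) // false: no extra peer, so "add-rule-peer" comes first
	fmt.Println(extraCount(rules, 4) > 0) // true: one extra peer, so "remove-orphan-peer" is allowed
}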