diff --git a/client/errs/errno.go b/client/errs/errno.go index 646af81929d..0f93ebf1472 100644 --- a/client/errs/errno.go +++ b/client/errs/errno.go @@ -54,7 +54,7 @@ var ( ErrClientGetMultiResponse = errors.Normalize("get invalid value response %v, must only one", errors.RFCCodeText("PD:client:ErrClientGetMultiResponse")) ErrClientGetServingEndpoint = errors.Normalize("get serving endpoint failed", errors.RFCCodeText("PD:client:ErrClientGetServingEndpoint")) ErrClientFindGroupByKeyspaceID = errors.Normalize("can't find keyspace group by keyspace id", errors.RFCCodeText("PD:client:ErrClientFindGroupByKeyspaceID")) - ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed, %s", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) + ErrClientWatchGCSafePointV2Stream = errors.Normalize("watch gc safe point v2 stream failed", errors.RFCCodeText("PD:client:ErrClientWatchGCSafePointV2Stream")) ) // grpcutil errors diff --git a/client/errs/errs.go b/client/errs/errs.go index e715056b055..47f7c29a467 100644 --- a/client/errs/errs.go +++ b/client/errs/errs.go @@ -27,7 +27,7 @@ func ZapError(err error, causeError ...error) zap.Field { } if e, ok := err.(*errors.Error); ok { if len(causeError) >= 1 { - err = e.Wrap(causeError[0]).FastGenWithCause() + err = e.Wrap(causeError[0]) } else { err = e.FastGenByArgs() } diff --git a/client/http/api.go b/client/http/api.go index f744fd0c395..2153cd286e8 100644 --- a/client/http/api.go +++ b/client/http/api.go @@ -38,6 +38,9 @@ const ( store = "/pd/api/v1/store" Stores = "/pd/api/v1/stores" StatsRegion = "/pd/api/v1/stats/region" + membersPrefix = "/pd/api/v1/members" + leaderPrefix = "/pd/api/v1/leader" + transferLeader = "/pd/api/v1/leader/transfer" // Config Config = "/pd/api/v1/config" ClusterVersion = "/pd/api/v1/config/cluster-version" @@ -124,6 +127,16 @@ func StoreLabelByID(id uint64) string { return fmt.Sprintf("%s/%d/label", store, id) } +// LabelByStoreID returns the path of PD HTTP API to set store label. +func LabelByStoreID(storeID int64) string { + return fmt.Sprintf("%s/%d/label", store, storeID) +} + +// TransferLeaderByID returns the path of PD HTTP API to transfer leader by ID. +func TransferLeaderByID(leaderID string) string { + return fmt.Sprintf("%s/%s", transferLeader, leaderID) +} + // ConfigWithTTLSeconds returns the config API with the TTL seconds parameter. 
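For reference, the two helpers added above are plain string formatting over the new path constants; a standalone sketch (reproducing only the constants and format strings from this diff, with an example store ID and member name) shows the URLs they yield:

```go
package main

import "fmt"

const (
	store          = "/pd/api/v1/store"
	transferLeader = "/pd/api/v1/leader/transfer"
)

// LabelByStoreID and TransferLeaderByID as introduced in client/http/api.go.
func LabelByStoreID(storeID int64) string { return fmt.Sprintf("%s/%d/label", store, storeID) }
func TransferLeaderByID(leaderID string) string {
	return fmt.Sprintf("%s/%s", transferLeader, leaderID)
}

func main() {
	fmt.Println(LabelByStoreID(1))          // /pd/api/v1/store/1/label
	fmt.Println(TransferLeaderByID("pd-2")) // /pd/api/v1/leader/transfer/pd-2
}
```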
func ConfigWithTTLSeconds(ttlSeconds float64) string { return fmt.Sprintf("%s?ttlSecond=%.0f", Config, ttlSeconds) diff --git a/client/http/client.go b/client/http/client.go index b79aa9ca002..d74c77571d6 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -26,6 +26,7 @@ import ( "time" "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/pdpb" "github.com/pingcap/log" "github.com/prometheus/client_golang/prometheus" "go.uber.org/zap" @@ -54,9 +55,19 @@ type Client interface { GetHistoryHotRegions(context.Context, *HistoryHotRegionsRequest) (*HistoryHotRegions, error) GetRegionStatusByKeyRange(context.Context, *KeyRange, bool) (*RegionStats, error) GetStores(context.Context) (*StoresInfo, error) + GetStore(context.Context, uint64) (*StoreInfo, error) + SetStoreLabels(context.Context, int64, map[string]string) error + GetMembers(context.Context) (*MembersInfo, error) + GetLeader(context.Context) (*pdpb.Member, error) + TransferLeader(context.Context, string) error /* Config-related interfaces */ GetScheduleConfig(context.Context) (map[string]interface{}, error) SetScheduleConfig(context.Context, map[string]interface{}) error + GetClusterVersion(context.Context) (string, error) + /* Scheduler-related interfaces */ + GetSchedulers(context.Context) ([]string, error) + CreateScheduler(ctx context.Context, name string, storeID uint64) error + SetSchedulerDelay(context.Context, string, int64) error /* Rule-related interfaces */ GetAllPlacementRuleBundles(context.Context) ([]*GroupBundle, error) GetPlacementRuleBundleByGroup(context.Context, string) (*GroupBundle, error) @@ -221,7 +232,7 @@ func (c *client) execDuration(name string, duration time.Duration) { // Header key definition constants. const ( pdAllowFollowerHandleKey = "PD-Allow-Follower-Handle" - componentSignatureKey = "component" + xCallerIDKey = "X-Caller-ID" ) // HeaderOption configures the HTTP header. @@ -239,7 +250,7 @@ func WithAllowFollowerHandle() HeaderOption { func (c *client) requestWithRetry( ctx context.Context, name, uri, method string, - body io.Reader, res interface{}, + body []byte, res interface{}, headerOpts ...HeaderOption, ) error { var ( @@ -261,7 +272,7 @@ func (c *client) requestWithRetry( func (c *client) request( ctx context.Context, name, url, method string, - body io.Reader, res interface{}, + body []byte, res interface{}, headerOpts ...HeaderOption, ) error { logFields := []zap.Field{ @@ -271,7 +282,7 @@ func (c *client) request( zap.String("caller-id", c.callerID), } log.Debug("[pd] request the http url", logFields...) - req, err := http.NewRequestWithContext(ctx, method, url, body) + req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(body)) if err != nil { log.Error("[pd] create http request failed", append(logFields, zap.Error(err))...) 
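The header rename means PD-side consumers now read the caller identity from `X-Caller-ID` rather than `component`. A minimal sketch, not part of this PR, of what a receiving handler observes (the httptest server and the `pd-ctl` caller ID are illustrative):

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

func main() {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The audit/metrics side now keys on this header value.
		fmt.Println("caller:", r.Header.Get("X-Caller-ID")) // caller: pd-ctl
	}))
	defer srv.Close()

	req, _ := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
	req.Header.Set("X-Caller-ID", "pd-ctl") // what c.request now sets from c.callerID
	_, _ = http.DefaultClient.Do(req)
}
```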
return errors.Trace(err) @@ -279,7 +290,7 @@ func (c *client) request( for _, opt := range headerOpts { opt(req.Header) } - req.Header.Set(componentSignatureKey, c.callerID) + req.Header.Set(xCallerIDKey, c.callerID) start := time.Now() resp, err := c.inner.cli.Do(req) @@ -333,7 +344,7 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64) (*RegionInf var region RegionInfo err := c.requestWithRetry(ctx, "GetRegionByID", RegionByID(regionID), - http.MethodGet, http.NoBody, ®ion) + http.MethodGet, nil, ®ion) if err != nil { return nil, err } @@ -345,7 +356,7 @@ func (c *client) GetRegionByKey(ctx context.Context, key []byte) (*RegionInfo, e var region RegionInfo err := c.requestWithRetry(ctx, "GetRegionByKey", RegionByKey(key), - http.MethodGet, http.NoBody, ®ion) + http.MethodGet, nil, ®ion) if err != nil { return nil, err } @@ -357,7 +368,7 @@ func (c *client) GetRegions(ctx context.Context) (*RegionsInfo, error) { var regions RegionsInfo err := c.requestWithRetry(ctx, "GetRegions", Regions, - http.MethodGet, http.NoBody, ®ions) + http.MethodGet, nil, ®ions) if err != nil { return nil, err } @@ -370,7 +381,7 @@ func (c *client) GetRegionsByKeyRange(ctx context.Context, keyRange *KeyRange, l var regions RegionsInfo err := c.requestWithRetry(ctx, "GetRegionsByKeyRange", RegionsByKeyRange(keyRange, limit), - http.MethodGet, http.NoBody, ®ions) + http.MethodGet, nil, ®ions) if err != nil { return nil, err } @@ -382,7 +393,7 @@ func (c *client) GetRegionsByStoreID(ctx context.Context, storeID uint64) (*Regi var regions RegionsInfo err := c.requestWithRetry(ctx, "GetRegionsByStoreID", RegionsByStoreID(storeID), - http.MethodGet, http.NoBody, ®ions) + http.MethodGet, nil, ®ions) if err != nil { return nil, err } @@ -395,7 +406,7 @@ func (c *client) GetRegionsReplicatedStateByKeyRange(ctx context.Context, keyRan var state string err := c.requestWithRetry(ctx, "GetRegionsReplicatedStateByKeyRange", RegionsReplicatedByKeyRange(keyRange), - http.MethodGet, http.NoBody, &state) + http.MethodGet, nil, &state) if err != nil { return "", err } @@ -407,7 +418,7 @@ func (c *client) GetHotReadRegions(ctx context.Context) (*StoreHotPeersInfos, er var hotReadRegions StoreHotPeersInfos err := c.requestWithRetry(ctx, "GetHotReadRegions", HotRead, - http.MethodGet, http.NoBody, &hotReadRegions) + http.MethodGet, nil, &hotReadRegions) if err != nil { return nil, err } @@ -419,7 +430,7 @@ func (c *client) GetHotWriteRegions(ctx context.Context) (*StoreHotPeersInfos, e var hotWriteRegions StoreHotPeersInfos err := c.requestWithRetry(ctx, "GetHotWriteRegions", HotWrite, - http.MethodGet, http.NoBody, &hotWriteRegions) + http.MethodGet, nil, &hotWriteRegions) if err != nil { return nil, err } @@ -435,7 +446,7 @@ func (c *client) GetHistoryHotRegions(ctx context.Context, req *HistoryHotRegion var historyHotRegions HistoryHotRegions err = c.requestWithRetry(ctx, "GetHistoryHotRegions", HotHistory, - http.MethodGet, bytes.NewBuffer(reqJSON), &historyHotRegions, + http.MethodGet, reqJSON, &historyHotRegions, WithAllowFollowerHandle()) if err != nil { return nil, err @@ -450,7 +461,7 @@ func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRan var regionStats RegionStats err := c.requestWithRetry(ctx, "GetRegionStatusByKeyRange", RegionStatsByKeyRange(keyRange, onlyCount), - http.MethodGet, http.NoBody, ®ionStats, + http.MethodGet, nil, ®ionStats, ) if err != nil { return nil, err @@ -458,12 +469,50 @@ func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRan 
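Passing the body as `[]byte` and wrapping it in `bytes.NewBuffer` only at request-creation time is why every read-only call site switches from `http.NoBody` to `nil`: a nil slice still yields a valid empty body, and each retry can rebuild the reader instead of reusing an already-drained `io.Reader`. A quick illustrative check (the URL is an example):

```go
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	var body []byte // what the GET-style call sites now pass
	req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:2379/pd/api/v1/stores", bytes.NewBuffer(body))
	if err != nil {
		panic(err)
	}
	fmt.Println(req.ContentLength) // 0: an empty body, equivalent to the old http.NoBody
}
```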
return ®ionStats, nil } +// SetStoreLabels sets the labels of a store. +func (c *client) SetStoreLabels(ctx context.Context, storeID int64, storeLabels map[string]string) error { + jsonInput, err := json.Marshal(storeLabels) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, "SetStoreLabel", LabelByStoreID(storeID), + http.MethodPost, jsonInput, nil) +} + +func (c *client) GetMembers(ctx context.Context) (*MembersInfo, error) { + var members MembersInfo + err := c.requestWithRetry(ctx, + "GetMembers", membersPrefix, + http.MethodGet, nil, &members) + if err != nil { + return nil, err + } + return &members, nil +} + +// GetLeader gets the leader of PD cluster. +func (c *client) GetLeader(ctx context.Context) (*pdpb.Member, error) { + var leader pdpb.Member + err := c.requestWithRetry(ctx, "GetLeader", leaderPrefix, + http.MethodGet, nil, &leader) + if err != nil { + return nil, err + } + return &leader, nil +} + +// TransferLeader transfers the PD leader. +func (c *client) TransferLeader(ctx context.Context, newLeader string) error { + return c.requestWithRetry(ctx, "TransferLeader", TransferLeaderByID(newLeader), + http.MethodPost, nil, nil) +} + // GetScheduleConfig gets the schedule configurations. func (c *client) GetScheduleConfig(ctx context.Context) (map[string]interface{}, error) { var config map[string]interface{} err := c.requestWithRetry(ctx, "GetScheduleConfig", ScheduleConfig, - http.MethodGet, http.NoBody, &config) + http.MethodGet, nil, &config) if err != nil { return nil, err } @@ -478,7 +527,7 @@ func (c *client) SetScheduleConfig(ctx context.Context, config map[string]interf } return c.requestWithRetry(ctx, "SetScheduleConfig", ScheduleConfig, - http.MethodPost, bytes.NewBuffer(configJSON), nil) + http.MethodPost, configJSON, nil) } // GetStores gets the stores info. @@ -486,19 +535,43 @@ func (c *client) GetStores(ctx context.Context) (*StoresInfo, error) { var stores StoresInfo err := c.requestWithRetry(ctx, "GetStores", Stores, - http.MethodGet, http.NoBody, &stores) + http.MethodGet, nil, &stores) if err != nil { return nil, err } return &stores, nil } +// GetStore gets the store info by ID. +func (c *client) GetStore(ctx context.Context, storeID uint64) (*StoreInfo, error) { + var store StoreInfo + err := c.requestWithRetry(ctx, + "GetStore", StoreByID(storeID), + http.MethodGet, nil, &store) + if err != nil { + return nil, err + } + return &store, nil +} + +// GetClusterVersion gets the cluster version. +func (c *client) GetClusterVersion(ctx context.Context) (string, error) { + var version string + err := c.requestWithRetry(ctx, + "GetClusterVersion", ClusterVersion, + http.MethodGet, nil, &version) + if err != nil { + return "", err + } + return version, nil +} + // GetAllPlacementRuleBundles gets all placement rules bundles. 
func (c *client) GetAllPlacementRuleBundles(ctx context.Context) ([]*GroupBundle, error) { var bundles []*GroupBundle err := c.requestWithRetry(ctx, "GetPlacementRuleBundle", PlacementRuleBundle, - http.MethodGet, http.NoBody, &bundles) + http.MethodGet, nil, &bundles) if err != nil { return nil, err } @@ -510,7 +583,7 @@ func (c *client) GetPlacementRuleBundleByGroup(ctx context.Context, group string var bundle GroupBundle err := c.requestWithRetry(ctx, "GetPlacementRuleBundleByGroup", PlacementRuleBundleByGroup(group), - http.MethodGet, http.NoBody, &bundle) + http.MethodGet, nil, &bundle) if err != nil { return nil, err } @@ -522,7 +595,7 @@ func (c *client) GetPlacementRulesByGroup(ctx context.Context, group string) ([] var rules []*Rule err := c.requestWithRetry(ctx, "GetPlacementRulesByGroup", PlacementRulesByGroup(group), - http.MethodGet, http.NoBody, &rules) + http.MethodGet, nil, &rules) if err != nil { return nil, err } @@ -537,7 +610,7 @@ func (c *client) SetPlacementRule(ctx context.Context, rule *Rule) error { } return c.requestWithRetry(ctx, "SetPlacementRule", PlacementRule, - http.MethodPost, bytes.NewBuffer(ruleJSON), nil) + http.MethodPost, ruleJSON, nil) } // SetPlacementRuleInBatch sets the placement rules in batch. @@ -548,7 +621,7 @@ func (c *client) SetPlacementRuleInBatch(ctx context.Context, ruleOps []*RuleOp) } return c.requestWithRetry(ctx, "SetPlacementRuleInBatch", PlacementRulesInBatch, - http.MethodPost, bytes.NewBuffer(ruleOpsJSON), nil) + http.MethodPost, ruleOpsJSON, nil) } // SetPlacementRuleBundles sets the placement rule bundles. @@ -560,14 +633,14 @@ func (c *client) SetPlacementRuleBundles(ctx context.Context, bundles []*GroupBu } return c.requestWithRetry(ctx, "SetPlacementRuleBundles", PlacementRuleBundleWithPartialParameter(partial), - http.MethodPost, bytes.NewBuffer(bundlesJSON), nil) + http.MethodPost, bundlesJSON, nil) } // DeletePlacementRule deletes the placement rule. func (c *client) DeletePlacementRule(ctx context.Context, group, id string) error { return c.requestWithRetry(ctx, "DeletePlacementRule", PlacementRuleByGroupAndID(group, id), - http.MethodDelete, http.NoBody, nil) + http.MethodDelete, nil, nil) } // GetAllPlacementRuleGroups gets all placement rule groups. @@ -575,7 +648,7 @@ func (c *client) GetAllPlacementRuleGroups(ctx context.Context) ([]*RuleGroup, e var ruleGroups []*RuleGroup err := c.requestWithRetry(ctx, "GetAllPlacementRuleGroups", placementRuleGroups, - http.MethodGet, http.NoBody, &ruleGroups) + http.MethodGet, nil, &ruleGroups) if err != nil { return nil, err } @@ -587,7 +660,7 @@ func (c *client) GetPlacementRuleGroupByID(ctx context.Context, id string) (*Rul var ruleGroup RuleGroup err := c.requestWithRetry(ctx, "GetPlacementRuleGroupByID", PlacementRuleGroupByID(id), - http.MethodGet, http.NoBody, &ruleGroup) + http.MethodGet, nil, &ruleGroup) if err != nil { return nil, err } @@ -602,14 +675,14 @@ func (c *client) SetPlacementRuleGroup(ctx context.Context, ruleGroup *RuleGroup } return c.requestWithRetry(ctx, "SetPlacementRuleGroup", placementRuleGroup, - http.MethodPost, bytes.NewBuffer(ruleGroupJSON), nil) + http.MethodPost, ruleGroupJSON, nil) } // DeletePlacementRuleGroupByID deletes the placement rule group by ID. 
func (c *client) DeletePlacementRuleGroupByID(ctx context.Context, id string) error { return c.requestWithRetry(ctx, "DeletePlacementRuleGroupByID", PlacementRuleGroupByID(id), - http.MethodDelete, http.NoBody, nil) + http.MethodDelete, nil, nil) } // GetAllRegionLabelRules gets all region label rules. @@ -617,7 +690,7 @@ func (c *client) GetAllRegionLabelRules(ctx context.Context) ([]*LabelRule, erro var labelRules []*LabelRule err := c.requestWithRetry(ctx, "GetAllRegionLabelRules", RegionLabelRules, - http.MethodGet, http.NoBody, &labelRules) + http.MethodGet, nil, &labelRules) if err != nil { return nil, err } @@ -633,7 +706,7 @@ func (c *client) GetRegionLabelRulesByIDs(ctx context.Context, ruleIDs []string) var labelRules []*LabelRule err = c.requestWithRetry(ctx, "GetRegionLabelRulesByIDs", RegionLabelRulesByIDs, - http.MethodGet, bytes.NewBuffer(idsJSON), &labelRules) + http.MethodGet, idsJSON, &labelRules) if err != nil { return nil, err } @@ -648,7 +721,7 @@ func (c *client) SetRegionLabelRule(ctx context.Context, labelRule *LabelRule) e } return c.requestWithRetry(ctx, "SetRegionLabelRule", RegionLabelRule, - http.MethodPost, bytes.NewBuffer(labelRuleJSON), nil) + http.MethodPost, labelRuleJSON, nil) } // PatchRegionLabelRules patches the region label rules. @@ -659,7 +732,32 @@ func (c *client) PatchRegionLabelRules(ctx context.Context, labelRulePatch *Labe } return c.requestWithRetry(ctx, "PatchRegionLabelRules", RegionLabelRules, - http.MethodPatch, bytes.NewBuffer(labelRulePatchJSON), nil) + http.MethodPatch, labelRulePatchJSON, nil) +} + +// GetSchedulers gets the schedulers from PD cluster. +func (c *client) GetSchedulers(ctx context.Context) ([]string, error) { + var schedulers []string + err := c.requestWithRetry(ctx, "GetSchedulers", Schedulers, + http.MethodGet, nil, &schedulers) + if err != nil { + return nil, err + } + return schedulers, nil +} + +// CreateScheduler creates a scheduler to PD cluster. +func (c *client) CreateScheduler(ctx context.Context, name string, storeID uint64) error { + inputJSON, err := json.Marshal(map[string]interface{}{ + "name": name, + "store_id": storeID, + }) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "CreateScheduler", Schedulers, + http.MethodPost, inputJSON, nil) } // AccelerateSchedule accelerates the scheduling of the regions within the given key range. @@ -675,7 +773,7 @@ func (c *client) AccelerateSchedule(ctx context.Context, keyRange *KeyRange) err } return c.requestWithRetry(ctx, "AccelerateSchedule", AccelerateSchedule, - http.MethodPost, bytes.NewBuffer(inputJSON), nil) + http.MethodPost, inputJSON, nil) } // AccelerateScheduleInBatch accelerates the scheduling of the regions within the given key ranges in batch. @@ -695,10 +793,27 @@ func (c *client) AccelerateScheduleInBatch(ctx context.Context, keyRanges []*Key } return c.requestWithRetry(ctx, "AccelerateScheduleInBatch", AccelerateScheduleInBatch, - http.MethodPost, bytes.NewBuffer(inputJSON), nil) + http.MethodPost, inputJSON, nil) +} + +// SetSchedulerDelay sets the delay of given scheduler. +func (c *client) SetSchedulerDelay(ctx context.Context, scheduler string, delaySec int64) error { + m := map[string]int64{ + "delay": delaySec, + } + inputJSON, err := json.Marshal(m) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "SetSchedulerDelay", SchedulerByName(scheduler), + http.MethodPost, inputJSON, nil) } // GetMinResolvedTSByStoresIDs get min-resolved-ts by stores IDs. 
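A hedged sketch of the new scheduler helpers in combination; the scheduler name, store ID, and delay value are illustrative, and the `pd.Client` handle is assumed to come from elsewhere:

```go
package pdexample

import (
	"context"
	"fmt"

	pd "github.com/tikv/pd/client/http"
)

// pauseEvictLeader lists the active schedulers, adds an evict-leader scheduler for
// one store, then delays (pauses) it for 300 seconds.
func pauseEvictLeader(ctx context.Context, c pd.Client) error {
	schedulers, err := c.GetSchedulers(ctx)
	if err != nil {
		return err
	}
	fmt.Println("schedulers:", schedulers)

	if err := c.CreateScheduler(ctx, "evict-leader-scheduler", 1); err != nil {
		return err
	}
	return c.SetSchedulerDelay(ctx, "evict-leader-scheduler", 300)
}
```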
+// - When storeIDs has zero length, it will return (cluster-level's min_resolved_ts, nil, nil) when no error. +// - When storeIDs is {"cluster"}, it will return (cluster-level's min_resolved_ts, stores_min_resolved_ts, nil) when no error. +// - When storeID is specified to ID lists, it will return (min_resolved_ts of given stores, stores_min_resolved_ts, nil) when no error. func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) { uri := MinResolvedTSPrefix // scope is an optional parameter, it can be `cluster` or specified store IDs. @@ -720,7 +835,7 @@ func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uin }{} err := c.requestWithRetry(ctx, "GetMinResolvedTSByStoresIDs", uri, - http.MethodGet, http.NoBody, &resp) + http.MethodGet, nil, &resp) if err != nil { return 0, nil, err } diff --git a/client/http/client_test.go b/client/http/client_test.go index 621910e29ea..70c2ddee08b 100644 --- a/client/http/client_test.go +++ b/client/http/client_test.go @@ -59,7 +59,7 @@ func TestCallerID(t *testing.T) { re := require.New(t) expectedVal := defaultCallerID httpClient := newHTTPClientWithRequestChecker(func(req *http.Request) error { - val := req.Header.Get(componentSignatureKey) + val := req.Header.Get(xCallerIDKey) if val != expectedVal { re.Failf("Caller ID header check failed", "should be %s, but got %s", expectedVal, val) diff --git a/client/http/types.go b/client/http/types.go index 1d8db36d100..b05e8e0efba 100644 --- a/client/http/types.go +++ b/client/http/types.go @@ -21,6 +21,7 @@ import ( "time" "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/kvproto/pkg/pdpb" ) // KeyRange defines a range of keys in bytes. @@ -574,3 +575,12 @@ type LabelRulePatch struct { SetRules []*LabelRule `json:"sets"` DeleteRules []string `json:"deletes"` } + +// MembersInfo is PD members info returned from PD RESTful interface +// type Members map[string][]*pdpb.Member +type MembersInfo struct { + Header *pdpb.ResponseHeader `json:"header,omitempty"` + Members []*pdpb.Member `json:"members,omitempty"` + Leader *pdpb.Member `json:"leader,omitempty"` + EtcdLeader *pdpb.Member `json:"etcd_leader,omitempty"` +} diff --git a/client/tso_dispatcher.go b/client/tso_dispatcher.go index 0de4dc3a49e..6b2c33ca58d 100644 --- a/client/tso_dispatcher.go +++ b/client/tso_dispatcher.go @@ -412,7 +412,7 @@ tsoBatchLoop: } else { log.Error("[tso] fetch pending tso requests error", zap.String("dc-location", dc), - errs.ZapError(errs.ErrClientGetTSO.FastGenByArgs("when fetch pending tso requests"), err)) + errs.ZapError(errs.ErrClientGetTSO, err)) } return } @@ -495,10 +495,10 @@ tsoBatchLoop: default: } c.svcDiscovery.ScheduleCheckMemberChanged() - log.Error("[tso] getTS error", + log.Error("[tso] getTS error after processing requests", zap.String("dc-location", dc), zap.String("stream-addr", streamAddr), - errs.ZapError(errs.ErrClientGetTSO.FastGenByArgs("after processing requests"), err)) + errs.ZapError(errs.ErrClientGetTSO, err)) // Set `stream` to nil and remove this stream from the `connectionCtxs` due to error. 
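Returning to `GetMinResolvedTSByStoresIDs`, whose contract is now spelled out in the doc comment above, a hedged sketch of the two common call shapes (the store IDs are examples, and the method is assumed to remain exposed on the `Client` interface as before):

```go
package pdexample

import (
	"context"

	pd "github.com/tikv/pd/client/http"
)

// minResolvedTS shows both call shapes: a nil slice asks for the cluster-level value
// only, while explicit store IDs also return the per-store map.
func minResolvedTS(ctx context.Context, c pd.Client) error {
	clusterTS, _, err := c.GetMinResolvedTSByStoresIDs(ctx, nil)
	if err != nil {
		return err
	}
	_ = clusterTS

	minTS, perStore, err := c.GetMinResolvedTSByStoresIDs(ctx, []uint64{1, 4, 5})
	if err != nil {
		return err
	}
	_, _ = minTS, perStore
	return nil
}
```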
connectionCtxs.Delete(streamAddr) cancel() diff --git a/errors.toml b/errors.toml index a318fc32492..69aa29d2dda 100644 --- a/errors.toml +++ b/errors.toml @@ -83,7 +83,7 @@ get min TSO failed, %v ["PD:client:ErrClientGetTSO"] error = ''' -get TSO failed, %v +get TSO failed ''' ["PD:client:ErrClientGetTSOTimeout"] diff --git a/pkg/audit/audit.go b/pkg/audit/audit.go index 3553e5f8377..b971b09ed7e 100644 --- a/pkg/audit/audit.go +++ b/pkg/audit/audit.go @@ -98,7 +98,7 @@ func (b *PrometheusHistogramBackend) ProcessHTTPRequest(req *http.Request) bool if !ok { return false } - b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.Component, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp)) + b.histogramVec.WithLabelValues(requestInfo.ServiceLabel, "HTTP", requestInfo.CallerID, requestInfo.IP).Observe(float64(endTime - requestInfo.StartTimeStamp)) return true } diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index d59c9627115..8098b36975e 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -51,7 +51,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { Name: "audit_handling_seconds_test", Help: "PD server service handling audit", Buckets: prometheus.DefBuckets, - }, []string{"service", "method", "component", "ip"}) + }, []string{"service", "method", "caller_id", "ip"}) prometheus.MustRegister(serviceAuditHistogramTest) @@ -62,7 +62,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { req, _ := http.NewRequest(http.MethodGet, "http://127.0.0.1:2379/test?test=test", http.NoBody) info := requestutil.GetRequestInfo(req) info.ServiceLabel = "test" - info.Component = "user1" + info.CallerID = "user1" info.IP = "localhost" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) re.False(backend.ProcessHTTPRequest(req)) @@ -73,7 +73,7 @@ func TestPrometheusHistogramBackend(t *testing.T) { re.True(backend.ProcessHTTPRequest(req)) re.True(backend.ProcessHTTPRequest(req)) - info.Component = "user2" + info.CallerID = "user2" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) re.True(backend.ProcessHTTPRequest(req)) @@ -85,8 +85,8 @@ func TestPrometheusHistogramBackend(t *testing.T) { defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2") - re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user1\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 2") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{caller_id=\"user2\",ip=\"localhost\",method=\"HTTP\",service=\"test\"} 1") } func TestLocalLogBackendUsingFile(t *testing.T) { @@ -103,7 +103,7 @@ func TestLocalLogBackendUsingFile(t *testing.T) { b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) re.Equal( - fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, Port:, "+ + fmt.Sprintf(" [\"audit log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, CallerID:anonymous, IP:, Port:, "+ "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", time.Unix(info.StartTimeStamp, 0).String()), output[3], diff --git a/pkg/autoscaling/prometheus.go b/pkg/autoscaling/prometheus.go 
index 43ba768a585..91b813b6ef2 100644 --- a/pkg/autoscaling/prometheus.go +++ b/pkg/autoscaling/prometheus.go @@ -94,7 +94,7 @@ func (prom *PrometheusQuerier) queryMetricsFromPrometheus(query string, timestam resp, warnings, err := prom.api.Query(ctx, query, timestamp) if err != nil { - return nil, errs.ErrPrometheusQuery.Wrap(err).FastGenWithCause() + return nil, errs.ErrPrometheusQuery.Wrap(err) } if len(warnings) > 0 { diff --git a/pkg/errs/errno.go b/pkg/errs/errno.go index a4320238374..03fa0f61158 100644 --- a/pkg/errs/errno.go +++ b/pkg/errs/errno.go @@ -86,7 +86,7 @@ var ( var ( ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream")) ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout")) - ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO")) + ErrClientGetTSO = errors.Normalize("get TSO failed", errors.RFCCodeText("PD:client:ErrClientGetTSO")) ErrClientGetLeader = errors.Normalize("get leader failed, %v", errors.RFCCodeText("PD:client:ErrClientGetLeader")) ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember")) ErrClientGetMinTSO = errors.Normalize("get min TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetMinTSO")) diff --git a/pkg/errs/errs.go b/pkg/errs/errs.go index acc42637733..5746b282f10 100644 --- a/pkg/errs/errs.go +++ b/pkg/errs/errs.go @@ -27,7 +27,7 @@ func ZapError(err error, causeError ...error) zap.Field { } if e, ok := err.(*errors.Error); ok { if len(causeError) >= 1 { - err = e.Wrap(causeError[0]).FastGenWithCause() + err = e.Wrap(causeError[0]) } else { err = e.FastGenByArgs() } diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index c242dd994f5..d76c02dc110 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -81,9 +81,19 @@ func TestError(t *testing.T) { log.Error("test", zap.Error(ErrEtcdLeaderNotFound.FastGenByArgs())) re.Contains(lg.Message(), rfc) err := errors.New("test error") - log.Error("test", ZapError(ErrEtcdLeaderNotFound, err)) - rfc = `[error="[PD:member:ErrEtcdLeaderNotFound]test error` - re.Contains(lg.Message(), rfc) + // use Info() because of no stack for comparing. 
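With `FastGenWithCause` dropped from `ZapError` and the `%v` removed from `ErrClientGetTSO`, the cause now travels on the wrapped error itself, so a call like the one below logs `[PD:client:ErrClientGetTSO]get TSO failed: <cause>`. A minimal sketch (the cause string is an example):

```go
package main

import (
	"github.com/pingcap/errors"
	"github.com/pingcap/log"

	"github.com/tikv/pd/pkg/errs"
)

func main() {
	// The normalized message carries no format verb anymore; the cause is appended
	// by Wrap inside ZapError.
	cause := errors.New("rpc error: context deadline exceeded")
	log.Error("[tso] getTS error after processing requests", errs.ZapError(errs.ErrClientGetTSO, cause))
}
```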
+ log.Info("test", ZapError(ErrEtcdLeaderNotFound, err)) + rfc = `[error="[PD:member:ErrEtcdLeaderNotFound]etcd leader not found: test error` + m1 := lg.Message() + re.Contains(m1, rfc) + log.Info("test", zap.Error(ErrEtcdLeaderNotFound.Wrap(err))) + m2 := lg.Message() + idx1 := strings.Index(m1, "[error") + idx2 := strings.Index(m2, "[error") + re.Equal(m1[idx1:], m2[idx2:]) + log.Info("test", zap.Error(ErrEtcdLeaderNotFound.Wrap(err).FastGenWithCause())) + m3 := lg.Message() + re.NotContains(m3, rfc) } func TestErrorEqual(t *testing.T) { @@ -94,24 +104,24 @@ func TestErrorEqual(t *testing.T) { re.True(errors.ErrorEqual(err1, err2)) err := errors.New("test") - err1 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() + err1 = ErrSchedulerNotFound.Wrap(err) + err2 = ErrSchedulerNotFound.Wrap(err) re.True(errors.ErrorEqual(err1, err2)) err1 = ErrSchedulerNotFound.FastGenByArgs() - err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() + err2 = ErrSchedulerNotFound.Wrap(err) re.False(errors.ErrorEqual(err1, err2)) err3 := errors.New("test") err4 := errors.New("test") - err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() - err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() + err1 = ErrSchedulerNotFound.Wrap(err3) + err2 = ErrSchedulerNotFound.Wrap(err4) re.True(errors.ErrorEqual(err1, err2)) err3 = errors.New("test1") err4 = errors.New("test") - err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() - err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() + err1 = ErrSchedulerNotFound.Wrap(err3) + err2 = ErrSchedulerNotFound.Wrap(err4) re.False(errors.ErrorEqual(err1, err2)) } @@ -135,11 +145,16 @@ func TestErrorWithStack(t *testing.T) { m1 := lg.Message() log.Error("test", zap.Error(errors.WithStack(err))) m2 := lg.Message() + log.Error("test", ZapError(ErrStrconvParseInt.GenWithStackByCause(), err)) + m3 := lg.Message() // This test is based on line number and the first log is in line 141, the second is in line 142. // So they have the same length stack. Move this test to another place need to change the corresponding length. idx1 := strings.Index(m1, "[stack=") re.GreaterOrEqual(idx1, -1) idx2 := strings.Index(m2, "[stack=") re.GreaterOrEqual(idx2, -1) + idx3 := strings.Index(m3, "[stack=") + re.GreaterOrEqual(idx3, -1) re.Len(m2[idx2:], len(m1[idx1:])) + re.Len(m3[idx3:], len(m1[idx1:])) } diff --git a/pkg/mcs/scheduling/server/server.go b/pkg/mcs/scheduling/server/server.go index a43cbbebd86..8ee8b81ae47 100644 --- a/pkg/mcs/scheduling/server/server.go +++ b/pkg/mcs/scheduling/server/server.go @@ -192,6 +192,10 @@ func (s *Server) updateAPIServerMemberLoop() { continue } for _, ep := range members.Members { + if len(ep.GetClientURLs()) == 0 { // This member is not started yet. + log.Info("member is not started yet", zap.String("member-id", fmt.Sprintf("%x", ep.GetID())), errs.ZapError(err)) + continue + } status, err := s.GetClient().Status(ctx, ep.ClientURLs[0]) if err != nil { log.Info("failed to get status of member", zap.String("member-id", fmt.Sprintf("%x", ep.ID)), zap.String("endpoint", ep.ClientURLs[0]), errs.ZapError(err)) diff --git a/pkg/ratelimit/controller.go b/pkg/ratelimit/controller.go new file mode 100644 index 00000000000..0c95be9b11b --- /dev/null +++ b/pkg/ratelimit/controller.go @@ -0,0 +1,79 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "sync" + + "golang.org/x/time/rate" +) + +var emptyFunc = func() {} + +// Controller is a controller which holds multiple limiters to manage the request rate of different objects. +type Controller struct { + limiters sync.Map + // the label which is in labelAllowList won't be limited, and only inited by hard code. + labelAllowList map[string]struct{} +} + +// NewController returns a global limiter which can be updated in the later. +func NewController() *Controller { + return &Controller{ + labelAllowList: make(map[string]struct{}), + } +} + +// Allow is used to check whether it has enough token. +func (l *Controller) Allow(label string) (DoneFunc, error) { + var ok bool + lim, ok := l.limiters.Load(label) + if ok { + return lim.(*limiter).allow() + } + return emptyFunc, nil +} + +// Update is used to update Ratelimiter with Options +func (l *Controller) Update(label string, opts ...Option) UpdateStatus { + var status UpdateStatus + for _, opt := range opts { + status |= opt(label, l) + } + return status +} + +// GetQPSLimiterStatus returns the status of a given label's QPS limiter. +func (l *Controller) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int) { + if limit, exist := l.limiters.Load(label); exist { + return limit.(*limiter).getQPSLimiterStatus() + } + return 0, 0 +} + +// GetConcurrencyLimiterStatus returns the status of a given label's concurrency limiter. +func (l *Controller) GetConcurrencyLimiterStatus(label string) (limit uint64, current uint64) { + if limit, exist := l.limiters.Load(label); exist { + return limit.(*limiter).getConcurrencyLimiterStatus() + } + return 0, 0 +} + +// IsInAllowList returns whether this label is in allow list. +// If returns true, the given label won't be limited +func (l *Controller) IsInAllowList(label string) bool { + _, allow := l.labelAllowList[label] + return allow +} diff --git a/pkg/ratelimit/controller_test.go b/pkg/ratelimit/controller_test.go new file mode 100644 index 00000000000..a830217cb9f --- /dev/null +++ b/pkg/ratelimit/controller_test.go @@ -0,0 +1,426 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
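A hedged usage sketch of the new `Controller`: options install per-label limiters, and `Allow` hands back a `DoneFunc` that must be deferred to release the concurrency slot. The label and limit values are examples:

```go
package main

import (
	"fmt"

	"github.com/tikv/pd/pkg/ratelimit"
)

func main() {
	c := ratelimit.NewController()

	// Attach a QPS limiter and a concurrency limiter to one label.
	status := c.Update("GetRegions", ratelimit.UpdateDimensionConfig(&ratelimit.DimensionConfig{
		QPS:              100,
		QPSBurst:         100,
		ConcurrencyLimit: 8,
	}))
	fmt.Println(status&ratelimit.QPSChanged != 0, status&ratelimit.ConcurrencyChanged != 0) // true true

	done, err := c.Allow("GetRegions")
	if err != nil {
		return // rejected: the limiter reported errs.ErrRateLimitExceeded
	}
	defer done() // release the concurrency slot once the request is finished
	// ... serve the request ...
}
```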
+ +package ratelimit + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/tikv/pd/pkg/utils/syncutil" + "golang.org/x/time/rate" +) + +type changeAndResult struct { + opt Option + checkOptionStatus func(string, Option) + totalRequest int + success int + fail int + release int + waitDuration time.Duration + checkStatusFunc func(string) +} + +type labelCase struct { + label string + round []changeAndResult +} + +func runMulitLabelLimiter(t *testing.T, limiter *Controller, testCase []labelCase) { + re := require.New(t) + var caseWG sync.WaitGroup + for _, tempCas := range testCase { + caseWG.Add(1) + cas := tempCas + go func() { + var lock syncutil.Mutex + successCount, failedCount := 0, 0 + var wg sync.WaitGroup + r := &releaseUtil{} + for _, rd := range cas.round { + rd.checkOptionStatus(cas.label, rd.opt) + time.Sleep(rd.waitDuration) + for i := 0; i < rd.totalRequest; i++ { + wg.Add(1) + go func() { + countRateLimiterHandleResult(limiter, cas.label, &successCount, &failedCount, &lock, &wg, r) + }() + } + wg.Wait() + re.Equal(rd.fail, failedCount) + re.Equal(rd.success, successCount) + for i := 0; i < rd.release; i++ { + r.release() + } + rd.checkStatusFunc(cas.label) + failedCount -= rd.fail + successCount -= rd.success + } + caseWG.Done() + }() + } + caseWG.Wait() +} + +func TestControllerWithConcurrencyLimiter(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 15, + fail: 5, + success: 10, + release: 10, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyNoChange != 0) + }, + checkStatusFunc: func(label string) {}, + }, + { + opt: UpdateConcurrencyLimiter(5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 15, + fail: 10, + success: 5, + release: 5, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(5), limit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateConcurrencyLimiter(0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyDeleted != 0) + }, + totalRequest: 15, + fail: 0, + success: 15, + release: 5, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(0), limit) + re.Equal(uint64(0), current) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateConcurrencyLimiter(15), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(15), limit) + re.Equal(uint64(10), current) + }, + }, + { + opt: 
UpdateConcurrencyLimiter(10), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&ConcurrencyChanged != 0) + }, + totalRequest: 10, + fail: 10, + success: 0, + release: 10, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(10), limit) + re.Equal(uint64(0), current) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + +func TestBlockList(t *testing.T) { + t.Parallel() + re := require.New(t) + opts := []Option{AddLabelAllowList()} + limiter := NewController() + label := "test" + + re.False(limiter.IsInAllowList(label)) + for _, opt := range opts { + opt(label, limiter) + } + re.True(limiter.IsInAllowList(label)) + + status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) + re.True(status&InAllowList != 0) + for i := 0; i < 10; i++ { + _, err := limiter.Allow(label) + re.NoError(err) + } +} + +func TestControllerWithQPSLimiter(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 3, + fail: 2, + success: 1, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(1), limit) + re.Equal(1, burst) + }, + }, + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSNoChange != 0) + }, + checkStatusFunc: func(label string) {}, + }, + { + opt: UpdateQPSLimiter(5, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(5), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(50, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(50), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + 
+func TestControllerWithTwoLimiters(t *testing.T) { + t.Parallel() + re := require.New(t) + limiter := NewController() + testCase := []labelCase{ + { + label: "test1", + round: []changeAndResult{ + { + opt: UpdateDimensionConfig(&DimensionConfig{ + QPS: 100, + QPSBurst: 100, + ConcurrencyLimit: 100, + }), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 200, + fail: 100, + success: 100, + release: 100, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(100), limit) + re.Equal(100, burst) + climit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(100), climit) + re.Equal(uint64(0), current) + }, + }, + { + opt: UpdateQPSLimiter(float64(rate.Every(time.Second)), 1), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 200, + fail: 199, + success: 1, + release: 0, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, current := limiter.GetConcurrencyLimiterStatus(label) + re.Equal(uint64(100), limit) + re.Equal(uint64(1), current) + }, + }, + }, + }, + { + label: "test2", + round: []changeAndResult{ + { + opt: UpdateQPSLimiter(50, 5), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSChanged != 0) + }, + totalRequest: 10, + fail: 5, + success: 5, + waitDuration: time.Second, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(50), limit) + re.Equal(5, burst) + }, + }, + { + opt: UpdateQPSLimiter(0, 0), + checkOptionStatus: func(label string, o Option) { + status := limiter.Update(label, o) + re.True(status&QPSDeleted != 0) + }, + totalRequest: 10, + fail: 0, + success: 10, + release: 0, + waitDuration: 0, + checkStatusFunc: func(label string) { + limit, burst := limiter.GetQPSLimiterStatus(label) + re.Equal(rate.Limit(0), limit) + re.Equal(0, burst) + }, + }, + }, + }, + } + runMulitLabelLimiter(t, limiter, testCase) +} + +func countRateLimiterHandleResult(limiter *Controller, label string, successCount *int, + failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup, r *releaseUtil) { + doneFucn, err := limiter.Allow(label) + lock.Lock() + defer lock.Unlock() + if err == nil { + *successCount++ + r.append(doneFucn) + } else { + *failedCount++ + } + wg.Done() +} diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index 4bf930ed6c5..444b5aa2481 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -1,4 +1,4 @@ -// Copyright 2022 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,11 +15,16 @@ package ratelimit import ( - "sync" + "math" + "github.com/tikv/pd/pkg/errs" + "github.com/tikv/pd/pkg/utils/syncutil" "golang.org/x/time/rate" ) +// DoneFunc is done function. +type DoneFunc func() + // DimensionConfig is the limit dimension config of one label type DimensionConfig struct { // qps conifg @@ -29,92 +34,125 @@ type DimensionConfig struct { ConcurrencyLimit uint64 } -// Limiter is a controller for the request rate. 
-type Limiter struct { - qpsLimiter sync.Map - concurrencyLimiter sync.Map - // the label which is in labelAllowList won't be limited - labelAllowList map[string]struct{} +type limiter struct { + mu syncutil.RWMutex + concurrency *concurrencyLimiter + rate *RateLimiter } -// NewLimiter returns a global limiter which can be updated in the later. -func NewLimiter() *Limiter { - return &Limiter{ - labelAllowList: make(map[string]struct{}), - } +func newLimiter() *limiter { + lim := &limiter{} + return lim } -// Allow is used to check whether it has enough token. -func (l *Limiter) Allow(label string) bool { - var cl *concurrencyLimiter - var ok bool - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - if cl, ok = limiter.(*concurrencyLimiter); ok && !cl.allow() { - return false - } - } +func (l *limiter) getConcurrencyLimiter() *concurrencyLimiter { + l.mu.RLock() + defer l.mu.RUnlock() + return l.concurrency +} - if limiter, exist := l.qpsLimiter.Load(label); exist { - if ql, ok := limiter.(*RateLimiter); ok && !ql.Allow() { - if cl != nil { - cl.release() - } - return false - } - } +func (l *limiter) getRateLimiter() *RateLimiter { + l.mu.RLock() + defer l.mu.RUnlock() + return l.rate +} - return true +func (l *limiter) deleteRateLimiter() bool { + l.mu.Lock() + defer l.mu.Unlock() + l.rate = nil + return l.isEmpty() } -// Release is used to refill token. It may be not uesful for some limiters because they will refill automatically -func (l *Limiter) Release(label string) { - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - if cl, ok := limiter.(*concurrencyLimiter); ok { - cl.release() - } - } +func (l *limiter) deleteConcurrency() bool { + l.mu.Lock() + defer l.mu.Unlock() + l.concurrency = nil + return l.isEmpty() } -// Update is used to update Ratelimiter with Options -func (l *Limiter) Update(label string, opts ...Option) UpdateStatus { - var status UpdateStatus - for _, opt := range opts { - status |= opt(label, l) - } - return status +func (l *limiter) isEmpty() bool { + return l.concurrency == nil && l.rate == nil } -// GetQPSLimiterStatus returns the status of a given label's QPS limiter. -func (l *Limiter) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int) { - if limiter, exist := l.qpsLimiter.Load(label); exist { - return limiter.(*RateLimiter).Limit(), limiter.(*RateLimiter).Burst() +func (l *limiter) getQPSLimiterStatus() (limit rate.Limit, burst int) { + baseLimiter := l.getRateLimiter() + if baseLimiter != nil { + return baseLimiter.Limit(), baseLimiter.Burst() } - return 0, 0 } -// QPSUnlimit deletes QPS limiter of the given label -func (l *Limiter) QPSUnlimit(label string) { - l.qpsLimiter.Delete(label) +func (l *limiter) getConcurrencyLimiterStatus() (limit uint64, current uint64) { + baseLimiter := l.getConcurrencyLimiter() + if baseLimiter != nil { + return baseLimiter.getLimit(), baseLimiter.getCurrent() + } + return 0, 0 } -// GetConcurrencyLimiterStatus returns the status of a given label's concurrency limiter. 
-func (l *Limiter) GetConcurrencyLimiterStatus(label string) (limit uint64, current uint64) { - if limiter, exist := l.concurrencyLimiter.Load(label); exist { - return limiter.(*concurrencyLimiter).getLimit(), limiter.(*concurrencyLimiter).getCurrent() +func (l *limiter) updateConcurrencyConfig(limit uint64) UpdateStatus { + oldConcurrencyLimit, _ := l.getConcurrencyLimiterStatus() + if oldConcurrencyLimit == limit { + return ConcurrencyNoChange + } + if limit < 1 { + l.deleteConcurrency() + return ConcurrencyDeleted } - return 0, 0 + l.mu.Lock() + defer l.mu.Unlock() + if l.concurrency != nil { + l.concurrency.setLimit(limit) + } else { + l.concurrency = newConcurrencyLimiter(limit) + } + return ConcurrencyChanged +} + +func (l *limiter) updateQPSConfig(limit float64, burst int) UpdateStatus { + oldQPSLimit, oldBurst := l.getQPSLimiterStatus() + if math.Abs(float64(oldQPSLimit)-limit) < eps && oldBurst == burst { + return QPSNoChange + } + if limit <= eps || burst < 1 { + l.deleteRateLimiter() + return QPSDeleted + } + l.mu.Lock() + defer l.mu.Unlock() + if l.rate != nil { + l.rate.SetLimit(rate.Limit(limit)) + l.rate.SetBurst(burst) + } else { + l.rate = NewRateLimiter(limit, burst) + } + return QPSChanged } -// ConcurrencyUnlimit deletes concurrency limiter of the given label -func (l *Limiter) ConcurrencyUnlimit(label string) { - l.concurrencyLimiter.Delete(label) +func (l *limiter) updateDimensionConfig(cfg *DimensionConfig) UpdateStatus { + status := l.updateQPSConfig(cfg.QPS, cfg.QPSBurst) + status |= l.updateConcurrencyConfig(cfg.ConcurrencyLimit) + return status } -// IsInAllowList returns whether this label is in allow list. -// If returns true, the given label won't be limited -func (l *Limiter) IsInAllowList(label string) bool { - _, allow := l.labelAllowList[label] - return allow +func (l *limiter) allow() (DoneFunc, error) { + concurrency := l.getConcurrencyLimiter() + if concurrency != nil && !concurrency.allow() { + return nil, errs.ErrRateLimitExceeded + } + + rate := l.getRateLimiter() + if rate != nil && !rate.Allow() { + if concurrency != nil { + concurrency.release() + } + return nil, errs.ErrRateLimitExceeded + } + return func() { + if concurrency != nil { + concurrency.release() + } + }, nil } diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index d5d9829816a..8834495f3e9 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -1,4 +1,4 @@ -// Copyright 2022 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,162 +24,145 @@ import ( "golang.org/x/time/rate" ) -func TestUpdateConcurrencyLimiter(t *testing.T) { +type releaseUtil struct { + dones []DoneFunc +} + +func (r *releaseUtil) release() { + if len(r.dones) > 0 { + r.dones[0]() + r.dones = r.dones[1:] + } +} + +func (r *releaseUtil) append(d DoneFunc) { + r.dones = append(r.dones, d) +} + +func TestWithConcurrencyLimiter(t *testing.T) { t.Parallel() re := require.New(t) - opts := []Option{UpdateConcurrencyLimiter(10)} - limiter := NewLimiter() - - label := "test" - status := limiter.Update(label, opts...) 
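For existing callers, the visible change is that the old `Allow(label) bool` / `Release(label)` pair collapses into a single `Allow` call returning a `DoneFunc`; a hedged migration sketch (the handler shape is illustrative):

```go
package pdexample

import "github.com/tikv/pd/pkg/ratelimit"

// handleWithLimit wraps a request handler with the new Allow/DoneFunc pattern.
// Previously callers checked limiter.Allow(label) and deferred limiter.Release(label).
func handleWithLimit(c *ratelimit.Controller, label string, serve func() error) error {
	done, err := c.Allow(label)
	if err != nil {
		return err // errs.ErrRateLimitExceeded when either limiter rejects
	}
	defer done()
	return serve()
}
```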
+ limiter := newLimiter() + status := limiter.updateConcurrencyConfig(10) re.True(status&ConcurrencyChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} for i := 0; i < 15; i++ { wg.Add(1) go func() { - countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) }() } wg.Wait() re.Equal(5, failedCount) re.Equal(10, successCount) for i := 0; i < 10; i++ { - limiter.Release(label) + r.release() } - limit, current := limiter.GetConcurrencyLimiterStatus(label) + limit, current := limiter.getConcurrencyLimiterStatus() re.Equal(uint64(10), limit) re.Equal(uint64(0), current) - status = limiter.Update(label, UpdateConcurrencyLimiter(10)) + status = limiter.updateConcurrencyConfig(10) re.True(status&ConcurrencyNoChange != 0) - status = limiter.Update(label, UpdateConcurrencyLimiter(5)) + status = limiter.updateConcurrencyConfig(5) re.True(status&ConcurrencyChanged != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(10, failedCount) re.Equal(5, successCount) for i := 0; i < 5; i++ { - limiter.Release(label) + r.release() } - status = limiter.Update(label, UpdateConcurrencyLimiter(0)) + status = limiter.updateConcurrencyConfig(0) re.True(status&ConcurrencyDeleted != 0) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { wg.Add(1) - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(0, failedCount) re.Equal(15, successCount) - limit, current = limiter.GetConcurrencyLimiterStatus(label) + limit, current = limiter.getConcurrencyLimiterStatus() re.Equal(uint64(0), limit) re.Equal(uint64(0), current) } -func TestBlockList(t *testing.T) { +func TestWithQPSLimiter(t *testing.T) { t.Parallel() re := require.New(t) - opts := []Option{AddLabelAllowList()} - limiter := NewLimiter() - label := "test" - - re.False(limiter.IsInAllowList(label)) - for _, opt := range opts { - opt(label, limiter) - } - re.True(limiter.IsInAllowList(label)) - - status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) - re.True(status&InAllowList != 0) - for i := 0; i < 10; i++ { - re.True(limiter.Allow(label)) - } -} - -func TestUpdateQPSLimiter(t *testing.T) { - t.Parallel() - re := require.New(t) - opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} - limiter := NewLimiter() - - label := "test" - status := limiter.Update(label, opts...) 
+ limiter := newLimiter() + status := limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) re.True(status&QPSChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} wg.Add(3) for i := 0; i < 3; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(2, failedCount) re.Equal(1, successCount) - limit, burst := limiter.GetQPSLimiterStatus(label) + limit, burst := limiter.getQPSLimiterStatus() re.Equal(rate.Limit(1), limit) re.Equal(1, burst) - status = limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)) + status = limiter.updateQPSConfig(float64(rate.Every(time.Second)), 1) re.True(status&QPSNoChange != 0) - status = limiter.Update(label, UpdateQPSLimiter(5, 5)) + status = limiter.updateQPSConfig(5, 5) re.True(status&QPSChanged != 0) - limit, burst = limiter.GetQPSLimiterStatus(label) + limit, burst = limiter.getQPSLimiterStatus() re.Equal(rate.Limit(5), limit) re.Equal(5, burst) time.Sleep(time.Second) for i := 0; i < 10; i++ { if i < 5 { - re.True(limiter.Allow(label)) + _, err := limiter.allow() + re.NoError(err) } else { - re.False(limiter.Allow(label)) + _, err := limiter.allow() + re.Error(err) } } time.Sleep(time.Second) - status = limiter.Update(label, UpdateQPSLimiter(0, 0)) + status = limiter.updateQPSConfig(0, 0) re.True(status&QPSDeleted != 0) for i := 0; i < 10; i++ { - re.True(limiter.Allow(label)) + _, err := limiter.allow() + re.NoError(err) } - qLimit, qCurrent := limiter.GetQPSLimiterStatus(label) + qLimit, qCurrent := limiter.getQPSLimiterStatus() re.Equal(rate.Limit(0), qLimit) re.Equal(0, qCurrent) -} -func TestQPSLimiter(t *testing.T) { - t.Parallel() - re := require.New(t) - opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} - limiter := NewLimiter() - - label := "test" - for _, opt := range opts { - opt(label, limiter) - } - - var lock syncutil.Mutex - successCount, failedCount := 0, 0 - var wg sync.WaitGroup + successCount = 0 + failedCount = 0 + status = limiter.updateQPSConfig(float64(rate.Every(3*time.Second)), 100) + re.True(status&QPSChanged != 0) wg.Add(200) for i := 0; i < 200; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(200, failedCount+successCount) @@ -188,12 +171,12 @@ func TestQPSLimiter(t *testing.T) { time.Sleep(4 * time.Second) // 3+1 wg.Add(1) - countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) wg.Wait() re.Equal(101, successCount) } -func TestTwoLimiters(t *testing.T) { +func TestWithTwoLimiters(t *testing.T) { t.Parallel() re := require.New(t) cfg := &DimensionConfig{ @@ -201,20 +184,18 @@ func TestTwoLimiters(t *testing.T) { QPSBurst: 100, ConcurrencyLimit: 100, } - opts := []Option{UpdateDimensionConfig(cfg)} - limiter := NewLimiter() - - label := "test" - for _, opt := range opts { - opt(label, limiter) - } + limiter := newLimiter() + status := limiter.updateDimensionConfig(cfg) + re.True(status&QPSChanged != 0) + re.True(status&ConcurrencyChanged != 0) var lock syncutil.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup + r := &releaseUtil{} wg.Add(200) for i := 0; i < 200; i++ { - go 
countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(100, failedCount) @@ -223,35 +204,42 @@ func TestTwoLimiters(t *testing.T) { wg.Add(100) for i := 0; i < 100; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(200, failedCount) re.Equal(100, successCount) for i := 0; i < 100; i++ { - limiter.Release(label) + r.release() } - limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(10*time.Second)), 1)) + status = limiter.updateQPSConfig(float64(rate.Every(10*time.Second)), 1) + re.True(status&QPSChanged != 0) wg.Add(100) for i := 0; i < 100; i++ { - go countRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) + go countSingleLimiterHandleResult(limiter, &successCount, &failedCount, &lock, &wg, r) } wg.Wait() re.Equal(101, successCount) re.Equal(299, failedCount) - limit, current := limiter.GetConcurrencyLimiterStatus(label) + limit, current := limiter.getConcurrencyLimiterStatus() re.Equal(uint64(100), limit) re.Equal(uint64(1), current) + + cfg = &DimensionConfig{} + status = limiter.updateDimensionConfig(cfg) + re.True(status&ConcurrencyDeleted != 0) + re.True(status&QPSDeleted != 0) } -func countRateLimiterHandleResult(limiter *Limiter, label string, successCount *int, - failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup) { - result := limiter.Allow(label) +func countSingleLimiterHandleResult(limiter *limiter, successCount *int, + failedCount *int, lock *syncutil.Mutex, wg *sync.WaitGroup, r *releaseUtil) { + doneFn, err := limiter.allow() lock.Lock() defer lock.Unlock() - if result { + if err == nil { *successCount++ + r.append(doneFn) } else { *failedCount++ } diff --git a/pkg/ratelimit/option.go b/pkg/ratelimit/option.go index 53afb9926d4..b1cc459d786 100644 --- a/pkg/ratelimit/option.go +++ b/pkg/ratelimit/option.go @@ -14,8 +14,6 @@ package ratelimit -import "golang.org/x/time/rate" - // UpdateStatus is flags for updating limiter config. type UpdateStatus uint32 @@ -40,77 +38,46 @@ const ( // Option is used to create a limiter with the optional settings. // these setting is used to add a kind of limiter for a service -type Option func(string, *Limiter) UpdateStatus +type Option func(string, *Controller) UpdateStatus // AddLabelAllowList adds a label into allow list. 
// It means the given label will not be limited func AddLabelAllowList() Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { l.labelAllowList[label] = struct{}{} return 0 } } -func updateConcurrencyConfig(l *Limiter, label string, limit uint64) UpdateStatus { - oldConcurrencyLimit, _ := l.GetConcurrencyLimiterStatus(label) - if oldConcurrencyLimit == limit { - return ConcurrencyNoChange - } - if limit < 1 { - l.ConcurrencyUnlimit(label) - return ConcurrencyDeleted - } - if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { - limiter.(*concurrencyLimiter).setLimit(limit) - } - return ConcurrencyChanged -} - -func updateQPSConfig(l *Limiter, label string, limit float64, burst int) UpdateStatus { - oldQPSLimit, oldBurst := l.GetQPSLimiterStatus(label) - - if (float64(oldQPSLimit)-limit < eps && float64(oldQPSLimit)-limit > -eps) && oldBurst == burst { - return QPSNoChange - } - if limit <= eps || burst < 1 { - l.QPSUnlimit(label) - return QPSDeleted - } - if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(limit, burst)); exist { - limiter.(*RateLimiter).SetLimit(rate.Limit(limit)) - limiter.(*RateLimiter).SetBurst(burst) - } - return QPSChanged -} - // UpdateConcurrencyLimiter creates a concurrency limiter for a given label if it doesn't exist. func UpdateConcurrencyLimiter(limit uint64) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - return updateConcurrencyConfig(l, label, limit) + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateConcurrencyConfig(limit) } } // UpdateQPSLimiter creates a QPS limiter for a given label if it doesn't exist. func UpdateQPSLimiter(limit float64, burst int) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - return updateQPSConfig(l, label, limit, burst) + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateQPSConfig(limit, burst) } } // UpdateDimensionConfig creates QPS limiter and concurrency limiter for a given label by config if it doesn't exist. 
func UpdateDimensionConfig(cfg *DimensionConfig) Option { - return func(label string, l *Limiter) UpdateStatus { + return func(label string, l *Controller) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { return InAllowList } - status := updateQPSConfig(l, label, cfg.QPS, cfg.QPSBurst) - status |= updateConcurrencyConfig(l, label, cfg.ConcurrencyLimit) - return status + lim, _ := l.limiters.LoadOrStore(label, newLimiter()) + return lim.(*limiter).updateDimensionConfig(cfg) } } diff --git a/pkg/schedule/hbstream/heartbeat_streams.go b/pkg/schedule/hbstream/heartbeat_streams.go index e7d7f688035..57a7521c0a7 100644 --- a/pkg/schedule/hbstream/heartbeat_streams.go +++ b/pkg/schedule/hbstream/heartbeat_streams.go @@ -139,7 +139,7 @@ func (s *HeartbeatStreams) run() { if stream, ok := s.streams[storeID]; ok { if err := stream.Send(msg); err != nil { log.Error("send heartbeat message fail", - zap.Uint64("region-id", msg.GetRegionId()), errs.ZapError(errs.ErrGRPCSend.Wrap(err).GenWithStackByArgs())) + zap.Uint64("region-id", msg.GetRegionId()), errs.ZapError(errs.ErrGRPCSend, err)) delete(s.streams, storeID) heartbeatStreamCounter.WithLabelValues(storeAddress, storeLabel, "push", "err").Inc() } else { diff --git a/pkg/schedule/plugin_interface.go b/pkg/schedule/plugin_interface.go index dd1ce3471e6..62ffe2eb900 100644 --- a/pkg/schedule/plugin_interface.go +++ b/pkg/schedule/plugin_interface.go @@ -46,19 +46,19 @@ func (p *PluginInterface) GetFunction(path string, funcName string) (plugin.Symb // open plugin filePath, err := filepath.Abs(path) if err != nil { - return nil, errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + return nil, errs.ErrFilePathAbs.Wrap(err) } log.Info("open plugin file", zap.String("file-path", filePath)) plugin, err := plugin.Open(filePath) if err != nil { - return nil, errs.ErrLoadPlugin.Wrap(err).FastGenWithCause() + return nil, errs.ErrLoadPlugin.Wrap(err) } p.pluginMap[path] = plugin } // get func from plugin f, err := p.pluginMap[path].Lookup(funcName) if err != nil { - return nil, errs.ErrLookupPluginFunc.Wrap(err).FastGenWithCause() + return nil, errs.ErrLookupPluginFunc.Wrap(err) } return f, nil } diff --git a/pkg/schedule/schedulers/evict_leader.go b/pkg/schedule/schedulers/evict_leader.go index 879aa9869b3..ae8b1ecf1ea 100644 --- a/pkg/schedule/schedulers/evict_leader.go +++ b/pkg/schedule/schedulers/evict_leader.go @@ -80,7 +80,7 @@ func (conf *evictLeaderSchedulerConfig) BuildWithArgs(args []string) error { id, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) if err != nil { diff --git a/pkg/schedule/schedulers/grant_leader.go b/pkg/schedule/schedulers/grant_leader.go index 885f81e2442..027350536aa 100644 --- a/pkg/schedule/schedulers/grant_leader.go +++ b/pkg/schedule/schedulers/grant_leader.go @@ -63,7 +63,7 @@ func (conf *grantLeaderSchedulerConfig) BuildWithArgs(args []string) error { id, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) if err != nil { diff --git a/pkg/schedule/schedulers/init.go b/pkg/schedule/schedulers/init.go index f60be1e5b06..57eb4b90985 100644 --- a/pkg/schedule/schedulers/init.go +++ b/pkg/schedule/schedulers/init.go @@ -129,7 +129,7 @@ func schedulersRegister() { id, err := strconv.ParseUint(args[0], 10, 64) if 
err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) @@ -180,14 +180,14 @@ func schedulersRegister() { } leaderID, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } storeIDs := make([]uint64, 0) for _, id := range strings.Split(args[1], ",") { storeID, err := strconv.ParseUint(id, 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } storeIDs = append(storeIDs, storeID) } @@ -248,7 +248,7 @@ func schedulersRegister() { id, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } ranges, err := getKeyRanges(args[1:]) if err != nil { @@ -365,7 +365,7 @@ func schedulersRegister() { if len(args) == 1 { limit, err := strconv.ParseUint(args[0], 10, 64) if err != nil { - return errs.ErrStrconvParseUint.Wrap(err).FastGenWithCause() + return errs.ErrStrconvParseUint.Wrap(err) } conf.Limit = limit } diff --git a/pkg/schedule/schedulers/scheduler.go b/pkg/schedule/schedulers/scheduler.go index 1c788989454..38fc8f5607d 100644 --- a/pkg/schedule/schedulers/scheduler.go +++ b/pkg/schedule/schedulers/scheduler.go @@ -52,7 +52,7 @@ type Scheduler interface { func EncodeConfig(v interface{}) ([]byte, error) { marshaled, err := json.Marshal(v) if err != nil { - return nil, errs.ErrJSONMarshal.Wrap(err).FastGenWithCause() + return nil, errs.ErrJSONMarshal.Wrap(err) } return marshaled, nil } @@ -61,7 +61,7 @@ func EncodeConfig(v interface{}) ([]byte, error) { func DecodeConfig(data []byte, v interface{}) error { err := json.Unmarshal(data, v) if err != nil { - return errs.ErrJSONUnmarshal.Wrap(err).FastGenWithCause() + return errs.ErrJSONUnmarshal.Wrap(err) } return nil } diff --git a/pkg/schedule/schedulers/utils.go b/pkg/schedule/schedulers/utils.go index fea51798d1c..a22f992bda1 100644 --- a/pkg/schedule/schedulers/utils.go +++ b/pkg/schedule/schedulers/utils.go @@ -218,11 +218,11 @@ func getKeyRanges(args []string) ([]core.KeyRange, error) { for len(args) > 1 { startKey, err := url.QueryUnescape(args[0]) if err != nil { - return nil, errs.ErrQueryUnescape.Wrap(err).FastGenWithCause() + return nil, errs.ErrQueryUnescape.Wrap(err) } endKey, err := url.QueryUnescape(args[1]) if err != nil { - return nil, errs.ErrQueryUnescape.Wrap(err).FastGenWithCause() + return nil, errs.ErrQueryUnescape.Wrap(err) } args = args[2:] ranges = append(ranges, core.NewKeyRange(startKey, endKey)) diff --git a/pkg/tso/keyspace_group_manager.go b/pkg/tso/keyspace_group_manager.go index badcb18d5d8..58534de1642 100644 --- a/pkg/tso/keyspace_group_manager.go +++ b/pkg/tso/keyspace_group_manager.go @@ -542,7 +542,7 @@ func (kgm *KeyspaceGroupManager) InitializeGroupWatchLoop() error { putFn := func(kv *mvccpb.KeyValue) error { group := &endpoint.KeyspaceGroup{} if err := json.Unmarshal(kv.Value, group); err != nil { - return errs.ErrJSONUnmarshal.Wrap(err).FastGenWithCause() + return errs.ErrJSONUnmarshal.Wrap(err) } kgm.updateKeyspaceGroup(group) if group.ID == mcsutils.DefaultKeyspaceGroupID { diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 164b3c0783d..53fab682fcb 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -39,15 +39,14 @@ import ( "go.uber.org/zap" ) -var ( 
- // componentSignatureKey is used for http request header key - // to identify component signature +const ( + // componentSignatureKey is used for http request header key to identify component signature. + // Deprecated: please use `XCallerIDHeader` below to obtain a more granular source identification. + // This is kept for backward compatibility. componentSignatureKey = "component" - // componentAnonymousValue identifies anonymous request source - componentAnonymousValue = "anonymous" -) + // anonymousValue identifies anonymous request source + anonymousValue = "anonymous" -const ( // PDRedirectorHeader is used to mark which PD redirected this request. PDRedirectorHeader = "PD-Redirector" // PDAllowFollowerHandleHeader is used to mark whether this request is allowed to be handled by the follower PD. @@ -58,6 +57,8 @@ const ( XForwardedPortHeader = "X-Forwarded-Port" // XRealIPHeader is used to mark the real client IP. XRealIPHeader = "X-Real-Ip" + // XCallerIDHeader is used to mark the caller ID. + XCallerIDHeader = "X-Caller-ID" // ForwardToMicroServiceHeader is used to mark the request is forwarded to micro service. ForwardToMicroServiceHeader = "Forward-To-Micro-Service" @@ -112,7 +113,7 @@ func ErrorResp(rd *render.Render, w http.ResponseWriter, err error) { // GetIPPortFromHTTPRequest returns http client host IP and port from context. // Because `X-Forwarded-For ` header has been written into RFC 7239(Forwarded HTTP Extension), -// so `X-Forwarded-For` has the higher priority than `X-Real-IP`. +// so `X-Forwarded-For` has the higher priority than `X-Real-Ip`. // And both of them have the higher priority than `RemoteAddr` func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { forwardedIPs := strings.Split(r.Header.Get(XForwardedForHeader), ",") @@ -136,32 +137,42 @@ func GetIPPortFromHTTPRequest(r *http.Request) (ip, port string) { return splitIP, splitPort } -// GetComponentNameOnHTTP returns component name from Request Header -func GetComponentNameOnHTTP(r *http.Request) string { +// getComponentNameOnHTTP returns component name from the request header. +func getComponentNameOnHTTP(r *http.Request) string { componentName := r.Header.Get(componentSignatureKey) if len(componentName) == 0 { - componentName = componentAnonymousValue + componentName = anonymousValue } return componentName } -// ComponentSignatureRoundTripper is used to add component signature in HTTP header -type ComponentSignatureRoundTripper struct { - proxied http.RoundTripper - component string +// GetCallerIDOnHTTP returns caller ID from the request header. +func GetCallerIDOnHTTP(r *http.Request) string { + callerID := r.Header.Get(XCallerIDHeader) + if len(callerID) == 0 { + // Fall back to get the component name to keep backward compatibility. + callerID = getComponentNameOnHTTP(r) + } + return callerID +} + +// CallerIDRoundTripper is used to add caller ID in the HTTP header. +type CallerIDRoundTripper struct { + proxied http.RoundTripper + callerID string } -// NewComponentSignatureRoundTripper returns a new ComponentSignatureRoundTripper. -func NewComponentSignatureRoundTripper(roundTripper http.RoundTripper, componentName string) *ComponentSignatureRoundTripper { - return &ComponentSignatureRoundTripper{ - proxied: roundTripper, - component: componentName, +// NewCallerIDRoundTripper returns a new `CallerIDRoundTripper`. 
+func NewCallerIDRoundTripper(roundTripper http.RoundTripper, callerID string) *CallerIDRoundTripper { + return &CallerIDRoundTripper{ + proxied: roundTripper, + callerID: callerID, } } // RoundTrip is used to implement RoundTripper -func (rt *ComponentSignatureRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { - req.Header.Add(componentSignatureKey, rt.component) +func (rt *CallerIDRoundTripper) RoundTrip(req *http.Request) (resp *http.Response, err error) { + req.Header.Add(XCallerIDHeader, rt.callerID) // Send the request, get the response and the error resp, err = rt.proxied.RoundTrip(req) return diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index a4e7b97aa4d..106d3fb21cb 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -101,7 +101,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "127.0.0.1", port: "5299", }, - // IPv4 "X-Real-IP" with port + // IPv4 "X-Real-Ip" with port { r: &http.Request{ Header: map[string][]string{ @@ -111,7 +111,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "127.0.0.1", port: "5299", }, - // IPv4 "X-Real-IP" without port + // IPv4 "X-Real-Ip" without port { r: &http.Request{ Header: map[string][]string{ @@ -158,7 +158,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "::1", port: "", }, - // IPv6 "X-Real-IP" with port + // IPv6 "X-Real-Ip" with port { r: &http.Request{ Header: map[string][]string{ @@ -168,7 +168,7 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { ip: "::1", port: "5299", }, - // IPv6 "X-Real-IP" without port + // IPv6 "X-Real-Ip" without port { r: &http.Request{ Header: map[string][]string{ diff --git a/pkg/utils/apiutil/serverapi/middleware.go b/pkg/utils/apiutil/serverapi/middleware.go old mode 100644 new mode 100755 index eb0f8a5f8eb..c7979dcc038 --- a/pkg/utils/apiutil/serverapi/middleware.go +++ b/pkg/utils/apiutil/serverapi/middleware.go @@ -182,14 +182,6 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } - // Prevent more than one redirection. - if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { - log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) - http.Error(w, errs.ErrRedirectToNotLeader.FastGenByArgs().Error(), http.StatusInternalServerError) - return - } - - r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) forwardedIP, forwardedPort := apiutil.GetIPPortFromHTTPRequest(r) if len(forwardedIP) > 0 { r.Header.Add(apiutil.XForwardedForHeader, forwardedIP) @@ -208,9 +200,9 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = append(clientUrls, targetAddr) + // Add a header to the response, this is not a failure injection + // it is used for testing, to check whether the request is forwarded to the micro service failpoint.Inject("checkHeader", func() { - // add a header to the response, this is not a failure injection - // it is used for testing, to check whether the request is forwarded to the micro service w.Header().Set(apiutil.ForwardToMicroServiceHeader, "true") }) } else { @@ -220,7 +212,15 @@ func (h *redirector) ServeHTTP(w http.ResponseWriter, r *http.Request, next http return } clientUrls = leader.GetClientUrls() + // Prevent more than one redirection among PD/API servers. 
+ if name := r.Header.Get(apiutil.PDRedirectorHeader); len(name) != 0 { + log.Error("redirect but server is not leader", zap.String("from", name), zap.String("server", h.s.Name()), errs.ZapError(errs.ErrRedirect)) + http.Error(w, errs.ErrRedirectToNotLeader.FastGenByArgs().Error(), http.StatusInternalServerError) + return + } + r.Header.Set(apiutil.PDRedirectorHeader, h.s.Name()) } + urls := make([]url.URL, 0, len(clientUrls)) for _, item := range clientUrls { u, err := url.Parse(item) diff --git a/pkg/utils/logutil/log.go b/pkg/utils/logutil/log.go index 3dc4430b066..8c0977818fa 100644 --- a/pkg/utils/logutil/log.go +++ b/pkg/utils/logutil/log.go @@ -70,7 +70,7 @@ func StringToZapLogLevel(level string) zapcore.Level { func SetupLogger(logConfig log.Config, logger **zap.Logger, logProps **log.ZapProperties, enabled ...bool) error { lg, p, err := log.InitLogger(&logConfig, zap.AddStacktrace(zapcore.FatalLevel)) if err != nil { - return errs.ErrInitLogger.Wrap(err).FastGenWithCause() + return errs.ErrInitLogger.Wrap(err) } *logger = lg *logProps = p diff --git a/pkg/utils/requestutil/context_test.go b/pkg/utils/requestutil/context_test.go index 475b109e410..298fc1ff8a3 100644 --- a/pkg/utils/requestutil/context_test.go +++ b/pkg/utils/requestutil/context_test.go @@ -34,7 +34,7 @@ func TestRequestInfo(t *testing.T) { RequestInfo{ ServiceLabel: "test label", Method: http.MethodPost, - Component: "pdctl", + CallerID: "pdctl", IP: "localhost", URLParam: "{\"id\"=1}", BodyParam: "{\"state\"=\"Up\"}", @@ -45,7 +45,7 @@ func TestRequestInfo(t *testing.T) { re.True(ok) re.Equal("test label", result.ServiceLabel) re.Equal(http.MethodPost, result.Method) - re.Equal("pdctl", result.Component) + re.Equal("pdctl", result.CallerID) re.Equal("localhost", result.IP) re.Equal("{\"id\"=1}", result.URLParam) re.Equal("{\"state\"=\"Up\"}", result.BodyParam) diff --git a/pkg/utils/requestutil/request_info.go b/pkg/utils/requestutil/request_info.go index 40724bb790f..cc5403f7232 100644 --- a/pkg/utils/requestutil/request_info.go +++ b/pkg/utils/requestutil/request_info.go @@ -27,9 +27,11 @@ import ( // RequestInfo holds service information from http.Request type RequestInfo struct { - ServiceLabel string - Method string - Component string + ServiceLabel string + Method string + // CallerID is used to identify the specific source of an HTTP request. It will be marked by + // the PD HTTP client, with a granularity that can be refined to a specific functionality within a component. 
+ CallerID string IP string Port string URLParam string @@ -38,8 +40,8 @@ } func (info *RequestInfo) String() string { - s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, Component:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", - info.ServiceLabel, info.Method, info.Component, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) + s := fmt.Sprintf("{ServiceLabel:%s, Method:%s, CallerID:%s, IP:%s, Port:%s, StartTime:%s, URLParam:%s, BodyParam:%s}", + info.ServiceLabel, info.Method, info.CallerID, info.IP, info.Port, time.Unix(info.StartTimeStamp, 0), info.URLParam, info.BodyParam) return s } @@ -49,7 +51,7 @@ func GetRequestInfo(r *http.Request) RequestInfo { return RequestInfo{ ServiceLabel: apiutil.GetRouteName(r), Method: fmt.Sprintf("%s/%s:%s", r.Proto, r.Method, r.URL.Path), - Component: apiutil.GetComponentNameOnHTTP(r), + CallerID: apiutil.GetCallerIDOnHTTP(r), IP: ip, Port: port, URLParam: getURLParam(r), diff --git a/pkg/utils/tempurl/check_env_linux.go b/pkg/utils/tempurl/check_env_linux.go index cf0e686cada..58f902f4bb7 100644 --- a/pkg/utils/tempurl/check_env_linux.go +++ b/pkg/utils/tempurl/check_env_linux.go @@ -36,7 +36,7 @@ func checkAddr(addr string) (bool, error) { return s.RemoteAddr.String() == addr || s.LocalAddr.String() == addr }) if err != nil { - return false, errs.ErrNetstatTCPSocks.Wrap(err).FastGenWithCause() + return false, errs.ErrNetstatTCPSocks.Wrap(err) } return len(tabs) < 1, nil } diff --git a/server/api/middleware.go b/server/api/middleware.go index 627d7fecc92..6536935592f 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -69,7 +69,7 @@ func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Reques w.Header().Add("body-param", requestInfo.BodyParam) w.Header().Add("url-param", requestInfo.URLParam) w.Header().Add("method", requestInfo.Method) - w.Header().Add("component", requestInfo.Component) + w.Header().Add("caller-id", requestInfo.CallerID) w.Header().Add("ip", requestInfo.IP) }) @@ -177,8 +177,8 @@ func (s *rateLimitMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, // There is no need to check whether rateLimiter is nil. CreateServer ensures that it is created rateLimiter := s.svr.GetServiceRateLimiter() - if rateLimiter.Allow(requestInfo.ServiceLabel) { - defer rateLimiter.Release(requestInfo.ServiceLabel) + if done, err := rateLimiter.Allow(requestInfo.ServiceLabel); err == nil { + defer done() next(w, r) } else { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) diff --git a/server/api/store.go b/server/api/store.go index 8537cd45c5b..44e178c23fd 100644 --- a/server/api/store.go +++ b/server/api/store.go @@ -334,7 +334,7 @@ func (h *storeHandler) SetStoreLabel(w http.ResponseWriter, r *http.Request) { // @Param id path integer true "Store Id" // @Param body body object true "Labels in json format" // @Produce json -// @Success 200 {string} string "The store's label is updated." +// @Success 200 {string} string "The store's label is deleted." // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /store/{id}/label [delete] @@ -369,7 +369,7 @@ func (h *storeHandler) DeleteStoreLabel(w http.ResponseWriter, r *http.Request) // @Param id path integer true "Store Id" // @Param body body object true "json params" // @Produce json -// @Success 200 {string} string "The store's label is updated." 
+// @Success 200 {string} string "The store's weight is updated." // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /store/{id}/weight [post] @@ -413,7 +413,7 @@ func (h *storeHandler) SetStoreWeight(w http.ResponseWriter, r *http.Request) { return } - h.rd.JSON(w, http.StatusOK, "The store's label is updated.") + h.rd.JSON(w, http.StatusOK, "The store's weight is updated.") } // FIXME: details of input json body params @@ -423,7 +423,7 @@ func (h *storeHandler) SetStoreWeight(w http.ResponseWriter, r *http.Request) { // @Param id path integer true "Store Id" // @Param body body object true "json params" // @Produce json -// @Success 200 {string} string "The store's label is updated." +// @Success 200 {string} string "The store's limit is updated." // @Failure 400 {string} string "The input is invalid." // @Failure 500 {string} string "PD server failed to proceed the request." // @Router /store/{id}/limit [post] @@ -486,7 +486,7 @@ func (h *storeHandler) SetStoreLimit(w http.ResponseWriter, r *http.Request) { return } } - h.rd.JSON(w, http.StatusOK, "The store's label is updated.") + h.rd.JSON(w, http.StatusOK, "The store's limit is updated.") } type storesHandler struct { diff --git a/server/grpc_service.go b/server/grpc_service.go index fa74f1ea8b6..24280f46437 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -431,11 +431,11 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetMembersResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -662,11 +662,11 @@ func (s *GrpcServer) GetStore(ctx context.Context, request *pdpb.GetStoreRequest if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetStoreResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -765,11 +765,11 @@ func (s *GrpcServer) GetAllStores(ctx context.Context, request *pdpb.GetAllStore if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetAllStoresResponse{ - Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), + Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, err.Error()), }, nil } } @@ -810,8 +810,8 @@ func (s *GrpcServer) StoreHeartbeat(ctx context.Context, request *pdpb.StoreHear if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer 
done() } else { return &pdpb.StoreHeartbeatResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1286,8 +1286,8 @@ func (s *GrpcServer) GetRegion(ctx context.Context, request *pdpb.GetRegionReque if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1330,8 +1330,8 @@ func (s *GrpcServer) GetPrevRegion(ctx context.Context, request *pdpb.GetRegionR if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1375,8 +1375,8 @@ func (s *GrpcServer) GetRegionByID(ctx context.Context, request *pdpb.GetRegionB if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.GetRegionResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), @@ -1419,8 +1419,8 @@ func (s *GrpcServer) ScanRegions(ctx context.Context, request *pdpb.ScanRegionsR if s.GetServiceMiddlewarePersistOptions().IsGRPCRateLimitEnabled() { fName := currentFunction() limiter := s.GetGRPCRateLimiter() - if limiter.Allow(fName) { - defer limiter.Release(fName) + if done, err := limiter.Allow(fName); err == nil { + defer done() } else { return &pdpb.ScanRegionsResponse{ Header: s.wrapErrorToHeader(pdpb.ErrorType_UNKNOWN, errs.ErrRateLimitExceeded.FastGenByArgs().Error()), diff --git a/server/handler.go b/server/handler.go index 6c0679bd9f9..b91c8e368f9 100644 --- a/server/handler.go +++ b/server/handler.go @@ -470,7 +470,7 @@ func (h *Handler) PluginLoad(pluginPath string) error { // make sure path is in data dir filePath, err := filepath.Abs(pluginPath) if err != nil || !isPathInDirectory(filePath, h.s.GetConfig().DataDir) { - return errs.ErrFilePathAbs.Wrap(err).FastGenWithCause() + return errs.ErrFilePathAbs.Wrap(err) } c.LoadPlugin(pluginPath, ch) diff --git a/server/metrics.go b/server/metrics.go index 2d13d02d564..54c5830dc52 100644 --- a/server/metrics.go +++ b/server/metrics.go @@ -151,7 +151,7 @@ var ( Name: "audit_handling_seconds", Help: "PD server service handling audit", Buckets: prometheus.DefBuckets, - }, []string{"service", "method", "component", "ip"}) + }, []string{"service", "method", "caller_id", "ip"}) serverMaxProcs = prometheus.NewGauge( prometheus.GaugeOpts{ Namespace: "pd", diff --git a/server/server.go b/server/server.go index c815e7d50c6..187c30dbf7a 100644 --- a/server/server.go +++ b/server/server.go @@ -216,11 +216,11 @@ type Server struct { // related data structures defined in the PD grpc service pdProtoFactory *tsoutil.PDProtoFactory - serviceRateLimiter *ratelimit.Limiter + serviceRateLimiter *ratelimit.Controller serviceLabels map[string][]apiutil.AccessPath apiServiceLabelMap 
map[apiutil.AccessPath]string - grpcServiceRateLimiter *ratelimit.Limiter + grpcServiceRateLimiter *ratelimit.Controller grpcServiceLabels map[string]struct{} grpcServer *grpc.Server @@ -273,8 +273,8 @@ func CreateServer(ctx context.Context, cfg *config.Config, services []string, le audit.NewLocalLogBackend(true), audit.NewPrometheusHistogramBackend(serviceAuditHistogram, false), } - s.serviceRateLimiter = ratelimit.NewLimiter() - s.grpcServiceRateLimiter = ratelimit.NewLimiter() + s.serviceRateLimiter = ratelimit.NewController() + s.grpcServiceRateLimiter = ratelimit.NewController() s.serviceAuditBackendLabels = make(map[string]*audit.BackendLabels) s.serviceLabels = make(map[string][]apiutil.AccessPath) s.grpcServiceLabels = make(map[string]struct{}) @@ -1467,7 +1467,7 @@ func (s *Server) SetServiceAuditBackendLabels(serviceLabel string, labels []stri } // GetServiceRateLimiter is used to get rate limiter -func (s *Server) GetServiceRateLimiter() *ratelimit.Limiter { +func (s *Server) GetServiceRateLimiter() *ratelimit.Controller { return s.serviceRateLimiter } @@ -1482,7 +1482,7 @@ func (s *Server) UpdateServiceRateLimiter(serviceLabel string, opts ...ratelimit } // GetGRPCRateLimiter is used to get rate limiter -func (s *Server) GetGRPCRateLimiter() *ratelimit.Limiter { +func (s *Server) GetGRPCRateLimiter() *ratelimit.Controller { return s.grpcServiceRateLimiter } diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 6c636d2a2a1..7c8f66f4826 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -49,7 +49,7 @@ func (suite *httpClientTestSuite) SetupSuite() { re := suite.Require() var err error suite.ctx, suite.cancelFunc = context.WithCancel(context.Background()) - suite.cluster, err = tests.NewTestCluster(suite.ctx, 1) + suite.cluster, err = tests.NewTestCluster(suite.ctx, 2) re.NoError(err) err = suite.cluster.RunInitialServers() re.NoError(err) @@ -134,6 +134,13 @@ func (suite *httpClientTestSuite) TestMeta() { re.NoError(err) re.Equal(1, store.Count) re.Len(store.Stores, 1) + storeID := uint64(store.Stores[0].Store.ID) // TODO: why type is different? 
+ store2, err := suite.client.GetStore(suite.ctx, storeID) + re.NoError(err) + re.EqualValues(storeID, store2.Store.ID) + version, err := suite.client.GetClusterVersion(suite.ctx) + re.NoError(err) + re.Equal("0.0.0", version) } func (suite *httpClientTestSuite) TestGetMinResolvedTSByStoresIDs() { @@ -384,3 +391,67 @@ func (suite *httpClientTestSuite) TestScheduleConfig() { re.Equal(float64(8), config["leader-schedule-limit"]) re.Equal(float64(2048), config["region-schedule-limit"]) } + +func (suite *httpClientTestSuite) TestSchedulers() { + re := suite.Require() + schedulers, err := suite.client.GetSchedulers(suite.ctx) + re.NoError(err) + re.Len(schedulers, 0) + + err = suite.client.CreateScheduler(suite.ctx, "evict-leader-scheduler", 1) + re.NoError(err) + schedulers, err = suite.client.GetSchedulers(suite.ctx) + re.NoError(err) + re.Len(schedulers, 1) + err = suite.client.SetSchedulerDelay(suite.ctx, "evict-leader-scheduler", 100) + re.NoError(err) + err = suite.client.SetSchedulerDelay(suite.ctx, "not-exist", 100) + re.ErrorContains(err, "500 Internal Server Error") // TODO: should return friendly error message +} + +func (suite *httpClientTestSuite) TestSetStoreLabels() { + re := suite.Require() + resp, err := suite.client.GetStores(suite.ctx) + re.NoError(err) + setStore := resp.Stores[0] + re.Empty(setStore.Store.Labels) + storeLabels := map[string]string{ + "zone": "zone1", + } + err = suite.client.SetStoreLabels(suite.ctx, 1, storeLabels) + re.NoError(err) + + resp, err = suite.client.GetStores(suite.ctx) + re.NoError(err) + for _, store := range resp.Stores { + if store.Store.ID == setStore.Store.ID { + for _, label := range store.Store.Labels { + re.Equal(label.Value, storeLabels[label.Key]) + } + } + } +} + +func (suite *httpClientTestSuite) TestTransferLeader() { + re := suite.Require() + members, err := suite.client.GetMembers(suite.ctx) + re.NoError(err) + re.Len(members.Members, 2) + + oldLeader, err := suite.client.GetLeader(suite.ctx) + re.NoError(err) + + // Transfer the leader to another PD member + for _, member := range members.Members { + if member.Name != oldLeader.Name { + err = suite.client.TransferLeader(suite.ctx, member.Name) + re.NoError(err) + break + } + } + + newLeader := suite.cluster.WaitLeader() + re.NotEmpty(newLeader) + re.NoError(err) + re.NotEqual(oldLeader.Name, newLeader) +} diff --git a/tests/integrations/mcs/scheduling/api_test.go b/tests/integrations/mcs/scheduling/api_test.go index 8f5d37ee1bb..4c71f8f14a3 100644 --- a/tests/integrations/mcs/scheduling/api_test.go +++ b/tests/integrations/mcs/scheduling/api_test.go @@ -1,6 +1,7 @@ package scheduling_test import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -40,10 +41,12 @@ func TestAPI(t *testing.T) { } func (suite *apiTestSuite) SetupSuite() { + suite.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)")) suite.env = tests.NewSchedulingTestEnvironment(suite.T()) } func (suite *apiTestSuite) TearDownSuite() { + suite.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader")) suite.env.Cleanup() } @@ -99,10 +102,6 @@ func (suite *apiTestSuite) TestAPIForward() { func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { re := suite.Require() - re.NoError(failpoint.Enable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader", "return(true)")) - defer func() { - re.NoError(failpoint.Disable("github.com/tikv/pd/pkg/utils/apiutil/serverapi/checkHeader")) - }() leader := 
cluster.GetLeaderServer().GetServer() urlPrefix := fmt.Sprintf("%s/pd/api/v1", leader.GetAddr()) @@ -300,7 +299,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) { rulesArgs, err := json.Marshal(rules) suite.NoError(err) - err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "/config/rules"), &rules, + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true")) re.NoError(err) err = testutil.CheckPostJSON(testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), rulesArgs, @@ -499,3 +498,49 @@ func (suite *apiTestSuite) checkAdminRegionCacheForward(cluster *tests.TestClust re.Equal(0, schedulingServer.GetCluster().GetRegionCount([]byte{}, []byte{})) re.Equal(0, apiServer.GetRaftCluster().GetRegionCount([]byte{}, []byte{}).Count) } + +func (suite *apiTestSuite) TestFollowerForward() { + suite.env.RunTestInTwoModes(suite.checkFollowerForward) +} + +func (suite *apiTestSuite) checkFollowerForward(cluster *tests.TestCluster) { + re := suite.Require() + leaderAddr := cluster.GetLeaderServer().GetAddr() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + follower, err := cluster.JoinAPIServer(ctx) + re.NoError(err) + re.NoError(follower.Run()) + re.NotEmpty(cluster.WaitLeader()) + + followerAddr := follower.GetAddr() + if cluster.GetLeaderServer().GetAddr() != leaderAddr { + followerAddr = leaderAddr + } + + urlPrefix := fmt.Sprintf("%s/pd/api/v1", followerAddr) + rules := []*placement.Rule{} + if sche := cluster.GetSchedulingPrimaryServer(); sche != nil { + // follower will forward to scheduling server directly + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + testutil.WithHeader(re, apiutil.ForwardToMicroServiceHeader, "true"), + ) + re.NoError(err) + } else { + // follower will forward to leader server + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config/rules"), &rules, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader), + ) + re.NoError(err) + } + + // follower will forward to leader server + re.NotEqual(cluster.GetLeaderServer().GetAddr(), followerAddr) + results := make(map[string]interface{}) + err = testutil.ReadGetJSON(re, testDialClient, fmt.Sprintf("%s/%s", urlPrefix, "config"), &results, + testutil.WithoutHeader(re, apiutil.ForwardToMicroServiceHeader), + ) + re.NoError(err) +} diff --git a/tests/pdctl/global_test.go b/tests/pdctl/global_test.go index 7e57f589249..00d31a384d5 100644 --- a/tests/pdctl/global_test.go +++ b/tests/pdctl/global_test.go @@ -30,18 +30,20 @@ import ( "go.uber.org/zap" ) +const pdControlCallerID = "pd-ctl" + func TestSendAndGetComponent(t *testing.T) { re := require.New(t) handler := func(ctx context.Context, s *server.Server) (http.Handler, apiutil.APIServiceGroup, error) { mux := http.NewServeMux() mux.HandleFunc("/pd/api/v1/health", func(w http.ResponseWriter, r *http.Request) { - component := apiutil.GetComponentNameOnHTTP(r) + callerID := apiutil.GetCallerIDOnHTTP(r) for k := range r.Header { log.Info("header", zap.String("key", k)) } - log.Info("component", zap.String("component", component)) - re.Equal("pdctl", component) - fmt.Fprint(w, component) + log.Info("caller id", zap.String("caller-id", callerID)) + 
re.Equal(pdControlCallerID, callerID) + fmt.Fprint(w, callerID) }) info := apiutil.APIServiceGroup{ IsCore: true, @@ -65,5 +67,5 @@ func TestSendAndGetComponent(t *testing.T) { args := []string{"-u", pdAddr, "health"} output, err := ExecuteCommand(cmd, args...) re.NoError(err) - re.Equal("pdctl\n", string(output)) + re.Equal(fmt.Sprintf("%s\n", pdControlCallerID), string(output)) } diff --git a/tests/pdctl/scheduler/scheduler_test.go b/tests/pdctl/scheduler/scheduler_test.go index d8d54a79d13..fb7c239b431 100644 --- a/tests/pdctl/scheduler/scheduler_test.go +++ b/tests/pdctl/scheduler/scheduler_test.go @@ -15,7 +15,6 @@ package scheduler_test import ( - "context" "encoding/json" "fmt" "reflect" @@ -691,47 +690,3 @@ func mightExec(re *require.Assertions, cmd *cobra.Command, args []string, v inte } json.Unmarshal(output, v) } - -func TestForwardSchedulerRequest(t *testing.T) { - re := require.New(t) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - cluster, err := tests.NewTestAPICluster(ctx, 1) - re.NoError(err) - re.NoError(cluster.RunInitialServers()) - re.NotEmpty(cluster.WaitLeader()) - server := cluster.GetLeaderServer() - re.NoError(server.BootstrapCluster()) - backendEndpoints := server.GetAddr() - tc, err := tests.NewTestSchedulingCluster(ctx, 1, backendEndpoints) - re.NoError(err) - defer tc.Destroy() - tc.WaitForPrimaryServing(re) - - cmd := pdctlCmd.GetRootCmd() - args := []string{"-u", backendEndpoints, "scheduler", "show"} - var sches []string - testutil.Eventually(re, func() bool { - output, err := pdctl.ExecuteCommand(cmd, args...) - re.NoError(err) - re.NoError(json.Unmarshal(output, &sches)) - return slice.Contains(sches, "balance-leader-scheduler") - }) - - mustUsage := func(args []string) { - output, err := pdctl.ExecuteCommand(cmd, args...) 
- re.NoError(err) - re.Contains(string(output), "Usage") - } - mustUsage([]string{"-u", backendEndpoints, "scheduler", "pause", "balance-leader-scheduler"}) - echo := mustExec(re, cmd, []string{"-u", backendEndpoints, "scheduler", "pause", "balance-leader-scheduler", "60"}, nil) - re.Contains(echo, "Success!") - checkSchedulerWithStatusCommand := func(status string, expected []string) { - var schedulers []string - mustExec(re, cmd, []string{"-u", backendEndpoints, "scheduler", "show", "--status", status}, &schedulers) - re.Equal(expected, schedulers) - } - checkSchedulerWithStatusCommand("paused", []string{ - "balance-leader-scheduler", - }) -} diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index 905fb8ec096..f5db6bb2513 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -164,7 +164,7 @@ func (suite *middlewareTestSuite) TestRequestInfoMiddleware() { suite.Equal("{\"force\":[\"true\"]}", resp.Header.Get("url-param")) suite.Equal("{\"testkey\":\"testvalue\"}", resp.Header.Get("body-param")) suite.Equal("HTTP/1.1/POST:/pd/api/v1/debug/pprof/profile", resp.Header.Get("method")) - suite.Equal("anonymous", resp.Header.Get("component")) + suite.Equal("anonymous", resp.Header.Get("caller-id")) suite.Equal("127.0.0.1", resp.Header.Get("ip")) input = map[string]interface{}{ @@ -408,7 +408,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1") + suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 1") // resign to test persist config oldLeaderName := leader.GetServer().Name() @@ -434,7 +434,7 @@ func (suite *middlewareTestSuite) TestAuditPrometheusBackend() { defer resp.Body.Close() content, _ = io.ReadAll(resp.Body) output = string(content) - suite.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2") + suite.Contains(output, "pd_service_audit_handling_seconds_count{caller_id=\"anonymous\",ip=\"127.0.0.1\",method=\"HTTP\",service=\"GetTrend\"} 2") input = map[string]interface{}{ "enable-audit": "false", @@ -543,7 +543,7 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { func doTestRequestWithLogAudit(srv *tests.TestServer) { req, _ := http.NewRequest(http.MethodDelete, fmt.Sprintf("%s/pd/api/v1/admin/cache/regions", srv.GetAddr()), http.NoBody) - req.Header.Set("component", "test") + req.Header.Set(apiutil.XCallerIDHeader, "test") resp, _ := dialClient.Do(req) resp.Body.Close() } @@ -551,7 +551,7 @@ func doTestRequestWithLogAudit(srv *tests.TestServer) { func doTestRequestWithPrometheus(srv *tests.TestServer) { timeUnix := time.Now().Unix() - 20 req, _ := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/pd/api/v1/trend?from=%d", srv.GetAddr(), timeUnix), http.NoBody) - req.Header.Set("component", "test") + req.Header.Set(apiutil.XCallerIDHeader, "test") resp, _ := dialClient.Do(req) resp.Body.Close() } diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 5d8552da51a..0b1f4b4409a 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -29,14 +29,15 @@ import ( "go.etcd.io/etcd/pkg/transport" ) -var ( - pdControllerComponentName = 
"pdctl" - dialClient = &http.Client{ - Transport: apiutil.NewComponentSignatureRoundTripper(http.DefaultTransport, pdControllerComponentName), - } - pingPrefix = "pd/api/v1/ping" +const ( + pdControlCallerID = "pd-ctl" + pingPrefix = "pd/api/v1/ping" ) +var dialClient = &http.Client{ + Transport: apiutil.NewCallerIDRoundTripper(http.DefaultTransport, pdControlCallerID), +} + // InitHTTPSClient creates https client with ca file func InitHTTPSClient(caPath, certPath, keyPath string) error { tlsInfo := transport.TLSInfo{ @@ -50,8 +51,8 @@ func InitHTTPSClient(caPath, certPath, keyPath string) error { } dialClient = &http.Client{ - Transport: apiutil.NewComponentSignatureRoundTripper( - &http.Transport{TLSClientConfig: tlsConfig}, pdControllerComponentName), + Transport: apiutil.NewCallerIDRoundTripper( + &http.Transport{TLSClientConfig: tlsConfig}, pdControlCallerID), } return nil