From 3883864220116d24b27083a76255dcf9a7aee065 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Wed, 29 Nov 2023 18:05:48 +0800 Subject: [PATCH] This is an automated cherry-pick of #7471 close tikv/pd#7469 Signed-off-by: ti-chi-bot --- client/client.go | 242 ++++++++++++++++++++++++++++++++++++++ client/gc_client.go | 136 +++++++++++++++++++++ client/keyspace_client.go | 153 ++++++++++++++++++++++++ 3 files changed, 531 insertions(+) create mode 100644 client/gc_client.go create mode 100644 client/keyspace_client.go diff --git a/client/client.go b/client/client.go index a14c6bdc293..17a0f0f91f7 100644 --- a/client/client.go +++ b/client/client.go @@ -581,8 +581,17 @@ func (c *client) GetAllMembers(ctx context.Context) ([]*pdpb.Member, error) { ctx, cancel := context.WithTimeout(ctx, c.option.timeout) req := &pdpb.GetMembersRequest{Header: c.requestHeader()} +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().GetMembers(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.GetMembers(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err = c.respForErr(cmdFailDurationGetAllMembers, start, err, resp.GetHeader()); err != nil { return nil, err @@ -1217,6 +1226,7 @@ func (c *client) getClient() pdpb.PDClient { return c.leaderClient() } +<<<<<<< HEAD func (c *client) getAllClients() map[string]pdpb.PDClient { var ( addrs = c.GetURLs() @@ -1249,6 +1259,17 @@ var tsoReqPool = sync.Pool{ logical: 0, } }, +======= +func (c *client) getClientAndContext(ctx context.Context) (pdpb.PDClient, context.Context) { + if c.option.enableForwarding && atomic.LoadInt32(&c.leaderNetworkFailure) == 1 { + backupClientConn, addr := c.backupClientConn() + if backupClientConn != nil { + log.Debug("[pd] use follower client", zap.String("addr", addr)) + return pdpb.NewPDClient(backupClientConn), grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) + } + } + return c.leaderClient(), ctx +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) } func (c *client) GetTSAsync(ctx context.Context) TSFuture { @@ -1345,6 +1366,7 @@ func handleRegionResponse(res *pdpb.GetRegionResponse) *Region { return r } +<<<<<<< HEAD func (c *client) GetRegion(ctx context.Context, key []byte, opts ...GetRegionOption) (*Region, error) { if span := opentracing.SpanFromContext(ctx); span != nil { span = opentracing.StartSpan("pdclient.GetRegion", opentracing.ChildOf(span.Context())) @@ -1373,6 +1395,8 @@ func (c *client) GetRegion(ctx context.Context, key []byte, opts ...GetRegionOpt return handleRegionResponse(resp), nil } +======= +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) func isNetworkError(code codes.Code) bool { return code == codes.Unavailable || code == codes.DeadlineExceeded } @@ -1415,6 +1439,38 @@ func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs return handleRegionResponse(resp), nil } +func (c *client) GetRegion(ctx context.Context, key []byte, opts ...GetRegionOption) (*Region, error) { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = opentracing.StartSpan("pdclient.GetRegion", opentracing.ChildOf(span.Context())) + defer span.Finish() + } + start := time.Now() + defer func() { cmdDurationGetRegion.Observe(time.Since(start).Seconds()) }() + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + + options := &GetRegionOp{} + for _, opt := range opts { + opt(options) + } + req := &pdpb.GetRegionRequest{ + Header: c.requestHeader(), + RegionKey: key, + NeedBuckets: options.needBuckets, + } + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.GetRegion(ctx, req) + cancel() + + if err = c.respForErr(cmdFailDurationGetRegion, start, err, resp.GetHeader()); err != nil { + return nil, err + } + return handleRegionResponse(resp), nil +} + func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...GetRegionOption) (*Region, error) { if span := opentracing.SpanFromContext(ctx); span != nil { span = opentracing.StartSpan("pdclient.GetPrevRegion", opentracing.ChildOf(span.Context())) @@ -1433,8 +1489,17 @@ func (c *client) GetPrevRegion(ctx context.Context, key []byte, opts ...GetRegio RegionKey: key, NeedBuckets: options.needBuckets, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().GetPrevRegion(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.GetPrevRegion(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err = c.respForErr(cmdFailDurationGetPrevRegion, start, err, resp.GetHeader()); err != nil { @@ -1461,8 +1526,17 @@ func (c *client) GetRegionByID(ctx context.Context, regionID uint64, opts ...Get RegionId: regionID, NeedBuckets: options.needBuckets, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().GetRegionByID(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.GetRegionByID(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err = c.respForErr(cmdFailedDurationGetRegionByID, start, err, resp.GetHeader()); err != nil { @@ -1491,8 +1565,17 @@ func (c *client) ScanRegions(ctx context.Context, key, endKey []byte, limit int) EndKey: endKey, Limit: int32(limit), } +<<<<<<< HEAD scanCtx = grpcutil.BuildForwardContext(scanCtx, c.GetLeaderAddr()) resp, err := c.getClient().ScanRegions(scanCtx, req) +======= + protoClient, scanCtx := c.getClientAndContext(scanCtx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.ScanRegions(scanCtx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) if err = c.respForErr(cmdFailedDurationScanRegions, start, err, resp.GetHeader()); err != nil { return nil, err @@ -1542,8 +1625,17 @@ func (c *client) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, e Header: c.requestHeader(), StoreId: storeID, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().GetStore(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.GetStore(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err = c.respForErr(cmdFailedDurationGetStore, start, err, resp.GetHeader()); err != nil { @@ -1582,8 +1674,17 @@ func (c *client) GetAllStores(ctx context.Context, opts ...GetStoreOption) ([]*m Header: c.requestHeader(), ExcludeTombstoneStores: options.excludeTombstone, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().GetAllStores(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.GetAllStores(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err = c.respForErr(cmdFailedDurationGetAllStores, start, err, resp.GetHeader()); err != nil { @@ -1605,8 +1706,17 @@ func (c *client) UpdateGCSafePoint(ctx context.Context, safePoint uint64) (uint6 Header: c.requestHeader(), SafePoint: safePoint, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().UpdateGCSafePoint(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return 0, errs.ErrClientGetProtoClient + } + resp, err := protoClient.UpdateGCSafePoint(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err = c.respForErr(cmdFailedDurationUpdateGCSafePoint, start, err, resp.GetHeader()); err != nil { @@ -1635,8 +1745,17 @@ func (c *client) UpdateServiceGCSafePoint(ctx context.Context, serviceID string, TTL: ttl, SafePoint: safePoint, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().UpdateServiceGCSafePoint(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return 0, errs.ErrClientGetProtoClient + } + resp, err := protoClient.UpdateServiceGCSafePoint(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err = c.respForErr(cmdFailedDurationUpdateServiceGCSafePoint, start, err, resp.GetHeader()); err != nil { @@ -1663,8 +1782,17 @@ func (c *client) scatterRegionsWithGroup(ctx context.Context, regionID uint64, g RegionId: regionID, Group: group, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().ScatterRegion(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return errs.ErrClientGetProtoClient + } + resp, err := protoClient.ScatterRegion(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err != nil { return err @@ -1703,8 +1831,17 @@ func (c *client) SplitAndScatterRegions(ctx context.Context, splitKeys [][]byte, RetryLimit: options.retryLimit, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) return c.getClient().SplitAndScatterRegions(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + return protoClient.SplitAndScatterRegions(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) } func (c *client) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { @@ -1721,8 +1858,17 @@ func (c *client) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOpe Header: c.requestHeader(), RegionId: regionID, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) return c.getClient().GetOperator(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + return protoClient.GetOperator(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) } // SplitRegions split regions by given split keys @@ -1744,8 +1890,17 @@ func (c *client) SplitRegions(ctx context.Context, splitKeys [][]byte, opts ...R SplitKeys: splitKeys, RetryLimit: options.retryLimit, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) return c.getClient().SplitRegions(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + return protoClient.SplitRegions(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) } func (c *client) requestHeader() *pdpb.RequestHeader { @@ -1769,8 +1924,17 @@ func (c *client) scatterRegionsWithOptions(ctx context.Context, regionsID []uint RetryLimit: options.retryLimit, } +<<<<<<< HEAD ctx = grpcutil.BuildForwardContext(ctx, c.GetLeaderAddr()) resp, err := c.getClient().ScatterRegion(ctx, req) +======= + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return nil, errs.ErrClientGetProtoClient + } + resp, err := protoClient.ScatterRegion(ctx, req) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) cancel() if err != nil { @@ -1807,8 +1971,19 @@ func trimHTTPPrefix(str string) string { return str } +<<<<<<< HEAD func (c *client) LoadGlobalConfig(ctx context.Context, names []string) ([]GlobalConfigItem, error) { resp, err := c.getClient().LoadGlobalConfig(ctx, &pdpb.LoadGlobalConfigRequest{Names: names}) +======= +func (c *client) LoadGlobalConfig(ctx context.Context, names []string, configPath string) ([]GlobalConfigItem, int64, error) { + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + return nil, 0, errs.ErrClientGetProtoClient + } + resp, err := protoClient.LoadGlobalConfig(ctx, &pdpb.LoadGlobalConfigRequest{Names: names, ConfigPath: configPath}) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) if err != nil { return nil, err } @@ -1834,7 +2009,17 @@ func (c *client) StoreGlobalConfig(ctx context.Context, items []GlobalConfigItem for i, it := range items { resArr[i] = &pdpb.GlobalConfigItem{Name: it.Name, Value: it.Value} } +<<<<<<< HEAD res, err := c.getClient().StoreGlobalConfig(ctx, &pdpb.StoreGlobalConfigRequest{Changes: resArr}) +======= + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + return errs.ErrClientGetProtoClient + } + _, err := protoClient.StoreGlobalConfig(ctx, &pdpb.StoreGlobalConfigRequest{Changes: resArr, ConfigPath: configPath}) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) if err != nil { return err } @@ -1847,7 +2032,20 @@ func (c *client) StoreGlobalConfig(ctx context.Context, items []GlobalConfigItem func (c *client) WatchGlobalConfig(ctx context.Context) (chan []GlobalConfigItem, error) { globalConfigWatcherCh := make(chan []GlobalConfigItem, 16) +<<<<<<< HEAD res, err := c.getClient().WatchGlobalConfig(ctx, &pdpb.WatchGlobalConfigRequest{}) +======= + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + return nil, errs.ErrClientGetProtoClient + } + res, err := protoClient.WatchGlobalConfig(ctx, &pdpb.WatchGlobalConfigRequest{ + ConfigPath: configPath, + Revision: revision, + }) +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) if err != nil { close(globalConfigWatcherCh) return nil, err @@ -1880,6 +2078,50 @@ func (c *client) WatchGlobalConfig(ctx context.Context) (chan []GlobalConfigItem return globalConfigWatcherCh, err } +<<<<<<< HEAD +======= +func (c *client) GetExternalTimestamp(ctx context.Context) (uint64, error) { + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + return 0, errs.ErrClientGetProtoClient + } + resp, err := protoClient.GetExternalTimestamp(ctx, &pdpb.GetExternalTimestampRequest{ + Header: c.requestHeader(), + }) + if err != nil { + return 0, err + } + resErr := resp.GetHeader().GetError() + if resErr != nil { + return 0, errors.Errorf("[pd]" + resErr.Message) + } + return resp.GetTimestamp(), nil +} + +func (c *client) SetExternalTimestamp(ctx context.Context, timestamp uint64) error { + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + return errs.ErrClientGetProtoClient + } + resp, err := protoClient.SetExternalTimestamp(ctx, &pdpb.SetExternalTimestampRequest{ + Header: c.requestHeader(), + Timestamp: timestamp, + }) + if err != nil { + return err + } + resErr := resp.GetHeader().GetError() + if resErr != nil { + return errors.Errorf("[pd]" + resErr.Message) + } + return nil +} + +>>>>>>> 180ff57af (client: avoid to add redundant grpc metadata (#7471)) func (c *client) respForErr(observer prometheus.Observer, start time.Time, err error, header *pdpb.ResponseHeader) error { if err != nil || header.GetError() != nil { observer.Observe(time.Since(start).Seconds()) diff --git a/client/gc_client.go b/client/gc_client.go new file mode 100644 index 00000000000..fff292405c2 --- /dev/null +++ b/client/gc_client.go @@ -0,0 +1,136 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/tikv/pd/client/errs" + "go.uber.org/zap" +) + +// GCClient is a client for doing GC +type GCClient interface { + UpdateGCSafePointV2(ctx context.Context, keyspaceID uint32, safePoint uint64) (uint64, error) + UpdateServiceSafePointV2(ctx context.Context, keyspaceID uint32, serviceID string, ttl int64, safePoint uint64) (uint64, error) + WatchGCSafePointV2(ctx context.Context, revision int64) (chan []*pdpb.SafePointEvent, error) +} + +// UpdateGCSafePointV2 update gc safe point for the given keyspace. +func (c *client) UpdateGCSafePointV2(ctx context.Context, keyspaceID uint32, safePoint uint64) (uint64, error) { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = opentracing.StartSpan("pdclient.UpdateGCSafePointV2", opentracing.ChildOf(span.Context())) + defer span.Finish() + } + start := time.Now() + defer func() { cmdDurationUpdateGCSafePointV2.Observe(time.Since(start).Seconds()) }() + + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + req := &pdpb.UpdateGCSafePointV2Request{ + Header: c.requestHeader(), + KeyspaceId: keyspaceID, + SafePoint: safePoint, + } + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return 0, errs.ErrClientGetProtoClient + } + resp, err := protoClient.UpdateGCSafePointV2(ctx, req) + cancel() + + if err = c.respForErr(cmdFailedDurationUpdateGCSafePointV2, start, err, resp.GetHeader()); err != nil { + return 0, err + } + return resp.GetNewSafePoint(), nil +} + +// UpdateServiceSafePointV2 update service safe point for the given keyspace. +func (c *client) UpdateServiceSafePointV2(ctx context.Context, keyspaceID uint32, serviceID string, ttl int64, safePoint uint64) (uint64, error) { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = opentracing.StartSpan("pdclient.UpdateServiceSafePointV2", opentracing.ChildOf(span.Context())) + defer span.Finish() + } + start := time.Now() + defer func() { cmdDurationUpdateServiceSafePointV2.Observe(time.Since(start).Seconds()) }() + + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + req := &pdpb.UpdateServiceSafePointV2Request{ + Header: c.requestHeader(), + KeyspaceId: keyspaceID, + ServiceId: []byte(serviceID), + SafePoint: safePoint, + Ttl: ttl, + } + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + cancel() + return 0, errs.ErrClientGetProtoClient + } + resp, err := protoClient.UpdateServiceSafePointV2(ctx, req) + cancel() + if err = c.respForErr(cmdFailedDurationUpdateServiceSafePointV2, start, err, resp.GetHeader()); err != nil { + return 0, err + } + return resp.GetMinSafePoint(), nil +} + +// WatchGCSafePointV2 watch gc safe point change. +func (c *client) WatchGCSafePointV2(ctx context.Context, revision int64) (chan []*pdpb.SafePointEvent, error) { + SafePointEventsChan := make(chan []*pdpb.SafePointEvent) + req := &pdpb.WatchGCSafePointV2Request{ + Header: c.requestHeader(), + Revision: revision, + } + + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + defer cancel() + protoClient, ctx := c.getClientAndContext(ctx) + if protoClient == nil { + return nil, errs.ErrClientGetProtoClient + } + stream, err := protoClient.WatchGCSafePointV2(ctx, req) + if err != nil { + close(SafePointEventsChan) + return nil, err + } + go func() { + defer func() { + close(SafePointEventsChan) + if r := recover(); r != nil { + log.Error("[pd] panic in gc client `WatchGCSafePointV2`", zap.Any("error", r)) + return + } + }() + for { + select { + case <-ctx.Done(): + return + default: + resp, err := stream.Recv() + if err != nil { + log.Error("watch gc safe point v2 error", errs.ZapError(errs.ErrClientWatchGCSafePointV2Stream, err)) + return + } + SafePointEventsChan <- resp.GetEvents() + } + } + }() + return SafePointEventsChan, err +} diff --git a/client/keyspace_client.go b/client/keyspace_client.go new file mode 100644 index 00000000000..fedb7452412 --- /dev/null +++ b/client/keyspace_client.go @@ -0,0 +1,153 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pd + +import ( + "context" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/kvproto/pkg/keyspacepb" +) + +// KeyspaceClient manages keyspace metadata. +type KeyspaceClient interface { + // LoadKeyspace load and return target keyspace's metadata. + LoadKeyspace(ctx context.Context, name string) (*keyspacepb.KeyspaceMeta, error) + // UpdateKeyspaceState updates target keyspace's state. + UpdateKeyspaceState(ctx context.Context, id uint32, state keyspacepb.KeyspaceState) (*keyspacepb.KeyspaceMeta, error) + // WatchKeyspaces watches keyspace meta changes. + WatchKeyspaces(ctx context.Context) (chan []*keyspacepb.KeyspaceMeta, error) + // GetAllKeyspaces get all keyspace's metadata. + GetAllKeyspaces(ctx context.Context, startID uint32, limit uint32) ([]*keyspacepb.KeyspaceMeta, error) +} + +// keyspaceClient returns the KeyspaceClient from current PD leader. +func (c *client) keyspaceClient() keyspacepb.KeyspaceClient { + if client := c.pdSvcDiscovery.GetServingEndpointClientConn(); client != nil { + return keyspacepb.NewKeyspaceClient(client) + } + return nil +} + +// LoadKeyspace loads and returns target keyspace's metadata. +func (c *client) LoadKeyspace(ctx context.Context, name string) (*keyspacepb.KeyspaceMeta, error) { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = opentracing.StartSpan("keyspaceClient.LoadKeyspace", opentracing.ChildOf(span.Context())) + defer span.Finish() + } + start := time.Now() + defer func() { cmdDurationLoadKeyspace.Observe(time.Since(start).Seconds()) }() + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + req := &keyspacepb.LoadKeyspaceRequest{ + Header: c.requestHeader(), + Name: name, + } + resp, err := c.keyspaceClient().LoadKeyspace(ctx, req) + cancel() + + if err != nil { + cmdFailedDurationLoadKeyspace.Observe(time.Since(start).Seconds()) + c.pdSvcDiscovery.ScheduleCheckMemberChanged() + return nil, err + } + + if resp.Header.GetError() != nil { + cmdFailedDurationLoadKeyspace.Observe(time.Since(start).Seconds()) + return nil, errors.Errorf("Load keyspace %s failed: %s", name, resp.Header.GetError().String()) + } + + return resp.Keyspace, nil +} + +// UpdateKeyspaceState attempts to update the keyspace specified by ID to the target state, +// it will also record StateChangedAt for the given keyspace if a state change took place. +// Currently, legal operations includes: +// +// ENABLED -> {ENABLED, DISABLED} +// DISABLED -> {ENABLED, DISABLED, ARCHIVED} +// ARCHIVED -> {ARCHIVED, TOMBSTONE} +// TOMBSTONE -> {TOMBSTONE} +// +// Updated keyspace meta will be returned. +func (c *client) UpdateKeyspaceState(ctx context.Context, id uint32, state keyspacepb.KeyspaceState) (*keyspacepb.KeyspaceMeta, error) { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = opentracing.StartSpan("keyspaceClient.UpdateKeyspaceState", opentracing.ChildOf(span.Context())) + defer span.Finish() + } + start := time.Now() + defer func() { cmdDurationUpdateKeyspaceState.Observe(time.Since(start).Seconds()) }() + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + req := &keyspacepb.UpdateKeyspaceStateRequest{ + Header: c.requestHeader(), + Id: id, + State: state, + } + resp, err := c.keyspaceClient().UpdateKeyspaceState(ctx, req) + cancel() + + if err != nil { + cmdFailedDurationUpdateKeyspaceState.Observe(time.Since(start).Seconds()) + c.pdSvcDiscovery.ScheduleCheckMemberChanged() + return nil, err + } + + if resp.Header.GetError() != nil { + cmdFailedDurationUpdateKeyspaceState.Observe(time.Since(start).Seconds()) + return nil, errors.Errorf("Update state for keyspace id %d failed: %s", id, resp.Header.GetError().String()) + } + + return resp.Keyspace, nil +} + +// WatchKeyspaces watches keyspace meta changes. +// It returns a stream of slices of keyspace metadata. +// The first message in stream contains all current keyspaceMeta, +// all subsequent messages contains new put events for all keyspaces. +func (c *client) WatchKeyspaces(ctx context.Context) (chan []*keyspacepb.KeyspaceMeta, error) { + return nil, errors.Errorf("WatchKeyspaces unimplemented") +} + +// GetAllKeyspaces get all keyspaces metadata. +func (c *client) GetAllKeyspaces(ctx context.Context, startID uint32, limit uint32) ([]*keyspacepb.KeyspaceMeta, error) { + if span := opentracing.SpanFromContext(ctx); span != nil { + span = opentracing.StartSpan("keyspaceClient.GetAllKeyspaces", opentracing.ChildOf(span.Context())) + defer span.Finish() + } + start := time.Now() + defer func() { cmdDurationGetAllKeyspaces.Observe(time.Since(start).Seconds()) }() + ctx, cancel := context.WithTimeout(ctx, c.option.timeout) + req := &keyspacepb.GetAllKeyspacesRequest{ + Header: c.requestHeader(), + StartId: startID, + Limit: limit, + } + resp, err := c.keyspaceClient().GetAllKeyspaces(ctx, req) + cancel() + + if err != nil { + cmdDurationGetAllKeyspaces.Observe(time.Since(start).Seconds()) + c.pdSvcDiscovery.ScheduleCheckMemberChanged() + return nil, err + } + + if resp.Header.GetError() != nil { + cmdDurationGetAllKeyspaces.Observe(time.Since(start).Seconds()) + return nil, errors.Errorf("Get all keyspaces metadata failed: %s", resp.Header.GetError().String()) + } + + return resp.Keyspaces, nil +}