From dac3836fa8d88a08afe4bfd24911cdc9efea6e9c Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Tue, 17 May 2022 18:12:38 +0800 Subject: [PATCH] config, api: Add Service Config (#4869) ref tikv/pd#4480 add service config Signed-off-by: Cabinfever_B Co-authored-by: Ti Chi Robot --- server/api/config_test.go | 21 --- server/api/middleware.go | 4 +- server/api/min_resolved_ts.go | 2 +- server/api/router.go | 4 + server/api/service_gc_safepoint.go | 4 +- server/api/service_middleware.go | 131 ++++++++++++++++++ server/api/service_middleware_test.go | 100 +++++++++++++ server/config/config.go | 6 - server/config/persist_options.go | 5 - server/config/service_middleware_config.go | 53 +++++++ .../service_middleware_persist_options.go | 77 ++++++++++ server/server.go | 62 +++++++-- server/storage/endpoint/key_path.go | 1 + server/storage/endpoint/service_middleware.go | 51 +++++++ server/storage/storage.go | 1 + tests/server/api/api_test.go | 46 +++--- 16 files changed, 498 insertions(+), 70 deletions(-) create mode 100644 server/api/service_middleware.go create mode 100644 server/api/service_middleware_test.go create mode 100644 server/config/service_middleware_config.go create mode 100644 server/config/service_middleware_persist_options.go create mode 100644 server/storage/endpoint/service_middleware.go diff --git a/server/api/config_test.go b/server/api/config_test.go index 1d115e714af..271849ce223 100644 --- a/server/api/config_test.go +++ b/server/api/config_test.go @@ -305,27 +305,6 @@ func (s *testConfigSuite) TestConfigPDServer(c *C) { c.Assert(sc.FlowRoundByDigit, Equals, int(3)) c.Assert(sc.MinResolvedTSPersistenceInterval, Equals, typeutil.NewDuration(0)) c.Assert(sc.MaxResetTSGap.Duration, Equals, 24*time.Hour) - c.Assert(sc.EnableAudit, Equals, false) - - // test update enable-audit - ms = map[string]interface{}{ - "enable-audit": true, - } - postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(c)), IsNil) - sc = &config.PDServerConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addrGet, sc), IsNil) - c.Assert(sc.EnableAudit, Equals, true) - ms = map[string]interface{}{ - "enable-audit": false, - } - postData, err = json.Marshal(ms) - c.Assert(err, IsNil) - c.Assert(tu.CheckPostJSON(testDialClient, addrPost, postData, tu.StatusOK(c)), IsNil) - sc = &config.PDServerConfig{} - c.Assert(tu.ReadGetJSON(c, testDialClient, addrGet, sc), IsNil) - c.Assert(sc.EnableAudit, Equals, false) } var ttlConfig = map[string]interface{}{ diff --git a/server/api/middleware.go b/server/api/middleware.go index b3022986b48..8cdd25156ca 100644 --- a/server/api/middleware.go +++ b/server/api/middleware.go @@ -56,7 +56,7 @@ func newRequestInfoMiddleware(s *server.Server) negroni.Handler { } func (rm *requestInfoMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - if !rm.svr.GetPersistOptions().IsAuditEnabled() { + if !rm.svr.GetServiceMiddlewarePersistOptions().IsAuditEnabled() { next(w, r) return } @@ -116,7 +116,7 @@ func newAuditMiddleware(s *server.Server) negroni.Handler { // ServeHTTP is used to implememt negroni.Handler for auditMiddleware func (s *auditMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - if !s.svr.GetPersistOptions().IsAuditEnabled() { + if !s.svr.GetServiceMiddlewarePersistOptions().IsAuditEnabled() { next(w, r) return } diff --git a/server/api/min_resolved_ts.go b/server/api/min_resolved_ts.go index 741c7da0da1..c717f0a3b42 100644 --- a/server/api/min_resolved_ts.go +++ b/server/api/min_resolved_ts.go @@ -41,7 +41,7 @@ type minResolvedTS struct { PersistInterval typeutil.Duration `json:"persist_interval,omitempty"` } -// @Tags minresolvedts +// @Tags min_resolved_ts // @Summary Get cluster-level min resolved ts. // @Produce json // @Success 200 {array} minResolvedTS diff --git a/server/api/router.go b/server/api/router.go index edef80fdf74..b9eaeb4b090 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -282,6 +282,10 @@ func createRouter(prefix string, svr *server.Server) *mux.Router { registerFunc(apiRouter, "/admin/persist-file/{file_name}", adminHandler.SavePersistFile, setMethods("POST"), setAuditBackend(localLog)) registerFunc(clusterRouter, "/admin/replication_mode/wait-async", adminHandler.UpdateWaitAsyncTime, setMethods("POST"), setAuditBackend(localLog)) + serviceMiddlewareHandler := newServiceMiddlewareHandler(svr, rd) + registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.GetServiceMiddlewareConfig, setMethods("GET")) + registerFunc(apiRouter, "/service-middleware/config", serviceMiddlewareHandler.SetServiceMiddlewareConfig, setMethods("POST"), setAuditBackend(localLog)) + logHandler := newLogHandler(svr, rd) registerFunc(apiRouter, "/admin/log", logHandler.SetLogLevel, setMethods("POST"), setAuditBackend(localLog)) replicationModeHandler := newReplicationModeHandler(svr, rd) diff --git a/server/api/service_gc_safepoint.go b/server/api/service_gc_safepoint.go index f385bb866e8..40c3aff1076 100644 --- a/server/api/service_gc_safepoint.go +++ b/server/api/service_gc_safepoint.go @@ -41,7 +41,7 @@ type listServiceGCSafepoint struct { GCSafePoint uint64 `json:"gc_safe_point"` } -// @Tags servicegcsafepoint +// @Tags service_gc_safepoint // @Summary Get all service GC safepoint. // @Produce json // @Success 200 {array} listServiceGCSafepoint @@ -66,7 +66,7 @@ func (h *serviceGCSafepointHandler) GetGCSafePoint(w http.ResponseWriter, r *htt h.rd.JSON(w, http.StatusOK, list) } -// @Tags servicegcsafepoint +// @Tags service_gc_safepoint // @Summary Delete a service GC safepoint. // @Param service_id path string true "Service ID" // @Produce json diff --git a/server/api/service_middleware.go b/server/api/service_middleware.go new file mode 100644 index 00000000000..c136f8fbf4e --- /dev/null +++ b/server/api/service_middleware.go @@ -0,0 +1,131 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "reflect" + "strings" + + "github.com/pingcap/errors" + "github.com/tikv/pd/pkg/reflectutil" + "github.com/tikv/pd/server" + "github.com/tikv/pd/server/config" + + "github.com/unrolled/render" +) + +type serviceMiddlewareHandler struct { + svr *server.Server + rd *render.Render +} + +func newServiceMiddlewareHandler(svr *server.Server, rd *render.Render) *serviceMiddlewareHandler { + return &serviceMiddlewareHandler{ + svr: svr, + rd: rd, + } +} + +// @Tags service_middleware +// @Summary Get Service Middleware config. +// @Produce json +// @Success 200 {object} config.Config +// @Router /service-middleware/config [get] +func (h *serviceMiddlewareHandler) GetServiceMiddlewareConfig(w http.ResponseWriter, r *http.Request) { + h.rd.JSON(w, http.StatusOK, h.svr.GetServiceMiddlewareConfig()) +} + +// @Tags service_middleware +// @Summary Update some service-middleware's config items. +// @Accept json +// @Param body body object false "json params" +// @Produce json +// @Success 200 {string} string "The config is updated." +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /service-middleware/config [post] +func (h *serviceMiddlewareHandler) SetServiceMiddlewareConfig(w http.ResponseWriter, r *http.Request) { + cfg := h.svr.GetServiceMiddlewareConfig() + data, err := io.ReadAll(r.Body) + r.Body.Close() + if err != nil { + h.rd.JSON(w, http.StatusInternalServerError, err.Error()) + return + } + + conf := make(map[string]interface{}) + if err := json.Unmarshal(data, &conf); err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + + if len(conf) == 0 { + h.rd.JSON(w, http.StatusOK, "The input is empty.") + } + + for k, v := range conf { + if s := strings.Split(k, "."); len(s) > 1 { + if err := h.updateServiceMiddlewareConfig(cfg, k, v); err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + continue + } + key := reflectutil.FindJSONFullTagByChildTag(reflect.TypeOf(config.ServiceMiddlewareConfig{}), k) + if key == "" { + h.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("config item %s not found", k)) + return + } + if err := h.updateServiceMiddlewareConfig(cfg, key, v); err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + } + + h.rd.JSON(w, http.StatusOK, "The service-middleware config is updated.") +} + +func (h *serviceMiddlewareHandler) updateServiceMiddlewareConfig(cfg *config.ServiceMiddlewareConfig, key string, value interface{}) error { + kp := strings.Split(key, ".") + if kp[0] == "audit" { + return h.updateAudit(cfg, kp[len(kp)-1], value) + } + return errors.Errorf("config prefix %s not found", kp[0]) +} + +func (h *serviceMiddlewareHandler) updateAudit(config *config.ServiceMiddlewareConfig, key string, value interface{}) error { + data, err := json.Marshal(map[string]interface{}{key: value}) + if err != nil { + return err + } + + updated, found, err := mergeConfig(&config.AuditConfig, data) + if err != nil { + return err + } + + if !found { + return errors.Errorf("config item %s not found", key) + } + + if updated { + err = h.svr.SetAuditConfig(config.AuditConfig) + } + return err +} diff --git a/server/api/service_middleware_test.go b/server/api/service_middleware_test.go new file mode 100644 index 00000000000..3d29b23a693 --- /dev/null +++ b/server/api/service_middleware_test.go @@ -0,0 +1,100 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "encoding/json" + "fmt" + "net/http" + + . "github.com/pingcap/check" + "github.com/pingcap/failpoint" + tu "github.com/tikv/pd/pkg/testutil" + "github.com/tikv/pd/server" + "github.com/tikv/pd/server/config" +) + +var _ = Suite(&testServiceMiddlewareSuite{}) + +type testServiceMiddlewareSuite struct { + svr *server.Server + cleanup cleanUpFunc + urlPrefix string +} + +func (s *testServiceMiddlewareSuite) SetUpSuite(c *C) { + s.svr, s.cleanup = mustNewServer(c, func(cfg *config.Config) { + cfg.Replication.EnablePlacementRules = false + }) + mustWaitLeader(c, []*server.Server{s.svr}) + + addr := s.svr.GetAddr() + s.urlPrefix = fmt.Sprintf("%s%s/api/v1", addr, apiPrefix) +} + +func (s *testServiceMiddlewareSuite) TearDownSuite(c *C) { + s.cleanup() +} + +func (s *testServiceMiddlewareSuite) TestConfigAudit(c *C) { + addr := fmt.Sprintf("%s/service-middleware/config", s.urlPrefix) + ms := map[string]interface{}{ + "enable-audit": "true", + } + postData, err := json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + sc := &config.ServiceMiddlewareConfig{} + c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) + c.Assert(sc.EnableAudit, Equals, true) + ms = map[string]interface{}{ + "audit.enable-audit": "false", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c)), IsNil) + sc = &config.ServiceMiddlewareConfig{} + c.Assert(tu.ReadGetJSON(c, testDialClient, addr, sc), IsNil) + c.Assert(sc.EnableAudit, Equals, false) + + // test empty + ms = map[string]interface{}{} + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.StatusOK(c), tu.StringContain(c, "The input is empty.")), IsNil) + + ms = map[string]interface{}{ + "audit": "false", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item audit not found")), IsNil) + + c.Assert(failpoint.Enable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail", "return(true)"), IsNil) + ms = map[string]interface{}{ + "audit.enable-audit": "true", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest)), IsNil) + c.Assert(failpoint.Disable("github.com/tikv/pd/server/config/persistServiceMiddlewareFail"), IsNil) + + ms = map[string]interface{}{ + "audit.audit": "false", + } + postData, err = json.Marshal(ms) + c.Assert(err, IsNil) + c.Assert(tu.CheckPostJSON(testDialClient, addr, postData, tu.Status(c, http.StatusBadRequest), tu.StringEqual(c, "config item audit not found")), IsNil) +} diff --git a/server/config/config.go b/server/config/config.go index 3db12b6b1e2..e1d470ffc78 100644 --- a/server/config/config.go +++ b/server/config/config.go @@ -232,7 +232,6 @@ const ( maxTraceFlowRoundByDigit = 5 // 0.1 MB defaultMaxResetTSGap = 24 * time.Hour defaultMinResolvedTSPersistenceInterval = 0 - defaultEnableAuditMiddleware = false defaultKeyType = "table" defaultStrictlyMatchLabel = false @@ -1116,8 +1115,6 @@ type PDServerConfig struct { FlowRoundByDigit int `toml:"flow-round-by-digit" json:"flow-round-by-digit"` // MinResolvedTSPersistenceInterval is the interval to save the min resolved ts. MinResolvedTSPersistenceInterval typeutil.Duration `toml:"min-resolved-ts-persistence-interval" json:"min-resolved-ts-persistence-interval"` - // EnableAudit controls the switch of the audit middleware - EnableAudit bool `toml:"enable-audit" json:"enable-audit"` } func (c *PDServerConfig) adjust(meta *configMetaData) error { @@ -1143,9 +1140,6 @@ func (c *PDServerConfig) adjust(meta *configMetaData) error { if !meta.IsDefined("min-resolved-ts-persistence-interval") { adjustDuration(&c.MinResolvedTSPersistenceInterval, defaultMinResolvedTSPersistenceInterval) } - if !meta.IsDefined("enable-audit") { - c.EnableAudit = defaultEnableAuditMiddleware - } c.migrateConfigurationFromFile(meta) return c.Validate() } diff --git a/server/config/persist_options.go b/server/config/persist_options.go index aa2bcea4abc..6fbb4161b62 100644 --- a/server/config/persist_options.go +++ b/server/config/persist_options.go @@ -456,11 +456,6 @@ func (o *PersistOptions) GetLeaderSchedulePolicy() core.SchedulePolicy { return core.StringToSchedulePolicy(o.GetScheduleConfig().LeaderSchedulePolicy) } -// IsAuditEnabled returns whether audit middleware is enabled -func (o *PersistOptions) IsAuditEnabled() bool { - return o.GetPDServerConfig().EnableAudit -} - // GetKeyType is to get key type. func (o *PersistOptions) GetKeyType() core.KeyType { return core.StringToKeyType(o.GetPDServerConfig().KeyType) diff --git a/server/config/service_middleware_config.go b/server/config/service_middleware_config.go new file mode 100644 index 00000000000..d1b600ccaf2 --- /dev/null +++ b/server/config/service_middleware_config.go @@ -0,0 +1,53 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +const ( + defaultEnableAuditMiddleware = false +) + +// ServiceMiddlewareConfig is is the configuration for PD Service middleware. +type ServiceMiddlewareConfig struct { + AuditConfig `json:"audit"` +} + +// NewServiceMiddlewareConfig returns a new service middleware config +func NewServiceMiddlewareConfig() *ServiceMiddlewareConfig { + audit := AuditConfig{ + EnableAudit: defaultEnableAuditMiddleware, + } + cfg := &ServiceMiddlewareConfig{ + AuditConfig: audit, + } + return cfg +} + +// Clone returns a cloned service middleware configuration. +func (c *ServiceMiddlewareConfig) Clone() *ServiceMiddlewareConfig { + cfg := *c + return &cfg +} + +// AuditConfig is the configuration for audit +type AuditConfig struct { + // EnableAudit controls the switch of the audit middleware + EnableAudit bool `json:"enable-audit,string"` +} + +// Clone returns a cloned audit config. +func (c *AuditConfig) Clone() *AuditConfig { + cfg := *c + return &cfg +} diff --git a/server/config/service_middleware_persist_options.go b/server/config/service_middleware_persist_options.go new file mode 100644 index 00000000000..7fde025b8c1 --- /dev/null +++ b/server/config/service_middleware_persist_options.go @@ -0,0 +1,77 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "errors" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/tikv/pd/server/storage/endpoint" +) + +// ServiceMiddlewarePersistOptions wraps all service middleware configurations that need to persist to storage and +// allows to access them safely. +type ServiceMiddlewarePersistOptions struct { + audit atomic.Value +} + +// NewServiceMiddlewarePersistOptions creates a new ServiceMiddlewarePersistOptions instance. +func NewServiceMiddlewarePersistOptions(cfg *ServiceMiddlewareConfig) *ServiceMiddlewarePersistOptions { + o := &ServiceMiddlewarePersistOptions{} + o.audit.Store(&cfg.AuditConfig) + return o +} + +// GetAuditConfig returns pd service middleware configurations. +func (o *ServiceMiddlewarePersistOptions) GetAuditConfig() *AuditConfig { + return o.audit.Load().(*AuditConfig) +} + +// SetAuditConfig sets the PD service middleware configuration. +func (o *ServiceMiddlewarePersistOptions) SetAuditConfig(cfg *AuditConfig) { + o.audit.Store(cfg) +} + +// IsAuditEnabled returns whether audit middleware is enabled +func (o *ServiceMiddlewarePersistOptions) IsAuditEnabled() bool { + return o.GetAuditConfig().EnableAudit +} + +// Persist saves the configuration to the storage. +func (o *ServiceMiddlewarePersistOptions) Persist(storage endpoint.ServiceMiddlewareStorage) error { + cfg := &ServiceMiddlewareConfig{ + AuditConfig: *o.GetAuditConfig(), + } + err := storage.SaveServiceMiddlewareConfig(cfg) + failpoint.Inject("persistServiceMiddlewareFail", func() { + err = errors.New("fail to persist") + }) + return err +} + +// Reload reloads the configuration from the storage. +func (o *ServiceMiddlewarePersistOptions) Reload(storage endpoint.ServiceMiddlewareStorage) error { + cfg := NewServiceMiddlewareConfig() + + isExist, err := storage.LoadServiceMiddlewareConfig(cfg) + if err != nil { + return err + } + if isExist { + o.audit.Store(&cfg.AuditConfig) + } + return nil +} diff --git a/server/server.go b/server/server.go index 6d83faf48f3..bd13193532f 100644 --- a/server/server.go +++ b/server/server.go @@ -100,10 +100,12 @@ type Server struct { startTimestamp int64 // Configs and initial fields. - cfg *config.Config - etcdCfg *embed.Config - persistOptions *config.PersistOptions - handler *Handler + cfg *config.Config + serviceMiddlewareCfg *config.ServiceMiddlewareConfig + etcdCfg *embed.Config + serviceMiddlewarePersistOptions *config.ServiceMiddlewarePersistOptions + persistOptions *config.PersistOptions + handler *Handler ctx context.Context serverLoopCtx context.Context @@ -237,14 +239,17 @@ func combineBuilderServerHTTPService(ctx context.Context, svr *Server, serviceBu func CreateServer(ctx context.Context, cfg *config.Config, serviceBuilders ...HandlerBuilder) (*Server, error) { log.Info("PD Config", zap.Reflect("config", cfg)) rand.Seed(time.Now().UnixNano()) + serviceMiddlewareCfg := config.NewServiceMiddlewareConfig() s := &Server{ - cfg: cfg, - persistOptions: config.NewPersistOptions(cfg), - member: &member.Member{}, - ctx: ctx, - startTimestamp: time.Now().Unix(), - DiagnosticsServer: sysutil.NewDiagnosticsServer(cfg.Log.File.Filename), + cfg: cfg, + persistOptions: config.NewPersistOptions(cfg), + serviceMiddlewareCfg: serviceMiddlewareCfg, + serviceMiddlewarePersistOptions: config.NewServiceMiddlewarePersistOptions(serviceMiddlewareCfg), + member: &member.Member{}, + ctx: ctx, + startTimestamp: time.Now().Unix(), + DiagnosticsServer: sysutil.NewDiagnosticsServer(cfg.Log.File.Filename), } s.handler = newHandler(s) @@ -762,6 +767,11 @@ func (s *Server) GetPersistOptions() *config.PersistOptions { return s.persistOptions } +// GetServiceMiddlewarePersistOptions returns the service middleware persist option. +func (s *Server) GetServiceMiddlewarePersistOptions() *config.ServiceMiddlewarePersistOptions { + return s.serviceMiddlewarePersistOptions +} + // GetHBStreams returns the heartbeat streams. func (s *Server) GetHBStreams() *hbstream.HeartbeatStreams { return s.hbStreams @@ -801,6 +811,13 @@ func (s *Server) GetMembers() ([]*pdpb.Member, error) { return members, err } +// GetServiceMiddlewareConfig gets the service middleware config information. +func (s *Server) GetServiceMiddlewareConfig() *config.ServiceMiddlewareConfig { + cfg := s.serviceMiddlewareCfg.Clone() + cfg.AuditConfig = *s.serviceMiddlewarePersistOptions.GetAuditConfig() + return cfg +} + // GetConfig gets the config information. func (s *Server) GetConfig() *config.Config { cfg := s.cfg.Clone() @@ -948,6 +965,27 @@ func (s *Server) SetReplicationConfig(cfg config.ReplicationConfig) error { return nil } +// GetAuditConfig gets the audit config information. +func (s *Server) GetAuditConfig() *config.AuditConfig { + return s.serviceMiddlewarePersistOptions.GetAuditConfig().Clone() +} + +// SetAuditConfig sets the audit config. +func (s *Server) SetAuditConfig(cfg config.AuditConfig) error { + old := s.serviceMiddlewarePersistOptions.GetAuditConfig() + s.serviceMiddlewarePersistOptions.SetAuditConfig(&cfg) + if err := s.serviceMiddlewarePersistOptions.Persist(s.storage); err != nil { + s.serviceMiddlewarePersistOptions.SetAuditConfig(old) + log.Error("failed to update Audit config", + zap.Reflect("new", cfg), + zap.Reflect("old", old), + errs.ZapError(err)) + return err + } + log.Info("Audit config is updated", zap.Reflect("new", cfg), zap.Reflect("old", old)) + return nil +} + // GetPDServerConfig gets the balance config information. func (s *Server) GetPDServerConfig() *config.PDServerConfig { return s.persistOptions.GetPDServerConfig().Clone() @@ -1416,6 +1454,10 @@ func (s *Server) reloadConfigFromKV() error { if err != nil { return err } + err = s.serviceMiddlewarePersistOptions.Reload(s.storage) + if err != nil { + return err + } switchableStorage, ok := s.storage.(interface { SwitchToRegionStorage() SwitchToDefaultStorage() diff --git a/server/storage/endpoint/key_path.go b/server/storage/endpoint/key_path.go index 1f5e05601cf..01f40ecc74d 100644 --- a/server/storage/endpoint/key_path.go +++ b/server/storage/endpoint/key_path.go @@ -22,6 +22,7 @@ import ( const ( clusterPath = "raft" configPath = "config" + serviceMiddlewarePath = "service_middleware" schedulePath = "schedule" gcPath = "gc" rulesPath = "rules" diff --git a/server/storage/endpoint/service_middleware.go b/server/storage/endpoint/service_middleware.go new file mode 100644 index 00000000000..62cf91c97bf --- /dev/null +++ b/server/storage/endpoint/service_middleware.go @@ -0,0 +1,51 @@ +// Copyright 2022 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package endpoint + +import ( + "encoding/json" + + "github.com/tikv/pd/pkg/errs" +) + +// ServiceMiddlewareStorage defines the storage operations on the service middleware. +type ServiceMiddlewareStorage interface { + LoadServiceMiddlewareConfig(cfg interface{}) (bool, error) + SaveServiceMiddlewareConfig(cfg interface{}) error +} + +var _ ServiceMiddlewareStorage = (*StorageEndpoint)(nil) + +// LoadServiceMiddlewareConfig loads service middleware config from serviceMiddlewarePath then unmarshal it to cfg. +func (se *StorageEndpoint) LoadServiceMiddlewareConfig(cfg interface{}) (bool, error) { + value, err := se.Load(serviceMiddlewarePath) + if err != nil || value == "" { + return false, err + } + err = json.Unmarshal([]byte(value), cfg) + if err != nil { + return false, errs.ErrJSONUnmarshal.Wrap(err).GenWithStackByCause() + } + return true, nil +} + +// SaveServiceMiddlewareConfig stores marshallable cfg to the serviceMiddlewarePath. +func (se *StorageEndpoint) SaveServiceMiddlewareConfig(cfg interface{}) error { + value, err := json.Marshal(cfg) + if err != nil { + return errs.ErrJSONMarshal.Wrap(err).GenWithStackByCause() + } + return se.Save(serviceMiddlewarePath, string(value)) +} diff --git a/server/storage/storage.go b/server/storage/storage.go index 45fc75b7ba9..3c0a959ca7a 100644 --- a/server/storage/storage.go +++ b/server/storage/storage.go @@ -32,6 +32,7 @@ type Storage interface { // Introducing the kv.Base here is to provide // the basic key-value read/write ability for the Storage. kv.Base + endpoint.ServiceMiddlewareStorage endpoint.ConfigStorage endpoint.MetaStorage endpoint.RuleStorage diff --git a/tests/server/api/api_test.go b/tests/server/api/api_test.go index b4a39aa8501..9a48d979fa5 100644 --- a/tests/server/api/api_test.go +++ b/tests/server/api/api_test.go @@ -153,15 +153,15 @@ func (s *testMiddlewareSuite) TestRequestInfoMiddleware(c *C) { leader := s.cluster.GetServer(s.cluster.GetLeader()) input := map[string]interface{}{ - "enable-audit": true, + "enable-audit": "true", } data, err := json.Marshal(input) c.Assert(err, IsNil) - req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err := dialClient.Do(req) c.Assert(err, IsNil) resp.Body.Close() - c.Assert(leader.GetServer().GetPersistOptions().IsAuditEnabled(), Equals, true) + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, true) labels := make(map[string]interface{}) labels["testkey"] = "testvalue" @@ -181,15 +181,15 @@ func (s *testMiddlewareSuite) TestRequestInfoMiddleware(c *C) { c.Assert(resp.Header.Get("ip"), Equals, "127.0.0.1") input = map[string]interface{}{ - "enable-audit": false, + "enable-audit": "false", } data, err = json.Marshal(input) c.Assert(err, IsNil) - req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err = dialClient.Do(req) c.Assert(err, IsNil) resp.Body.Close() - c.Assert(leader.GetServer().GetPersistOptions().IsAuditEnabled(), Equals, false) + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, false) header := mustRequestSuccess(c, leader.GetServer()) c.Assert(header.Get("service-label"), Equals, "") @@ -206,10 +206,10 @@ func BenchmarkDoRequestWithServiceMiddleware(b *testing.B) { cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) input := map[string]interface{}{ - "enable-audit": true, + "enable-audit": "true", } data, _ := json.Marshal(input) - req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, _ := dialClient.Do(req) resp.Body.Close() b.StartTimer() @@ -229,10 +229,10 @@ func BenchmarkDoRequestWithoutServiceMiddleware(b *testing.B) { cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) input := map[string]interface{}{ - "enable-audit": false, + "enable-audit": "false", } data, _ := json.Marshal(input) - req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, _ := dialClient.Do(req) resp.Body.Close() b.StartTimer() @@ -254,15 +254,15 @@ func doTestRequest(srv *tests.TestServer) { func (s *testMiddlewareSuite) TestAuditPrometheusBackend(c *C) { leader := s.cluster.GetServer(s.cluster.GetLeader()) input := map[string]interface{}{ - "enable-audit": true, + "enable-audit": "true", } data, err := json.Marshal(input) c.Assert(err, IsNil) - req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err := dialClient.Do(req) c.Assert(err, IsNil) resp.Body.Close() - c.Assert(leader.GetServer().GetPersistOptions().IsAuditEnabled(), Equals, true) + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, true) timeUnix := time.Now().Unix() - 20 req, _ = http.NewRequest("GET", fmt.Sprintf("%s/pd/api/v1/trend?from=%d", leader.GetAddr(), timeUnix), nil) resp, err = dialClient.Do(req) @@ -302,15 +302,15 @@ func (s *testMiddlewareSuite) TestAuditPrometheusBackend(c *C) { c.Assert(strings.Contains(output, "pd_service_audit_handling_seconds_count{component=\"anonymous\",method=\"HTTP\",service=\"GetTrend\"} 2"), Equals, true) input = map[string]interface{}{ - "enable-audit": false, + "enable-audit": "false", } data, err = json.Marshal(input) c.Assert(err, IsNil) - req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err = dialClient.Do(req) c.Assert(err, IsNil) resp.Body.Close() - c.Assert(leader.GetServer().GetPersistOptions().IsAuditEnabled(), Equals, false) + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, false) } func (s *testMiddlewareSuite) TestAuditLocalLogBackend(c *C) { @@ -322,15 +322,15 @@ func (s *testMiddlewareSuite) TestAuditLocalLogBackend(c *C) { log.ReplaceGlobals(lg, p) leader := s.cluster.GetServer(s.cluster.GetLeader()) input := map[string]interface{}{ - "enable-audit": true, + "enable-audit": "true", } data, err := json.Marshal(input) c.Assert(err, IsNil) - req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, err := dialClient.Do(req) c.Assert(err, IsNil) resp.Body.Close() - c.Assert(leader.GetServer().GetPersistOptions().IsAuditEnabled(), Equals, true) + c.Assert(leader.GetServer().GetServiceMiddlewarePersistOptions().IsAuditEnabled(), Equals, true) req, _ = http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/admin/log", strings.NewReader("\"info\"")) resp, err = dialClient.Do(req) @@ -354,10 +354,10 @@ func BenchmarkDoRequestWithLocalLogAudit(b *testing.B) { cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) input := map[string]interface{}{ - "enable-audit": true, + "enable-audit": "true", } data, _ := json.Marshal(input) - req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, _ := dialClient.Do(req) resp.Body.Close() b.StartTimer() @@ -377,10 +377,10 @@ func BenchmarkDoRequestWithoutLocalLogAudit(b *testing.B) { cluster.WaitLeader() leader := cluster.GetServer(cluster.GetLeader()) input := map[string]interface{}{ - "enable-audit": false, + "enable-audit": "false", } data, _ := json.Marshal(input) - req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/config", bytes.NewBuffer(data)) + req, _ := http.NewRequest("POST", leader.GetAddr()+"/pd/api/v1/service-middleware/config", bytes.NewBuffer(data)) resp, _ := dialClient.Do(req) resp.Body.Close() b.StartTimer()