diff --git a/api/internal/handler/global_rule/global_rule.go b/api/internal/handler/global_rule/global_rule.go index 81b1877c75..2e3657b5a7 100644 --- a/api/internal/handler/global_rule/global_rule.go +++ b/api/internal/handler/global_rule/global_rule.go @@ -18,10 +18,12 @@ package global_rule import ( "encoding/json" + "net/http" "reflect" "github.com/gin-gonic/gin" "github.com/shiningrush/droplet" + "github.com/shiningrush/droplet/data" "github.com/shiningrush/droplet/wrapper" wgin "github.com/shiningrush/droplet/wrapper/gin" @@ -49,9 +51,9 @@ func (h *Handler) ApplyRoute(r *gin.Engine) { r.GET("/apisix/admin/global_rules", wgin.Wraps(h.List, wrapper.InputType(reflect.TypeOf(ListInput{})))) r.PUT("/apisix/admin/global_rules/:id", wgin.Wraps(h.Set, - wrapper.InputType(reflect.TypeOf(entity.GlobalPlugins{})))) + wrapper.InputType(reflect.TypeOf(SetInput{})))) r.PUT("/apisix/admin/global_rules", wgin.Wraps(h.Set, - wrapper.InputType(reflect.TypeOf(entity.GlobalPlugins{})))) + wrapper.InputType(reflect.TypeOf(SetInput{})))) r.PATCH("/apisix/admin/global_rules/:id", consts.ErrorWrapper(Patch)) r.PATCH("/apisix/admin/global_rules/:id/*path", consts.ErrorWrapper(Patch)) @@ -121,10 +123,25 @@ func (h *Handler) List(c droplet.Context) (interface{}, error) { return ret, nil } +type SetInput struct { + entity.GlobalPlugins + ID string `auto_read:"id,path"` +} + func (h *Handler) Set(c droplet.Context) (interface{}, error) { - input := c.Input().(*entity.GlobalPlugins) + input := c.Input().(*SetInput) + + // check if ID in body is equal ID in path + if err := handler.IDCompare(input.ID, input.GlobalPlugins.ID); err != nil { + return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, err + } + + // if has id in path, use it + if input.ID != "" { + input.GlobalPlugins.ID = input.ID + } - if err := h.globalRuleStore.Create(c.Context(), input); err != nil { + if err := h.globalRuleStore.Update(c.Context(), &input.GlobalPlugins, true); err != nil { return handler.SpecCodeResponse(err), err } diff --git a/api/internal/handler/global_rule/global_rule_test.go b/api/internal/handler/global_rule/global_rule_test.go index 62a719d8fb..9311b77b25 100644 --- a/api/internal/handler/global_rule/global_rule_test.go +++ b/api/internal/handler/global_rule/global_rule_test.go @@ -209,7 +209,7 @@ func TestHandler_List(t *testing.T) { func TestHandler_Set(t *testing.T) { tests := []struct { caseDesc string - giveInput *entity.GlobalPlugins + giveInput *SetInput giveCtx context.Context giveErr error wantErr error @@ -219,10 +219,12 @@ func TestHandler_Set(t *testing.T) { }{ { caseDesc: "normal", - giveInput: &entity.GlobalPlugins{ + giveInput: &SetInput{ ID: "name", - Plugins: map[string]interface{}{ - "jwt-auth": map[string]interface{}{}, + GlobalPlugins: entity.GlobalPlugins{ + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{}, + }, }, }, giveCtx: context.WithValue(context.Background(), "test", "value"), @@ -237,9 +239,9 @@ func TestHandler_Set(t *testing.T) { }, { caseDesc: "store create failed", - giveInput: &entity.GlobalPlugins{ - ID: "name", - Plugins: nil, + giveInput: &SetInput{ + ID: "name", + GlobalPlugins: entity.GlobalPlugins{}, }, giveErr: fmt.Errorf("create failed"), wantInput: &entity.GlobalPlugins{ @@ -258,10 +260,11 @@ func TestHandler_Set(t *testing.T) { t.Run(tc.caseDesc, func(t *testing.T) { methodCalled := true mStore := &store.MockInterface{} - mStore.On("Create", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + mStore.On("Update", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { methodCalled = true assert.Equal(t, tc.giveCtx, args.Get(0)) assert.Equal(t, tc.wantInput, args.Get(1)) + assert.True(t, args.Bool(2)) }).Return(tc.giveErr) h := Handler{globalRuleStore: mStore} diff --git a/api/test/e2e/global_rule_test.go b/api/test/e2e/global_rule_test.go index 19a59b3961..7d368c4d8e 100644 --- a/api/test/e2e/global_rule_test.go +++ b/api/test/e2e/global_rule_test.go @@ -56,7 +56,6 @@ func TestGlobalRule(t *testing.T) { Path: "/apisix/admin/global_rules/1", Method: http.MethodPut, Body: `{ - "id": "1", "plugins": { "response-rewrite": { "headers": { @@ -182,6 +181,36 @@ func TestGlobalRule(t *testing.T) { ExpectHeaders: map[string]string{"X-VERSION": "2.0"}, Sleep: sleepTime, }, + { + Desc: "update global rule", + Object: ManagerApiExpect(t), + Path: "/apisix/admin/global_rules/1", + Method: http.MethodPut, + Body: `{ + "id": "1", + "plugins": { + "response-rewrite": { + "headers": { + "X-VERSION":"1.0" + } + }, + "uri-blocker": { + "block_rules": ["root.exe", "root.m+"] + } + } + }`, + Headers: map[string]string{"Authorization": token}, + ExpectStatus: http.StatusOK, + }, + { + Desc: "make sure that update succeeded", + Object: APISIXExpect(t), + Method: http.MethodGet, + Path: "/hello", + Query: "file=root.exe", + ExpectStatus: http.StatusForbidden, + Sleep: sleepTime, + }, { Desc: "delete global rule", Object: ManagerApiExpect(t),