From ff00a12805e2436e0a8373d6d40f4920d06ab86a Mon Sep 17 00:00:00 2001 From: sthuang <167743503+shaoting-huang@users.noreply.github.com> Date: Sat, 9 Nov 2024 20:40:25 +0800 Subject: [PATCH] enhance: RBAC custom privilege group ut coverage (#37558) issue: https://github.com/milvus-io/milvus/issues/37031 Signed-off-by: shaoting-huang --- .../proxy/httpserver/handler_v2.go | 18 +++++----- .../proxy/httpserver/handler_v2_test.go | 28 +++++++++++++-- .../proxy/httpserver/request_v2.go | 5 ++- internal/rootcoord/meta_table_test.go | 12 +++++++ internal/rootcoord/mock_test.go | 8 +++++ internal/rootcoord/root_coord_test.go | 33 +++++++++++++++++ .../integration/rbac/privilege_group_test.go | 35 +++++++++++++++++++ 7 files changed, 126 insertions(+), 13 deletions(-) diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index e78e566279772..8fc88382335ac 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -133,11 +133,11 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(RoleCategory+RevokePrivilegeAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.removePrivilegeFromRole)))) // privilege group - router.POST(PrivilegeGroupCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.createPrivilegeGroup)))) - router.POST(PrivilegeGroupCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.dropPrivilegeGroup)))) - router.POST(PrivilegeGroupCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.listPrivilegeGroups)))) - router.POST(PrivilegeGroupCategory+AddPrivilegesToGroupAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.addPrivilegesToGroup)))) - router.POST(PrivilegeGroupCategory+RemovePrivilegesFromGroupAction, timeoutMiddleware(wrapperPost(func() any { return &GrantReq{} }, wrapperTraceLog(h.removePrivilegesFromGroup)))) + router.POST(PrivilegeGroupCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.createPrivilegeGroup)))) + router.POST(PrivilegeGroupCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.dropPrivilegeGroup)))) + router.POST(PrivilegeGroupCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.listPrivilegeGroups)))) + router.POST(PrivilegeGroupCategory+AddPrivilegesToGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.addPrivilegesToGroup)))) + router.POST(PrivilegeGroupCategory+RemovePrivilegesFromGroupAction, timeoutMiddleware(wrapperPost(func() any { return &PrivilegeGroupReq{} }, wrapperTraceLog(h.removePrivilegesFromGroup)))) router.POST(IndexCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &CollectionNameReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.listIndexes))))) router.POST(IndexCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &IndexReq{} }, wrapperTraceLog(h.wrapperCheckDatabase(h.describeIndex))))) @@ -1857,9 +1857,11 @@ func (h *HandlersV2) removePrivilegesFromGroup(ctx context.Context, c *gin.Conte func (h *HandlersV2) operatePrivilegeGroup(ctx context.Context, c *gin.Context, anyReq any, dbName string, operateType milvuspb.OperatePrivilegeGroupType) (interface{}, error) { httpReq := anyReq.(*PrivilegeGroupReq) req := &milvuspb.OperatePrivilegeGroupRequest{ - GroupName: httpReq.PrivilegeGroupName, - Privileges: httpReq.Privileges, - Type: operateType, + GroupName: httpReq.PrivilegeGroupName, + Privileges: lo.Map(httpReq.Privileges, func(p string, _ int) *milvuspb.PrivilegeEntity { + return &milvuspb.PrivilegeEntity{Name: p} + }), + Type: operateType, } resp, err := wrapperProxy(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/OperatePrivilegeGroup", func(reqCtx context.Context, req any) (interface{}, error) { return h.proxy.OperatePrivilegeGroup(reqCtx, req.(*milvuspb.OperatePrivilegeGroupRequest)) diff --git a/internal/distributed/proxy/httpserver/handler_v2_test.go b/internal/distributed/proxy/httpserver/handler_v2_test.go index d8047bde8b323..70c2f63e627d5 100644 --- a/internal/distributed/proxy/httpserver/handler_v2_test.go +++ b/internal/distributed/proxy/httpserver/handler_v2_test.go @@ -1230,6 +1230,10 @@ func TestMethodGet(t *testing.T) { Status: &StatusSuccess, Alias: DefaultAliasName, }, nil).Once() + mp.EXPECT().ListPrivilegeGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListPrivilegeGroupsResponse{ + Status: &StatusSuccess, + PrivilegeGroups: []*milvuspb.PrivilegeGroupInfo{{GroupName: "group1", Privileges: []*milvuspb.PrivilegeEntity{{Name: "*"}}}}, + }, nil).Once() testEngine := initHTTPServerV2(mp, false) queryTestCases := []rawTestCase{} @@ -1320,6 +1324,9 @@ func TestMethodGet(t *testing.T) { queryTestCases = append(queryTestCases, rawTestCase{ path: versionalV2(AliasCategory, DescribeAction), }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PrivilegeGroupCategory, ListAction), + }) for _, testcase := range queryTestCases { t.Run(testcase.path, func(t *testing.T) { @@ -1329,7 +1336,8 @@ func TestMethodGet(t *testing.T) { `"indexName": "` + DefaultIndexName + `",` + `"userName": "` + util.UserRoot + `",` + `"roleName": "` + util.RoleAdmin + `",` + - `"aliasName": "` + DefaultAliasName + `"` + + `"aliasName": "` + DefaultAliasName + `",` + + `"privilegeGroupName": "pg"` + `}`)) req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader) w := httptest.NewRecorder() @@ -1369,6 +1377,7 @@ func TestMethodDelete(t *testing.T) { mp.EXPECT().DropRole(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() mp.EXPECT().DropIndex(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() mp.EXPECT().DropAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().DropPrivilegeGroup(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() testEngine := initHTTPServerV2(mp, false) queryTestCases := []rawTestCase{} queryTestCases = append(queryTestCases, rawTestCase{ @@ -1389,10 +1398,13 @@ func TestMethodDelete(t *testing.T) { queryTestCases = append(queryTestCases, rawTestCase{ path: versionalV2(AliasCategory, DropAction), }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PrivilegeGroupCategory, DropAction), + }) for _, testcase := range queryTestCases { t.Run(testcase.path, func(t *testing.T) { bodyReader := bytes.NewReader([]byte(`{"collectionName": "` + DefaultCollectionName + `", "partitionName": "` + DefaultPartitionName + - `", "userName": "` + util.UserRoot + `", "roleName": "` + util.RoleAdmin + `", "indexName": "` + DefaultIndexName + `", "aliasName": "` + DefaultAliasName + `"}`)) + `", "userName": "` + util.UserRoot + `", "roleName": "` + util.RoleAdmin + `", "indexName": "` + DefaultIndexName + `", "aliasName": "` + DefaultAliasName + `", "privilegeGroupName": "pg"}`)) req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) @@ -1431,6 +1443,8 @@ func TestMethodPost(t *testing.T) { mp.EXPECT().CreateIndex(mock.Anything, mock.Anything).Return(commonErrorStatus, nil).Once() mp.EXPECT().CreateAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() mp.EXPECT().AlterAlias(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().CreatePrivilegeGroup(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Once() + mp.EXPECT().OperatePrivilegeGroup(mock.Anything, mock.Anything).Return(commonSuccessStatus, nil).Twice() mp.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(&internalpb.ImportResponse{ Status: commonSuccessStatus, JobID: "1234567890", }, nil).Once() @@ -1523,6 +1537,15 @@ func TestMethodPost(t *testing.T) { queryTestCases = append(queryTestCases, rawTestCase{ path: versionalV2(ImportJobCategory, DescribeAction), }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PrivilegeGroupCategory, CreateAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PrivilegeGroupCategory, AddPrivilegesToGroupAction), + }) + queryTestCases = append(queryTestCases, rawTestCase{ + path: versionalV2(PrivilegeGroupCategory, RemovePrivilegesFromGroupAction), + }) for _, testcase := range queryTestCases { t.Run(testcase.path, func(t *testing.T) { @@ -1533,6 +1556,7 @@ func TestMethodPost(t *testing.T) { `"indexParams": [{"indexName": "` + DefaultIndexName + `", "fieldName": "book_intro", "metricType": "L2", "params": {"nlist": 30, "index_type": "IVF_FLAT"}}],` + `"userName": "` + util.UserRoot + `", "password": "Milvus", "newPassword": "milvus", "roleName": "` + util.RoleAdmin + `",` + `"roleName": "` + util.RoleAdmin + `", "objectType": "Global", "objectName": "*", "privilege": "*",` + + `"privilegeGroupName": "pg", "privileges": ["create", "drop"],` + `"aliasName": "` + DefaultAliasName + `",` + `"jobId": "1234567890",` + `"files": [["book.json"]]` + diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 2c2ab8f6e79c2..63ed000d2d477 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -6,7 +6,6 @@ import ( "github.com/gin-gonic/gin" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/util/merr" ) @@ -275,8 +274,8 @@ func (req *RoleReq) GetRoleName() string { } type PrivilegeGroupReq struct { - PrivilegeGroupName string `json:"privilegeGroupName" binding:"required"` - Privileges []*milvuspb.PrivilegeEntity `json:"privileges"` + PrivilegeGroupName string `json:"privilegeGroupName" binding:"required"` + Privileges []string `json:"privileges"` } type GrantReq struct { diff --git a/internal/rootcoord/meta_table_test.go b/internal/rootcoord/meta_table_test.go index a9ce1ad78a0cb..4192b655c03be 100644 --- a/internal/rootcoord/meta_table_test.go +++ b/internal/rootcoord/meta_table_test.go @@ -2093,10 +2093,22 @@ func TestMetaTable_PrivilegeGroup(t *testing.T) { } err := mt.CreatePrivilegeGroup("pg1") assert.Error(t, err) + err = mt.CreatePrivilegeGroup("") + assert.Error(t, err) + err = mt.CreatePrivilegeGroup("Insert") + assert.Error(t, err) err = mt.CreatePrivilegeGroup("pg2") assert.NoError(t, err) + err = mt.DropPrivilegeGroup("") + assert.Error(t, err) err = mt.DropPrivilegeGroup("pg1") assert.NoError(t, err) + err = mt.OperatePrivilegeGroup("", []*milvuspb.PrivilegeEntity{}, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup) + assert.Error(t, err) + err = mt.OperatePrivilegeGroup("pg3", []*milvuspb.PrivilegeEntity{}, milvuspb.OperatePrivilegeGroupType_AddPrivilegesToGroup) + assert.Error(t, err) + _, err = mt.GetPrivilegeGroupRoles("") + assert.Error(t, err) _, err = mt.ListPrivilegeGroups() assert.NoError(t, err) } diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 0bb82803f22c0..b4b4369f379dd 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -101,6 +101,7 @@ type mockMetaTable struct { DropPrivilegeGroupFunc func(groupName string) error ListPrivilegeGroupsFunc func() ([]*milvuspb.PrivilegeGroupInfo, error) OperatePrivilegeGroupFunc func(groupName string, privileges []*milvuspb.PrivilegeEntity, operateType milvuspb.OperatePrivilegeGroupType) error + GetPrivilegeGroupRolesFunc func(groupName string) ([]*milvuspb.RoleEntity, error) } func (m mockMetaTable) GetDatabaseByName(ctx context.Context, dbName string, ts Timestamp) (*model.Database, error) { @@ -271,6 +272,10 @@ func (m mockMetaTable) OperatePrivilegeGroup(groupName string, privileges []*mil return m.OperatePrivilegeGroupFunc(groupName, privileges, operateType) } +func (m mockMetaTable) GetPrivilegeGroupRoles(groupName string) ([]*milvuspb.RoleEntity, error) { + return m.GetPrivilegeGroupRolesFunc(groupName) +} + func newMockMetaTable() *mockMetaTable { return &mockMetaTable{} } @@ -559,6 +564,9 @@ func withInvalidMeta() Opt { meta.OperatePrivilegeGroupFunc = func(groupName string, privileges []*milvuspb.PrivilegeEntity, operateType milvuspb.OperatePrivilegeGroupType) error { return errors.New("error mock OperatePrivilegeGroup") } + meta.GetPrivilegeGroupRolesFunc = func(groupName string) ([]*milvuspb.RoleEntity, error) { + return nil, errors.New("error mock GetPrivilegeGroupRoles") + } return withMeta(meta) } diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index a3ad2b376f8c2..e8a9dc5d1a8cc 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -1787,6 +1787,39 @@ func TestRootCoord_RBACError(t *testing.T) { } }) + t.Run("operate privilege group failed", func(t *testing.T) { + mockMeta := c.meta.(*mockMetaTable) + mockMeta.ListPrivilegeGroupsFunc = func() ([]*milvuspb.PrivilegeGroupInfo, error) { + return nil, errors.New("mock error") + } + mockMeta.CreatePrivilegeGroupFunc = func(groupName string) error { + return errors.New("mock error") + } + mockMeta.GetPrivilegeGroupRolesFunc = func(groupName string) ([]*milvuspb.RoleEntity, error) { + return nil, errors.New("mock error") + } + { + resp, err := c.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{}) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + } + { + resp, err := c.ListPrivilegeGroups(ctx, &milvuspb.ListPrivilegeGroupsRequest{}) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetStatus().GetErrorCode()) + } + { + resp, err := c.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{}) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + } + { + resp, err := c.CreatePrivilegeGroup(ctx, &milvuspb.CreatePrivilegeGroupRequest{}) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.GetErrorCode()) + } + }) + t.Run("select grant failed", func(t *testing.T) { { resp, err := c.SelectGrant(ctx, &milvuspb.SelectGrantRequest{}) diff --git a/tests/integration/rbac/privilege_group_test.go b/tests/integration/rbac/privilege_group_test.go index 864565281341c..bea3af0514737 100644 --- a/tests/integration/rbac/privilege_group_test.go +++ b/tests/integration/rbac/privilege_group_test.go @@ -199,6 +199,41 @@ func (s *PrivilegeGroupTestSuite) TestCustomPrivilegeGroup() { s.True(merr.Ok(dropRoleResp)) } +func (s *PrivilegeGroupTestSuite) TestInvalidPrivilegeGroup() { + ctx := GetContext(context.Background(), "root:123456") + + createResp, err := s.Cluster.Proxy.CreatePrivilegeGroup(ctx, &milvuspb.CreatePrivilegeGroupRequest{ + GroupName: "", + }) + s.NoError(err) + s.False(merr.Ok(createResp)) + + dropResp, err := s.Cluster.Proxy.DropPrivilegeGroup(ctx, &milvuspb.DropPrivilegeGroupRequest{ + GroupName: "group1", + }) + s.NoError(err) + s.True(merr.Ok(dropResp)) + + dropResp, err = s.Cluster.Proxy.DropPrivilegeGroup(ctx, &milvuspb.DropPrivilegeGroupRequest{ + GroupName: "", + }) + s.NoError(err) + s.False(merr.Ok(dropResp)) + + operateResp, err := s.Cluster.Proxy.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{ + GroupName: "", + }) + s.NoError(err) + s.False(merr.Ok(operateResp)) + + operateResp, err = s.Cluster.Proxy.OperatePrivilegeGroup(ctx, &milvuspb.OperatePrivilegeGroupRequest{ + GroupName: "group1", + Privileges: []*milvuspb.PrivilegeEntity{{Name: "123"}}, + }) + s.NoError(err) + s.False(merr.Ok(operateResp)) +} + func (s *PrivilegeGroupTestSuite) operatePrivilege(ctx context.Context, role, privilege, objectType string, operateType milvuspb.OperatePrivilegeType) { resp, err := s.Cluster.Proxy.OperatePrivilege(ctx, &milvuspb.OperatePrivilegeRequest{ Type: operateType,