From 34e13d4d2dd3d88170d9e5a38c5f671e11328fb8 Mon Sep 17 00:00:00 2001 From: Hao Liu <519555327@qq.com> Date: Sat, 16 Apr 2022 10:56:10 +0800 Subject: [PATCH] fix: add column v6, v7, update unit tests (#157) --- adapter.go | 63 +++++++++++++++++++++++++++++++++++++++++++++++-- adapter_test.go | 11 +++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/adapter.go b/adapter.go index 4142993..ff359c7 100755 --- a/adapter.go +++ b/adapter.go @@ -49,6 +49,8 @@ type CasbinRule struct { V3 string `gorm:"size:100"` V4 string `gorm:"size:100"` V5 string `gorm:"size:100"` + V6 string `gorm:"size:25"` + V7 string `gorm:"size:25"` } func (CasbinRule) TableName() string { @@ -63,6 +65,8 @@ type Filter struct { V3 []string V4 []string V5 []string + V6 []string + V7 []string } // Adapter represents the Gorm adapter for policy storage. @@ -378,7 +382,7 @@ func (a *Adapter) createTable() error { index := strings.ReplaceAll("idx_"+tableName, ".", "_") hasIndex := a.db.Migrator().HasIndex(t, index) if !hasIndex { - if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5)", index, tableName)).Error; err != nil { + if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5,v6,v7)", index, tableName)).Error; err != nil { return err } } @@ -397,7 +401,8 @@ func (a *Adapter) dropTable() error { func loadPolicyLine(line CasbinRule, model model.Model) { var p = []string{line.Ptype, line.V0, line.V1, line.V2, - line.V3, line.V4, line.V5} + line.V3, line.V4, line.V5, + line.V6, line.V7} index := len(p) - 1 for p[index] == "" { @@ -473,6 +478,12 @@ func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gor if len(filter.V5) > 0 { db = db.Where("v5 in (?)", filter.V5) } + if len(filter.V6) > 0 { + db = db.Where("v6 in (?)", filter.V6) + } + if len(filter.V7) > 0 { + db = db.Where("v7 in (?)", filter.V7) + } return db } } @@ -499,6 +510,12 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule { if len(rule) > 5 { line.V5 = rule[5] } + if len(rule) > 6 { + line.V6 = rule[6] + } + if len(rule) > 7 { + line.V7 = rule[7] + } return *line } @@ -616,6 +633,12 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { line.V5 = fieldValues[5-fieldIndex] } + if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) { + line.V6 = fieldValues[6-fieldIndex] + } + if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) { + line.V7 = fieldValues[7-fieldIndex] + } err = a.rawDelete(a.db, *line) return err } @@ -658,6 +681,14 @@ func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error { queryStr += " and v5 = ?" queryArgs = append(queryArgs, line.V5) } + if line.V6 != "" { + queryStr += " and v6 = ?" + queryArgs = append(queryArgs, line.V6) + } + if line.V7 != "" { + queryStr += " and v7 = ?" + queryArgs = append(queryArgs, line.V7) + } args := append([]interface{}{queryStr}, queryArgs...) err := db.Delete(a.getTableInstance(), args...).Error return err @@ -691,6 +722,14 @@ func appendWhere(line CasbinRule) (string, []interface{}) { queryStr += " and v5 = ?" queryArgs = append(queryArgs, line.V5) } + if line.V6 != "" { + queryStr += " and v6 = ?" + queryArgs = append(queryArgs, line.V6) + } + if line.V7 != "" { + queryStr += " and v7 = ?" + queryArgs = append(queryArgs, line.V7) + } return queryStr, queryArgs } @@ -743,6 +782,12 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [ if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) { line.V5 = fieldValues[5-fieldIndex] } + if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) { + line.V6 = fieldValues[6-fieldIndex] + } + if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) { + line.V7 = fieldValues[7-fieldIndex] + } newP := make([]CasbinRule, 0, len(newPolicies)) oldP := make([]CasbinRule, 0) @@ -805,6 +850,14 @@ func (c *CasbinRule) queryString() (interface{}, []interface{}) { queryStr += " and v5 = ?" queryArgs = append(queryArgs, c.V5) } + if c.V6 != "" { + queryStr += " and v6 = ?" + queryArgs = append(queryArgs, c.V6) + } + if c.V7 != "" { + queryStr += " and v7 = ?" + queryArgs = append(queryArgs, c.V7) + } return queryStr, queryArgs } @@ -832,5 +885,11 @@ func (c *CasbinRule) toStringPolicy() []string { if c.V5 != "" { policy = append(policy, c.V5) } + if c.V6 != "" { + policy = append(policy, c.V6) + } + if c.V7 != "" { + policy = append(policy, c.V7) + } return policy } diff --git a/adapter_test.go b/adapter_test.go index 8076e75..be5bedc 100755 --- a/adapter_test.go +++ b/adapter_test.go @@ -164,6 +164,8 @@ func initAdapterWithGormInstanceAndCustomTable(t *testing.T, db *gorm.DB) *Adapt V3 string `gorm:"size:128;uniqueIndex:unique_index"` V4 string `gorm:"size:128;uniqueIndex:unique_index"` V5 string `gorm:"size:128;uniqueIndex:unique_index"` + V6 string `gorm:"size:128;uniqueIndex:unique_index"` + V7 string `gorm:"size:128;uniqueIndex:unique_index"` } // Create an adapter @@ -525,3 +527,12 @@ func TestAddPolicies(t *testing.T) { testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"jack", "data1", "read"}, {"jack2", "data1", "read"}}) } + +func TestAddPoliciesFullColumn(t *testing.T) { + a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule") + e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a) + e.AddPolicies([][]string{{"jack", "data1", "read", "col3", "col4", "col5", "col6", "col7"}, {"jack2", "data1", "read", "col3", "col4", "col5", "col6", "col7"}}) + e.LoadPolicy() + + testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"jack", "data1", "read", "col3", "col4", "col5", "col6", "col7"}, {"jack2", "data1", "read", "col3", "col4", "col5", "col6", "col7"}}) +}