Skip to content

Commit

Permalink
fix: add column v6, v7, update unit tests (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
greathaoliu authored Apr 16, 2022
1 parent 7b3bc27 commit 34e13d4
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
63 changes: 61 additions & 2 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ type CasbinRule struct {
V3 string `gorm:"size:100"`
V4 string `gorm:"size:100"`
V5 string `gorm:"size:100"`
V6 string `gorm:"size:25"`
V7 string `gorm:"size:25"`
}

func (CasbinRule) TableName() string {
Expand All @@ -63,6 +65,8 @@ type Filter struct {
V3 []string
V4 []string
V5 []string
V6 []string
V7 []string
}

// Adapter represents the Gorm adapter for policy storage.
Expand Down Expand Up @@ -378,7 +382,7 @@ func (a *Adapter) createTable() error {
index := strings.ReplaceAll("idx_"+tableName, ".", "_")
hasIndex := a.db.Migrator().HasIndex(t, index)
if !hasIndex {
if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5)", index, tableName)).Error; err != nil {
if err := a.db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX %s ON %s (ptype,v0,v1,v2,v3,v4,v5,v6,v7)", index, tableName)).Error; err != nil {
return err
}
}
Expand All @@ -397,7 +401,8 @@ func (a *Adapter) dropTable() error {
func loadPolicyLine(line CasbinRule, model model.Model) {
var p = []string{line.Ptype,
line.V0, line.V1, line.V2,
line.V3, line.V4, line.V5}
line.V3, line.V4, line.V5,
line.V6, line.V7}

index := len(p) - 1
for p[index] == "" {
Expand Down Expand Up @@ -473,6 +478,12 @@ func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gor
if len(filter.V5) > 0 {
db = db.Where("v5 in (?)", filter.V5)
}
if len(filter.V6) > 0 {
db = db.Where("v6 in (?)", filter.V6)
}
if len(filter.V7) > 0 {
db = db.Where("v7 in (?)", filter.V7)
}
return db
}
}
Expand All @@ -499,6 +510,12 @@ func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {
if len(rule) > 5 {
line.V5 = rule[5]
}
if len(rule) > 6 {
line.V6 = rule[6]
}
if len(rule) > 7 {
line.V7 = rule[7]
}

return *line
}
Expand Down Expand Up @@ -616,6 +633,12 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
}
if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
line.V6 = fieldValues[6-fieldIndex]
}
if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
line.V7 = fieldValues[7-fieldIndex]
}
err = a.rawDelete(a.db, *line)
return err
}
Expand Down Expand Up @@ -658,6 +681,14 @@ func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, line.V5)
}
if line.V6 != "" {
queryStr += " and v6 = ?"
queryArgs = append(queryArgs, line.V6)
}
if line.V7 != "" {
queryStr += " and v7 = ?"
queryArgs = append(queryArgs, line.V7)
}
args := append([]interface{}{queryStr}, queryArgs...)
err := db.Delete(a.getTableInstance(), args...).Error
return err
Expand Down Expand Up @@ -691,6 +722,14 @@ func appendWhere(line CasbinRule) (string, []interface{}) {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, line.V5)
}
if line.V6 != "" {
queryStr += " and v6 = ?"
queryArgs = append(queryArgs, line.V6)
}
if line.V7 != "" {
queryStr += " and v7 = ?"
queryArgs = append(queryArgs, line.V7)
}
return queryStr, queryArgs
}

Expand Down Expand Up @@ -743,6 +782,12 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
line.V5 = fieldValues[5-fieldIndex]
}
if fieldIndex <= 6 && 6 < fieldIndex+len(fieldValues) {
line.V6 = fieldValues[6-fieldIndex]
}
if fieldIndex <= 7 && 7 < fieldIndex+len(fieldValues) {
line.V7 = fieldValues[7-fieldIndex]
}

newP := make([]CasbinRule, 0, len(newPolicies))
oldP := make([]CasbinRule, 0)
Expand Down Expand Up @@ -805,6 +850,14 @@ func (c *CasbinRule) queryString() (interface{}, []interface{}) {
queryStr += " and v5 = ?"
queryArgs = append(queryArgs, c.V5)
}
if c.V6 != "" {
queryStr += " and v6 = ?"
queryArgs = append(queryArgs, c.V6)
}
if c.V7 != "" {
queryStr += " and v7 = ?"
queryArgs = append(queryArgs, c.V7)
}

return queryStr, queryArgs
}
Expand Down Expand Up @@ -832,5 +885,11 @@ func (c *CasbinRule) toStringPolicy() []string {
if c.V5 != "" {
policy = append(policy, c.V5)
}
if c.V6 != "" {
policy = append(policy, c.V6)
}
if c.V7 != "" {
policy = append(policy, c.V7)
}
return policy
}
11 changes: 11 additions & 0 deletions adapter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ func initAdapterWithGormInstanceAndCustomTable(t *testing.T, db *gorm.DB) *Adapt
V3 string `gorm:"size:128;uniqueIndex:unique_index"`
V4 string `gorm:"size:128;uniqueIndex:unique_index"`
V5 string `gorm:"size:128;uniqueIndex:unique_index"`
V6 string `gorm:"size:128;uniqueIndex:unique_index"`
V7 string `gorm:"size:128;uniqueIndex:unique_index"`
}

// Create an adapter
Expand Down Expand Up @@ -525,3 +527,12 @@ func TestAddPolicies(t *testing.T) {

testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"jack", "data1", "read"}, {"jack2", "data1", "read"}})
}

func TestAddPoliciesFullColumn(t *testing.T) {
a := initAdapter(t, "mysql", "root:@tcp(127.0.0.1:3306)/", "casbin", "casbin_rule")
e, _ := casbin.NewEnforcer("examples/rbac_model.conf", a)
e.AddPolicies([][]string{{"jack", "data1", "read", "col3", "col4", "col5", "col6", "col7"}, {"jack2", "data1", "read", "col3", "col4", "col5", "col6", "col7"}})
e.LoadPolicy()

testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"jack", "data1", "read", "col3", "col4", "col5", "col6", "col7"}, {"jack2", "data1", "read", "col3", "col4", "col5", "col6", "col7"}})
}

0 comments on commit 34e13d4

Please sign in to comment.