From b0e1967f877e9c2f2c8fd8a3fdde505329d882a2 Mon Sep 17 00:00:00 2001 From: Ter Date: Sun, 19 Jun 2022 23:17:41 +0800 Subject: [PATCH] feat: support customized table --- adapter.go | 178 +++++++++++++++++++++++++++++++++++++++--------- adapter_test.go | 80 ++++++++++++++++++++++ 2 files changed, 227 insertions(+), 31 deletions(-) diff --git a/adapter.go b/adapter.go index 9c118e5..149094f 100755 --- a/adapter.go +++ b/adapter.go @@ -83,6 +83,7 @@ type Adapter struct { dbSpecified bool db *gorm.DB isFiltered bool + customTable interface{} } // finalizer is the destructor for Adapter. @@ -197,6 +198,7 @@ func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (* a := &Adapter{ tablePrefix: prefix, tableName: tableName, + customTable: db.Statement.Context.Value(customTableKey{}), } a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context}) @@ -267,6 +269,17 @@ func TurnOffAutoMigrate(db *gorm.DB) { } func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) { + + r := reflect.TypeOf(t) + if r.Kind() == reflect.Ptr { + r = r.Elem() + } + for _, field := range []string{"ID", "Ptype", "V0", "V1", "V2", "V3", "V4", "V5", "V6", "V7"} { + if _, ok := r.FieldByName(field); !ok { + return nil, errors.New(fmt.Sprintf("The custom table has no column named `%s`", field)) + } + } + ctx := db.Statement.Context if ctx == nil { ctx = context.Background() @@ -363,9 +376,113 @@ func (a *Adapter) Close() error { return nil } -// getTableInstance return the dynamic table name -func (a *Adapter) getTableInstance() *CasbinRule { - return &CasbinRule{} +// getTableObject return the dynamic table object +func (a *Adapter) getTableObject() interface{} { + if a.customTable == nil { + return &CasbinRule{} + } + return a.customTable +} + +func convertToTableDefault(lines interface{}) *[]CasbinRule { + cusLines := reflect.ValueOf(lines) + if cusLines.Kind() == reflect.Ptr { + cusLines = cusLines.Elem() + } + + length := cusLines.Len() + retLines := make([]CasbinRule, length) + + for i := 0; i < length; i++ { + tmp := cusLines.Index(i) + + retLines[i].ID = uint(tmp.FieldByName("ID").Uint()) + retLines[i].Ptype = tmp.FieldByName("Ptype").String() + retLines[i].V0 = tmp.FieldByName("V0").String() + retLines[i].V1 = tmp.FieldByName("V1").String() + retLines[i].V2 = tmp.FieldByName("V2").String() + retLines[i].V3 = tmp.FieldByName("V3").String() + retLines[i].V4 = tmp.FieldByName("V4").String() + retLines[i].V5 = tmp.FieldByName("V5").String() + retLines[i].V6 = tmp.FieldByName("V6").String() + retLines[i].V7 = tmp.FieldByName("V7").String() + } + + return &retLines +} + +func convertToTableCustomized(t interface{}, lines *[]CasbinRule) interface{} { + modelType := reflect.TypeOf(t) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + cusLines := reflect.New(reflect.ArrayOf(len(*lines), modelType)) + + for i, line := range *lines { + tmp := reflect.New(modelType).Elem() + + tmp.FieldByName("ID").SetUint(uint64(line.ID)) + tmp.FieldByName("Ptype").SetString(line.Ptype) + tmp.FieldByName("V0").SetString(line.V0) + tmp.FieldByName("V1").SetString(line.V1) + tmp.FieldByName("V2").SetString(line.V2) + tmp.FieldByName("V3").SetString(line.V3) + tmp.FieldByName("V4").SetString(line.V4) + tmp.FieldByName("V5").SetString(line.V5) + tmp.FieldByName("V6").SetString(line.V6) + tmp.FieldByName("V7").SetString(line.V7) + + cusLines.Elem().Index(i).Set(tmp) + } + + return cusLines.Interface() +} + +func (a *Adapter) customizedDBDelete(db *gorm.DB, conds ...interface{}) error { + if a.customTable == nil { + return db.Delete(&CasbinRule{}, conds...).Error + } + + return db.Delete(a.customTable, conds...).Error +} + +func (a *Adapter) customizedDBFind(db *gorm.DB, linesPtr *[]CasbinRule, conds ...interface{}) error { + if a.customTable == nil { + return db.Find(linesPtr, conds...).Error + } + + modelType := reflect.TypeOf(a.customTable) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + cusLinesPtr := reflect.New(reflect.SliceOf(modelType)).Interface() + + err := db.Find(cusLinesPtr, conds...).Error + if err != nil { + return err + } + + *linesPtr = *convertToTableDefault(cusLinesPtr) + + return nil +} + +func (a *Adapter) customizedDBCreate(db *gorm.DB, linePtr *CasbinRule) error { + if a.customTable == nil { + return db.Create(linePtr).Error + } + + lines := convertToTableCustomized(a.customTable, &[]CasbinRule{*linePtr}) + return db.Create(lines).Error +} + +func (a *Adapter) customizedDBCreateMany(db *gorm.DB, linesPtr *[]CasbinRule) error { + if a.customTable == nil { + return db.Create(linesPtr).Error + } + + lines := convertToTableCustomized(a.customTable, linesPtr) + return db.Create(lines).Error } func (a *Adapter) getFullTableName() string { @@ -394,7 +511,7 @@ func (a *Adapter) createTable() error { return a.db.AutoMigrate(t) } - t = a.getTableInstance() + t = a.getTableObject() if err := a.db.AutoMigrate(t); err != nil { return err } @@ -413,7 +530,7 @@ func (a *Adapter) createTable() error { func (a *Adapter) dropTable() error { t := a.db.Statement.Context.Value(customTableKey) if t == nil { - return a.db.Migrator().DropTable(a.getTableInstance()) + return a.db.Migrator().DropTable(a.getTableObject()) } return a.db.Migrator().DropTable(t) @@ -445,7 +562,7 @@ func loadPolicyLine(line CasbinRule, model model.Model) { // LoadPolicy loads policy from database. func (a *Adapter) LoadPolicy(model model.Model) error { var lines []CasbinRule - if err := a.db.Order("ID").Find(&lines).Error; err != nil { + if err := a.customizedDBFind(a.db.Order("ID"), &lines); err != nil { return err } @@ -479,10 +596,11 @@ func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) erro } for _, f := range batchFilter.filters { - if err := a.db.Scopes(a.filterQuery(a.db, f)).Order("ID").Find(&lines).Error; err != nil { + if err := a.customizedDBFind(a.db.Scopes(a.filterQuery(a.db, f)).Order("ID"), &lines); err != nil { return err } + for _, line := range lines { loadPolicyLine(line, model) } @@ -532,7 +650,7 @@ func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gor } func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule { - line := a.getTableInstance() + line := &CasbinRule{} line.Ptype = ptype if len(rule) > 0 { @@ -576,7 +694,7 @@ func (a *Adapter) SavePolicy(model model.Model) error { for _, rule := range ast.Policy { lines = append(lines, a.savePolicyLine(ptype, rule)) if len(lines) > flushEvery { - if err := a.db.Create(&lines).Error; err != nil { + if err := a.customizedDBCreateMany(a.db, &lines); err != nil { return err } lines = nil @@ -588,7 +706,7 @@ func (a *Adapter) SavePolicy(model model.Model) error { for _, rule := range ast.Policy { lines = append(lines, a.savePolicyLine(ptype, rule)) if len(lines) > flushEvery { - if err := a.db.Create(&lines).Error; err != nil { + if err := a.customizedDBCreateMany(a.db, &lines); err != nil { return err } lines = nil @@ -596,7 +714,7 @@ func (a *Adapter) SavePolicy(model model.Model) error { } } if len(lines) > 0 { - if err := a.db.Create(&lines).Error; err != nil { + if err := a.customizedDBCreateMany(a.db, &lines); err != nil { return err } } @@ -607,7 +725,7 @@ func (a *Adapter) SavePolicy(model model.Model) error { // AddPolicy adds a policy rule to the storage. func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error { line := a.savePolicyLine(ptype, rule) - err := a.db.Create(&line).Error + err := a.customizedDBCreate(a.db, &line) return err } @@ -625,7 +743,7 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error line := a.savePolicyLine(ptype, rule) lines = append(lines, line) } - return a.db.Create(&lines).Error + return a.customizedDBCreateMany(a.db, &lines) } // RemovePolicies removes multiple policy rules from the storage. @@ -643,7 +761,7 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err // RemoveFilteredPolicy removes policy rules that match the filter from the storage. func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { - line := a.getTableInstance() + line := &CasbinRule{} line.Ptype = ptype @@ -731,8 +849,8 @@ func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error { queryArgs = append(queryArgs, line.V7) } args := append([]interface{}{queryStr}, queryArgs...) - err := db.Delete(a.getTableInstance(), args...).Error - return err + + return a.customizedDBDelete(db, args...) } func appendWhere(line CasbinRule) (string, []interface{}) { @@ -802,7 +920,7 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [] func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) { // UpdateFilteredPolicies deletes old rules and adds new rules. - line := a.getTableInstance() + line := &CasbinRule{} line.Ptype = ptype if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) { @@ -837,21 +955,19 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [ } tx := a.db.Begin() + str, args := line.queryString() - for i := range newP { - str, args := line.queryString() - if err := tx.Where(str, args...).Find(&oldP).Error; err != nil { - tx.Rollback() - return nil, err - } - if err := tx.Where(str, args...).Delete([]CasbinRule{}).Error; err != nil { - tx.Rollback() - return nil, err - } - if err := tx.Create(&newP[i]).Error; err != nil { - tx.Rollback() - return nil, err - } + if err := a.customizedDBFind(tx.Where(str, args...), &oldP); err != nil { + tx.Rollback() + return nil, err + } + if err := a.customizedDBDelete(tx.Where(str, args...)); err != nil { + tx.Rollback() + return nil, err + } + if err := a.customizedDBCreateMany(tx, &newP); err != nil { + tx.Rollback() + return nil, err } // return deleted rulues diff --git a/adapter_test.go b/adapter_test.go index 0c57eb4..91b0150 100755 --- a/adapter_test.go +++ b/adapter_test.go @@ -111,6 +111,7 @@ func initPolicy(t *testing.T, a *Adapter) { if err != nil { panic(err) } + fmt.Println(e.GetPolicy()) testGetPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) } @@ -255,6 +256,85 @@ func initAdapterWithGormInstanceByPrefixAndName(t *testing.T, db *gorm.DB, prefi return a } +func TestSoftDelete(t *testing.T) { + type TestCasbinRule struct { + CasbinRule + DeletedAt gorm.DeletedAt + //CreatedAt time.Time + } + + // Start preparing + db, err := gorm.Open(mysql.Open("root:@tcp(127.0.0.1:3306)/casbin?parseTime=true"), &gorm.Config{}) + if err != nil { + panic(err) + } + + a, err := NewAdapterByDBWithCustomTable(db, &TestCasbinRule{}, "casbin_create_custom") + if err != nil { + panic(err) + } + + initPolicy(t, a) + + e, err := casbin.NewEnforcer("examples/rbac_model.conf", a) + if err != nil { + panic(err) + } + e.EnableAutoSave(true) + // End of preparation. + + // Test Add & delete policy + ok, err := e.AddPolicy("carol", "data1", "read") + assert.Nil(t, err) + assert.Equal(t, ok, true) + + ok, err = e.RemovePolicy("bob", "data2", "write") + assert.Nil(t, err) + assert.Equal(t, ok, true) + + e.ClearPolicy() + err = a.LoadPolicy(e.GetModel()) + assert.Nil(t, err) + + testGetPolicy(t, e, [][]string{ + {"alice", "data1", "read"}, + //{"bob", "data2", "write"}, + {"data2_admin", "data2", "read"}, + {"data2_admin", "data2", "write"}, + {"carol", "data1", "read"}, + }) + + res := TestCasbinRule{} + err = a.db.Unscoped().Find(&res, "ptype = ? and v0 = ? and v1 = ? and v2 = ?", "p", "bob", "data2", "write").Error + assert.Nil(t, err) + assert.NotNil(t, res.DeletedAt) + log.Print("SoftDeletedRecord: ", res) + + // Test LoadFilteredPolicy + e.ClearPolicy() + err = a.LoadFilteredPolicy(e.GetModel(), Filter{ + V0: []string{"bob", "alice", "carol"}, + }) + assert.Nil(t, err) + + testGetPolicy(t, e, [][]string{ + {"alice", "data1", "read"}, + {"carol", "data1", "read"}, + }) + + //Test Update + _, err = e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read") + assert.Nil(t, err) + + e.LoadPolicy() + testGetPolicyWithoutOrder(t, e, [][]string{ + {"alice", "data1", "write"}, + {"data2_admin", "data2", "read"}, + {"data2_admin", "data2", "write"}, + {"carol", "data1", "read"}, + }) +} + func TestNilField(t *testing.T) { a, err := NewAdapter("sqlite3", "test.db") assert.Nil(t, err)