Skip to content

Commit

Permalink
feat: support customized table
Browse files Browse the repository at this point in the history
  • Loading branch information
JalinWang committed Aug 25, 2022
1 parent 41953cc commit b0e1967
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 31 deletions.
178 changes: 147 additions & 31 deletions adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ type Adapter struct {
dbSpecified bool
db *gorm.DB
isFiltered bool
customTable interface{}
}

// finalizer is the destructor for Adapter.
Expand Down Expand Up @@ -197,6 +198,7 @@ func NewAdapterByDBUseTableName(db *gorm.DB, prefix string, tableName string) (*
a := &Adapter{
tablePrefix: prefix,
tableName: tableName,
customTable: db.Statement.Context.Value(customTableKey{}),
}

a.db = db.Scopes(a.casbinRuleTable()).Session(&gorm.Session{Context: db.Statement.Context})
Expand Down Expand Up @@ -267,6 +269,17 @@ func TurnOffAutoMigrate(db *gorm.DB) {
}

func NewAdapterByDBWithCustomTable(db *gorm.DB, t interface{}, tableName ...string) (*Adapter, error) {

r := reflect.TypeOf(t)
if r.Kind() == reflect.Ptr {
r = r.Elem()
}
for _, field := range []string{"ID", "Ptype", "V0", "V1", "V2", "V3", "V4", "V5", "V6", "V7"} {
if _, ok := r.FieldByName(field); !ok {
return nil, errors.New(fmt.Sprintf("The custom table has no column named `%s`", field))
}
}

ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
Expand Down Expand Up @@ -363,9 +376,113 @@ func (a *Adapter) Close() error {
return nil
}

// getTableInstance return the dynamic table name
func (a *Adapter) getTableInstance() *CasbinRule {
return &CasbinRule{}
// getTableObject return the dynamic table object
func (a *Adapter) getTableObject() interface{} {
if a.customTable == nil {
return &CasbinRule{}
}
return a.customTable
}

func convertToTableDefault(lines interface{}) *[]CasbinRule {
cusLines := reflect.ValueOf(lines)
if cusLines.Kind() == reflect.Ptr {
cusLines = cusLines.Elem()
}

length := cusLines.Len()
retLines := make([]CasbinRule, length)

for i := 0; i < length; i++ {
tmp := cusLines.Index(i)

retLines[i].ID = uint(tmp.FieldByName("ID").Uint())
retLines[i].Ptype = tmp.FieldByName("Ptype").String()
retLines[i].V0 = tmp.FieldByName("V0").String()
retLines[i].V1 = tmp.FieldByName("V1").String()
retLines[i].V2 = tmp.FieldByName("V2").String()
retLines[i].V3 = tmp.FieldByName("V3").String()
retLines[i].V4 = tmp.FieldByName("V4").String()
retLines[i].V5 = tmp.FieldByName("V5").String()
retLines[i].V6 = tmp.FieldByName("V6").String()
retLines[i].V7 = tmp.FieldByName("V7").String()
}

return &retLines
}

func convertToTableCustomized(t interface{}, lines *[]CasbinRule) interface{} {
modelType := reflect.TypeOf(t)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
cusLines := reflect.New(reflect.ArrayOf(len(*lines), modelType))

for i, line := range *lines {
tmp := reflect.New(modelType).Elem()

tmp.FieldByName("ID").SetUint(uint64(line.ID))
tmp.FieldByName("Ptype").SetString(line.Ptype)
tmp.FieldByName("V0").SetString(line.V0)
tmp.FieldByName("V1").SetString(line.V1)
tmp.FieldByName("V2").SetString(line.V2)
tmp.FieldByName("V3").SetString(line.V3)
tmp.FieldByName("V4").SetString(line.V4)
tmp.FieldByName("V5").SetString(line.V5)
tmp.FieldByName("V6").SetString(line.V6)
tmp.FieldByName("V7").SetString(line.V7)

cusLines.Elem().Index(i).Set(tmp)
}

return cusLines.Interface()
}

func (a *Adapter) customizedDBDelete(db *gorm.DB, conds ...interface{}) error {
if a.customTable == nil {
return db.Delete(&CasbinRule{}, conds...).Error
}

return db.Delete(a.customTable, conds...).Error
}

func (a *Adapter) customizedDBFind(db *gorm.DB, linesPtr *[]CasbinRule, conds ...interface{}) error {
if a.customTable == nil {
return db.Find(linesPtr, conds...).Error
}

modelType := reflect.TypeOf(a.customTable)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
cusLinesPtr := reflect.New(reflect.SliceOf(modelType)).Interface()

err := db.Find(cusLinesPtr, conds...).Error
if err != nil {
return err
}

*linesPtr = *convertToTableDefault(cusLinesPtr)

return nil
}

func (a *Adapter) customizedDBCreate(db *gorm.DB, linePtr *CasbinRule) error {
if a.customTable == nil {
return db.Create(linePtr).Error
}

lines := convertToTableCustomized(a.customTable, &[]CasbinRule{*linePtr})
return db.Create(lines).Error
}

func (a *Adapter) customizedDBCreateMany(db *gorm.DB, linesPtr *[]CasbinRule) error {
if a.customTable == nil {
return db.Create(linesPtr).Error
}

lines := convertToTableCustomized(a.customTable, linesPtr)
return db.Create(lines).Error
}

func (a *Adapter) getFullTableName() string {
Expand Down Expand Up @@ -394,7 +511,7 @@ func (a *Adapter) createTable() error {
return a.db.AutoMigrate(t)
}

t = a.getTableInstance()
t = a.getTableObject()
if err := a.db.AutoMigrate(t); err != nil {
return err
}
Expand All @@ -413,7 +530,7 @@ func (a *Adapter) createTable() error {
func (a *Adapter) dropTable() error {
t := a.db.Statement.Context.Value(customTableKey)
if t == nil {
return a.db.Migrator().DropTable(a.getTableInstance())
return a.db.Migrator().DropTable(a.getTableObject())
}

return a.db.Migrator().DropTable(t)
Expand Down Expand Up @@ -445,7 +562,7 @@ func loadPolicyLine(line CasbinRule, model model.Model) {
// LoadPolicy loads policy from database.
func (a *Adapter) LoadPolicy(model model.Model) error {
var lines []CasbinRule
if err := a.db.Order("ID").Find(&lines).Error; err != nil {
if err := a.customizedDBFind(a.db.Order("ID"), &lines); err != nil {
return err
}

Expand Down Expand Up @@ -479,10 +596,11 @@ func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) erro
}

for _, f := range batchFilter.filters {
if err := a.db.Scopes(a.filterQuery(a.db, f)).Order("ID").Find(&lines).Error; err != nil {
if err := a.customizedDBFind(a.db.Scopes(a.filterQuery(a.db, f)).Order("ID"), &lines); err != nil {
return err
}


for _, line := range lines {
loadPolicyLine(line, model)
}
Expand Down Expand Up @@ -532,7 +650,7 @@ func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gor
}

func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {
line := a.getTableInstance()
line := &CasbinRule{}

line.Ptype = ptype
if len(rule) > 0 {
Expand Down Expand Up @@ -576,7 +694,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
for _, rule := range ast.Policy {
lines = append(lines, a.savePolicyLine(ptype, rule))
if len(lines) > flushEvery {
if err := a.db.Create(&lines).Error; err != nil {
if err := a.customizedDBCreateMany(a.db, &lines); err != nil {
return err
}
lines = nil
Expand All @@ -588,15 +706,15 @@ func (a *Adapter) SavePolicy(model model.Model) error {
for _, rule := range ast.Policy {
lines = append(lines, a.savePolicyLine(ptype, rule))
if len(lines) > flushEvery {
if err := a.db.Create(&lines).Error; err != nil {
if err := a.customizedDBCreateMany(a.db, &lines); err != nil {
return err
}
lines = nil
}
}
}
if len(lines) > 0 {
if err := a.db.Create(&lines).Error; err != nil {
if err := a.customizedDBCreateMany(a.db, &lines); err != nil {
return err
}
}
Expand All @@ -607,7 +725,7 @@ func (a *Adapter) SavePolicy(model model.Model) error {
// AddPolicy adds a policy rule to the storage.
func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
line := a.savePolicyLine(ptype, rule)
err := a.db.Create(&line).Error
err := a.customizedDBCreate(a.db, &line)
return err
}

Expand All @@ -625,7 +743,7 @@ func (a *Adapter) AddPolicies(sec string, ptype string, rules [][]string) error
line := a.savePolicyLine(ptype, rule)
lines = append(lines, line)
}
return a.db.Create(&lines).Error
return a.customizedDBCreateMany(a.db, &lines)
}

// RemovePolicies removes multiple policy rules from the storage.
Expand All @@ -643,7 +761,7 @@ func (a *Adapter) RemovePolicies(sec string, ptype string, rules [][]string) err

// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
line := a.getTableInstance()
line := &CasbinRule{}

line.Ptype = ptype

Expand Down Expand Up @@ -731,8 +849,8 @@ func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
queryArgs = append(queryArgs, line.V7)
}
args := append([]interface{}{queryStr}, queryArgs...)
err := db.Delete(a.getTableInstance(), args...).Error
return err

return a.customizedDBDelete(db, args...)
}

func appendWhere(line CasbinRule) (string, []interface{}) {
Expand Down Expand Up @@ -802,7 +920,7 @@ func (a *Adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules []

func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) ([][]string, error) {
// UpdateFilteredPolicies deletes old rules and adds new rules.
line := a.getTableInstance()
line := &CasbinRule{}

line.Ptype = ptype
if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
Expand Down Expand Up @@ -837,21 +955,19 @@ func (a *Adapter) UpdateFilteredPolicies(sec string, ptype string, newPolicies [
}

tx := a.db.Begin()
str, args := line.queryString()

for i := range newP {
str, args := line.queryString()
if err := tx.Where(str, args...).Find(&oldP).Error; err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Where(str, args...).Delete([]CasbinRule{}).Error; err != nil {
tx.Rollback()
return nil, err
}
if err := tx.Create(&newP[i]).Error; err != nil {
tx.Rollback()
return nil, err
}
if err := a.customizedDBFind(tx.Where(str, args...), &oldP); err != nil {
tx.Rollback()
return nil, err
}
if err := a.customizedDBDelete(tx.Where(str, args...)); err != nil {
tx.Rollback()
return nil, err
}
if err := a.customizedDBCreateMany(tx, &newP); err != nil {
tx.Rollback()
return nil, err
}

// return deleted rulues
Expand Down
Loading

0 comments on commit b0e1967

Please sign in to comment.