-
Notifications
You must be signed in to change notification settings - Fork 3
/
casbin.go
92 lines (86 loc) · 2.46 KB
/
casbin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
package middleware
import (
_ "embed"
"fmt"
"strings"
"github.com/MQEnergy/go-skeleton/configs"
"github.com/MQEnergy/go-skeleton/internal/vars"
"github.com/MQEnergy/go-skeleton/pkg/helper"
"github.com/MQEnergy/go-skeleton/pkg/response"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/util"
gormadapter "github.com/casbin/gorm-adapter/v3"
"github.com/gofiber/fiber/v2"
"github.com/spf13/cast"
"gorm.io/gorm"
)
// CasbinMiddleware returns a Fiber handler that authorizes the current request
// against the casbin RBAC policy stored in the database.
//
// db is the GORM connection holding the policy table. When it is nil, the
// middleware is a no-op and every request is passed through. prefix and
// tableName are forwarded to gormadapter.NewFilteredAdapterByDB. An empty
// tableName falls back to "casbin_rule".
//
// The comma-separated role ids are read from the "role_ids" response header
// (NOTE(review): presumably set by an earlier auth middleware; confirm).
// Requests are handled as follows:
//   - No role ids: respond with response.UnauthorizedException.
//   - A role equal to the configured "server.superRoleId": pass through
//     without consulting the policy.
//   - Otherwise: allow the request if any role satisfies
//     Enforce(role, path, method), and respond with
//     response.ForbiddenException if none does.
//
// Failures while building the enforcer or loading the policy are returned as
// errors, so Fiber's error handler reports them. Previously they were ignored,
// which could lead to a nil-enforcer panic.
func CasbinMiddleware(db *gorm.DB, prefix, tableName string) fiber.Handler {
	// Resolve the default once, before the handler is created. Assigning to the
	// captured parameter inside the handler was a data race between concurrent
	// requests.
	if tableName == "" {
		tableName = "casbin_rule"
	}
	return func(ctx *fiber.Ctx) error {
		if db == nil {
			return ctx.Next()
		}

		// Check the cheap, request-local conditions before touching the database.
		roleIds := ctx.GetRespHeader("role_ids")
		if roleIds == "" {
			return response.UnauthorizedException(ctx, "该用户还未分配权限")
		}
		roleList := strings.Split(roleIds, ",")
		if helper.InAnySlice[string](roleList, vars.Config.GetString("server.superRoleId")) {
			return ctx.Next()
		}

		// The enforcer is built and the policy reloaded per request, so policy
		// changes in the database take effect immediately.
		adapter, err := gormadapter.NewFilteredAdapterByDB(db, prefix, tableName)
		if err != nil {
			return fmt.Errorf("casbin: creating gorm adapter: %w", err)
		}
		rc, err := model.NewModelFromString(configs.RbacModelConf)
		if err != nil {
			return fmt.Errorf("casbin: parsing rbac model: %w", err)
		}
		e, err := casbin.NewEnforcer(rc, adapter)
		if err != nil {
			return fmt.Errorf("casbin: creating enforcer: %w", err)
		}
		e.AddFunction("ParamsObjMatch", ParamsObjMatchFunc)
		e.AddFunction("ParamsActMatch", ParamsActMatchFunc)
		if err = e.LoadPolicy(); err != nil {
			return fmt.Errorf("casbin: loading policy: %w", err)
		}

		// Current request path and HTTP method.
		obj := ctx.Path()
		act := ctx.Method()
		for _, sub := range roleList {
			// An Enforce error is treated as "not allowed" for that role (fail
			// closed), matching the original behavior.
			if ok, _ := e.Enforce(sub, obj, act); ok {
				return ctx.Next()
			}
		}
		return response.ForbiddenException(ctx, "该用户无此权限")
	}
}
// ParamsActMatchFunc is a custom casbin matcher function for the request method.
//
// It takes exactly two arguments: the request action and the policy action,
// both converted with cast.ToString. The policy action may list several
// methods separated by commas (e.g. "GET,POST"). The result is true when the
// request action equals one of the listed entries. Entries are not trimmed,
// so "GET, POST" lists " POST" with a leading space.
func ParamsActMatchFunc(args ...interface{}) (interface{}, error) {
	if len(args) != 2 {
		return nil, fmt.Errorf("must be 2 arguments")
	}
	requestAct, policyAct := cast.ToString(args[0]), cast.ToString(args[1])
	allowed := strings.Split(policyAct, ",")
	switch n := len(allowed); {
	case n == 1:
		return allowed[0] == requestAct, nil
	case n > 1:
		return helper.InAnySlice[string](allowed, requestAct), nil
	default:
		// Unreachable in practice: Split with a non-empty separator always
		// yields at least one element.
		return false, nil
	}
}
// ParamsObjMatchFunc is a custom casbin matcher function for the request path.
//
// It takes exactly two arguments: the request object and the policy object,
// both converted with cast.ToString. Everything from the first "?" onward is
// stripped from the request object. The remaining path is matched against the
// policy pattern with util.KeyMatch2 (see casbin's util package for its
// pattern syntax).
func ParamsObjMatchFunc(args ...interface{}) (interface{}, error) {
	if len(args) != 2 {
		return nil, fmt.Errorf("must be 2 arguments")
	}
	// strings.Cut returns the whole input when there is no "?", which is the
	// same as taking the first element of strings.Split.
	path, _, _ := strings.Cut(cast.ToString(args[0]), "?")
	pattern := cast.ToString(args[1])
	return util.KeyMatch2(path, pattern), nil
}