This repository has been archived by the owner on Mar 10, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
context.go
135 lines (126 loc) · 2.9 KB
/
context.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
package bean
import (
"fmt"
"sync"
)
// NewApplicationContext creates an ApplicationContext over definitions and
// immediately indexes them (via Init) so beans can be looked up by name.
// The definitions slice is stored as-is, not copied.
func NewApplicationContext(definitions []BeanDefinition) *ApplicationContext {
	app := new(ApplicationContext)
	app.Definitions = definitions
	app.Init()
	return app
}
// ApplicationContext is the application context: it holds the bean
// definitions and the cached singleton instances built from them.
type ApplicationContext struct {
	// Definitions is the raw list of bean definitions; Init indexes it.
	Definitions []BeanDefinition

	// tidyDefinitions maps bean name -> *BeanDefinition (filled by Init);
	// instances maps bean name -> cached instance for singleton-scope beans
	// (filled by GetBean).
	tidyDefinitions, instances sync.Map
}
// Init (re)builds the name -> definition index from t.Definitions.
//
// Each definition is shallow-copied and bound to this context; the copies
// share their ConstructorArgs slice and Fields map with the originals. When
// several definitions have the same Name, the later one wins. Singleton
// instances already cached by GetBean are left untouched.
func (t *ApplicationContext) Init() {
	// Clear the existing index in place rather than assigning a fresh
	// sync.Map over it: Init is exported, and overwriting a sync.Map that may
	// already be in use races with concurrent Load calls.
	t.tidyDefinitions.Range(func(key, _ interface{}) bool {
		t.tidyDefinitions.Delete(key)
		return true
	})
	for _, d := range t.Definitions {
		t.tidyDefinitions.Store(d.Name, &BeanDefinition{
			Name:            d.Name,
			Reflect:         d.Reflect,
			Scope:           d.Scope,
			InitMethod:      d.InitMethod,
			ConstructorArgs: d.ConstructorArgs,
			Fields:          d.Fields,
			context:         t,
		})
	}
}
// GetBeanDefinition returns the definition indexed under name by Init.
// It panics with "Bean not found: <name>" when no such definition exists.
func (t *ApplicationContext) GetBeanDefinition(name string) *BeanDefinition {
	raw, found := t.tidyDefinitions.Load(name)
	if !found {
		panic(fmt.Sprintf("Bean not found: %s", name))
	}
	return raw.(*BeanDefinition)
}
// GetBean returns an instance of the bean named name, built from its
// definition with fields and args applied on top (see merge). It panics if
// no bean with that name is registered.
//
// Non-singleton beans call instance() on every call. Singleton beans are
// cached per name. Once an instance is stored, it is returned as-is, and the
// fields/args of later calls are not applied to it. If several goroutines
// miss the cache at the same time, instance() may run more than once, but
// only the first stored result is kept and returned to all of them.
func (t *ApplicationContext) GetBean(name string, fields Fields, args ConstructorArgs) interface{} {
	def := merge(t.GetBeanDefinition(name), fields, args)
	if def.Scope != SINGLETON {
		return def.instance()
	}
	if cached, found := t.instances.Load(name); found {
		return cached
	}
	// LoadOrStore resolves concurrent first calls: the loser's freshly built
	// value is discarded in favor of the one already stored.
	actual, _ := t.instances.LoadOrStore(name, def.instance())
	return actual
}
// Get is shorthand for GetBean with no field or constructor-argument
// overrides: it returns the bean named name exactly as its definition
// describes it. Like GetBean, it panics if no such bean is registered.
func (t *ApplicationContext) Get(name string) interface{} {
	var (
		noFields = Fields{}
		noArgs   = ConstructorArgs{}
	)
	return t.GetBean(name, noFields, noArgs)
}
// Has reports whether a bean definition named name has been indexed by Init.
func (t *ApplicationContext) Has(name string) (ok bool) {
	// Look the name up directly instead of calling GetBeanDefinition and
	// recovering from its panic. The recover-based version paid for a
	// panic/unwind on every miss and would also hide any other panic raised
	// inside GetBeanDefinition.
	_, ok = t.tidyDefinitions.Load(name)
	return ok
}
// merge returns the definition a single GetBean call should instantiate:
// bd with the caller's overrides applied. bd itself is never modified.
//
//   - When both fields and args are empty, bd is returned as-is (same pointer).
//   - args is positional. A non-nil args[k] replaces bd's k-th constructor
//     argument, or is placed at position k when bd has fewer arguments.
//     nil entries are ignored: they never replace an existing argument and
//     are never appended on their own.
//   - fields entries overwrite bd.Fields entries with the same key; all other
//     keys from both maps are kept.
//
// The result is a shallow copy of bd: ConstructorArgs or Fields that are not
// overridden are shared with bd rather than copied.
func merge(bd *BeanDefinition, fields Fields, args ConstructorArgs) *BeanDefinition {
	hasFields := len(fields) > 0
	hasArgs := len(args) > 0
	if !hasFields && !hasArgs {
		return bd
	}
	nbd := &BeanDefinition{
		Name:            bd.Name,
		Scope:           bd.Scope,
		Reflect:         bd.Reflect,
		InitMethod:      bd.InitMethod,
		ConstructorArgs: bd.ConstructorArgs,
		Fields:          bd.Fields,
		context:         bd.context,
	}
	if hasArgs {
		// Copy into a fresh slice so bd.ConstructorArgs is never written to.
		// The final length is at most max(len(base), len(args)), so the sum
		// is a safe capacity.
		merged := make(ConstructorArgs, len(bd.ConstructorArgs), len(bd.ConstructorArgs)+len(args))
		copy(merged, bd.ConstructorArgs)
		for k, v := range args {
			if v == nil {
				continue
			}
			if k < len(merged) {
				merged[k] = v
				continue
			}
			// Pad any gap left by skipped nils so v keeps its position k.
			// Appending directly would shift it (and every later argument)
			// to a lower index.
			for len(merged) < k {
				merged = append(merged, nil)
			}
			merged = append(merged, v)
		}
		nbd.ConstructorArgs = merged
	}
	if hasFields {
		merged := Fields{}
		for k, v := range bd.Fields {
			merged[k] = v
		}
		for k, v := range fields {
			merged[k] = v
		}
		nbd.Fields = merged
	}
	return nbd
}