-
Notifications
You must be signed in to change notification settings - Fork 10
/
memo.nim
199 lines (157 loc) · 5.9 KB
/
memo.nim
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import tables, macros, options
proc memoize*[A, B](f: proc(a: A): B): proc(a: A): B =
  ## Wraps `f` in a closure that remembers every result it computes.
  ##
  ## The first call with a given argument evaluates `f` and stores the
  ## result in a private `Table`. Later calls with an equal argument return
  ## the stored value without calling `f` again. Only calls made through the
  ## returned closure go through the cache.
  var memo = initTable[A, B]()
  result = proc(a: A): B =
    if a notin memo:
      memo[a] = f(a)
    memo[a]
proc getSignature(fun: NimNode): (NimNode, NimNode) =
  ## Splits the formal parameters of routine `fun` into its return type and
  ## a flat list of arguments.
  ##
  ## Grouped declarations are unrolled, so `x, y: int` becomes two separate
  ## `nnkIdentDefs` nodes, `x: int` and `y: int`. Each node carries its
  ## group's type and default value. The arguments are collected under an
  ## `nnkArgList` node, which is the second element of the returned tuple.
  let formals = fun.params
  let flat = newTree(nnkArgList)
  # formals[0] is the return type; parameter groups start at index 1
  for group in formals[1 .. ^1]:
    # a group is laid out as `name1, name2, ..., type, default`
    let typ = group[^2]
    let default = group[^1]
    for name in group[0 .. ^3]:
      flat.add newTree(nnkIdentDefs, name, typ, default)
  (formals[0], flat)
proc toIdents(args: NimNode): NimNode =
  ## Collects the argument names from an `nnkArgList` of `nnkIdentDefs`.
  ##
  ## Exactly one argument yields its bare identifier. Any other count
  ## (including zero) yields an `nnkPar` node with every name in order,
  ## which can be used as a tuple constructor.
  assert args.kind == nnkArgList
  case args.len
  of 1:
    result = args[0][0]
  else:
    result = newTree(nnkPar)
    for identDefs in args:
      result.add identDefs[0]  # each IdentDefs starts with the name
proc toTypes(args: NimNode): NimNode =
  ## Collects the argument types from an `nnkArgList` of `nnkIdentDefs`.
  ##
  ## Exactly one argument yields its type node as-is. Any other count
  ## (including zero) yields an `nnkPar` with every type in order, which
  ## can be used as an anonymous tuple type.
  assert args.kind == nnkArgList
  if args.len != 1:
    result = newTree(nnkPar)
    for i in 0 ..< args.len:
      result.add args[i][1]  # the type is the second child of IdentDefs
  else:
    result = args[0][1]
type OwnedCache = object
  ## A generated cache for one routine, bundled with the AST that declares
  ## it and the AST of the template that lets its owner reset it.
  sym: NimNode    ## symbol of the cache variable
  decl: NimNode   ## AST declaring the cache variable
  reset: NimNode  ## AST of the `resetCache<owner>` template
proc declCache(owner, argType, retType: NimNode): OwnedCache =
  ## Builds the cache for a routine `owner` that takes arguments.
  ##
  ## The cache is a `Table[argType, retType]`, where `argType` is a single
  ## type or a tuple of the argument types. Returns the cache's fresh
  ## symbol, a `var` declaration that initialises it to an empty table,
  ## and a parameterless template named `resetCache<owner name>` that
  ## clears it.
  template tableDecl(cacheVar, keyType, valType) =
    var cacheVar = initTable[keyType, valType]()
  template resetDecl(cacheVar, routineName) =
    template `resetCache routineName`() =
      cacheVar.clear()
  let sym = genSym(nskVar, "cache")
  OwnedCache(
    sym: sym,
    decl: getAst(tableDecl(sym, argType, retType)),
    reset: getAst(resetDecl(sym, owner.name)))
proc declCacheNiladic(owner, argType, retType: NimNode): OwnedCache =
  ## Builds the cache for a routine `owner` that takes no arguments.
  ##
  ## The cache is an `Option[retType]`. `argType` is accepted so the
  ## signature matches `declCache`, but it is not used. Returns the cache's
  ## fresh symbol, a `var` declaration that initialises it to
  ## `none(retType)`, and a parameterless template named
  ## `resetCache<owner name>` that sets it back to `none(retType)`.
  template optionDecl(cacheVar, valType) =
    var cacheVar: Option[valType] = none(valType)
  template resetDecl(cacheVar, routineName, valType) =
    template `resetCache routineName`() =
      cacheVar = none(valType)
  let sym = genSym(nskVar, "cache")
  OwnedCache(
    sym: sym,
    decl: getAst(optionDecl(sym, retType)),
    reset: getAst(resetDecl(sym, owner.name, retType)))
proc destructurizedCall(fun, args: NimNode): NimNode =
  ## Builds a call of `fun`.
  ##
  ## When `args` is an `nnkPar`, each of its elements becomes a separate
  ## argument. Any other node is passed as the single argument.
  result = newCall(fun)
  if args.kind == nnkPar:
    for i in 0 ..< args.len:
      result.add args[i]
  else:
    result.add args
proc destrTupNode(lhs, rhs: NimNode): NimNode =
  ## Builds a `let` statement that binds `rhs` to `lhs`.
  ##
  ## If `lhs` is a single identifier, the result is `let lhs = rhs`. If it
  ## is an `nnkPar` of identifiers, the result is the unpacking form
  ## `let (a, b, ...) = rhs`. In that case `rhs` should be a tuple symbol
  ## or a tuple literal.
  result =
    if lhs.kind == nnkPar:
      let unpack = newNimNode(nnkVarTuple)
      for name in lhs.children:
        unpack.add name
      # a VarTuple ends with an (empty) type slot and then the value
      unpack.add(newEmptyNode(), rhs)
      newTree(nnkLetSection, unpack)
    else:
      newLetStmt(lhs, rhs)
macro memoized*(e: untyped): auto =
  ## Pragma macro that rewrites routine `e` so its results are memoized.
  ##
  ## For `proc f(args): R {.memoized.} = body` it emits:
  ## * a cache variable: a `Table` keyed by the argument (or by a tuple of
  ##   all arguments), or an `Option[R]` when `f` takes no arguments;
  ## * `f` itself, which looks its arguments up in the cache and runs the
  ##   original `body` only on a miss;
  ## * a template `resetCache<f>()` that resets the cache.
  ##
  ## The original body is nested inside the new `f`, so calls to `f` from
  ## within `body` resolve to the memoized version. An export marker on `f`
  ## (`proc f*`) is kept on the generated `f` and is not copied to the
  ## nested implementation procs, where it would be rejected.
  let (retType, args) = getSignature(e)
  let nams = args.toIdents()
  let atyp = args.toTypes()
  let hasArgs = args.len > 0
  let cache =
    if hasArgs: declCache(e, atyp, retType)
    else: declCacheNiladic(e, atyp, retType)
  # helper whose results are memoized: it receives the packed arguments
  # and forwards them to the original implementation
  let mem = newProc(name = genSym(nskProc, "memoized"))
  mem.params = newNimNode(nnkFormalParams).add(e.params[0])
  # wrap the original implementation under a fresh name; assigning org[0]
  # directly (not via `name=`, which keeps a `*` postfix) drops any export
  # marker, because export is only allowed at module scope
  let org = e.copy()
  org[0] = genSym(nskProc, "impl")
  mem.body = newStmtList().add(org)
  if hasArgs:
    # take all arguments as one value (a tuple if there are several),
    # unpack it and call the implementation with the pieces
    let argSym = genSym(nskParam, "arg")
    mem.params.add(newTree(nnkIdentDefs, argSym, atyp, newEmptyNode()))
    mem.body.add(nams.destrTupNode(argSym))
    mem.body.add(newAssignment(ident("result"),
                               org.name.destructurizedCall(nams)))
  else:
    mem.body.add(newAssignment(ident("result"), newCall(org.name)))
  # main procedure; e[0] (unlike e.name) keeps a possible export marker
  let fun = newProc(name = e[0].copy)
  fun.params = e.params.copy
  # build the key; on a cache miss compute and store the result
  template funImpl(impl, cache, fun, lhs, rhs) =
    impl
    let lhs = rhs
    if not cache.hasKey(lhs):
      cache[lhs] = fun(lhs)
  # if the cache is still empty, compute and store the result
  template funImplNiladic(impl, cache, fun) =
    impl
    if options.isNone(cache):
      cache = some(fun())
  if hasArgs:
    let packSym = genSym(nskLet, "pack")
    fun.body = getAst(funImpl(mem, cache.sym, mem.name, packSym, nams))
    # look up with the key built above instead of re-packing the arguments
    fun.body.add(newAssignment(
      ident("result"),
      newCall(ident("[]"), cache.sym, packSym)))
  else:
    fun.body = getAst(funImplNiladic(mem, cache.sym, mem.name))
    fun.body.add(newAssignment(
      ident("result"),
      newCall(ident("get"), cache.sym)))
  # emit the cache, its owner procedure and the reset template
  result = newStmtList(cache.decl, fun, cache.reset)
# `memoized` emits `cache[...]` and `get(cache)` through unbound identifiers
# (`ident("[]")`, `ident("get")`), which are looked up in the module that
# uses the pragma. Re-exporting these operators (and `[]=`, used for the
# cache update) means users do not need to import `tables` / `options`.
export tables.`[]`, tables.`[]=`, options.get