-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
SaveManager.ttslua
executable file
·236 lines (187 loc) · 8.53 KB
/
SaveManager.ttslua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
require('ge_tts.License')
local Logger = require('ge_tts.Logger')
local TableUtils = require('ge_tts.TableUtils')
-- Marker written at the very start of every saved state produced by this file's onSave; this file's onLoad asserts
-- that any non-empty saved state begins with it.
local SAVE_STATE_IDENTIFIER = "__ge_tts_save__"
---@class ge_tts__SaveManager
local SaveManager = {}
---@shape __ge_tts__SaveManager_Callbacks
---@field onLoads (fun(savedState: string): void)[]
---@field onSave nil | (fun(): nil | string)
-- Module name under which the state returned by a pre-existing global onSave is recorded in the saved state.
local ORIGINAL_PSEUDO_MODULE_NAME = '__originalSavedState'
-- Registered callbacks, keyed by module name.
---@type table<string, __ge_tts__SaveManager_Callbacks>
local callbacks = {}
-- Capture whatever global onSave/onLoad exist at this point: both globals are redefined further down in this file.
local originalOnSave = --[[---@type nil | fun(): string]] _G.onSave
local originalOnLoad = --[[---@type nil | fun(savedState: string): void]] _G.onLoad
---Looks up the callback record for moduleName, creating (and storing) an empty record the first time a module is
---seen.
---@param moduleName string
---@return __ge_tts__SaveManager_Callbacks
local function getModuleCallbacks(moduleName)
    local existing = callbacks[moduleName]

    if existing then
        return existing
    end

    local created = { onLoads = {} }
    callbacks[moduleName] = created
    return created
end
---Delivers savedState to the load handlers of moduleName.
---
---The pseudo module that stores the pre-existing global onLoad's state is forwarded to that captured function (when
---one was captured). Any other module has each of its registered onLoad callbacks invoked in turn.
---@param moduleName string
---@param savedState string
local function executeOnLoads(moduleName, savedState)
    local isOriginalModule = moduleName == ORIGINAL_PSEUDO_MODULE_NAME

    if isOriginalModule and originalOnLoad then
        (--[[---@not nil]] originalOnLoad)(savedState)
        return
    end

    -- Walk a snapshot rather than the live list: a callback may modify onLoads while we are still iterating.
    local registered = getModuleCallbacks(moduleName).onLoads
    local snapshot = TableUtils.copy(registered)

    for _, callback in ipairs(snapshot) do
        callback(savedState)
    end
end
---
---Registers onSave for the specified moduleName. moduleName must be unique.
---
---Any onLoad registered for the same moduleName will be called with the savedState returned from onSave. This allows
---several Lua modules/files to independently maintain their own savedState.
---
---Both arguments are validated with Logger.assert at registration time, so a bad callback is reported where it is
---registered rather than later when the game is saved.
---
---@param moduleName string
---@param onSave fun(): nil | string
function SaveManager.registerOnSave(moduleName, onSave)
    Logger.assert(type(moduleName) == 'string' and moduleName ~= '', 'moduleName must be specified')
    Logger.assert(type(onSave) == 'function', 'onSave must be a function')

    local moduleCallbacks = getModuleCallbacks(moduleName)

    -- Only one onSave per module: its return value is the single saved state handed back to that module's onLoads.
    Logger.assert(moduleCallbacks.onSave == nil, 'onSave is already registered for module: ' .. moduleName)
    moduleCallbacks.onSave = onSave
end
---
---Registers onLoad for the specified moduleName. You may have multiple onLoad registered for the same moduleName.
---
---The provided onLoad function will only be called with data pertaining to the provided moduleName. This allows Lua
---modules to independently maintain their own savedState.
---
---If the moduleName argument is omitted, the provided onLoad will be called with an empty string. This is useful if you
---simply want your onLoad callback called when Tabletop Simulator finished loading, but you don't need any saved state.
---
---The callback is validated with Logger.assert at registration time. This function does not return a value.
---
---@overload fun(onLoad: (fun(savedState: string): void)): void
---@overload fun(moduleName: string, onLoad: (fun(savedState: string): void)): void
---@param moduleNameOrOnLoad string | fun(savedState: string): void
---@param nilOrOnLoad nil | fun(savedState: string): void
function SaveManager.registerOnLoad(moduleNameOrOnLoad, nilOrOnLoad)
    if type(moduleNameOrOnLoad) == 'function' then
        -- Single-argument form: register under the empty module name, which never has saved state of its own.
        SaveManager.registerOnLoad('', --[[---@type fun(savedState: string): void]] moduleNameOrOnLoad)
        return
    end

    Logger.assert(type(moduleNameOrOnLoad) == 'string', 'moduleName must be a string')
    -- Without this check a nil callback is silently dropped by table.insert, and any other non-function value would
    -- only fail much later, when the saved state is being loaded.
    Logger.assert(type(nilOrOnLoad) == 'function', 'onLoad must be a function')

    local moduleName = --[[---@type string]] moduleNameOrOnLoad
    local moduleCallbacks = getModuleCallbacks(moduleName)
    local onLoad = --[[---@type fun(savedState: string): void]] nilOrOnLoad

    table.insert(moduleCallbacks.onLoads, onLoad)
end
---
---Unregisters the onSave callback of moduleName, if it has one.
---
---Returns true when a callback was registered and has now been removed, false when moduleName had no onSave callback.
---
---@param moduleName string
---@return boolean
function SaveManager.removeOnSave(moduleName)
    local entry = callbacks[moduleName]

    -- Nothing to remove: either the module is unknown or it never registered an onSave.
    if not entry or not entry.onSave then
        return false
    end

    entry.onSave = nil
    return true
end
---Unregisters a previously registered onLoad callback.
---
---When called with just a function, the callback is looked up under the empty module name. Only the first matching
---registration is removed. Returns true if a registration was removed, false otherwise.
---
---@overload fun(onLoad: (fun(savedState: string): void)): boolean
---@overload fun(moduleName: string, onLoad: (fun(savedState: string): void)): boolean
---@param moduleNameOrOnLoad string | fun(savedState: string): void
---@param nilOrOnLoad nil | fun(savedState: string): void
---@return boolean
function SaveManager.removeOnLoad(moduleNameOrOnLoad, nilOrOnLoad)
    if type(moduleNameOrOnLoad) == 'function' then
        return SaveManager.removeOnLoad('', --[[---@type fun(savedState: string): void]] moduleNameOrOnLoad)
    end

    Logger.assert(type(moduleNameOrOnLoad) == 'string', 'SaveManager moduleName must be a string')

    local moduleName = --[[---@type string]] moduleNameOrOnLoad
    local entry = callbacks[moduleName]

    if not entry then
        return false
    end

    local registered = entry.onLoads

    for index, candidate in ipairs(registered) do
        if candidate == nilOrOnLoad then
            table.remove(registered, index)
            return true
        end
    end

    return false
end
---Appends one record to parts, in the form "<nameLength> <name> <stateLength> <state>". The explicit lengths let
---names and states contain spaces without any escaping.
---@param parts string[]
---@param moduleName string
---@param moduleSavedState string
local function appendModuleSavedState(parts, moduleName, moduleSavedState)
    parts[#parts + 1] = moduleName:len() .. ' ' .. moduleName .. ' ' .. moduleSavedState:len() .. ' ' .. moduleSavedState
end

---Global save handler. Builds the combined saved state: SAVE_STATE_IDENTIFIER followed by one length-prefixed record
---per module whose onSave returned a string, then a record for the pre-existing global onSave (if one was captured).
---
---A callback that returns nil contributes no record. This also applies to the captured original onSave.
---@return string
function onSave()
    -- Collect the pieces and join once at the end instead of re-concatenating a growing string per module.
    local parts = { SAVE_STATE_IDENTIFIER }

    for moduleName, moduleCallbacks in pairs(callbacks) do
        if moduleCallbacks.onSave then
            local moduleSavedState = (--[[---@not nil]] moduleCallbacks.onSave)()

            if moduleSavedState ~= nil then
                Logger.assert(type(moduleSavedState) == 'string', moduleName .. "'s onSave returned a " .. type(moduleSavedState) .. ', a string is required.')
                appendModuleSavedState(parts, moduleName, --[[---@not nil]] moduleSavedState)
            end
        end
    end

    if originalOnSave then
        local originalSavedState = (--[[---@not nil]] originalOnSave)()

        -- The original handler may return nothing; treat that like a module with no state instead of failing.
        if originalSavedState ~= nil then
            Logger.assert(type(originalSavedState) == 'string', 'The original onSave returned a ' .. type(originalSavedState) .. ', a string is required.')
            appendModuleSavedState(parts, ORIGINAL_PSEUDO_MODULE_NAME, originalSavedState)
        end
    end

    return table.concat(parts)
end
local GE_MODULE_PREFIX = 'ge_tts.'

---Parses the records that follow SAVE_STATE_IDENTIFIER in savedState. Each record has the form
---"<nameLength> <name> <stateLength> <state>"; the result maps each module name to the character range of its state.
---
---Malformed records (a missing separator after a module name, or a non-numeric/negative length) are reported through
---Logger.assert. Trailing text that contains no further separator is ignored.
---@param savedState string
---@return table<string, {rangeStart: number, rangeEnd: number}>
local function parseModuleStateRanges(savedState)
    ---@type table<string, {rangeStart: number, rangeEnd: number}>
    local moduleStateRanges = {}
    local savedStateLength = savedState:len()
    local offset = SAVE_STATE_IDENTIFIER:len() + 1

    while offset <= savedStateLength do
        -- Plain (non-pattern) search for the space terminating the module name length.
        local nameLengthEnd = savedState:find(' ', offset, true)

        if not nameLengthEnd then
            break
        end

        local nameLength = tonumber(savedState:sub(offset, nameLengthEnd - 1))
        Logger.assert(nameLength ~= nil and nameLength >= 0, 'Corrupt saved state: invalid module name length at offset ' .. offset)

        local moduleName = savedState:sub(nameLengthEnd + 1, nameLengthEnd + nameLength)

        -- Skip the name and the single space that follows it.
        local stateLengthStart = nameLengthEnd + nameLength + 2
        local stateLengthEnd = savedState:find(' ', stateLengthStart, true)
        local stateLength = stateLengthEnd and tonumber(savedState:sub(stateLengthStart, stateLengthEnd - 1))
        Logger.assert(stateLength ~= nil and stateLength >= 0, 'Corrupt saved state: invalid state length for module: ' .. moduleName)

        local stateEnd = stateLengthEnd + stateLength
        moduleStateRanges[moduleName] = {
            rangeStart = stateLengthEnd + 1,
            rangeEnd = stateEnd
        }

        -- The next record (if any) starts immediately after this module's state.
        offset = stateEnd + 1
    end

    return moduleStateRanges
end

---Global load handler. Splits savedState into per-module states and hands each registered module its own state (or
---an empty string when nothing was saved for it).
---
---Modules whose name starts with GE_MODULE_PREFIX are loaded first, then all other modules, and finally the
---pre-existing global onLoad (if one was captured) receives the state its onSave counterpart produced.
---@param savedState string
function onLoad(savedState)
    savedState = savedState or ''
    Logger.assert(savedState == '' or savedState:sub(1, SAVE_STATE_IDENTIFIER:len()) == SAVE_STATE_IDENTIFIER, "When working with ge_tts, you must use ge_tts.SaveManager instead of writing directly to script_state.")

    local moduleStateRanges = parseModuleStateRanges(savedState)

    ---@param moduleName string
    local function dispatch(moduleName)
        local stateRange = moduleStateRanges[moduleName]
        local moduleSavedState = ''

        if stateRange then
            moduleSavedState = savedState:sub(stateRange.rangeStart, stateRange.rangeEnd)
        end

        executeOnLoads(moduleName, moduleSavedState)
    end

    -- Snapshot the module names before running any callback: a callback may register callbacks for a new module,
    -- and adding keys to a table while it is being traversed with pairs is not allowed.
    ---@type string[]
    local geModuleNames = {}
    ---@type string[]
    local otherModuleNames = {}

    for moduleName, _ in pairs(callbacks) do
        -- The original onLoad is dispatched explicitly below; never run it twice.
        local handledSeparately = originalOnLoad ~= nil and moduleName == ORIGINAL_PSEUDO_MODULE_NAME

        if not handledSeparately then
            if moduleName:sub(1, GE_MODULE_PREFIX:len()) == GE_MODULE_PREFIX then
                table.insert(geModuleNames, moduleName)
            else
                table.insert(otherModuleNames, moduleName)
            end
        end
    end

    -- ge_tts listeners execute first
    for _, moduleName in ipairs(geModuleNames) do
        dispatch(moduleName)
    end

    for _, moduleName in ipairs(otherModuleNames) do
        dispatch(moduleName)
    end

    -- The pseudo module is never present in callbacks, so the captured original onLoad has to be invoked explicitly.
    if originalOnLoad then
        dispatch(ORIGINAL_PSEUDO_MODULE_NAME)
    end
end
return SaveManager