-
Notifications
You must be signed in to change notification settings - Fork 0
/
fiber.lua
127 lines (102 loc) · 3 KB
/
fiber.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
local coroutine = require('coroutine')
local debug = require('debug')
-- `unpack` is a global in Lua 5.1/LuaJIT; Lua 5.2 moved it to `table.unpack`
-- and 5.3+ no longer provides the global by default.
local unpack = table.unpack or unpack
-- Module table; the public functions new, wait and await are attached to it
-- below and it is returned at the end of the file.
local fiber = {}
-- Map of managed coroutines: coroutine -> per-fiber state table.
-- State fields used in this file: `callback` (set by fiber.new) and
-- `paused` (set by fiber.wait, cleared by check). check() removes an entry
-- when its coroutine returns. Keys are held strongly (no weak metatable).
local fibers = {}
-- Attach a traceback of coroutine `co` to the error value `err`.
--
-- Non-table errors are converted with tostring() and a new string is
-- returned: the message followed by the traceback of `co`.
--
-- Table errors are mutated in place. Their `message` field is replaced with
-- the traceback text and the same table is returned, so other fields on the
-- error object survive. If the table already carries a `message`, that text
-- heads the traceback. Otherwise tostring(err) is used, which for tables
-- without a __tostring metamethod is only "table: 0x...".
local function formatError(co, err)
  local isTable = type(err) == "table"
  local header
  if isTable and err.message ~= nil then
    -- Keep the original message instead of losing it to "table: 0x...".
    header = tostring(err.message)
  else
    header = tostring(err)
  end
  local stack = debug.traceback(co, header)
  if isTable then
    err.message = stack
    return err
  end
  return stack
end
-- Process the results of `coroutine.resume(co, ...)` for coroutine `co`.
--
-- On failure the error is decorated with a traceback (formatError). A managed
-- fiber is unregistered. If it has a callback, the callback receives the
-- error and its results are returned; otherwise the error is re-raised.
--
-- On success:
--   * unmanaged coroutines: the resume results are returned unchanged;
--   * managed fibers that have finished: the fiber is unregistered, its
--     callback (if any) is called with nil followed by the results, and the
--     results are returned;
--   * managed fibers that are merely suspended (e.g. inside fiber.wait):
--     the `paused` flag is cleared and nothing is returned.
local function check(co, success, ...)
  local state = fibers[co]
  if not success then
    -- A coroutine that raised is dead; drop it from the registry so the
    -- strong reference in `fibers` does not leak.
    fibers[co] = nil
    local err = formatError(co, ...)
    if state and state.callback then
      return state.callback(err)
    end
    error(err)
  end
  -- Unmanaged coroutines are passed straight through.
  if not state then
    return ...
  end
  -- Decide completion from the coroutine's real status rather than only the
  -- `paused` flag, so a fiber that yields through some other path is not
  -- mistaken for a finished one.
  if coroutine.status(co) == "dead" then
    fibers[co] = nil
    if state.callback then
      state.callback(nil, ...)
    end
    return ...
  end
  state.paused = false
end
-- Build a "continuable" that runs `fn(...)` as a managed fiber.
--
-- The extra arguments are captured now, including trailing nils. The
-- returned function takes an optional `callback`. Each call creates a fresh
-- coroutine, registers it in `fibers`, and starts it with the captured
-- arguments. Completion or failure is reported through `callback` by check().
fiber.new = function(fn, ...)
  local argc = select("#", ...)
  local argv = { ... }

  return function(callback)
    local thread = coroutine.create(fn)
    fibers[thread] = { callback = callback }
    check(thread, coroutine.resume(thread, unpack(argv, 1, argc)))
  end
end
-- Suspend the current coroutine until `continuation` reports completion.
--
-- `continuation` is called with a single callback argument. The values
-- later passed to that callback become the return values of fiber.wait.
--
-- * If the callback fires before `continuation` returns, its values are
--   buffered and returned directly, without ever yielding.
-- * Otherwise the coroutine yields. When the callback fires later it resumes
--   the coroutine with those values and routes the outcome through check().
--
-- Raises an error if `continuation` is not a function or if called from
-- the main thread.
function fiber.wait(continuation)
  if type(continuation) ~= "function" then
    error("Continuation must be a function.")
  end

  -- Lua 5.1 reports the main thread as nil; Lua 5.2+ returns the main
  -- coroutine together with `true`.
  local co, isMain = coroutine.running()
  if isMain or not co then
    error("Can't wait from the main thread.")
  end

  local state = fibers[co]

  -- Completion mode:
  --   nil   -> continuation still running, callback not called yet
  --   false -> callback ran synchronously; its results are buffered
  --   true  -> continuation returned first; the callback must resume `co`
  local deferred, buffered, count

  continuation(function (...)
    if deferred then
      -- We are parked in coroutine.yield() below: wake up and let check()
      -- deal with the next yield, completion or error.
      check(co, coroutine.resume(co, ...))
      return
    end
    deferred = false
    buffered = { ... }
    count = select("#", ...)
  end)

  if deferred == false then
    -- Synchronous completion: no need to suspend at all.
    return unpack(buffered, 1, count)
  end

  deferred = true
  -- Mark the managed fiber (if any) as paused; check() clears the flag after
  -- the coroutine yields.
  if state then
    state.paused = true
  end
  return coroutine.yield()
end
-- Split fiber.wait's results into (err, ...). A truthy `err` is raised;
-- otherwise the remaining values are returned. Because this is a vararg
-- function, the exact number of results is preserved, including nils in
-- any position. Packing them into a table and relying on `#` would not
-- guarantee that.
local function unwrapAwaitResults(err, ...)
  if err then
    error(err)
  end
  return ...
end

-- Wait for `continuation` like fiber.wait, but follow the error-first
-- callback convention. The first value passed to the callback is treated as
-- an error and raised if truthy. Otherwise every value after it is returned.
-- All arguments are forwarded to fiber.wait unchanged.
function fiber.await(...)
  return unwrapAwaitResults(fiber.wait(...))
end
-- Export the module table (new, wait, await).
return fiber