From 793e7febd7388755d78c7fe62625890e6acac789 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Mon, 9 Jan 2023 16:55:45 +0000 Subject: [PATCH] feat: yet another async implementation Migrate changes discussed in https://github.com/neovim/neovim/issues/19624 --- lua/gitsigns/async.lua | 158 ++++++++++++++++++++++++++++++++------ teal/gitsigns/async.tl | 170 ++++++++++++++++++++++++++++++++++------- types/vim.d.tl | 2 + 3 files changed, 278 insertions(+), 52 deletions(-) diff --git a/lua/gitsigns/async.lua b/lua/gitsigns/async.lua index 99cd2727..f573b49b 100644 --- a/lua/gitsigns/async.lua +++ b/lua/gitsigns/async.lua @@ -16,19 +16,139 @@ + + local M = {} +local Async_T = {} + + + + + + + + + + + + + + + + + + + + + + + + +local handles = setmetatable({}, { __mode = 'k' }) + + +function M.running() + local current = coroutine.running() + if current and handles[current] then + return true + end +end + + +local function maxn(x) + return ((table).maxn)(x) +end + +local function is_Async_T(handle) + if handle and + type(handle) == 'table' and + vim.is_callable(handle.cancel) and + vim.is_callable(handle.is_cancelled) then + return true + end +end + + +function Async_T:cancel(cb) + + if self._current and not self._current:is_cancelled() then + self._current:cancel(cb) + end +end + +function Async_T.new(co) + local handle = setmetatable({}, { __index = Async_T }) + handles[co] = handle + return handle +end + + +function Async_T:is_cancelled() + return self._current and self._current:is_cancelled() +end + +local function run(func, callback, ...) + local co = coroutine.create(func) + local handle = Async_T.new(co) + + local function step(...) + local ret = { coroutine.resume(co, ...) } + local stat, nargs, err_or_fn = unpack(ret) + + if not stat then + error(string.format("The coroutine failed with this message: %s\n%s", + err_or_fn, debug.traceback(co))) + end + + if coroutine.status(co) == 'dead' then + if callback then + callback(unpack(ret, 4, maxn(ret))) + end + return + end + + assert(type(err_or_fn) == 'function', "type error :: expected func") + + local args = { select(4, unpack(ret)) } + args[nargs] = step + local r = err_or_fn(unpack(args, 1, nargs)) + if is_Async_T(r) then + handle._current = r + end + end + step(...) + return handle +end +local function wait(argc, func, ...) + local function pfunc(...) + local args = { ... } + local cb = args[argc] + args[argc] = function(...) + cb(true, ...) + end + xpcall(func, function(err) + cb(false, err, debug.traceback()) + end, unpack(args, 1, argc)) + end + local ret = { coroutine.yield(argc, pfunc, ...) } + local ok = ret[1] + if not ok then + local _, err, traceback = unpack(ret) + error(string.format("Wrapped function failed: %s\n%s", err, traceback)) + end -local main_co_or_nil = coroutine.running() + return unpack(ret, 2, maxn(ret)) +end @@ -37,10 +157,10 @@ local main_co_or_nil = coroutine.running() function M.wrap(func, argc) assert(argc) return function(...) - if coroutine.running() == main_co_or_nil then + if not M.running() then return func(...) end - return coroutine.yield(func, argc, ...) + return wait(argc, func, ...) end end @@ -48,35 +168,27 @@ end -function M.void(func) +function M.create(func, argc) + argc = argc or 0 return function(...) - if coroutine.running() ~= main_co_or_nil then + if M.running() then return func(...) end + local callback = select(argc + 1, ...) + return run(func, callback, unpack({ ... }, 1, argc)) + end +end - local co = coroutine.create(func) - local function step(...) - local ret = { coroutine.resume(co, ...) } - local stat, err_or_fn, nargs = unpack(ret) - if not stat then - error(string.format("The coroutine failed with this message: %s\n%s", - err_or_fn, debug.traceback(co))) - end - - if coroutine.status(co) == 'dead' then - return - end - assert(type(err_or_fn) == "function", "type error :: expected func") - local args = { select(4, unpack(ret)) } - args[nargs] = step - err_or_fn(unpack(args, 1, nargs)) +function M.void(func) + return function(...) + if M.running() then + return func(...) end - - step(...) + return run(func, nil, ...) end end diff --git a/teal/gitsigns/async.tl b/teal/gitsigns/async.tl index 655b63fe..68858f7e 100644 --- a/teal/gitsigns/async.tl +++ b/teal/gitsigns/async.tl @@ -1,6 +1,8 @@ local record Async -- Order by highest number of return types + create: function(T, integer): T + void: function (function() ): function() void: function (function(A1) ): function(A1) void: function(function(A1,A2)): function(A1,A2) @@ -20,6 +22,20 @@ local record M scheduler: function() end +local record Async_T + + -- Handle for an object currently running on the event loop. + -- The coroutine is paused while this is active. + -- Must provide methods cancel() and is_cancelled() + -- + -- Handle gets updated on each call to a wrapped functions, so provide access + -- to it via a proxy + _current: Async_T + + cancel: function(Async_T, function) + is_cancelled: function(Async_T) +end + -- Coroutine.running() was changed between Lua 5.1 and 5.2: -- - 5.1: Returns the running coroutine, or nil when called by the main thread. -- - 5.2: Returns the running coroutine plus a boolean, true when the running @@ -28,7 +44,111 @@ end -- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT -- -- We need to handle both. -local main_co_or_nil = coroutine.running() + +-- Store all the async threads in a weak table so we don't prevent them from +-- being garbage collected +local handles = setmetatable({}, { __mode = 'k' }) + +--- Returns whether the current execution context is async. +function M.running(): boolean + local current = coroutine.running() + if current and handles[current] then + return true + end +end + +-- hack: teal doesn't know table.maxn exists +local function maxn(x: table): integer + return ((table as table).maxn as function)(x) as integer +end + +local function is_Async_T(handle: Async_T): boolean + if handle + and type(handle) == 'table' + and vim.is_callable(handle.cancel) + and vim.is_callable(handle.is_cancelled) then + return true + end +end + +-- Analogous to uv.close +function Async_T:cancel(cb: function) + -- Cancel anything running on the event loop + if self._current and not self._current:is_cancelled() then + self._current:cancel(cb) + end +end + +function Async_T.new(co: thread): Async_T + local handle = setmetatable({} as Async_T, { __index = Async_T }) + handles[co] = handle + return handle +end + +-- Analogous to uv.is_closing +function Async_T:is_cancelled(): boolean + return self._current and self._current:is_cancelled() +end + +local function run(func: function, callback: function, ...: any): Async_T + local co = coroutine.create(func) + local handle = Async_T.new(co) + + local function step(...: any) + local ret = {coroutine.resume(co, ...)} + local stat, nargs, err_or_fn = unpack(ret) as (boolean, integer, (function(...:any): Async_T)) + + if not stat then + error(string.format("The coroutine failed with this message: %s\n%s", + err_or_fn, debug.traceback(co))) + end + + if coroutine.status(co) == 'dead' then + if callback then + callback(unpack(ret, 4, maxn(ret))) + end + return + end + + assert(type(err_or_fn) == 'function', "type error :: expected func") + + local args = {select(4, unpack(ret))} + args[nargs] = step + + local r = err_or_fn(unpack(args, 1, nargs)) + if is_Async_T(r) then + handle._current = r + end + end + + step(...) + return handle +end + +local function wait(argc: integer, func: function, ...): any... + -- Always run the wrapped functions in xpcall and re-raise the error in the + -- coroutine. This makes pcall work as normal. + local function pfunc(...: any) + local args = { ... } + local cb = args[argc] as function + args[argc] = function(...: any) + cb(true, ...) + end + xpcall(func, function(err) + cb(false, err, debug.traceback()) + end, unpack(args, 1, argc)) + end + + local ret = {coroutine.yield(argc, pfunc, ...)} + + local ok = ret[1] + if not ok then + local _, err, traceback = unpack(ret) + error(string.format("Wrapped function failed: %s\n%s", err, traceback)) + end + + return unpack(ret, 2, maxn(ret)) +end ---Creates an async function with a callback style function. ---@param func function: A callback style function to be converted. The last argument must be the callback. @@ -37,10 +157,10 @@ local main_co_or_nil = coroutine.running() function M.wrap(func: function, argc: integer): function assert(argc) return function(...): any... - if coroutine.running() == main_co_or_nil then + if not M.running() then return func(...) end - return coroutine.yield(func, argc, ...) + return wait(argc, func, ...) end end @@ -48,35 +168,27 @@ end ---called from a non-async context. Inherently this cannot return anything ---since it is non-blocking ---@param func function -function M.void(func: function): function - return function(...): any - if coroutine.running() ~= main_co_or_nil then - return func(...) +function M.create(func: function, argc: integer): function(...: any): Async_T + argc = argc or 0 + return function(...: any): Async_T + if M.running() then + return func(...) as Async_T end + local callback = select(argc+1, ...) as function + return run(func, callback, unpack({...}, 1, argc)) + end +end - local co = coroutine.create(func) - - local function step(...) - local ret = {coroutine.resume(co, ...)} as {any} - local stat, err_or_fn, nargs = unpack(ret) as (boolean, function, integer) - - if not stat then - error(string.format("The coroutine failed with this message: %s\n%s", - err_or_fn, debug.traceback(co))) - end - - if coroutine.status(co) == 'dead' then - return - end - - assert(err_or_fn is function, "type error :: expected func") - - local args = {select(4, unpack(ret))} - args[nargs] = step - err_or_fn(unpack(args, 1, nargs)) +---Use this to create a function which executes in an async context but +---called from a non-async context. Inherently this cannot return anything +---since it is non-blocking +---@param func function +function M.void(func: function(...:any)): function(...:any): Async_T + return function(...: any): Async_T + if M.running() then + return func(...) end - - step(...) + return run(func, nil, ...) end end diff --git a/types/vim.d.tl b/types/vim.d.tl index f8fdd98e..92e149e1 100644 --- a/types/vim.d.tl +++ b/types/vim.d.tl @@ -22,6 +22,8 @@ local record M iconv: function(string, string, string): string + is_callable: function(any): boolean + record go operatorfunc: string end