From badaef04f845828652a13493a5f5f237682763f7 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sun, 1 Oct 2023 16:41:58 +0100 Subject: [PATCH] test: refactor --- .luarc.json | 1 + test/client/session.lua | 30 ++-- test/gitsigns_spec.lua | 4 +- test/helpers.lua | 312 +++++++++++----------------------------- test/screen.lua | 1 + 5 files changed, 106 insertions(+), 242 deletions(-) diff --git a/.luarc.json b/.luarc.json index ece73b80..4149a4f2 100644 --- a/.luarc.json +++ b/.luarc.json @@ -8,6 +8,7 @@ "lua", "$VIMRUNTIME", "${3rd}/busted/library", + "${3rd}/luassert/library", "${3rd}/luv/library" ], "checkThirdParty": false diff --git a/test/client/session.lua b/test/client/session.lua index 75179ad8..4b43637e 100644 --- a/test/client/session.lua +++ b/test/client/session.lua @@ -1,11 +1,11 @@ -local uv = require('luv') +local uv = vim.loop local MsgpackRpcStream = require('test.client.msgpack_rpc_stream') ---- @class Session +--- @class NvimSession --- @field private _msgpack_rpc_stream MsgpackRpcStream --- @field private _pending_messages string[] ---- @field private _prepare uv_prepare_t ---- @field private _timer uv_timer_t +--- @field private _prepare uv.uv_prepare_t +--- @field private _timer uv.uv_timer_t --- @field private _is_running boolean local Session = {} Session.__index = Session @@ -57,7 +57,17 @@ function Session.new(stream) }, Session) end +--- @param timeout integer +--- @return string function Session:next_message(timeout) + if self._is_running then + error('Event loop already running') + end + + if #self._pending_messages > 0 then + return table.remove(self._pending_messages, 1) + end + local function on_request(method, args, response) table.insert(self._pending_messages, {'request', method, args, response}) uv.stop() @@ -68,14 +78,6 @@ function Session:next_message(timeout) uv.stop() end - if self._is_running then - error('Event loop already running') - end - - if #self._pending_messages > 0 then - return table.remove(self._pending_messages, 1) - end - self:_run(on_request, on_notification, timeout) return table.remove(self._pending_messages, 1) end @@ -100,6 +102,10 @@ function Session:request(method, ...) return true, result end +---@param request_cb fun()? +---@param notification_cb fun()? +---@param setup_cb fun()? +---@param timeout integer? function Session:run(request_cb, notification_cb, setup_cb, timeout) local function on_request(method, args, response) coroutine_exec(request_cb, method, args, function(status, result, flag) diff --git a/test/gitsigns_spec.lua b/test/gitsigns_spec.lua index eddce370..21755240 100644 --- a/test/gitsigns_spec.lua +++ b/test/gitsigns_spec.lua @@ -35,8 +35,8 @@ local eq = helpers.eq helpers.env() describe('gitsigns', function() - local screen - local config + local screen --- @type NvimScreen + local config --- @type table before_each(function() clear() diff --git a/test/helpers.lua b/test/helpers.lua index d35855df..8e7d048c 100644 --- a/test/helpers.lua +++ b/test/helpers.lua @@ -1,6 +1,5 @@ local assert = require('luassert') -local busted = require('busted') -local luv = require('luv') +local luv = vim.loop local Session = require('test.client.session') local uv_stream = require('test.client.uv_stream') local ChildProcessStream = uv_stream.ChildProcessStream @@ -14,36 +13,16 @@ function M.sleep(ms) luv.sleep(ms) end --- Calls fn() until it succeeds, up to `max` times or until `max_ms` --- milliseconds have passed. -function M.retry(max, max_ms, fn) - assert(max == nil or max > 0) - assert(max_ms == nil or max_ms > 0) - local tries = 1 - local timeout = (max_ms and max_ms or 10000) - local start_time = luv.now() - while true do - local status, result = pcall(fn) - if status then - return result - end - luv.update_time() -- Update cached value of luv.now() (libuv: uv_now()). - if (max and tries >= max) or (luv.now() - start_time > timeout) then - busted.fail(string.format("retry() attempts: %d\n%s", tries, tostring(result)), 2) - end - tries = tries + 1 - luv.sleep(20) -- Avoid hot loop... - end -end - M.eq = assert.are.same M.neq = assert.are_not.same local function epicfail(state, arguments, _) + --- @diagnostic disable-next-line state.failure_message = arguments[1] return false end +--- @diagnostic disable-next-line:missing-parameter assert:register("assertion", "epicfail", epicfail) function M.matches(pat, actual) @@ -53,17 +32,21 @@ function M.matches(pat, actual) error(string.format('Pattern does not match.\nPattern:\n%s\nActual:\n%s', pat, actual)) end --- Reads text lines from `filename` into a table. --- --- filename: path to file --- start: start line (1-indexed), negative means "lines before end" (tail) +--- Reads text lines from `filename` into a table. +--- +--- filename: path to file +--- start: start line (1-indexed), negative means "lines before end" (tail) +--- @param filename string +--- @param start? integer +--- @return string[]? local function read_file_list(filename, start) - local lnum = (start ~= nil and type(start) == 'number') and start or 1 - local tail = (lnum < 0) + local lnum = start or 1 + local tail = lnum < 0 local maxlines = tail and math.abs(lnum) or nil local file = io.open(filename, 'r') + if not file then - return nil + return end -- There is no need to read more than the last 2MB of the log file, so seek @@ -75,10 +58,10 @@ local function read_file_list(filename, start) end file:seek("set", offset) - local lines = {} + local lines = {} --- @type string[] local i = 1 local line = file:read("*l") - while line ~= nil do + while line do if i >= start then table.insert(lines, line) if #lines > maxlines then @@ -92,6 +75,11 @@ local function read_file_list(filename, start) return lines end +--- @generic R +--- @param fn fun(...): R +--- @param ... any arguments +--- @return boolean +--- @return R|string function M.pcall(fn, ...) assert(type(fn) == 'function') local status, rv = pcall(fn, ...) @@ -119,19 +107,23 @@ function M.pcall(fn, ...) return status, errmsg end --- Invokes `fn` and returns the error string (with truncated paths), or raises --- an error if `fn` succeeds. --- --- Replaces line/column numbers with zero: --- shared.lua:0: in function 'gsplit' --- shared.lua:0: in function ' --- --- Usage: --- -- Match exact string. --- eq('e', pcall_err(function(a, b) error('e') end, 'arg1', 'arg2')) --- -- Match Lua pattern. --- matches('e[or]+$', pcall_err(function(a, b) error('some error') end, 'arg1', 'arg2')) --- +--- Invokes `fn` and returns the error string (with truncated paths), or raises +--- an error if `fn` succeeds. +--- +--- Replaces line/column numbers with zero: +--- shared.lua:0: in function 'gsplit' +--- shared.lua:0: in function ' +--- +--- Usage: +--- -- Match exact string. +--- eq('e', pcall_err(function(a, b) error('e') end, 'arg1', 'arg2')) +--- -- Match Lua pattern. +--- matches('e[or]+$', pcall_err(function(a, b) error('some error') end, 'arg1', 'arg2')) +--- +--- @generic R +--- @param fn fun(...): R +--- @param ... any arguments +--- @return R local function pcall_err_withfile(fn, ...) assert(type(fn) == 'function') local status, rv = M.pcall(fn, ...) @@ -171,9 +163,12 @@ function M.concat_tables(...) return ret end +--- @param str string +--- @param leave_indent? integer +--- @return string function M.dedent(str, leave_indent) -- find minimum common indent across lines - local indent = nil + local indent = nil --- @type string? for line in str:gmatch('[^\n]+') do local line_indent = line:match('^%s+') or '' if indent == nil or #line_indent < #indent then @@ -240,31 +235,23 @@ local nvim_argv = { '--cmd', nvim_set, '--cmd', 'mapclear', '--cmd', 'mapclear!', - '--embed' + '--embed', + '--headless' } -local prepend_argv - -if prepend_argv then - local new_nvim_argv = {} - local len = #prepend_argv - for i = 1, len do - new_nvim_argv[i] = prepend_argv[i] - end - for i = 1, #nvim_argv do - new_nvim_argv[i + len] = nvim_argv[i] - end - nvim_argv = new_nvim_argv - M.prepend_argv = prepend_argv -end - -local session, loop_running, last_error +local session --- @type NvimSession? +local loop_running = false +local last_error --- @type string? function M.get_session() return session end +--- @param method string +--- @param ... any +--- @return any[] local function request(method, ...) + assert(session) local status, rv = session:request(method, ...) if not status then if loop_running then @@ -287,8 +274,14 @@ local function call_and_stop_on_error(lsession, ...) return result end +--- @param lsession NvimSession +--- @param request_cb fun()? +--- @param notification_cb fun()? +--- @param timeout integer +--- @return unknown function M.run_session(lsession, request_cb, notification_cb, timeout) - local on_request, on_notification + local on_request --- @type fun() + local on_notification --- @type fun() if request_cb then function on_request(method, args) @@ -314,20 +307,26 @@ function M.run_session(lsession, request_cb, notification_cb, timeout) return lsession.eof_err end ----- Executes an ex-command. VimL errors manifest as client (lua) errors, but ----- v:errmsg will not be updated. +--- Executes an ex-command. VimL errors manifest as client (lua) errors, but +--- v:errmsg will not be updated. +--- @param cmd string function M.command(cmd) request('nvim_command', cmd) end ----- Evaluates a VimL expression. ----- Fails on VimL error, but does not update v:errmsg. +--- Evaluates a VimL expression. +--- Fails on VimL error, but does not update v:errmsg. +--- @param expr string +--- @return any[] function M.eval(expr) return request('nvim_eval', expr) end ----- Executes a VimL function via RPC. ----- Fails on VimL error, but does not update v:errmsg. +--- Executes a VimL function via RPC. +--- Fails on VimL error, but does not update v:errmsg. +--- @param name string +--- @param ... any +--- @return any[] function M.call(name, ...) return request('nvim_call_function', name, {...}) end @@ -337,8 +336,9 @@ local function assert_alive() assert(2 == M.eval('1+1'), 'crash? request failed') end ----- Sends user input to Nvim. ----- Does not fail on VimL error, but v:errmsg will be updated. +--- Sends user input to Nvim. +--- Does not fail on VimL error, but v:errmsg will be updated. +--- @param input string local function nvim_feed(input) while #input > 0 do local written = request('nvim_input', input) @@ -350,77 +350,21 @@ local function nvim_feed(input) end end +--- @param ... string function M.feed(...) for _, v in ipairs({...}) do nvim_feed(M.dedent(v)) end end +--- @param ... string local function rawfeed(...) for _, v in ipairs({...}) do nvim_feed(M.dedent(v)) end end -local function merge_args(...) - local i = 1 - local argv = {} - for anum = 1,select('#', ...) do - local args = select(anum, ...) - if args then - for _, arg in ipairs(args) do - argv[i] = arg - i = i + 1 - end - end - end - return argv -end - --- Removes Nvim startup args from `args` matching items in `args_rm`. --- --- - Special case: "-u", "-i", "--cmd" are treated specially: their "values" are also removed. --- - Special case: "runtimepath" will remove only { '--cmd', 'set runtimepath^=…', } --- --- Example: --- args={'--headless', '-u', 'NONE'} --- args_rm={'--cmd', '-u'} --- Result: --- {'--headless'} --- --- All matching cases are removed. --- --- Example: --- args={'--cmd', 'foo', '-N', '--cmd', 'bar'} --- args_rm={'--cmd', '-u'} --- Result: --- {'-N'} -local function remove_args(args, args_rm) - local new_args = {} - local skip_following = {'-u', '-i', '-c', '--cmd', '-s', '--listen'} - if not args_rm or #args_rm == 0 then - return {unpack(args)} - end - for _, v in ipairs(args_rm) do - assert(type(v) == 'string') - end - local last = '' - for _, arg in ipairs(args) do - if vim.tbl_contains(skip_following, last) then - last = '' - elseif vim.tbl_contains(args_rm, arg) then - last = arg - elseif arg == runtime_set and vim.tbl_contains(args_rm, 'runtimepath') then - table.remove(new_args) -- Remove the preceding "--cmd". - last = '' - else - table.insert(new_args, arg) - end - end - return new_args -end - -function M.check_close() +local function check_close() if not session then return end @@ -437,100 +381,14 @@ function M.check_close() session = nil end ---- @param io_extra used for stdin_fd, see :help ui-option -function M.spawn(argv, merge, env, keep, io_extra) - if not keep then - M.check_close() - end - - local child_stream = ChildProcessStream.spawn( - merge and merge_args(prepend_argv, argv) or argv, - env, io_extra) - return Session.new(child_stream) -end - --- Builds an argument list for use in clear(). --- ----@see clear() for parameters. -local function new_argv(...) - local args = {unpack(nvim_argv)} - table.insert(args, '--headless') - if _G._nvim_test_id then - -- Set the server name to the test-id for logging. #8519 - table.insert(args, '--listen') - table.insert(args, _G._nvim_test_id) - end - local new_args - local io_extra - local env = nil - local opts = select(1, ...) - if type(opts) ~= 'table' then - new_args = {...} - else - args = remove_args(args, opts.args_rm) - if opts.env then - local env_opt = {} - for k, v in pairs(opts.env) do - assert(type(k) == 'string') - assert(type(v) == 'string') - env_opt[k] = v - end - for _, k in ipairs({ - 'HOME', - 'ASAN_OPTIONS', - 'TSAN_OPTIONS', - 'MSAN_OPTIONS', - 'LD_LIBRARY_PATH', - 'PATH', - 'NVIM_LOG_FILE', - 'NVIM_RPLUGIN_MANIFEST', - 'GCOV_ERROR_FILE', - 'XDG_DATA_DIRS', - 'TMPDIR', - 'VIMRUNTIME', - }) do - -- Set these from the environment unless the caller defined them. - if not env_opt[k] then - env_opt[k] = os.getenv(k) - end - end - env = {} - for k, v in pairs(env_opt) do - env[#env + 1] = k .. '=' .. v - end - end - new_args = opts.args or {} - io_extra = opts.io_extra - end - for _, arg in ipairs(new_args) do - table.insert(args, arg) - end - return args, env, io_extra -end - --- same params as clear, but does returns the session instead --- of replacing the default session -local function spawn_argv(keep, ...) - local argv, env, io_extra = new_argv(...) - return M.spawn(argv, nil, env, keep, io_extra) -end - --- Starts a new global Nvim session. --- --- Parameters are interpreted as startup args, OR a map with these keys: --- args: List: Args appended to the default `nvim_argv` set. --- args_rm: List: Args removed from the default set. All cases are --- removed, e.g. args_rm={'--cmd'} removes all cases of "--cmd" --- (and its value) from the default set. --- env: Map: Defines the environment of the new session. --- --- Example: --- clear('-e') --- clear{args={'-e'}, args_rm={'-i'}, env={TERM=term}} -function M.clear(...) - session = spawn_argv(false, ...) +--- Starts a new global Nvim session. +function M.clear() + check_close() + local child_stream = ChildProcessStream.spawn(nvim_argv) + session = Session.new(child_stream) end +---@param ... string function M.insert(...) nvim_feed('i') for _, v in ipairs({...}) do @@ -541,8 +399,7 @@ function M.insert(...) end function M.create_callindex(func) - local table = {} - setmetatable(table, { + return setmetatable({}, { __index = function(tbl, arg1) local ret = function(...) return func(arg1, ...) @@ -551,7 +408,6 @@ function M.create_callindex(func) return ret end, }) - return table end function M.nvim(method, ...) @@ -609,7 +465,7 @@ function M.exec_lua(code, ...) return M.meths.exec_lua(code, {...}) end ---- @param after_each fun(name:string,block:fun()) +--- @param after_each fun(block:fun()) function M.after_each(after_each) after_each(function() if not session then diff --git a/test/screen.lua b/test/screen.lua index 3c173e9f..5be916c8 100644 --- a/test/screen.lua +++ b/test/screen.lua @@ -88,6 +88,7 @@ local function isempty(v) return type(v) == 'table' and next(v) == nil end +--- @class NvimScreen local Screen = {} Screen.__index = Screen