From 458fd9f2ee884c215a661133516462d296363efa Mon Sep 17 00:00:00 2001 From: Michael Martin <3277009+flrgh@users.noreply.github.com> Date: Tue, 21 Mar 2023 09:51:20 -0700 Subject: [PATCH] tests(helpers): add extra context to wait_until() --- .../01-helpers/01-helpers_spec.lua | 324 +++++++++++- spec/helpers.lua | 56 +-- spec/helpers/wait.lua | 466 ++++++++++++++++++ 3 files changed, 791 insertions(+), 55 deletions(-) create mode 100644 spec/helpers/wait.lua diff --git a/spec/02-integration/01-helpers/01-helpers_spec.lua b/spec/02-integration/01-helpers/01-helpers_spec.lua index 7045d930e97..96d089465ec 100644 --- a/spec/02-integration/01-helpers/01-helpers_spec.lua +++ b/spec/02-integration/01-helpers/01-helpers_spec.lua @@ -1,5 +1,6 @@ local helpers = require "spec.helpers" local cjson = require "cjson" +local spy = require "luassert.spy" for _, strategy in helpers.each_strategy() do @@ -101,16 +102,14 @@ for _, strategy in helpers.each_strategy() do assert(helpers.restart_kong(env)) -- ensure we can make at least one successful request after restarting - helpers.wait_until(function() - -- helpers.proxy_client() will throw an error if connect() fails, - -- so we need to wrap the whole thing in pcall - return pcall(function() + assert.eventually(function() local httpc = helpers.proxy_client(1000, 15555) local res = httpc:get("/") - assert(res.status == 200) + res:read_body() httpc:close() + assert.res_status(200, res) end) - end) + .has_no_error("Kong responds to proxy requests after restart") end before_each(function() @@ -589,6 +588,276 @@ for _, strategy in helpers.each_strategy() do end) + describe("eventually()", function() + local function returns(v) + return function() + return v + end + end + + local function raises(e) + return function() + error(e) + end + end + + local function noop() end + + local function after_n(n, callback, before) + local i = 0 + before = before or noop + + return function() + i = i + 1 + + if i > n then + return callback() + end + + return before() + end + end + + -- returns a value on the nth time the function is called + -- otherwise, returns pre_n + local function return_after_n(n, value, pre_n) + return after_n(n, returns(value), returns(pre_n)) + end + + --- raise an error until the function is called n times + local function raise_for_n(n) + return after_n(n, noop, raises("not yet")) + end + + --- raise an error until the function is called n times + local function raise_after(n) + return after_n(n, raises("done"), noop) + end + + local function new_timer() + local start + local timer = {} + timer.start = function() + ngx.update_time() + start = ngx.now() + return timer + end + + timer.elapsed = function() + ngx.update_time() + return ngx.now() - start + end + + return timer + end + + it("calls a function until a condition is met", function() + assert.has_no_error(function() + assert.eventually(return_after_n(10, true)) + .is_truthy() + end) + end) + + it("returns on the first success", function() + local fn = spy.new(returns(true)) + + assert.has_no_error(function() + assert.eventually(fn) + .is_truthy() + end) + + assert.spy(fn).was.called(1) + + + fn = spy.new(return_after_n(5, true, nil)) + + assert.has_no_error(function() + assert.eventually(fn) + .is_truthy() + end) + + assert.spy(fn).was.called(6) + end) + + it("gives up after a timeout", function() + local timer = new_timer().start() + + local timeout = 0.5 + local sleep = timeout / 10 + + assert.has_error(function() + assert.eventually(function() + ngx.sleep(sleep) + end) + .with_timeout(timeout) + .is_truthy() + end) + + assert.near(timeout, timer.elapsed(), 0.1) + end) + + describe("max_tries", function() + it("can limit the number of tries until giving up", function() + local n = 10 + local fn = spy.new(return_after_n(n, true)) + + assert.has_error(function() + assert.eventually(fn) + .with_max_tries(n) + .is_truthy() + end) + + assert.spy(fn).was.called(n) + end) + end) + + describe("ignore_exceptions", function() + it("causes raised errors to be treated as a normal failure", function() + local maybe_raise = raise_for_n(2) + local eventually_succeed = return_after_n(3, true) + + local fn = spy.new(function() + maybe_raise() + return eventually_succeed() + end) + + assert.has_no_error(function() + assert.eventually(fn) + .ignore_exceptions(true) + .is_truthy() + end) + + assert.spy(fn).was.called(6) + end) + + it("is turned off by default", function() + local maybe_raise = raise_for_n(5) + local eventually_succeed = return_after_n(3, true) + + local fn = spy.new(function() + maybe_raise() + return eventually_succeed() + end) + + assert.has_error(function() + assert.eventually(fn) + .is_truthy() + end) + + assert.spy(fn).was.called(1) + end) + end) + + describe("conditions", function() + it("is_truthy() requires a truthy return value", function() + assert.has_no_error(function() + assert.eventually(returns("yup")) + .is_truthy() + end) + + -- it's common knowledge that null is truthy, but let's + -- test it anyways + assert.has_no_error(function() + assert.eventually(returns(ngx.null)) + .is_truthy() + end) + + assert.has_error(function() + assert.eventually(returns(false)) + .with_timeout(0.001) + .is_truthy() + end) + + assert.has_error(function() + assert.eventually(returns(nil)) + .with_timeout(0.001) + .is_truthy() + end) + end) + + it("is_falsy() requires a falsy return value", function() + assert.has_no_error(function() + assert.eventually(returns(nil)) + .is_falsy() + end) + + assert.has_no_error(function() + assert.eventually(returns(false)) + .is_falsy() + end) + + assert.has_error(function() + assert.eventually(returns(true)) + .with_timeout(0.001) + .is_falsy() + end) + + assert.has_error(function() + assert.eventually(returns("yup")) + .with_timeout(0.001) + .is_falsy() + end) + + -- it's common knowledge that null is truthy, but let's + -- test it anyways + assert.has_error(function() + assert.eventually(returns(ngx.null)) + .with_timeout(0.001) + .is_falsy() + end) + end) + + it("has_no_error() requires the function not to raise an error()", function() + assert.has_no_error(function() + assert.eventually(returns(true)) + .has_no_error() + end) + + assert.has_no_error(function() + -- note: raise_until() does not return any value in the success case + assert.eventually(raise_for_n(5)) + .has_no_error() + end) + + assert.has_error(function() + assert.eventually(error) + .with_timeout(0.001) + .has_no_error() + end) + + assert.has_error(function() + assert.eventually(error) + .with_timeout(0.001) + .has_no_error() + end) + end) + + it("has_error() requires the function to raise an error()", function() + assert.has_no_error(function() + assert.eventually(error) + .has_error() + end) + + assert.has_no_error(function() + assert.eventually(raise_after(5)) + .has_error() + end) + + assert.has_error(function() + assert.eventually(returns(true)) + .with_timeout(0.001) + .has_error() + end) + + assert.has_error(function() + assert.eventually(returns(false)) + .with_timeout(0.001) + .has_error() + end) + end) + end) + end) + end) end @@ -602,7 +871,7 @@ describe("helpers: utilities", function() end) describe("wait_until()", function() - it("does not errors out if thing happens", function() + it("does not raise an error when the function returns truth-y", function() assert.has_no_error(function() local i = 0 helpers.wait_until(function() @@ -611,19 +880,48 @@ describe("helpers: utilities", function() end, 3) end) end) - it("errors out after delay", function() + + it("raises an error if the function does not return truth-y within the timeout", function() assert.error_matches(function() helpers.wait_until(function() return false, "thing still not done" end, 1) - end, "timeout: thing still not done") + end, "timed out") end) - it("reports errors in test function", function() - assert.error_matches(function() + + it("fails when test function raises an error()", function() + local e = assert.has_error(function() + helpers.wait_until(function() + error("oops") + end, 1) + end) + assert.is_string(e) + assert.matches("oops", e) + end) + + it("fails when test function raised an assertion error", function() + assert.has_error(function() helpers.wait_until(function() - assert.equal("foo", "bar") + assert.is_true(false) + end, 1) + end) + end) + end) + + describe("pwait_until()", function() + it("succeeds when a function does not raise an error()", function() + assert.has_no_error(function() + local i = 0 + helpers.pwait_until(function() + i = i + 1 + + if i < 5 then + error("i is less than 5") + end + + assert.is_true(i > 6) end, 1) - end, "Expected objects to be equal.", nil, true) + end) end) end) diff --git a/spec/helpers.lua b/spec/helpers.lua index 3be4570b86b..42df24cd80f 100644 --- a/spec/helpers.lua +++ b/spec/helpers.lua @@ -1430,7 +1430,7 @@ end local say = require "say" local luassert = require "luassert.assert" - +require("spec.helpers.wait") --- Waits until a specific condition is met. -- The check function will repeatedly be called (with a fixed interval), until @@ -1449,44 +1449,12 @@ local luassert = require "luassert.assert" -- -- wait 10 seconds for a file "myfilename" to appear -- helpers.wait_until(function() return file_exist("myfilename") end, 10) local function wait_until(f, timeout, step) - if type(f) ~= "function" then - error("arg #1 must be a function", 2) - end - - if timeout ~= nil and type(timeout) ~= "number" then - error("arg #2 must be a number", 2) - end - - if step ~= nil and type(step) ~= "number" then - error("arg #3 must be a number", 2) - end - - ngx.update_time() - - timeout = timeout or 5 - step = step or 0.05 - - local tstart = ngx.now() - local texp = tstart + timeout - local ok, res, err - - repeat - ok, res, err = pcall(f) - ngx.sleep(step) - ngx.update_time() - until not ok or res or ngx.now() >= texp - - if not ok then - -- report error from `f`, such as assert gone wrong - error(tostring(res), 2) - elseif not res and err then - -- report a failure for `f` to meet its condition - -- and eventually an error return value which could be the cause - error("wait_until() timeout: " .. tostring(err) .. " (after delay: " .. timeout .. "s)", 2) - elseif not res then - -- report a failure for `f` to meet its condition - error("wait_until() timeout (after delay " .. timeout .. "s)", 2) - end + luassert.wait_until({ + condition = "truthy", + fn = f, + timeout = timeout, + step = step, + }) end @@ -1503,11 +1471,15 @@ end -- @return nothing. It returns when the condition is met, or throws an error -- when it times out. local function pwait_until(f, timeout, step) - wait_until(function() - return pcall(f) - end, timeout, step) + luassert.wait_until({ + condition = "no_error", + fn = f, + timeout = timeout, + step = step, + }) end + --- Wait for some timers, throws an error on timeout. -- -- NOTE: this is a regular Lua function, not a Luassert assertion. diff --git a/spec/helpers/wait.lua b/spec/helpers/wait.lua new file mode 100644 index 00000000000..fb14b430f59 --- /dev/null +++ b/spec/helpers/wait.lua @@ -0,0 +1,466 @@ +local say = require "say" +local luassert = require "luassert.assert" +local pretty = require "pl.pretty" + +local fmt = string.format +local insert = table.insert + +local E_ARG_COUNT = "assertion.internal.argtolittle" +local E_ARG_TYPE = "assertion.internal.badargtype" + + +---@alias spec.helpers.wait.ctx.result +---| "timeout" +---| "error" +---| "success" +---| "max tries" + +local TIMEOUT = "timeout" +local ERROR = "error" +local SUCCESS = "success" +local MAX_TRIES = "max tries" + + +---@alias spec.helpers.wait.ctx.condition +---| "truthy" +---| "falsy" +---| "error" +---| "no_error" + + +--- helper functions that check the result of pcall() and report if the +--- wait ctx condition has been met +--- +---@type table +local COND = { + truthy = function(pok, ok_or_err) + return (pok and ok_or_err and true) or false + end, + + falsy = function(pok, ok_or_err) + return (pok and not ok_or_err) or false + end, + + error = function(pok) + return not pok + end, + + no_error = function(pok) + return (pok and true) or false + end, +} + + +---@param ... any +---@return any +local function first_non_nil(...) + local n = select("#", ...) + for i = 1, n do + local v = select(i, ...) + if v ~= nil then + return v + end + end +end + + +---@param exp_type string +---@param field string|integer +---@param value any +---@param caller? string +---@param level? integer +---@return any +local function check_type(exp_type, field, value, caller, level) + caller = caller or "wait_until" + level = (level or 1) + 1 + + local got_type = type(value) + + -- accept callable tables + if exp_type == "function" + and got_type == "table" + and type(debug.getmetatable(value)) == "table" + and type(debug.getmetatable(value).__call) == "function" + then + got_type = "function" + end + + if got_type ~= exp_type then + error(say(E_ARG_TYPE, { field, caller, exp_type, type(value) }), + level) + end + + return value +end + + +local DEFAULTS = { + timeout = 5, + step = 0.05, + message = "UNSPECIFIED", + max_tries = 0, + ignore_exceptions = false, + condition = "truthy", +} + + +---@class spec.helpers.wait.ctx +--- +---@field condition "truthy"|"falsy"|"error"|"no_error" +---@field condition_met boolean +---@field debug? boolean +---@field elapsed number +---@field last_raised_error any +---@field error_raised boolean +---@field fn function +---@field ignore_exceptions boolean +---@field last_returned_error any +---@field last_returned_value any +---@field last_error any +---@field message? string +---@field result spec.helpers.wait.ctx.result +---@field step number +---@field timeout number +---@field tries number +local wait_ctx = { + condition = nil, + condition_met = false, + debug = nil, + elapsed = 0, + error = nil, + error_raised = false, + ignore_exceptions = nil, + last_returned_error = nil, + last_returned_value = nil, + max_tries = nil, + message = nil, + result = "timeout", + step = nil, + timeout = nil, + tries = 0, +} + + +local wait_ctx_mt = { __index = wait_ctx } + +function wait_ctx:dd(msg) + if self.debug then + print(fmt("\n\n%s\n\n", pretty.write(msg))) + end +end + + +function wait_ctx:wait() + ngx.update_time() + + local tstart = ngx.now() + local texp = tstart + self.timeout + local ok, res, err + + local is_met = COND[self.condition] + + if self.condition == "no_error" then + self.ignore_exceptions = true + end + + local tries_remain = self.max_tries + + local f = self.fn + + while true do + ok, res, err = pcall(f) + + self.tries = self.tries + 1 + tries_remain = tries_remain - 1 + + self.condition_met = is_met(ok, res) + + self:dd(self) + + -- yay! + if self.condition_met then + self.last_returned_value = res + self.result = SUCCESS + break + + -- non-truthy return value + elseif ok and not res then + self.last_returned_error = first_non_nil(err, self.last_returned_error) + self.last_error = self.last_returned_error + + -- error() + else + self.error_raised = true + self.last_raised_error = first_non_nil(res, "UNKNOWN") + self.last_error = self.last_raised_error + + if not self.ignore_exceptions then + self.result = ERROR + break + end + end + + if tries_remain == 0 then + self.result = MAX_TRIES + break + end + + ngx.update_time() + + if ngx.now() >= texp then + self.result = TIMEOUT + break + end + + ngx.sleep(self.step) + end + + ngx.update_time() + self.elapsed = ngx.now() - tstart + + self:dd(self) + + -- re-raise + if self.error_raised and not self.ignore_exceptions then + error(self.last_raised_error, 2) + end +end + + +local CTX_TYPES = { + condition = "string", + fn = "function", + max_tries = "number", + timeout = "number", + message = "string", + step = "number", + ignore_exceptions = "boolean", +} + + +function wait_ctx:validate(key, value, caller, level) + local typ = CTX_TYPES[key] + + if not typ then + -- we don't care about validating this key + return value + end + + if key == "condition" and type(value) == "string" then + assert(COND[value] ~= nil, + say(E_ARG_TYPE, { "condition", caller or "wait_until", + "one of: 'truthy', 'falsy', 'error', 'no_error'", + value }), level + 1) + end + + + return check_type(typ, key, value, caller, level) +end + + +---@param state table +---@return spec.helpers.wait.ctx +local function get_or_create_ctx(state) + local ctx = rawget(state, "wait_ctx") + + if not ctx then + ctx = setmetatable({}, wait_ctx_mt) + rawset(state, "wait_ctx", ctx) + end + + return ctx +end + + +---@param ctx spec.helpers.wait.ctx +---@param key string +---@param ... any +local function param(ctx, key, ...) + local value = first_non_nil(first_non_nil(...), DEFAULTS[key]) + ctx[key] = ctx:validate(key, value, "wait_until", 3) +end + + +---@param state table +---@param arguments table +---@param level integer +---@return boolean ok +---@return table return_values +local function wait_until(state, arguments, level) + assert(arguments.n > 0, + say(E_ARG_COUNT, { "wait_until", 1, arguments.n }), + level + 1) + + local input = check_type("table", 1, arguments[1]) + local ctx = get_or_create_ctx(state) + + param(ctx, "fn", input.fn) + param(ctx, "timeout", input.timeout) + param(ctx, "step", input.step) + param(ctx, "message", input.message, arguments[2]) + param(ctx, "max_tries", input.max_tries) + param(ctx, "debug", input.debug, ctx.debug, false) + param(ctx, "condition", input.condition) + param(ctx, "ignore_exceptions", input.ignore_exceptions) + + -- reset the state + rawset(state, "wait_ctx", nil) + + ctx:wait() + + if ctx.condition_met then + return true, { ctx.last_returned_value, n = 1 } + end + + local errors = {} + local result + if ctx.result == ERROR then + result = "error() raised" + + elseif ctx.result == MAX_TRIES then + result = ("max tries (%s) reached"):format(ctx.max_tries) + + elseif ctx.result == TIMEOUT then + result = ("timed out after %ss"):format(ctx.elapsed) + end + + if ctx.last_raised_error then + insert(errors, "Last raised error:") + insert(errors, "") + insert(errors, pretty.write(ctx.last_raised_error)) + insert(errors, "") + end + + if ctx.last_returned_error then + insert(errors, "Last returned error:") + insert(errors, "") + insert(errors, pretty.write(ctx.last_returned_error)) + insert(errors, "") + end + + arguments[1] = ctx.message + arguments[2] = result + arguments[3] = table.concat(errors, "\n") + arguments[4] = ctx.timeout + arguments[5] = ctx.step + arguments[6] = ctx.elapsed + arguments[7] = ctx.tries + arguments[8] = ctx.error_raised + arguments.n = 8 + + arguments.nofmt = {} + for i = 1, arguments.n do + arguments.nofmt[i] = true + end + + return false, { ctx.last_error, n = 1 } +end + + +say:set("assertion.wait_until.failed", [[ +Failed to assert eventual condition: + +%q + +Result: %s + +%s +--- + +Timeout = %s +Step = %s +Elapsed = %s +Tries = %s +Raised = %s +]]) + +luassert:register("assertion", "wait_until", wait_until, + "assertion.wait_until.failed") + + +local function wait_until_modifier(key) + return function(state, arguments) + local ctx = get_or_create_ctx(state) + ctx[key] = ctx:validate(key, arguments[1], key, 1) + + return state + end +end + +luassert:register("modifier", "with_timeout", + wait_until_modifier("timeout")) + +luassert:register("modifier", "with_step", + wait_until_modifier("step")) + +luassert:register("modifier", "with_max_tries", + wait_until_modifier("max_tries")) + +-- luassert blows up on us if we try to use 'error' or 'errors' +luassert:register("modifier", "ignore_exceptions", + wait_until_modifier("ignore_exceptions")) + + +---@param ctx spec.helpers.wait.ctx +local function ctx_builder(ctx) + local self = setmetatable({}, { + __index = function(_, key) + error("unknown modifier/assertion: " .. tostring(key), 2) + end + }) + + local function with(field) + return function(value) + ctx[field] = ctx:validate(field, value, "with_" .. field, 2) + return self + end + end + + self.with_timeout = with("timeout") + self.with_step = with("step") + self.with_max_tries = with("max_tries") + + self.ignore_exceptions = function(ignore) + ctx.ignore_exceptions = ctx:validate("ignore_exceptions", ignore, + "ignore_exceptions", 2) + return self + end + + self.is_truthy = function(msg) + ctx.condition = "truthy" + return luassert.wait_until(ctx, msg) + end + + self.is_falsy = function(msg) + ctx.condition = "falsy" + return luassert.wait_until(ctx, msg) + end + + self.has_error = function(msg) + ctx.condition = "error" + return luassert.wait_until(ctx, msg) + end + + self.has_no_error = function(msg) + ctx.condition = "no_error" + return luassert.wait_until(ctx, msg) + end + + return self +end + + +local function eventually(state, arguments) + local ctx = get_or_create_ctx(state) + + ctx.fn = first_non_nil(arguments[1], ctx.fn) + + check_type("function", 1, ctx.fn, "eventually") + + arguments[1] = ctx_builder(ctx) + arguments.n = 1 + + return true, arguments +end + +luassert:register("assertion", "eventually", eventually)