Skip to content

Commit

Permalink
feat(run): augment args
Browse files Browse the repository at this point in the history
Allow users to augment the arguments to all tests being run from a
singular function.

```lua
local nio = require("nio")
neotest.setup({
  run = {
    augment = function(tree, args)
      local name = nio.ui.input({ prompt = "What is your name?" })

      args.env = { USER_NAME = name }

      return args
    end,
  },
})
```

See nvim-neotest#431
  • Loading branch information
rcarriga committed Jul 13, 2024
1 parent 0fe9186 commit 32ff2ac
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
9 changes: 9 additions & 0 deletions doc/neotest.txt
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ Default values:
elixir = <function 1>,
go = " ;query\n ;Captures imported types\n (qualified_type name: (type_identifier) @symbol)\n ;Captures package-local and built-in types\n (type_identifier)@symbol\n ;Captures imported function calls and variables/constants\n (selector_expression field: (field_identifier) @symbol)\n ;Captures package-local functions calls\n (call_expression function: (identifier) @symbol)\n ",
haskell = " ;query\n ;explicit import\n ((import_item [(variable)]) @symbol)\n ;symbols that may be imported implicitly\n ((type) @symbol)\n (qualified_variable (variable) @symbol)\n (exp_apply (exp_name (variable) @symbol))\n ((constructor) @symbol)\n ((operator) @symbol)\n ",
java = " ;query\n ;captures imported classes\n (import_declaration\n (scoped_identifier name: ((identifier) @symbol))\n )\n ",
javascript = ' ;query\n ;Captures named imports\n (import_specifier name: (identifier) @symbol)\n ;Captures default import\n (import_clause (identifier) @symbol)\n ;Capture require statements\n (variable_declarator \n name: (identifier) @symbol\n value: (call_expression (identifier) @function (#eq? @function "require")))\n ;Capture namespace imports\n (namespace_import (identifier) @symbol)\n',
lua = ' ;query\n ;Captures module names in require calls\n (function_call\n name: ((identifier) @function (#eq? @function "require"))\n arguments: (arguments (string) @symbol))\n ',
python = " ;query\n ;Captures imports and modules they're imported from\n (import_from_statement (_ (identifier) @symbol))\n (import_statement (_ (identifier) @symbol))\n ",
Expand Down Expand Up @@ -234,6 +235,7 @@ Fields~
{highlights} `(table<string, string>)`
{floating} `(neotest.Config.floating)`
{strategies} `(neotest.Config.strategies)`
{run} `(neotest.Config.run)`
{summary} `(neotest.Config.summary)`
{output} `(neotest.Config.output)`
{output_panel} `(neotest.Config.output_panel)`
Expand Down Expand Up @@ -280,6 +282,13 @@ Fields~
Fields~
{integrated} `(neotest.Config.strategies.integrated)`

*neotest.Config.run*
Fields~
{enabled} `(boolean)`
{augment?} `(fun(tree: neotest.Tree, arg:
neotest.run.RunArgs):neotest.run.RunArgs)` A function to augment the arguments
any tests being run

*neotest.Config.summary*
Fields~
{enabled} `(boolean)`
Expand Down
5 changes: 5 additions & 0 deletions lua/neotest/config/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ local js_watch_query = [[
---@field highlights table<string, string>
---@field floating neotest.Config.floating
---@field strategies neotest.Config.strategies
---@field run neotest.Config.run
---@field summary neotest.Config.summary
---@field output neotest.Config.output
---@field output_panel neotest.Config.output_panel
Expand Down Expand Up @@ -87,6 +88,10 @@ local js_watch_query = [[
---@class neotest.Config.strategies
---@field integrated neotest.Config.strategies.integrated

---@class neotest.Config.run
---@field enabled boolean
---@field augment? fun(tree: neotest.Tree, arg: neotest.run.RunArgs):neotest.run.RunArgs A function to augment the arguments any tests being run

---@class neotest.Config.summary
---@field enabled boolean
---@field animated boolean Enable/disable animation of icons
Expand Down
18 changes: 16 additions & 2 deletions lua/neotest/consumers/run.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
local nio = require("nio")
local lib = require("neotest.lib")
local config = require("neotest.config")

---@private
---@type neotest.Client
Expand Down Expand Up @@ -46,6 +47,17 @@ end
---@field [1] string? Position ID to run
---@field suite boolean Run the entire suite instead of a single position

local function augment_args(tree, args)
args = type(args) == "string" and { args } or args
args = args or {}
local aug = config.run.augment
if not aug then
return args
end
nio.scheduler()
return aug(tree, args)
end

--- Run the given position or the nearest position if not given.
--- All arguments are optional
---
Expand All @@ -70,7 +82,7 @@ function neotest.run.run(args)
lib.notify("No tests found")
return
end
client:run_tree(tree, type(args) == "string" and { args } or args)
client:run_tree(tree, augment_args(tree, args))
end

neotest.run.run = nio.create(neotest.run.run, 1)
Expand Down Expand Up @@ -103,7 +115,7 @@ function neotest.run.run_last(args)
lib.notify("Last test run no longer exists")
return
end
client:run_tree(tree, args)
client:run_tree(tree, augment_args(tree, args))
end)
end

Expand Down Expand Up @@ -147,6 +159,7 @@ function neotest.run.stop(args)
end
client:stop(pos, args)
end

neotest.run.stop = nio.create(neotest.run.stop, 1)

---@class neotest.run.AttachArgs : neotest.client.AttachArgs
Expand All @@ -173,6 +186,7 @@ function neotest.run.attach(args)
end
client:attach(pos, args)
end

neotest.run.attach = nio.create(neotest.run.attach, 1)

--- Get the list of all known adapter IDs.
Expand Down

0 comments on commit 32ff2ac

Please sign in to comment.