Skip to content

Commit

Permalink
Infer the parameter types of a same-named function in the subclass ba…
Browse files Browse the repository at this point in the history
…sed on the parameter types in the superclass function.
  • Loading branch information
sumneko committed Sep 6, 2024
1 parent 1ea4c04 commit 7c481f5
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 17 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<!-- Add all new changes here. They will be moved under a version at release -->
* `NEW` Custom documentation exporter
* `NEW` Setting: `Lua.docScriptPath`: Path to a script that overrides `cli.doc.export`, allowing user-specified documentation exporting.
* `NEW` Infer the parameter types of a same-named function in the subclass based on the parameter types in the superclass function.
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
* `FIX` Improve type narrow by checking exact match on literal type params
Expand Down
70 changes: 53 additions & 17 deletions script/vm/compiler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1088,21 +1088,26 @@ end
---@param func parser.object
---@param source parser.object
local function compileFunctionParam(func, source)
local aindex
for index, arg in ipairs(func.args) do
if arg == source then
aindex = index
break
end
end
---@cast aindex integer

-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
local funcNode = vm.compileNode(func)
for n in funcNode:eachObject() do
if n.type == 'doc.type.function' then
for index, arg in ipairs(n.args) do
if func.args[index] == source then
local argNode = vm.compileNode(arg)
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
end
end
return true
local argNode = vm.compileNode(n.args[aindex])
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
end
end
return true
end
end

Expand All @@ -1118,19 +1123,50 @@ local function compileFunctionParam(func, source)
if not caller.args then
goto continue
end
for index, arg in ipairs(source.parent) do
if arg == source then
local callerArg = caller.args[index]
if callerArg then
vm.setNode(source, vm.compileNode(callerArg))
found = true
end
end
local callerArg = caller.args[aindex]
if callerArg then
vm.setNode(source, vm.compileNode(callerArg))
found = true
end
::continue::
end
return found
end

do
local parent = func.parent
local key = vm.getKeyName(parent)
local classDef = vm.getParentClass(parent)
local suri = guide.getUri(func)
if classDef and key then
local found
for _, set in ipairs(classDef:getSets(suri)) do
if set.type == 'doc.class' and set.extends then
for _, ext in ipairs(set.extends) do
local extClass = vm.getGlobal('type', ext[1])
if extClass then
vm.getClassFields(suri, extClass, key, function (field, isMark)
for n in vm.compileNode(field):eachObject() do
if n.type == 'function' then
local argNode = vm.compileNode(n.args[aindex])
for an in argNode:eachObject() do
if an.type ~= 'doc.generic.name' then
vm.setNode(source, an)
found = true
end
end
end
end
end)
end
end
end
end
if found then
return true
end
end
end
end

---@param source parser.object
Expand Down
13 changes: 13 additions & 0 deletions test/type_inference/common.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4428,3 +4428,16 @@ TEST 'A' [[
local x
local <?y?> = 1 >> x
]]

TEST 'number' [[
---@class A
local A = {}
---@param x number
function A:func(x) end
---@class B: A
local B = {}
function B:func(<?x?>) end
]]

0 comments on commit 7c481f5

Please sign in to comment.