From 154373762723ce5787c9d6a34d4f6ee460621864 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Wed, 25 Sep 2024 15:37:00 +0100 Subject: [PATCH] fix: type narrow on fields with multiple literals --- script/vm/tracer.lua | 14 ++++++++---- test/type_inference/common.lua | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua index ef94cc9e8..bb701fe60 100644 --- a/script/vm/tracer.lua +++ b/script/vm/tracer.lua @@ -261,7 +261,7 @@ end --- @param source parser.object --- @param fieldName string --- @param literal parser.object ---- @return string[]? +--- @return [string, boolean][]? local function getNodeTypesWithLiteralField(uri, source, fieldName, literal) local loc = vm.getVariable(source) if not loc then @@ -279,7 +279,7 @@ local function getNodeTypesWithLiteralField(uri, source, fieldName, literal) for _, t in ipairs(f.extends.types) do if t[1] == literal[1] then tys = tys or {} - table.insert(tys, set.class[1]) + table.insert(tys, {set.class[1], #f.extends.types > 1}) break end end @@ -682,15 +682,19 @@ local lookIntoChild = util.switch() -- TODO: handle more types if tys and #tys == 1 then - local ty = tys[1] + -- If the type is in a union (e.g. 'lit' | foo), then the type + -- cannot be removed from the node. + local ty, tyInUnion = tys[1][1], tys[1][2] topNode = topNode:copy() if action.op.type == '==' then topNode:narrow(tracer.uri, ty) - if outNode then + if not tyInUnion and outNode then outNode:remove(ty) end else - topNode:remove(ty) + if not tyInUnion then + topNode:remove(ty) + end if outNode then outNode:narrow(tracer.uri, ty) end diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 13b0f61ec..792967722 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4529,3 +4529,45 @@ if obj.type == 'a' then local = obj end ]] + +TEST 'A|B' [[ +--- @class A +--- @field mode? 'a' | 'b' + +--- @class B + +local a --- @type A | B + +if a.mode == 'a' then + local b = a +else + local = a +end +]] + +TEST 'A|B' [[ +--- @class A +--- @field mode? 'a' | 'b' + +--- @class B + +local a --- @type A | B + +if a.mode ~= 'a' then + local = a +end +]] + +TEST 'A' [[ +--- @class A +--- @field mode? 'a' | 'b' + +--- @class B + +local a --- @type A | B + +if a.mode ~= 'a' then +else + local = a +end +]]