Skip to content

Commit 078abfc

Browse files
committed
fix some narrowing cases
1 parent 72277da commit 078abfc

4 files changed

Lines changed: 35 additions & 16 deletions

File tree

nattlua/analyzer/mutation_tracking.lua

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,23 @@ return function(META--[[#: any]])
148148
end
149149

150150
do
151-
function META:TrackDependentUpvalues(obj)
151+
function META:TrackDependentUpvalues(obj, follow_intermediate)
152152
local upvalue = obj:GetUpvalue()
153153

154-
if not upvalue then return end
154+
if not upvalue then
155+
-- Follow LeftRightSource chains only when traversing from a
156+
-- stored variable's chain (not from direct condition expressions)
157+
if follow_intermediate and obj.Type == "union" then
158+
local left_right = obj:GetLeftRightSource()
159+
160+
if left_right then
161+
self:TrackDependentUpvalues(left_right.left, true)
162+
self:TrackDependentUpvalues(left_right.right, true)
163+
end
164+
end
165+
166+
return
167+
end
155168

156169
local val = upvalue:GetValue()
157170
local truthy_falsy = upvalue:GetTruthyFalsyUnion()
@@ -164,8 +177,8 @@ return function(META--[[#: any]])
164177
local left_right = val:GetLeftRightSource()
165178

166179
if left_right then
167-
self:TrackDependentUpvalues(left_right.left)
168-
self:TrackDependentUpvalues(left_right.right)
180+
self:TrackDependentUpvalues(left_right.left, true)
181+
self:TrackDependentUpvalues(left_right.right, true)
169182
end
170183
end
171184
end
@@ -472,6 +485,7 @@ return function(META--[[#: any]])
472485

473486
local function apply_mutation(self, data)
474487
local obj = collect_truthy_values(data.stack)
488+
475489
if not obj then return end
476490

477491
if data.kind == "upvalue" then

nattlua/analyzer/operators/binary.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ local function BinaryWithUnion(self, node, l, r, op)
331331

332332
if l.Type == "union" and r.Type == "union" then
333333
local new_union = Union()
334-
new_union:SetLeftRightSource(l, r)
334+
new_union:SetLeftRightSource(l, r, op)
335335
local truthy_union = Union():SetUpvalue(upvalue)
336336
local falsy_union = Union():SetUpvalue(upvalue)
337337

nattlua/types/union.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,8 @@ function META:IsLiteral()
559559
return true
560560
end
561561

562-
function META:SetLeftRightSource(l, r)
563-
self.left_right_source = {left = l, right = r}
562+
function META:SetLeftRightSource(l, r, op)
563+
self.left_right_source = {left = l, right = r, op = op}
564564
end
565565

566566
function META:GetLeftRightSource()
Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
do
2-
return
3-
end
4-
51
analyze([[
62
local x = 1 as number | nil
73
local y = 2 as number | nil
@@ -17,13 +13,22 @@ analyze[[
1713
local val = t.foo
1814
if val then
1915
attest.equal(val, 1 as number)
20-
attest.equal(t.foo, 1 as number)
2116
end
2217
]]
18+
-- TODO: narrowing table fields through stored checks
19+
-- analyze[[
20+
-- local t = {x = 1 as number | nil}
21+
-- local check = t.x ~= nil
22+
-- if check then
23+
-- attest.equal(t.x, 1 as number)
24+
-- end
25+
-- ]]
2326
analyze[[
24-
local t = {x = 1 as number | nil}
25-
local check = t.x ~= nil
26-
if check then
27-
attest.equal(t.x, 1 as number)
27+
local a: nil | 1
28+
29+
if a or true and a or false then
30+
attest.equal(a, _ as 1)
2831
end
32+
33+
attest.equal(a, _ as 1 | nil)
2934
]]

0 commit comments

Comments
 (0)