Skip to content

Commit

Permalink
feat[venom]: improve liveness computation (vyperlang#4330)
Browse files Browse the repository at this point in the history
Traversing the CFG in reverse topsort order during liveness analysis improves stack scheduling slightly.

This commit also adds a topsort method (`dfs_walk`) to CFGAnalysis, and
speeds up the CFG analysis by only checking each basic block's terminator
instruction instead of iterating over all the instructions in every basic block.

additional refactors:
- move dfs order calculation from domtree to cfg analysis.
- remove unnecessary calculation of domtree in sccp
- remove redundant IRFunction.compute_reachability
- change cfg_out order
- refactor shared phi fixup code
- remove useless `__eq__()` and `__hash__()` for IRBasicBlock
  • Loading branch information
charles-cooper authored Nov 12, 2024
1 parent fee16e6 commit 48cb39b
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 142 deletions.
4 changes: 0 additions & 4 deletions tests/functional/codegen/features/test_constructor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import pytest

from tests.evm_backends.base_env import _compile
from vyper.exceptions import StackTooDeep
from vyper.utils import method_id


Expand Down Expand Up @@ -169,7 +166,6 @@ def get_foo() -> uint256:
assert c.get_foo() == 39


@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_nested_dynamic_array_constructor_arg_2(env, get_contract):
code = """
foo: int128
Expand Down
2 changes: 0 additions & 2 deletions tests/functional/codegen/types/test_dynamic_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
CompilerPanic,
ImmutableViolation,
OverflowException,
StackTooDeep,
StateAccessViolation,
TypeMismatch,
)
Expand Down Expand Up @@ -737,7 +736,6 @@ def test_array_decimal_return3() -> DynArray[DynArray[decimal, 2], 2]:
]


@pytest.mark.venom_xfail(raises=StackTooDeep, reason="stack scheduler regression")
def test_mult_list(get_contract):
code = """
nest3: DynArray[DynArray[DynArray[uint256, 2], 2], 2]
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/compiler/venom/test_branch_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_simple_jump_case():
fn.append_basic_block(br2)

p1 = bb.append_instruction("param")
p2 = bb.append_instruction("param")
op1 = bb.append_instruction("store", p1)
op2 = bb.append_instruction("store", 64)
op3 = bb.append_instruction("add", op1, op2)
Expand All @@ -24,7 +25,7 @@ def test_simple_jump_case():

br1.append_instruction("add", op3, p1)
br1.append_instruction("stop")
br2.append_instruction("add", op3, 10)
br2.append_instruction("add", op3, p2)
br2.append_instruction("stop")

term_inst = bb.instructions[-1]
Expand Down
9 changes: 5 additions & 4 deletions tests/unit/compiler/venom/test_sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def test_cont_phi_case():
assert sccp.lattice[IRVariable("%2")].value == 32
assert sccp.lattice[IRVariable("%3")].value == 64
assert sccp.lattice[IRVariable("%4")].value == 96
assert sccp.lattice[IRVariable("%5", version=1)].value == 106
assert sccp.lattice[IRVariable("%5", version=2)] == LatticeEnum.BOTTOM
assert sccp.lattice[IRVariable("%5", version=2)].value == 106
assert sccp.lattice[IRVariable("%5", version=1)] == LatticeEnum.BOTTOM
assert sccp.lattice[IRVariable("%5")].value == 2


Expand Down Expand Up @@ -207,8 +207,9 @@ def test_cont_phi_const_case():
assert sccp.lattice[IRVariable("%2")].value == 32
assert sccp.lattice[IRVariable("%3")].value == 64
assert sccp.lattice[IRVariable("%4")].value == 96
assert sccp.lattice[IRVariable("%5", version=1)].value == 106
assert sccp.lattice[IRVariable("%5", version=2)].value == 97
# dependent on cfg traversal order
assert sccp.lattice[IRVariable("%5", version=2)].value == 106
assert sccp.lattice[IRVariable("%5", version=1)].value == 97
assert sccp.lattice[IRVariable("%5")].value == 2


Expand Down
47 changes: 34 additions & 13 deletions vyper/venom/analysis/cfg.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,62 @@
from typing import Iterator

from vyper.utils import OrderedSet
from vyper.venom.analysis import IRAnalysis
from vyper.venom.basicblock import CFG_ALTERING_INSTRUCTIONS
from vyper.venom.basicblock import CFG_ALTERING_INSTRUCTIONS, IRBasicBlock


class CFGAnalysis(IRAnalysis):
"""
Compute control flow graph information for each basic block in the function.
"""

_dfs: OrderedSet[IRBasicBlock]

def analyze(self) -> None:
fn = self.function
self._dfs = OrderedSet()

for bb in fn.get_basic_blocks():
bb.cfg_in = OrderedSet()
bb.cfg_out = OrderedSet()
bb.out_vars = OrderedSet()
bb.is_reachable = False

for bb in fn.get_basic_blocks():
assert len(bb.instructions) > 0, "Basic block should not be empty"
last_inst = bb.instructions[-1]
assert last_inst.is_bb_terminator, f"Last instruction should be a terminator {bb}"
assert bb.is_terminated

for inst in bb.instructions:
if inst.opcode in CFG_ALTERING_INSTRUCTIONS:
ops = inst.get_label_operands()
for op in ops:
fn.get_basic_block(op.value).add_cfg_in(bb)
term = bb.instructions[-1]
if term.opcode in CFG_ALTERING_INSTRUCTIONS:
ops = term.get_label_operands()
# order of cfg_out matters to performance!
for op in reversed(list(ops)):
next_bb = fn.get_basic_block(op.value)
bb.add_cfg_out(next_bb)
next_bb.add_cfg_in(bb)

# Fill in the "out" set for each basic block
for bb in fn.get_basic_blocks():
for in_bb in bb.cfg_in:
in_bb.add_cfg_out(bb)
self._compute_dfs_r(self.function.entry)

def _compute_dfs_r(self, bb):
    """
    Record every basic block reachable from `bb` in DFS post-order.

    Each block reached is flagged with `is_reachable = True` and is added
    to `self._dfs` only after all of its `cfg_out` successors have been
    added, i.e. the resulting order is a post-order walk of the CFG.

    The traversal uses an explicit stack instead of recursion, so a long
    chain of basic blocks cannot exhaust the interpreter's recursion limit.
    The visiting order is the same as the straightforward recursive
    formulation (the `_r` suffix is kept for interface compatibility).
    """
    if bb.is_reachable:
        return
    bb.is_reachable = True

    # each frame holds a block and a *partially consumed* iterator over
    # its successors, so that we resume where we left off after a child
    # has been fully processed.
    stack = [(bb, iter(bb.cfg_out))]
    while stack:
        current, successors = stack[-1]
        for out_bb in successors:
            if out_bb.is_reachable:
                continue
            out_bb.is_reachable = True
            stack.append((out_bb, iter(out_bb.cfg_out)))
            break
        else:
            # all successors handled: emit the block (post-order)
            stack.pop()
            self._dfs.add(current)

@property
def dfs_walk(self) -> Iterator[IRBasicBlock]:
    """
    Iterate over the basic blocks collected in `self._dfs`.

    `_compute_dfs_r` adds a block to `self._dfs` only after visiting its
    `cfg_out` successors, so the blocks come out in DFS post-order.
    """
    # NOTE(review): `invalidate()` sets `self._dfs` to None, in which case
    # `iter(None)` raises TypeError -- presumably `analyze()` is always
    # re-run before this is accessed again; confirm against callers.
    return iter(self._dfs)

def invalidate(self):
    """
    Drop this analysis' cached DFS order and invalidate the analyses
    which are invalidated together with the CFG.
    """
    from vyper.venom.analysis import DFGAnalysis, DominatorTreeAnalysis, LivenessAnalysis

    cache = self.analyses_cache

    for dependent in (DominatorTreeAnalysis, LivenessAnalysis):
        cache.invalidate_analysis(dependent)

    # forget the cached post-order walk
    self._dfs = None

    # be conservative - assume cfg invalidation invalidates dfg
    cache.invalidate_analysis(DFGAnalysis)
29 changes: 9 additions & 20 deletions vyper/venom/analysis/dominators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import cached_property

from vyper.exceptions import CompilerPanic
from vyper.utils import OrderedSet
from vyper.venom.analysis import CFGAnalysis, IRAnalysis
Expand All @@ -14,8 +16,6 @@ class DominatorTreeAnalysis(IRAnalysis):

fn: IRFunction
entry_block: IRBasicBlock
dfs_order: dict[IRBasicBlock, int]
dfs_walk: list[IRBasicBlock]
dominators: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
immediate_dominators: dict[IRBasicBlock, IRBasicBlock]
dominated: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
Expand All @@ -27,16 +27,13 @@ def analyze(self):
"""
self.fn = self.function
self.entry_block = self.fn.entry
self.dfs_order = {}
self.dfs_walk = []
self.dominators = {}
self.immediate_dominators = {}
self.dominated = {}
self.dominator_frontiers = {}

self.analyses_cache.request_analysis(CFGAnalysis)
self.cfg = self.analyses_cache.request_analysis(CFGAnalysis)

self._compute_dfs(self.entry_block, OrderedSet())
self._compute_dominators()
self._compute_idoms()
self._compute_df()
Expand Down Expand Up @@ -131,21 +128,13 @@ def _intersect(self, bb1, bb2):
bb2 = self.immediate_dominators[bb2]
return bb1

def _compute_dfs(self, entry: IRBasicBlock, visited):
"""
Depth-first search to compute the DFS order of the basic blocks. This
is used to compute the dominator tree. The sequence of basic blocks in
the DFS order is stored in `self.dfs_walk`. The DFS order of each basic
block is stored in `self.dfs_order`.
"""
visited.add(entry)

for bb in entry.cfg_out:
if bb not in visited:
self._compute_dfs(bb, visited)
@cached_property
def dfs_walk(self) -> list[IRBasicBlock]:
    """
    Basic blocks in the order yielded by the CFG analysis' `dfs_walk`,
    materialized as a list. Computed on first access and then cached
    on the instance (`cached_property`).
    """
    return list(self.cfg.dfs_walk)

self.dfs_walk.append(entry)
self.dfs_order[entry] = len(self.dfs_walk)
@cached_property
def dfs_order(self) -> dict[IRBasicBlock, int]:
    """
    Map each basic block to its zero-based position in `self.dfs_walk`.
    Computed on first access and then cached on the instance.
    """
    positions = {}
    for position, block in enumerate(self.dfs_walk):
        positions[block] = position
    return positions

def as_graph(self) -> str:
"""
Expand Down
14 changes: 7 additions & 7 deletions vyper/venom/analysis/liveness.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ class LivenessAnalysis(IRAnalysis):
"""

def analyze(self):
self.analyses_cache.request_analysis(CFGAnalysis)
cfg = self.analyses_cache.request_analysis(CFGAnalysis)
self._reset_liveness()

self._worklist = deque()
self._worklist.extend(self.function.get_basic_blocks())
worklist = deque(cfg.dfs_walk)

while len(self._worklist) > 0:
while len(worklist) > 0:
changed = False
bb = self._worklist.popleft()

bb = worklist.popleft()
changed |= self._calculate_out_vars(bb)
changed |= self._calculate_liveness(bb)
# recompute liveness for basic blocks pointing into
# this basic block
if changed:
self._worklist.extend(bb.cfg_in)
worklist.extend(bb.cfg_in)

def _reset_liveness(self) -> None:
for bb in self.function.get_basic_blocks():
Expand Down Expand Up @@ -64,7 +64,7 @@ def _calculate_out_vars(self, bb: IRBasicBlock) -> bool:
bb.out_vars = OrderedSet()
for out_bb in bb.cfg_out:
target_vars = self.input_vars_from(bb, out_bb)
bb.out_vars = bb.out_vars.union(target_vars)
bb.out_vars.update(target_vars)
return out_vars != bb.out_vars

# calculate the input variables into self from source
Expand Down
37 changes: 26 additions & 11 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,6 @@ def __init__(self, value: str, is_symbol: bool = False) -> None:
self.value = value
self.is_symbol = is_symbol

def __eq__(self, other):
# no need for is_symbol to participate in equality
return super().__eq__(other)

def __hash__(self):
# __hash__ is required when __eq__ is overridden --
# https://docs.python.org/3/reference/datamodel.html#object.__hash__
return super().__hash__()


class IRInstruction:
"""
Expand Down Expand Up @@ -393,7 +384,6 @@ class IRBasicBlock:
# stack items which this basic block produces
out_vars: OrderedSet[IRVariable]

reachable: OrderedSet["IRBasicBlock"]
is_reachable: bool = False

def __init__(self, label: IRLabel, parent: "IRFunction") -> None:
Expand All @@ -404,7 +394,6 @@ def __init__(self, label: IRLabel, parent: "IRFunction") -> None:
self.cfg_in = OrderedSet()
self.cfg_out = OrderedSet()
self.out_vars = OrderedSet()
self.reachable = OrderedSet()
self.is_reachable = False

def add_cfg_in(self, bb: "IRBasicBlock") -> None:
Expand Down Expand Up @@ -495,6 +484,32 @@ def replace_operands(self, replacements: dict) -> None:
for instruction in self.instructions:
instruction.replace_operands(replacements)

def fix_phi_instructions(self):
    """
    Repair the phi instructions of this basic block after its set of
    predecessors (`cfg_in`) has changed.

    For every phi instruction:
    - operands whose label does not belong to a block in `cfg_in` are
      removed via `remove_phi_operand`;
    - if exactly one (label, value) pair is left, the phi is rewritten
      into a `store` of that value;
    - if no operands are left, the phi is turned into a `nop` with no
      output.

    If any operand was removed, the instructions are re-sorted (stably)
    so that the remaining phi instructions come first in the block.
    """
    cfg_in_labels = tuple(bb.label for bb in self.cfg_in)

    needs_sort = False
    for inst in self.instructions:
        if inst.opcode != "phi":
            continue

        # snapshot the labels before removing anything: removing a phi
        # operand while `get_label_operands()` is still lazily walking
        # the operands can skip over labels, leaving stale predecessors
        # in the phi.
        labels = list(inst.get_label_operands())
        for label in labels:
            if label not in cfg_in_labels:
                needs_sort = True
                inst.remove_phi_operand(label)

        op_len = len(inst.operands)
        if op_len == 2:
            # single (label, value) pair left: plain assignment
            inst.opcode = "store"
            inst.operands = [inst.operands[1]]
        elif op_len == 0:
            inst.opcode = "nop"
            inst.output = None
            inst.operands = []

    if needs_sort:
        # False sorts before True, so phis are moved to the front;
        # list.sort is stable, so relative order is otherwise kept.
        self.instructions.sort(key=lambda inst: inst.opcode != "phi")

def get_assignments(self):
"""
Get all assignments in basic block.
Expand Down
58 changes: 14 additions & 44 deletions vyper/venom/function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Iterator, Optional

from vyper.codegen.ir_node import IRnode
from vyper.utils import OrderedSet
from vyper.venom.basicblock import CFG_ALTERING_INSTRUCTIONS, IRBasicBlock, IRLabel, IRVariable
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable


class IRFunction:
Expand Down Expand Up @@ -89,60 +88,31 @@ def get_last_variable(self) -> str:
return f"%{self.last_variable}"

def remove_unreachable_blocks(self) -> int:
self._compute_reachability()
# Remove unreachable basic blocks
# pre: requires CFG analysis!
# NOTE: should this be a pass?

removed = []
removed = set()

# Remove unreachable basic blocks
for bb in self.get_basic_blocks():
if not bb.is_reachable:
removed.append(bb)
removed.add(bb)

for bb in removed:
self.remove_basic_block(bb)

# Remove phi instructions that reference removed basic blocks
for bb in removed:
for out_bb in bb.cfg_out:
out_bb.remove_cfg_in(bb)
for inst in out_bb.instructions:
if inst.opcode != "phi":
continue
in_labels = inst.get_label_operands()
if bb.label in in_labels:
inst.remove_phi_operand(bb.label)
op_len = len(inst.operands)
if op_len == 2:
inst.opcode = "store"
inst.operands = [inst.operands[1]]
elif op_len == 0:
out_bb.remove_instruction(inst)

return len(removed)

def _compute_reachability(self) -> None:
"""
Compute reachability of basic blocks.
"""
for bb in self.get_basic_blocks():
bb.reachable = OrderedSet()
bb.is_reachable = False
for in_bb in list(bb.cfg_in):
if in_bb not in removed:
continue

self._compute_reachability_from(self.entry)
bb.remove_cfg_in(in_bb)

def _compute_reachability_from(self, bb: IRBasicBlock) -> None:
"""
Compute reachability of basic blocks from bb.
"""
if bb.is_reachable:
return
bb.is_reachable = True
for inst in bb.instructions:
if inst.opcode in CFG_ALTERING_INSTRUCTIONS:
for op in inst.get_label_operands():
out_bb = self.get_basic_block(op.value)
bb.reachable.add(out_bb)
self._compute_reachability_from(out_bb)
# TODO: only run this if cfg_in changed
bb.fix_phi_instructions()

return len(removed)

@property
def normalized(self) -> bool:
Expand Down
Loading

0 comments on commit 48cb39b

Please sign in to comment.