diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7de30982..d2c2e7e6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,3 +38,25 @@ jobs: shell: bash -l {0} run: | coverage report + + pre-commit-hook: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Setup Miniconda + uses: conda-incubator/setup-miniconda@v2 + with: + auto-update-conda: true + auto-activate-base: false + + - name: Install dependencies + shell: bash -l {0} + run: | + conda install python=3.11 pre-commit pyyaml graphviz + + - name: Run pre-commit + shell: bash -l {0} + run: | + pre-commit run -a diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 77fd0503..da51c7d3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,5 +39,8 @@ repos: hooks: - id: mypy additional_dependencies: + - types-pyyaml - types-filelock - types-setuptools + entry: mypy + diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..a1f460d9 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,12 @@ +# Global options: + +[mypy] +warn_unused_configs = True +follow_imports = silent +show_error_context = True +namespace_packages = True +strict = True +files = + numba_rvsdg +exclude = + numba_rvsdg/tests|conf.py diff --git a/numba_rvsdg/core/datastructures/basic_block.py b/numba_rvsdg/core/datastructures/basic_block.py index 393b1057..758d932d 100644 --- a/numba_rvsdg/core/datastructures/basic_block.py +++ b/numba_rvsdg/core/datastructures/basic_block.py @@ -1,6 +1,6 @@ import dis -from typing import Tuple, Dict, List -from dataclasses import dataclass, replace +from typing import Tuple, Dict, List, Optional +from dataclasses import dataclass, replace, field from numba_rvsdg.core.utils import _next_inst_offset from numba_rvsdg.core.datastructures import block_names @@ -26,9 +26,9 @@ class BasicBlock: name: str - _jump_targets: Tuple[str] = tuple() + _jump_targets: Tuple[str, ...] = tuple() - backedges: Tuple[str] = tuple() + backedges: Tuple[str, ...] = tuple() @property def is_exiting(self) -> bool: @@ -58,7 +58,7 @@ def fallthrough(self) -> bool: return len(self._jump_targets) == 1 @property - def jump_targets(self) -> Tuple[str]: + def jump_targets(self) -> Tuple[str, ...]: """Retrieves the jump targets for this block, excluding any jump targets that are also backedges. @@ -94,7 +94,9 @@ def declare_backedge(self, target: str) -> "BasicBlock": return replace(self, backedges=(target,)) return self - def replace_jump_targets(self, jump_targets: Tuple) -> "BasicBlock": + def replace_jump_targets( + self, jump_targets: Tuple[str, ...] + ) -> "BasicBlock": """Replaces jump targets of this block by the given tuple. This method replaces the jump targets of the current BasicBlock. @@ -118,7 +120,7 @@ def replace_jump_targets(self, jump_targets: Tuple) -> "BasicBlock": """ return replace(self, _jump_targets=jump_targets) - def replace_backedges(self, backedges: Tuple) -> "BasicBlock": + def replace_backedges(self, backedges: Tuple[str, ...]) -> "BasicBlock": """Replaces back edges of this block by the given tuple. This method replaces the back edges of the current BasicBlock. @@ -153,9 +155,9 @@ class PythonBytecodeBlock(BasicBlock): The bytecode offset immediately after the last bytecode of the block. """ - begin: int = None + begin: int = -1 - end: int = None + end: int = -1 def get_instructions( self, bcmap: Dict[int, dis.Instruction] @@ -234,8 +236,8 @@ class SyntheticFill(SyntheticBlock): @dataclass(frozen=True) class SyntheticAssignment(SyntheticBlock): - """The SyntheticAssignment class represents a artificially added assignment block - in a structured control flow graph (SCFG). + """The SyntheticAssignment class represents a artificially added + assignment block in a structured control flow graph (SCFG). This block is responsible for giving variables their values, once the respective block is executed. @@ -248,7 +250,7 @@ class SyntheticAssignment(SyntheticBlock): the block is executed. """ - variable_assignment: dict = None + variable_assignment: Dict[str, int] = field(default_factory=lambda: {}) @dataclass(frozen=True) @@ -266,10 +268,12 @@ class SyntheticBranch(SyntheticBlock): to be executed on the basis of that value. """ - variable: str = None - branch_value_table: dict = None + variable: str = "" + branch_value_table: Dict[int, str] = field(default_factory=lambda: {}) - def replace_jump_targets(self, jump_targets: Tuple) -> "BasicBlock": + def replace_jump_targets( + self, jump_targets: Tuple[str, ...] + ) -> "BasicBlock": """Replaces jump targets of this block by the given tuple. This method replaces the jump targets of the current BasicBlock. @@ -360,13 +364,13 @@ class RegionBlock(BasicBlock): The exiting node of the region. """ - kind: str = None - parent_region: "RegionBlock" = None - header: str = None - subregion: "SCFG" = None # noqa - exiting: str = None + kind: Optional[str] = None + parent_region: Optional["RegionBlock"] = None + header: Optional[str] = None + subregion: Optional["SCFG"] = None # type: ignore # noqa + exiting: Optional[str] = None - def replace_header(self, new_header): + def replace_header(self, new_header: str) -> None: """This method performs a inplace replacement of the header block. Parameters @@ -376,7 +380,7 @@ def replace_header(self, new_header): """ object.__setattr__(self, "header", new_header) - def replace_exiting(self, new_exiting): + def replace_exiting(self, new_exiting: str) -> None: """This method performs a inplace replacement of the header block. Parameters diff --git a/numba_rvsdg/core/datastructures/byte_flow.py b/numba_rvsdg/core/datastructures/byte_flow.py index 9b1ebb04..fb08a799 100644 --- a/numba_rvsdg/core/datastructures/byte_flow.py +++ b/numba_rvsdg/core/datastructures/byte_flow.py @@ -1,6 +1,7 @@ import dis from copy import deepcopy from dataclasses import dataclass +from typing import Generator, Callable from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.basic_block import RegionBlock @@ -30,10 +31,10 @@ class ByteFlow: """ bc: dis.Bytecode - scfg: "SCFG" + scfg: SCFG @staticmethod - def from_bytecode(code) -> "ByteFlow": + def from_bytecode(code: Callable) -> "ByteFlow": # type: ignore """Creates a ByteFlow object from the given python function. @@ -54,13 +55,13 @@ def from_bytecode(code) -> "ByteFlow": The resulting ByteFlow object. """ bc = dis.Bytecode(code) - _logger.debug("Bytecode\n%s", _LogWrap(lambda: bc.dis())) + _logger.debug("Bytecode\n%s", _LogWrap(lambda: bc.dis())) # type: ignore # noqa E501 flowinfo = FlowInfo.from_bytecode(bc) scfg = flowinfo.build_basicblocks() return ByteFlow(bc=bc, scfg=scfg) - def _join_returns(self): + def _join_returns(self) -> "ByteFlow": """Joins the return blocks within the corresponding SCFG. This method creates a deep copy of the SCFG and performs @@ -76,7 +77,7 @@ def _join_returns(self): scfg.join_returns() return ByteFlow(bc=self.bc, scfg=scfg) - def _restructure_loop(self): + def _restructure_loop(self) -> "ByteFlow": """Restructures the loops within the corresponding SCFG. Creates a deep copy of the SCFG and performs the operation to @@ -97,7 +98,7 @@ def _restructure_loop(self): restructure_loop(region) return ByteFlow(bc=self.bc, scfg=scfg) - def _restructure_branch(self): + def _restructure_branch(self) -> "ByteFlow": """Restructures the branches within the corresponding SCFG. Creates a deep copy of the SCFG and performs the operation to @@ -117,7 +118,7 @@ def _restructure_branch(self): restructure_branch(region) return ByteFlow(bc=self.bc, scfg=scfg) - def restructure(self): + def restructure(self) -> "ByteFlow": """Applies join_returns, restructure_loop and restructure_branch in the respective order on the SCFG. @@ -146,8 +147,9 @@ def restructure(self): return ByteFlow(bc=self.bc, scfg=scfg) -def _iter_subregions(scfg: "SCFG"): +def _iter_subregions(scfg: SCFG) -> Generator[RegionBlock, SCFG, None]: for node in scfg.graph.values(): if isinstance(node, RegionBlock): yield node + assert node.subregion is not None yield from _iter_subregions(node.subregion) diff --git a/numba_rvsdg/core/datastructures/flow_info.py b/numba_rvsdg/core/datastructures/flow_info.py index 562bf864..a6af42d3 100644 --- a/numba_rvsdg/core/datastructures/flow_info.py +++ b/numba_rvsdg/core/datastructures/flow_info.py @@ -1,6 +1,6 @@ import dis -from typing import Set, Tuple, Dict, Sequence +from typing import Set, Tuple, Dict, Sequence, Optional from dataclasses import dataclass, field from numba_rvsdg.core.datastructures.basic_block import PythonBytecodeBlock @@ -36,7 +36,7 @@ class FlowInfo: last_offset: int = field(default=0) - def _add_jump_inst(self, offset: int, targets: Sequence[int]): + def _add_jump_inst(self, offset: int, targets: Sequence[int]) -> None: """Internal method to add a jump instruction to the FlowInfo. This method adds the target offsets of the jump instruction @@ -93,7 +93,9 @@ def from_bytecode(bc: dis.Bytecode) -> "FlowInfo": flowinfo.last_offset = inst.offset return flowinfo - def build_basicblocks(self: "FlowInfo", end_offset=None) -> "SCFG": + def build_basicblocks( + self: "FlowInfo", end_offset: Optional[int] = None + ) -> "SCFG": """Builds a graph of basic blocks based on the flow information. It creates a structured control flow graph (SCFG) object, assigns diff --git a/numba_rvsdg/core/datastructures/scfg.py b/numba_rvsdg/core/datastructures/scfg.py index 4e5276d0..5b4d15e2 100644 --- a/numba_rvsdg/core/datastructures/scfg.py +++ b/numba_rvsdg/core/datastructures/scfg.py @@ -1,10 +1,20 @@ import dis import yaml +from typing import ( + Any, + Set, + Tuple, + Dict, + List, + Iterator, + Optional, + Generator, + Mapping, + Sized, +) from textwrap import indent -from typing import Set, Tuple, Dict, List, Iterator from dataclasses import dataclass, field from collections import deque -from collections.abc import Mapping from numba_rvsdg.core.datastructures.basic_block import ( BasicBlock, @@ -137,7 +147,7 @@ def new_var_name(self, kind: str) -> str: @dataclass(frozen=True) -class SCFG: +class SCFG(Sized): """SCFG (Structured Control Flow Graph) class. The SCFG class represents a map of names to blocks within the control @@ -163,7 +173,7 @@ class SCFG: # This is the top-level region that this SCFG represents. region: RegionBlock = field(init=False, compare=False) - def __post_init__(self): + def __post_init__(self) -> None: name = self.name_gen.new_region_name("meta") new_region = RegionBlock( name=name, @@ -175,7 +185,7 @@ def __post_init__(self): ) object.__setattr__(self, "region", new_region) - def __getitem__(self, index): + def __getitem__(self, index: str) -> BasicBlock: """Access a block from the graph dictionary using the block name. Parameters @@ -190,7 +200,7 @@ def __getitem__(self, index): """ return self.graph[index] - def __contains__(self, index): + def __contains__(self, index: str) -> bool: """Checks if the given index exists in the graph dictionary. Parameters @@ -206,7 +216,15 @@ def __contains__(self, index): """ return index in self.graph - def __iter__(self): + def __len__(self) -> int: + """ + Returns + ------- + Number of nodes in the graph + """ + return len(self.graph) + + def __iter__(self) -> Generator[Tuple[str, BasicBlock], None, None]: """Returns an iterator over the blocks in the SCFG. Returns an iterator that yields the names and corresponding blocks @@ -242,12 +260,13 @@ def __iter__(self): # If this is a region, recursively yield everything from that # specific region. if type(block) == RegionBlock: + assert block.subregion is not None yield from block.subregion # finally add any jump_targets to the list of names to visit to_visit.extend(block.jump_targets) @property - def concealed_region_view(self): + def concealed_region_view(self) -> "ConcealedRegionView": """A property that returns a ConcealedRegionView object, representing a concealed view of the control flow graph. @@ -317,22 +336,22 @@ def compute_scc(self) -> List[Set[str]]: from numba_rvsdg.networkx_vendored.scc import scc class GraphWrap: - def __init__(self, graph): + def __init__(self, graph: Mapping[str, BasicBlock]) -> None: self.graph = graph - def __getitem__(self, vertex): + def __getitem__(self, vertex: str) -> List[str]: out = self.graph[vertex].jump_targets # Exclude node outside of the subgraph return [k for k in out if k in self.graph] - def __iter__(self): + def __iter__(self) -> Iterator[str]: return iter(self.graph.keys()) - return list(scc(GraphWrap(self.graph))) + return list(scc(GraphWrap(self.graph))) # type: ignore def find_headers_and_entries( self, subgraph: Set[str] - ) -> Tuple[Set[str], Set[str]]: + ) -> Tuple[List[str], List[str]]: """Finds entries and headers in a given subgraph. Entries are blocks outside the subgraph that have an edge pointing to @@ -372,6 +391,8 @@ def find_headers_and_entries( # to it's parent region block's graph. if self.region.kind != "meta": parent_region = self.region.parent_region + assert parent_region is not None + assert parent_region.subregion is not None _, entries = parent_region.subregion.find_headers_and_entries( {self.region.name} ) @@ -379,7 +400,7 @@ def find_headers_and_entries( def find_exiting_and_exits( self, subgraph: Set[str] - ) -> Tuple[Set[str], Set[str]]: + ) -> Tuple[List[str], List[str]]: """Finds exiting and exit blocks in a given subgraph. Existing blocks are blocks inside the subgraph that have edges to @@ -412,7 +433,7 @@ def find_exiting_and_exits( exiting.add(inside) return sorted(exiting), sorted(exits) - def is_reachable_dfs(self, begin: str, end: str): # -> TypeGuard: + def is_reachable_dfs(self, begin: str, end: str) -> bool: """Checks if the end block is reachable from the begin block in the SCFG. @@ -451,7 +472,7 @@ def is_reachable_dfs(self, begin: str, end: str): # -> TypeGuard: if block in self.graph: to_vist.extend(self.graph[block].jump_targets) - def add_block(self, basic_block: BasicBlock): + def add_block(self, basic_block: BasicBlock) -> None: """Adds a BasicBlock object to the control flow graph. Parameters @@ -461,7 +482,7 @@ def add_block(self, basic_block: BasicBlock): """ self.graph[basic_block.name] = basic_block - def remove_blocks(self, names: Set[str]): + def remove_blocks(self, names: Set[str]) -> None: """Removes a BasicBlock object from the control flow graph. Parameters @@ -475,10 +496,10 @@ def remove_blocks(self, names: Set[str]): def insert_block( self, new_name: str, - predecessors: Set[str], - successors: Set[str], - block_type: SyntheticBlock, - ): + predecessors: List[str], + successors: List[str], + block_type: type[SyntheticBlock], + ) -> None: """Inserts a new synthetic block into the SCFG between the given successors and predecessors. @@ -494,11 +515,11 @@ def insert_block( ---------- new_name: str The name of the newly created block. - predecessors: Set[str] - The set of names of BasicBlock that act as predecessors + predecessors: List[str] + The list of names of BasicBlock that act as predecessors for the block to be inserted. - successors: Set[str] - The set of names of BasicBlock that act as successors + successors: List[str] + The list of names of BasicBlock that act as successors for the block to be inserted. block_type: SyntheticBlock The type/class of the newly created block. @@ -506,7 +527,7 @@ def insert_block( # TODO: needs a diagram and documentaion # initialize new block new_block = block_type( - name=new_name, _jump_targets=successors, backedges=set() + name=new_name, _jump_targets=tuple(successors), backedges=tuple() ) # add block to self self.add_block(new_block) @@ -529,9 +550,9 @@ def insert_block( def insert_SyntheticExit( self, new_name: str, - predecessors: Set[str], - successors: Set[str], - ): + predecessors: List[str], + successors: List[str], + ) -> None: """Inserts a synthetic exit block into the SCFG. Parameters same as insert_block method. @@ -544,9 +565,9 @@ def insert_SyntheticExit( def insert_SyntheticTail( self, new_name: str, - predecessors: Set[str], - successors: Set[str], - ): + predecessors: List[str], + successors: List[str], + ) -> None: """Inserts a synthetic tail block into the SCFG. Parameters same as insert_block method. @@ -559,9 +580,9 @@ def insert_SyntheticTail( def insert_SyntheticReturn( self, new_name: str, - predecessors: Set[str], - successors: Set[str], - ): + predecessors: List[str], + successors: List[str], + ) -> None: """Inserts a synthetic return block into the SCFG. Parameters same as insert_block method. @@ -574,9 +595,9 @@ def insert_SyntheticReturn( def insert_SyntheticFill( self, new_name: str, - predecessors: Set[str], - successors: Set[str], - ): + predecessors: List[str], + successors: List[str], + ) -> None: """Inserts a synthetic fill block into the SCFG. Parameters same as insert_block method. @@ -587,8 +608,8 @@ def insert_SyntheticFill( self.insert_block(new_name, predecessors, successors, SyntheticFill) def insert_block_and_control_blocks( - self, new_name: str, predecessors: Set[str], successors: Set[str] - ): + self, new_name: str, predecessors: List[str], successors: List[str] + ) -> None: """Inserts a new block along with control blocks into the SCFG. This method is used for branching assignments. Parameters same as insert_block method. @@ -640,14 +661,14 @@ def insert_block_and_control_blocks( new_block = SyntheticHead( name=new_name, _jump_targets=tuple(successors), - backedges=set(), + backedges=tuple(), variable=branch_variable, branch_value_table=branch_value_table, ) # add block to self self.add_block(new_block) - def join_returns(self): + def join_returns(self) -> None: """Close the CFG. A closed CFG is a CFG with a unique entry and exit node that have no @@ -660,19 +681,19 @@ def join_returns(self): # close if more than one is found if len(return_nodes) > 1: return_solo_name = self.name_gen.new_block_name(SYNTH_RETURN) - self.insert_SyntheticReturn( - return_solo_name, return_nodes, tuple() - ) + self.insert_SyntheticReturn(return_solo_name, return_nodes, []) - def join_tails_and_exits(self, tails: Set[str], exits: Set[str]): + def join_tails_and_exits( + self, tails: List[str], exits: List[str] + ) -> Tuple[str, str]: """Joins the tails and exits of the SCFG. Parameters ---------- - tails: Set[str] - The set of names of BasicBlock that act as tails in the SCFG. - exits: Set[str] - The set of names of BasicBlock that act as exits in the SCFG. + tails: List[str] + The list of names of BasicBlock that act as tails in the SCFG. + exits: List[str] + The list of names of BasicBlock that act as exits in the SCFG. Return ------ @@ -706,11 +727,13 @@ def join_tails_and_exits(self, tails: Set[str], exits: Set[str]): solo_tail_name = self.name_gen.new_block_name(SYNTH_TAIL) solo_exit_name = self.name_gen.new_block_name(SYNTH_EXIT) self.insert_SyntheticTail(solo_tail_name, tails, exits) - self.insert_SyntheticExit(solo_exit_name, {solo_tail_name}, exits) + self.insert_SyntheticExit(solo_exit_name, [solo_tail_name], exits) return solo_tail_name, solo_exit_name + assert False, "unreachable" + @staticmethod - def bcmap_from_bytecode(bc: dis.Bytecode): + def bcmap_from_bytecode(bc: dis.Bytecode) -> Dict[int, dis.Instruction]: """Static method that creates a bytecode map from a `dis.Bytecode` object. @@ -727,7 +750,7 @@ def bcmap_from_bytecode(bc: dis.Bytecode): """ return {inst.offset: inst for inst in bc} - def view(self, name: str = None): + def view(self, name: Optional[str] = None) -> None: """View the current SCFG as a external PDF file. This method internally creates a SCFGRenderer corresponding to @@ -744,7 +767,7 @@ def view(self, name: str = None): SCFGRenderer(self).view(name) @staticmethod - def from_yaml(yaml_string: str): + def from_yaml(yaml_string: str) -> "Tuple[SCFG, Dict[str, str]]": """Static method that creates an SCFG object from a YAML representation. @@ -776,7 +799,9 @@ def from_yaml(yaml_string: str): return SCFGIO.from_yaml(yaml_string) @staticmethod - def from_dict(graph_dict: dict): + def from_dict( + graph_dict: Dict[str, Dict[str, List[str]]] + ) -> Tuple["SCFG", Dict[str, str]]: """Static method that creates an SCFG object from a dictionary representation. @@ -808,7 +833,7 @@ def from_dict(graph_dict: dict): """ return SCFGIO.from_dict(graph_dict) - def to_yaml(self): + def to_yaml(self) -> str: """Converts the SCFG object to a YAML string representation. The method returns a YAML string representing the control @@ -830,7 +855,7 @@ def to_yaml(self): """ return SCFGIO.to_yaml(self) - def to_dict(self): + def to_dict(self) -> Dict[str, Dict[str, Any]]: """Converts the SCFG object to a dictionary representation. This method returns a dictionary representing the control flow @@ -859,7 +884,7 @@ class SCFGIO: """ @staticmethod - def from_yaml(yaml_string: str): + def from_yaml(yaml_string: str) -> Tuple["SCFG", Dict[str, str]]: """Static helper method that creates an SCFG object from a YAML representation. @@ -886,8 +911,10 @@ def from_yaml(yaml_string: str): return scfg, block_dict @staticmethod - def from_dict(graph_dict: dict): - """Static helper method that creates an SCFG object from a dictionary + def from_dict( + graph_dict: Dict[str, Dict[str, Any]] + ) -> "Tuple[SCFG, Dict[str, str]]": + """Static method that creates an SCFG object from a dictionary representation. This method takes a dictionary (graph_dict) @@ -926,12 +953,12 @@ def from_dict(graph_dict: dict): @staticmethod def make_scfg( - graph_dict, - curr_heads: set, - block_ref_dict, - name_gen, - exiting: str = None, - ): + graph_dict: Dict[str, Dict[str, Any]], + curr_heads: Set[str], + block_ref_dict: Dict[str, str], + name_gen: NameGenerator, + exiting: Optional[str] = None, + ) -> "SCFG": """Helper method for building a single 'level' of the hierarchical structure in an `SCFG` graph at a time. Recursively calls itself to build the entire graph. @@ -1007,7 +1034,7 @@ def make_scfg( return scfg @staticmethod - def to_yaml(scfg): + def to_yaml(scfg: "SCFG") -> str: """Helper method to convert the SCFG object to a YAML string representation. @@ -1052,7 +1079,7 @@ def to_yaml(scfg): return ys @staticmethod - def to_dict(scfg): + def to_dict(scfg: "SCFG") -> Dict[str, Dict[str, Any]]: """Helper method to convert the SCFG object to a dictionary representation. @@ -1071,9 +1098,10 @@ def to_dict(scfg): graph_dict: Dict[Dict[...]] A dictionary representing the SCFG. """ - blocks, edges, backedges = {}, {}, {} + blocks: Dict[str, Any] = {} + edges, backedges = {}, {} - def reverse_lookup(value: type): + def reverse_lookup(value: type) -> str: for k, v in block_type_names.items(): if v == value: return k @@ -1081,7 +1109,7 @@ def reverse_lookup(value: type): raise TypeError("Block type not found.") seen = set() - q = set() + q: Set[Tuple[str, BasicBlock]] = set() # Order of elements doesn't matter since they're going to # be sorted at the end. q.update(scfg.graph.items()) @@ -1095,6 +1123,8 @@ def reverse_lookup(value: type): block_type = reverse_lookup(type(value)) blocks[key] = {"type": block_type} if isinstance(value, RegionBlock): + assert value.subregion is not None + assert value.parent_region is not None q.update(value.subregion.graph.items()) blocks[key]["kind"] = value.kind blocks[key]["contains"] = sorted( @@ -1119,7 +1149,7 @@ def reverse_lookup(value: type): return graph_dict @staticmethod - def find_outer_graph(graph_dict: dict): + def find_outer_graph(graph_dict: Dict[str, Dict[str, Any]]) -> Set[str]: """Helper method to find the outermost graph components of an `SCFG` object. (i.e. Components that aren't contained in any other region) @@ -1146,8 +1176,12 @@ def find_outer_graph(graph_dict: dict): @staticmethod def extract_block_info( - blocks, current_name, block_ref_dict, edges, backedges - ): + blocks: Dict[str, Dict[str, Any]], + current_name: str, + block_ref_dict: Dict[str, str], + edges: Dict[str, List[str]], + backedges: Dict[str, List[str]], + ) -> Tuple[Dict[str, Any], str, Tuple[str, ...], Tuple[str, ...]]: """Helper method to extract information from various components of an `SCFG` graph. @@ -1191,13 +1225,15 @@ def extract_block_info( return block_info, block_type, block_edges, block_backedges -class AbstractGraphView(Mapping): +class AbstractGraphView( + Mapping[str, BasicBlock] +): # todo: improve this annotation """Abstract Graph View class. The AbstractGraphView class serves as a template for graph views. """ - def __getitem__(self, item): + def __getitem__(self, item: str) -> BasicBlock: """Retrieves the value associated with the given key or name in the respective graph view. @@ -1213,7 +1249,7 @@ def __getitem__(self, item): """ raise NotImplementedError - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Returns an iterator over the name of blocks in the graph view. Returns @@ -1223,7 +1259,7 @@ def __iter__(self): """ raise NotImplementedError - def __len__(self): + def __len__(self) -> int: """Returns the number of elements in the given region view. Return @@ -1253,9 +1289,9 @@ class ConcealedRegionView(AbstractGraphView): is based on. """ - scfg: SCFG = None + scfg: SCFG - def __init__(self, scfg): + def __init__(self, scfg: SCFG) -> None: """Initializes the ConcealedRegionView with the given SCFG. Parameters @@ -1265,7 +1301,7 @@ def __init__(self, scfg): """ self.scfg = scfg - def __getitem__(self, item): + def __getitem__(self, item: str) -> BasicBlock: """Retrieves the value associated with the given key or name in the respective graph view. @@ -1281,7 +1317,7 @@ def __getitem__(self, item): """ return self.scfg[item] - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Returns an iterator over the name of blocks in the concealed graph view. @@ -1292,7 +1328,9 @@ def __iter__(self): """ return self.region_view_iterator() - def region_view_iterator(self, head: str = None) -> Iterator[str]: + def region_view_iterator( + self, head: Optional[str] = None + ) -> Iterator[str]: """Region View Iterator. This iterator is region aware, which means that regions are "concealed" @@ -1338,6 +1376,7 @@ def region_view_iterator(self, head: str = None) -> Iterator[str]: # If this is a region, continue on to the exiting block, i.e. # the region is presented a single fall-through block to the # consumer of this iterator. + assert block.subregion is not None to_visit.extend(block.subregion[block.exiting].jump_targets) else: # otherwise add any jump_targets to the list of names to visit @@ -1346,7 +1385,7 @@ def region_view_iterator(self, head: str = None) -> Iterator[str]: # finally, yield the name yield name - def __len__(self): + def __len__(self) -> int: """Returns the number of elements in the concealed region view. Return diff --git a/numba_rvsdg/core/transformations.py b/numba_rvsdg/core/transformations.py index 812024de..ab1cd135 100644 --- a/numba_rvsdg/core/transformations.py +++ b/numba_rvsdg/core/transformations.py @@ -1,10 +1,11 @@ from collections import defaultdict -from typing import Set, Dict, List +from typing import Set, Dict, List, Tuple, Optional, Mapping, Iterator from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.basic_block import ( BasicBlock, SyntheticAssignment, + SyntheticBranch, SyntheticExitingLatch, SyntheticExitBranch, RegionBlock, @@ -14,7 +15,7 @@ from numba_rvsdg.core.utils import _logger -def loop_restructure_helper(scfg: SCFG, loop: Set[str]): +def loop_restructure_helper(scfg: SCFG, loop: Set[str]) -> None: """Loop Restructuring Applies the algorithm LOOP RESTRUCTURING from section 4.1 of Bahmann2015. @@ -37,14 +38,15 @@ def loop_restructure_helper(scfg: SCFG, loop: Set[str]): # If there are multiple headers, insert assignment and control blocks, # such that only a single loop header remains. + loop_head = None if len(headers) > 1: headers_were_unified = True solo_head_name = scfg.name_gen.new_block_name(block_names.SYNTH_HEAD) scfg.insert_block_and_control_blocks(solo_head_name, entries, headers) loop.add(solo_head_name) - loop_head: str = solo_head_name + loop_head = solo_head_name else: - loop_head: str = next(iter(headers)) + loop_head = next(iter(headers)) # If there is only a single exiting latch (an exiting block that also has a # backedge to the loop header) we can exit early, since the condition for # SCFG is fullfilled. @@ -79,7 +81,9 @@ def loop_restructure_helper(scfg: SCFG, loop: Set[str]): # If there were multiple headers, we must re-use the variable that was used # for looping as the exit variable if headers_were_unified: - exit_variable = scfg[solo_head_name].variable + bb_branch = scfg[solo_head_name] + assert isinstance(bb_branch, SyntheticBranch) + exit_variable = bb_branch.variable else: exit_variable = scfg.name_gen.new_var_name("exit") # This variable denotes the backedge @@ -96,18 +100,20 @@ def loop_restructure_helper(scfg: SCFG, loop: Set[str]): i: j for i, j in enumerate((loop_head, next(iter(exit_blocks)))) } if headers_were_unified: - header_value_table = scfg[solo_head_name].branch_value_table + bb_branch = scfg[solo_head_name] + assert isinstance(bb_branch, SyntheticBranch) + header_value_table = bb_branch.branch_value_table else: header_value_table = {} # This does a dictionary reverse lookup, to determine the key for a given # value. - def reverse_lookup(d, value): + def reverse_lookup(d: Mapping[int, str], value: str) -> int: for k, v in d.items(): if v == value: return k else: - return "UNUSED" + return -1 # Now that everything is in place, we can start to insert blocks, depending # on what is needed @@ -233,10 +239,11 @@ def reverse_lookup(d, value): scfg.add_block(synth_exit_block) -def restructure_loop(parent_region: RegionBlock): +def restructure_loop(parent_region: RegionBlock) -> None: """Inplace restructuring of the given graph to extract loops using strongly-connected components """ + assert parent_region.subregion is not None scfg = parent_region.subregion # obtain a List of Sets of names, where all names in each set are strongly # connected, i.e. all reachable from one another by traversing the subset @@ -276,15 +283,18 @@ def find_head_blocks(scfg: SCFG, begin: str) -> Set[str]: return head_region_blocks -def find_branch_regions(scfg: SCFG, begin: str, end: str) -> Set[str]: +def find_branch_regions( + scfg: SCFG, begin: str, end: str +) -> List[Optional[Tuple[str, Set[str]]]]: # identify branch regions doms = _doms(scfg) - branch_regions = [] + branch_regions: List[Optional[Tuple[str, Set[str]]]] = [] jump_targets = scfg.graph[begin].jump_targets for bra_start in jump_targets: for jt in jump_targets: if jt != bra_start and scfg.is_reachable_dfs(jt, bra_start): - branch_regions.append(tuple()) + # placeholder for empty branch region + branch_regions.append(None) break else: sub_keys: Set[str] = set() @@ -298,22 +308,17 @@ def find_branch_regions(scfg: SCFG, begin: str, end: str) -> Set[str]: return branch_regions -def _find_branch_regions(scfg: SCFG, begin: str, end: str) -> Set[str]: - # identify branch regions - branch_regions = [] - for bra_start in scfg[begin].jump_targets: - region = [] - region.append(bra_start) - return branch_regions - - def find_tail_blocks( - scfg: SCFG, begin: Set[str], head_region_blocks, branch_regions -): + scfg: SCFG, + begin: str, + head_region_blocks: Set[str], + branch_regions: List[Optional[Tuple[str, Set[str]]]], +) -> Set[str]: tail_subregion = {b for b in scfg.graph.keys()} tail_subregion.difference_update(head_region_blocks) for reg in branch_regions: - if not reg: + if reg is None: + # empty branch region continue b, sub = reg tail_subregion.discard(b) @@ -326,9 +331,10 @@ def find_tail_blocks( def update_exiting( region_block: RegionBlock, new_region_header: str, new_region_name: str -): +) -> RegionBlock: # Recursively updates the exiting blocks of a regionblock region_exiting = region_block.exiting + assert region_block.subregion is not None region_exiting_block: BasicBlock = region_block.subregion.graph.pop( region_exiting ) @@ -355,8 +361,11 @@ def update_exiting( def extract_region( - scfg: SCFG, region_blocks, region_kind, parent_region: RegionBlock -): + scfg: SCFG, + region_blocks: Set[str], + region_kind: str, + parent_region: RegionBlock, +) -> None: headers, entries = scfg.find_headers_and_entries(region_blocks) exiting_blocks, exit_blocks = scfg.find_exiting_and_exits(region_blocks) assert len(headers) == 1 @@ -424,12 +433,14 @@ def extract_region( parent_region.replace_exiting(region_name) # For every region block inside the newly created region, # update the parent region + assert region.subregion is not None for k, v in region.subregion.graph.items(): if isinstance(v, RegionBlock): object.__setattr__(v, "parent_region", region) -def restructure_branch(parent_region: RegionBlock): +def restructure_branch(parent_region: RegionBlock) -> None: + assert parent_region.subregion is not None scfg: SCFG = parent_region.subregion print("restructure_branch", scfg.graph) doms = _doms(scfg) @@ -485,7 +496,7 @@ def restructure_branch(parent_region: RegionBlock): block_names.SYNTH_FILL ) scfg.insert_SyntheticFill( - synthetic_branch_block_name, (begin,), tail_headers + synthetic_branch_block_name, [begin], tail_headers ) # Recompute regions. @@ -507,7 +518,7 @@ def restructure_branch(parent_region: RegionBlock): def _iter_branch_regions( scfg: SCFG, immdoms: Dict[str, str], postimmdoms: Dict[str, str] -): +) -> Iterator[Tuple[str, str]]: for begin, node in scfg.concealed_region_view.items(): if len(node.jump_targets) > 1: # found branch @@ -537,7 +548,7 @@ def _imm_doms(doms: Dict[str, Set[str]]) -> Dict[str, str]: return out -def _doms(scfg: SCFG): +def _doms(scfg: SCFG) -> Dict[str, Set[str]]: # compute dom entries = set() preds_table = defaultdict(set) @@ -559,7 +570,7 @@ def _doms(scfg: SCFG): ) -def _post_doms(scfg: SCFG): +def _post_doms(scfg: SCFG) -> Dict[str, Set[str]]: # compute post dom entries = set() for k, v in scfg.graph.items(): @@ -582,7 +593,12 @@ def _post_doms(scfg: SCFG): ) -def _find_dominators_internal(entries, nodes, preds_table, succs_table): +def _find_dominators_internal( + entries: Set[str], + nodes: List[str], + preds_table: Dict[str, Set[str]], + succs_table: Dict[str, Set[str]], +) -> Dict[str, Set[str]]: # From NUMBA # See theoretical description in # http://en.wikipedia.org/wiki/Dominator_%28graph_theory%29 @@ -623,7 +639,7 @@ def _find_dominators_internal(entries, nodes, preds_table, succs_table): preds = preds_table[n] if preds: new_doms |= functools.reduce( - set.intersection, [doms[p] for p in preds] + set.intersection, [doms[p] for p in preds] # type: ignore ) if new_doms != doms[n]: assert len(new_doms) < len(doms[n]) diff --git a/numba_rvsdg/core/utils.py b/numba_rvsdg/core/utils.py index a951233a..08940db8 100644 --- a/numba_rvsdg/core/utils.py +++ b/numba_rvsdg/core/utils.py @@ -4,10 +4,10 @@ class _LogWrap: - def __init__(self, fn): + def __init__(self, fn): # type: ignore self._fn = fn - def __str__(self): + def __str__(self): # type: ignore return self._fn() diff --git a/numba_rvsdg/networkx_vendored/scc.py b/numba_rvsdg/networkx_vendored/scc.py index 1352e4ad..befb44cf 100644 --- a/numba_rvsdg/networkx_vendored/scc.py +++ b/numba_rvsdg/networkx_vendored/scc.py @@ -5,6 +5,9 @@ """ +# Ignore all mypy errors since this file has been vendored. +# mypy: ignore-errors + def scc(G): preorder = {} diff --git a/numba_rvsdg/rendering/rendering.py b/numba_rvsdg/rendering/rendering.py index d0671c8b..a8c1ba0f 100644 --- a/numba_rvsdg/rendering/rendering.py +++ b/numba_rvsdg/rendering/rendering.py @@ -1,4 +1,5 @@ import logging +from abc import abstractmethod from numba_rvsdg.core.datastructures.basic_block import ( BasicBlock, RegionBlock, @@ -10,7 +11,8 @@ from numba_rvsdg.core.datastructures.scfg import SCFG from numba_rvsdg.core.datastructures.byte_flow import ByteFlow import dis -from typing import Dict +from typing import Dict, Optional +from graphviz import Digraph class BaseRenderer: @@ -21,9 +23,35 @@ class BaseRenderer: edges of the graph are rendered respectively. """ + g: "Digraph" + + @abstractmethod + def render_basic_block( + self, digraph: "Digraph", name: str, block: BasicBlock + ) -> None: + """ """ + + @abstractmethod + def render_control_variable_block( + self, digraph: "Digraph", name: str, block: SyntheticAssignment + ) -> None: + """ """ + + @abstractmethod + def render_branching_block( + self, digraph: "Digraph", name: str, block: SyntheticBranch + ) -> None: + """ """ + + @abstractmethod + def render_region_block( + self, digraph: "Digraph", name: str, regionblock: RegionBlock + ) -> None: + """ """ + def render_block( - self, digraph: "Digraph", name: str, block: BasicBlock # noqa - ): + self, digraph: "Digraph", name: str, block: BasicBlock + ) -> None: """Function that defines how the BasicBlocks in a graph should be rendered. @@ -40,20 +68,20 @@ def render_block( """ if type(block) == BasicBlock: self.render_basic_block(digraph, name, block) - elif type(block) == PythonBytecodeBlock: + if type(block) == PythonBytecodeBlock: self.render_basic_block(digraph, name, block) elif type(block) == SyntheticAssignment: self.render_control_variable_block(digraph, name, block) elif isinstance(block, SyntheticBranch): self.render_branching_block(digraph, name, block) - elif isinstance(block, SyntheticBlock): - self.render_basic_block(digraph, name, block) elif type(block) == RegionBlock: self.render_region_block(digraph, name, block) + elif isinstance(block, SyntheticBlock): + self.render_basic_block(digraph, name, block) else: raise Exception("unreachable") - def render_edges(self, scfg: SCFG): + def render_edges(self, scfg: SCFG) -> None: """Function that renders the edges in an SCFG. Parameters @@ -64,9 +92,9 @@ def render_edges(self, scfg: SCFG): """ blocks = dict(scfg) - def find_base_header(block: BasicBlock): + def find_base_header(block: BasicBlock) -> BasicBlock: if isinstance(block, RegionBlock): - block = blocks[block.header] + block = blocks[block.header] # type: ignore block = find_base_header(block) return block @@ -108,14 +136,14 @@ class ByteFlowRenderer(BaseRenderer): """ - def __init__(self): + def __init__(self) -> None: from graphviz import Digraph self.g = Digraph() def render_region_block( - self, digraph: "Digraph", name: str, regionblock: RegionBlock # noqa - ): + self, digraph: "Digraph", name: str, regionblock: RegionBlock + ) -> None: # render subgraph with digraph.subgraph(name=f"cluster_{name}") as subg: color = "blue" @@ -126,13 +154,16 @@ def render_region_block( if regionblock.kind == "head": color = "red" subg.attr(color=color, label=regionblock.name) + assert regionblock.subregion is not None for name, block in regionblock.subregion.graph.items(): self.render_block(subg, name, block) def render_basic_block( - self, digraph: "Digraph", name: str, block: BasicBlock # noqa - ): - if name.startswith("python_bytecode"): + self, digraph: "Digraph", name: str, block: BasicBlock + ) -> None: + if name.startswith("python_bytecode") and isinstance( + block, PythonBytecodeBlock + ): instlist = block.get_instructions(self.bcmap) body = name + r"\l" body += r"\l".join( @@ -144,8 +175,8 @@ def render_basic_block( digraph.node(str(name), shape="rect", label=body) def render_control_variable_block( - self, digraph: "Digraph", name: str, block: BasicBlock # noqa - ): + self, digraph: "Digraph", name: str, block: SyntheticAssignment + ) -> None: if isinstance(name, str): body = name + r"\l" body += r"\l".join( @@ -156,8 +187,8 @@ def render_control_variable_block( digraph.node(str(name), shape="rect", label=body) def render_branching_block( - self, digraph: "Digraph", name: str, block: BasicBlock # noqa - ): + self, digraph: "Digraph", name: str, block: SyntheticBranch + ) -> None: if isinstance(name, str): body = name + r"\l" body += rf"variable: {block.variable}\l" @@ -168,9 +199,8 @@ def render_branching_block( raise Exception("Unknown name type: " + name) digraph.node(str(name), shape="rect", label=body) - def render_byteflow(self, byteflow: ByteFlow): - """Renders the provided `ByteFlow` object. - """ + def render_byteflow(self, byteflow: ByteFlow) -> "Digraph": + """Renders the provided `ByteFlow` object.""" self.bcmap_from_bytecode(byteflow.bc) # render nodes for name, block in byteflow.scfg.graph.items(): @@ -178,7 +208,7 @@ def render_byteflow(self, byteflow: ByteFlow): self.render_edges(byteflow.scfg) return self.g - def bcmap_from_bytecode(self, bc: dis.Bytecode): + def bcmap_from_bytecode(self, bc: dis.Bytecode) -> None: self.bcmap: Dict[int, dis.Instruction] = SCFG.bcmap_from_bytecode(bc) @@ -204,8 +234,8 @@ def __init__(self, scfg: SCFG): self.render_edges(scfg) def render_region_block( - self, digraph: "Digraph", name: str, regionblock: RegionBlock # noqa - ): + self, digraph: "Digraph", name: str, regionblock: RegionBlock + ) -> None: # render subgraph with digraph.subgraph(name=f"cluster_{name}") as subg: color = "blue" @@ -224,12 +254,13 @@ def render_region_block( ) subg.attr(color=color, label=label) + assert regionblock.subregion is not None for name, block in regionblock.subregion.graph.items(): self.render_block(subg, name, block) def render_basic_block( - self, digraph: "Digraph", name: str, block: BasicBlock # noqa - ): + self, digraph: "Digraph", name: str, block: BasicBlock + ) -> None: body = ( name + r"\l" @@ -242,8 +273,8 @@ def render_basic_block( digraph.node(str(name), shape="rect", label=body) def render_control_variable_block( - self, digraph: "Digraph", name: str, block: BasicBlock # noqa - ): + self, digraph: "Digraph", name: str, block: SyntheticAssignment + ) -> None: if isinstance(name, str): body = name + r"\l" body += r"\l".join( @@ -261,8 +292,8 @@ def render_control_variable_block( digraph.node(str(name), shape="rect", label=body) def render_branching_block( - self, digraph: "Digraph", name: str, block: BasicBlock # noqa - ): + self, digraph: "Digraph", name: str, block: SyntheticBranch + ) -> None: if isinstance(name, str): body = name + r"\l" body += rf"variable: {block.variable}\l" @@ -280,7 +311,7 @@ def render_branching_block( raise Exception("Unknown name type: " + name) digraph.node(str(name), shape="rect", label=body) - def view(self, name: str): + def view(self, name: Optional[str] = None) -> None: """Method used to view the current SCFG as an external graphviz generated PDF file. @@ -295,7 +326,7 @@ def view(self, name: str): logging.basicConfig(level=logging.DEBUG) -def render_func(func): +def render_func(func) -> None: # type: ignore """The `render_func`` function takes a `func` parameter as the Python function to be transformed and rendered and renders the byte flow representation of the bytecode of the function. @@ -308,7 +339,7 @@ def render_func(func): render_flow(ByteFlow.from_bytecode(func)) -def render_flow(flow): +def render_flow(flow: ByteFlow) -> None: """Renders multiple ByteFlow representations across various SCFG transformations. @@ -344,7 +375,7 @@ def render_flow(flow): ByteFlowRenderer().render_byteflow(bflow).view("branch restructured") -def render_scfg(scfg): +def render_scfg(scfg: SCFG) -> None: """The `render_scfg` function takes a `scfg` parameter as the SCFG object to be transformed and rendered and renders the graphviz representation of the SCFG. @@ -354,4 +385,5 @@ def render_scfg(scfg): scfg: SCFG The structured control flow graph (SCFG) to be rendered. """ - ByteFlowRenderer().render_scfg(scfg).view("scfg") + # is this function used?? + ByteFlowRenderer().render_scfg(scfg).view("scfg") # type: ignore diff --git a/numba_rvsdg/tests/mock_asm.py b/numba_rvsdg/tests/mock_asm.py index 76c0c92a..6904bedd 100644 --- a/numba_rvsdg/tests/mock_asm.py +++ b/numba_rvsdg/tests/mock_asm.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + """ Defines a mock assembly with a minimal operation semantic for testing control flow transformation. @@ -6,7 +8,7 @@ from dataclasses import dataclass from enum import IntEnum -from typing import IO +from typing import IO, List, Tuple import random @@ -57,22 +59,22 @@ class Inst: @dataclass(frozen=True) -class PrintOperands: +class PrintOperands(Operands): text: str @dataclass(frozen=True) -class GotoOperands: +class GotoOperands(Operands): jump_target: int @dataclass(frozen=True) -class CtrOperands: +class CtrOperands(Operands): counter: int @dataclass(frozen=True) -class BrCtrOperands: +class BrCtrOperands(Operands): true_target: int false_target: int @@ -84,7 +86,7 @@ def parse(asm: str) -> list[Inst]: """ # pass 1: scan for labels labelmap: dict[str, int] = {} - todos = [] + todos: List[Tuple[str, List[str]]] = [] for line in asm.splitlines(): line = line.strip() if not line: diff --git a/numba_rvsdg/tests/simulator.py b/numba_rvsdg/tests/simulator.py index 984c3fd7..16e9e77c 100644 --- a/numba_rvsdg/tests/simulator.py +++ b/numba_rvsdg/tests/simulator.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from collections import ChainMap from dis import Instruction from numba_rvsdg.core.datastructures.byte_flow import ByteFlow diff --git a/numba_rvsdg/tests/test_byteflow.py b/numba_rvsdg/tests/test_byteflow.py index a92db249..46aea65c 100644 --- a/numba_rvsdg/tests/test_byteflow.py +++ b/numba_rvsdg/tests/test_byteflow.py @@ -1,5 +1,5 @@ +# mypy: ignore-errors from dis import Bytecode, Instruction, Positions - import unittest from numba_rvsdg.core.datastructures.basic_block import PythonBytecodeBlock from numba_rvsdg.core.datastructures.byte_flow import ByteFlow diff --git a/numba_rvsdg/tests/test_figures.py b/numba_rvsdg/tests/test_figures.py index 82419f1d..3309c048 100644 --- a/numba_rvsdg/tests/test_figures.py +++ b/numba_rvsdg/tests/test_figures.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from numba_rvsdg.core.datastructures.byte_flow import ByteFlow from numba_rvsdg.core.datastructures.flow_info import FlowInfo from numba_rvsdg.core.datastructures.scfg import SCFG diff --git a/numba_rvsdg/tests/test_mock_asm.py b/numba_rvsdg/tests/test_mock_asm.py index 6773a5f4..ed0941fe 100644 --- a/numba_rvsdg/tests/test_mock_asm.py +++ b/numba_rvsdg/tests/test_mock_asm.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from io import StringIO import random import textwrap diff --git a/numba_rvsdg/tests/test_scc.py b/numba_rvsdg/tests/test_scc.py index b3ba2d75..f0de80d5 100644 --- a/numba_rvsdg/tests/test_scc.py +++ b/numba_rvsdg/tests/test_scc.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from numba_rvsdg.core.datastructures.byte_flow import ByteFlow from numba_rvsdg.rendering.rendering import render_flow diff --git a/numba_rvsdg/tests/test_scfg.py b/numba_rvsdg/tests/test_scfg.py index cbcdcb7a..71e167e9 100644 --- a/numba_rvsdg/tests/test_scfg.py +++ b/numba_rvsdg/tests/test_scfg.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from unittest import main, TestCase from textwrap import dedent from numba_rvsdg.core.datastructures.scfg import SCFG, NameGenerator diff --git a/numba_rvsdg/tests/test_simulate.py b/numba_rvsdg/tests/test_simulate.py index 68f8f44b..9986e716 100644 --- a/numba_rvsdg/tests/test_simulate.py +++ b/numba_rvsdg/tests/test_simulate.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from numba_rvsdg.core.datastructures.byte_flow import ByteFlow from numba_rvsdg.tests.simulator import Simulator import unittest diff --git a/numba_rvsdg/tests/test_transforms.py b/numba_rvsdg/tests/test_transforms.py index d42c2e0b..746c7ac2 100644 --- a/numba_rvsdg/tests/test_transforms.py +++ b/numba_rvsdg/tests/test_transforms.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from unittest import main from numba_rvsdg.core.datastructures.scfg import SCFG diff --git a/numba_rvsdg/tests/test_utils.py b/numba_rvsdg/tests/test_utils.py index 488a9713..ffd405a8 100644 --- a/numba_rvsdg/tests/test_utils.py +++ b/numba_rvsdg/tests/test_utils.py @@ -1,3 +1,5 @@ +# mypy: ignore-errors + from unittest import TestCase import yaml @@ -78,14 +80,14 @@ def assertSCFGEqual( stack.append(be1) def assertYAMLEqual( - self, first_yaml: SCFG, second_yaml: SCFG, head_map: dict + self, first_yaml: str, second_yaml: str, head_map: dict ): self.assertDictEqual( yaml.safe_load(first_yaml), yaml.safe_load(second_yaml), head_map ) - def assertDictEqual( - self, first_yaml: str, second_yaml: str, head_map: dict + def assertDictEqual( # type: ignore + self, first_yaml: dict, second_yaml: dict, head_map: dict ): block_mapping = head_map stack = list(block_mapping.keys()) @@ -96,7 +98,7 @@ def assertDictEqual( seen = set() while stack: - node_name: BasicBlock = stack.pop() + node_name = stack.pop() if node_name in seen: continue seen.add(node_name)