Skip to content

Commit

Permalink
[Schedule] Allowed typing.Tuple in tir.schedule._type_checker (#11289)
Browse files Browse the repository at this point in the history
* [Schedule] Allowed typing.Tuple in tir.schedule._type_checker

Previously, `typing.Tuple` annotations could not be used with
`tir.schedule._type_checker.type_checked` annotations.  This allows
`Tuple` type annotations to be type-checked.

* Revert change, allow tuples input as List arguments

* Suppress mypy errors

Directly interacting with a type object would otherwise cause some
false positives.

* Corrected unit test for allowing tuples to be used as typing.List

* Represent multi-type lists as List[Union[...]] instead of List[Any]

This gives a better error message and plays nicely with _type2str,
since `typing.Any` doesn't have a `__name__` field.
  • Loading branch information
Lunderberg authored May 20, 2022
1 parent 01b472f commit 72a5219
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 1 deletion.
49 changes: 48 additions & 1 deletion python/tvm/tir/schedule/_type_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def list_(type_: Any) -> Any:
return [subtype]
return None

@staticmethod
def tuple_(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is tuple:
subtypes = type_.__args__
return subtypes
return None

@staticmethod
def optional(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is Union:
Expand Down Expand Up @@ -68,6 +75,14 @@ def list_(type_: Any) -> Optional[List[type]]:
return [subtype]
return None

@staticmethod
def tuple_(type_: Any) -> Optional[List[type]]:
if isinstance(type_, typing.GenericMeta): # type: ignore # pylint: disable=no-member
if type_.__name__ == "Tuple":
subtypes = type_.__args__ # type: ignore # pylint: disable=no-member
return subtypes
return None

@staticmethod
def optional(type_: Any) -> Optional[List[type]]:
if isinstance(type_, typing._Union): # type: ignore # pylint: disable=no-member,protected-access
Expand All @@ -93,6 +108,10 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
if subtype is not None:
return "list", subtype

subtype = _Subtype.tuple_(type_)
if subtype is not None:
return "tuple", subtype

subtype = _Subtype.optional(type_)
if subtype is not None:
return "optional", subtype
Expand All @@ -108,6 +127,7 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
"none": lambda: "None",
"atomic": lambda t: str(t.__name__),
"list": lambda t: f"List[{_type2str(t)}]",
"tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]",
"optional": lambda t: f"Optional[{_type2str(t)}]",
"union": lambda *t: f"Union[{', '.join([_type2str(x) for x in t])}]",
}
Expand All @@ -118,11 +138,26 @@ def _type2str(type_: Any) -> str:
return _TYPE2STR[key](*subtypes)


def _val2type(value: Any):
if isinstance(value, list):
types = set(_val2type(x) for x in value)
if len(types) == 1:
return List[types.pop()] # type: ignore

return List[Union[tuple(types)]] # type: ignore

if isinstance(value, tuple):
types = tuple(_val2type(x) for x in value) # type: ignore
return Tuple[types]

return type(value)


def _type_check_err(x: Any, name: str, expected: Any) -> str:
return (
f'"{name}" has wrong type. '
f'Expected "{_type2str(expected)}", '
f'but gets: "{_type2str(type(x))}"'
f'but gets: "{_type2str(_val2type(x))}"'
)


Expand All @@ -142,6 +177,17 @@ def _type_check_list(v: List[Any], name: str, type_: Any) -> Optional[str]:
return error_msg
return None

def _type_check_tuple(v: Any, name: str, *types: Any) -> Optional[str]:
if not isinstance(v, tuple):
return _type_check_err(v, name, Tuple[types])
if len(types) != len(v):
return _type_check_err(v, name, Tuple[types])
for i, (x, type_) in enumerate(zip(v, types)):
error_msg = _type_check(x, f"{name}[{i}]", type_)
if error_msg is not None:
return error_msg
return None

def _type_check_optional(v: Any, name: str, type_: Any) -> Optional[str]:
return None if v is None else _type_check(v, name, type_)

Expand All @@ -156,6 +202,7 @@ def _type_check_union(v: Any, name: str, *types: Any) -> Optional[str]:
"none": _type_check_none,
"atomic": _type_check_atomic,
"list": _type_check_list,
"tuple": _type_check_tuple,
"optional": _type_check_optional,
"union": _type_check_union,
}
Expand Down
121 changes: 121 additions & 0 deletions tests/python/unittest/test_type_annotation_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test type checker based on python's type annotations"""

from typing import List, Tuple

import pytest

from tvm.tir.schedule._type_checker import type_checked


test_cases = [
{
"type_annotation": int,
"positive_cases": [5],
"negative_cases": ["5"],
},
{
"type_annotation": List[int],
"positive_cases": [
[5],
[],
# Tuples are allowed to be used as lists, because both are
# represented in FFI as tvm::runtime::Array.
(1, 2, 3),
],
"negative_cases": [
None,
5,
["5"],
],
},
{
"type_annotation": Tuple[int],
"positive_cases": [
(5,),
],
"negative_cases": [
None,
(1, 2, 3),
[1],
5,
["5"],
],
},
{
"type_annotation": Tuple[str, int],
"positive_cases": [
("x", 5),
],
"negative_cases": [
42,
("x", 5, 6),
("x", 5, "y"),
("x", 5.0),
(None, 5),
],
},
]

positive_cases = [
(config["type_annotation"], case) for config in test_cases for case in config["positive_cases"]
]

negative_cases = [
(config["type_annotation"], case) for config in test_cases for case in config["negative_cases"]
]


def format_name(type_annotation, case):
try:
name = type_annotation.__name__
except AttributeError:
name = str(type_annotation).replace("typing.", "")

return f"{name}_{case}"


@pytest.mark.parametrize(
["type_annotation", "case"],
positive_cases,
ids=[format_name(t, c) for t, c in positive_cases],
)
def test_matches_type(type_annotation, case):
@type_checked
def func(_: type_annotation):
pass

func(case)


@pytest.mark.parametrize(
["type_annotation", "case"],
negative_cases,
ids=[format_name(t, c) for t, c in negative_cases],
)
def test_not_matches(type_annotation, case):
@type_checked
def func(_: type_annotation):
pass

with pytest.raises(TypeError):
func(case)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 72a5219

Please sign in to comment.