From 72a5219aad7c9b807169f74f8954580a36c1d85e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 20 May 2022 12:20:12 -0500 Subject: [PATCH] [Schedule] Allowed typing.Tuple in tir.schedule._type_checker (#11289) * [Schedule] Allowed typing.Tuple in tir.schedule._type_checker Previously, `typing.Tuple` annotations could not be used with `tir.schedule._type_checker.type_checked` annotations. This allows `Tuple` type annotations to be type-checked. * Revert change, allow tuples input as List arguments * Suppress mypy errors Directly interacting with a type object would otherwise cause some false positives. * Corrected unit test for allowing tuples to be used as typing.List * Represent multi-type lists as List[Union[...]] instead of List[Any] This gives a better error message and plays nicely with _type2str, since `typing.Any` doesn't have a `__name__` field. --- python/tvm/tir/schedule/_type_checker.py | 49 ++++++- .../unittest/test_type_annotation_checker.py | 121 ++++++++++++++++++ 2 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 tests/python/unittest/test_type_annotation_checker.py diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py index 1b86c4aa30db..21ca0c5a922b 100644 --- a/python/tvm/tir/schedule/_type_checker.py +++ b/python/tvm/tir/schedule/_type_checker.py @@ -41,6 +41,13 @@ def list_(type_: Any) -> Any: return [subtype] return None + @staticmethod + def tuple_(type_: Any) -> Optional[List[type]]: + if _Subtype._origin(type_) is tuple: + subtypes = type_.__args__ + return subtypes + return None + @staticmethod def optional(type_: Any) -> Optional[List[type]]: if _Subtype._origin(type_) is Union: @@ -68,6 +75,14 @@ def list_(type_: Any) -> Optional[List[type]]: return [subtype] return None + @staticmethod + def tuple_(type_: Any) -> Optional[List[type]]: + if isinstance(type_, typing.GenericMeta): # type: ignore # pylint: disable=no-member + if type_.__name__ == "Tuple": + subtypes = type_.__args__ # type: ignore # pylint: disable=no-member + return subtypes + return None + @staticmethod def optional(type_: Any) -> Optional[List[type]]: if isinstance(type_, typing._Union): # type: ignore # pylint: disable=no-member,protected-access @@ -93,6 +108,10 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]: if subtype is not None: return "list", subtype + subtype = _Subtype.tuple_(type_) + if subtype is not None: + return "tuple", subtype + subtype = _Subtype.optional(type_) if subtype is not None: return "optional", subtype @@ -108,6 +127,7 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]: "none": lambda: "None", "atomic": lambda t: str(t.__name__), "list": lambda t: f"List[{_type2str(t)}]", + "tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]", "optional": lambda t: f"Optional[{_type2str(t)}]", "union": lambda *t: f"Union[{', '.join([_type2str(x) for x in t])}]", } @@ -118,11 +138,26 @@ def _type2str(type_: Any) -> str: return _TYPE2STR[key](*subtypes) +def _val2type(value: Any): + if isinstance(value, list): + types = set(_val2type(x) for x in value) + if len(types) == 1: + return List[types.pop()] # type: ignore + + return List[Union[tuple(types)]] # type: ignore + + if isinstance(value, tuple): + types = tuple(_val2type(x) for x in value) # type: ignore + return Tuple[types] + + return type(value) + + def _type_check_err(x: Any, name: str, expected: Any) -> str: return ( f'"{name}" has wrong type. ' f'Expected "{_type2str(expected)}", ' - f'but gets: "{_type2str(type(x))}"' + f'but gets: "{_type2str(_val2type(x))}"' ) @@ -142,6 +177,17 @@ def _type_check_list(v: List[Any], name: str, type_: Any) -> Optional[str]: return error_msg return None + def _type_check_tuple(v: Any, name: str, *types: Any) -> Optional[str]: + if not isinstance(v, tuple): + return _type_check_err(v, name, Tuple[types]) + if len(types) != len(v): + return _type_check_err(v, name, Tuple[types]) + for i, (x, type_) in enumerate(zip(v, types)): + error_msg = _type_check(x, f"{name}[{i}]", type_) + if error_msg is not None: + return error_msg + return None + def _type_check_optional(v: Any, name: str, type_: Any) -> Optional[str]: return None if v is None else _type_check(v, name, type_) @@ -156,6 +202,7 @@ def _type_check_union(v: Any, name: str, *types: Any) -> Optional[str]: "none": _type_check_none, "atomic": _type_check_atomic, "list": _type_check_list, + "tuple": _type_check_tuple, "optional": _type_check_optional, "union": _type_check_union, } diff --git a/tests/python/unittest/test_type_annotation_checker.py b/tests/python/unittest/test_type_annotation_checker.py new file mode 100644 index 000000000000..7317e05b1a75 --- /dev/null +++ b/tests/python/unittest/test_type_annotation_checker.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test type checker based on python's type annotations""" + +from typing import List, Tuple + +import pytest + +from tvm.tir.schedule._type_checker import type_checked + + +test_cases = [ + { + "type_annotation": int, + "positive_cases": [5], + "negative_cases": ["5"], + }, + { + "type_annotation": List[int], + "positive_cases": [ + [5], + [], + # Tuples are allowed to be used as lists, because both are + # represented in FFI as tvm::runtime::Array. + (1, 2, 3), + ], + "negative_cases": [ + None, + 5, + ["5"], + ], + }, + { + "type_annotation": Tuple[int], + "positive_cases": [ + (5,), + ], + "negative_cases": [ + None, + (1, 2, 3), + [1], + 5, + ["5"], + ], + }, + { + "type_annotation": Tuple[str, int], + "positive_cases": [ + ("x", 5), + ], + "negative_cases": [ + 42, + ("x", 5, 6), + ("x", 5, "y"), + ("x", 5.0), + (None, 5), + ], + }, +] + +positive_cases = [ + (config["type_annotation"], case) for config in test_cases for case in config["positive_cases"] +] + +negative_cases = [ + (config["type_annotation"], case) for config in test_cases for case in config["negative_cases"] +] + + +def format_name(type_annotation, case): + try: + name = type_annotation.__name__ + except AttributeError: + name = str(type_annotation).replace("typing.", "") + + return f"{name}_{case}" + + +@pytest.mark.parametrize( + ["type_annotation", "case"], + positive_cases, + ids=[format_name(t, c) for t, c in positive_cases], +) +def test_matches_type(type_annotation, case): + @type_checked + def func(_: type_annotation): + pass + + func(case) + + +@pytest.mark.parametrize( + ["type_annotation", "case"], + negative_cases, + ids=[format_name(t, c) for t, c in negative_cases], +) +def test_not_matches(type_annotation, case): + @type_checked + def func(_: type_annotation): + pass + + with pytest.raises(TypeError): + func(case) + + +if __name__ == "__main__": + sys.exit(pytest.main(sys.argv))