-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Schedule] Allowed typing.Tuple in tir.schedule._type_checker (#11289)
* [Schedule] Allowed typing.Tuple in tir.schedule._type_checker Previously, `typing.Tuple` annotations could not be used with `tir.schedule._type_checker.type_checked` annotations. This allows `Tuple` type annotations to be type-checked. * Revert change, allow tuples input as List arguments * Suppress mypy errors Directly interacting with a type object would otherwise cause some false positives. * Corrected unit test for allowing tuples to be used as typing.List * Represent multi-type lists as List[Union[...]] instead of List[Any] This gives a better error message and plays nicely with _type2str, since `typing.Any` doesn't have a `__name__` field.
- Loading branch information
1 parent
01b472f
commit 72a5219
Showing
2 changed files
with
169 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Test type checker based on python's type annotations""" | ||
|
||
from typing import List, Tuple | ||
|
||
import pytest | ||
|
||
from tvm.tir.schedule._type_checker import type_checked | ||
|
||
|
||
test_cases = [ | ||
{ | ||
"type_annotation": int, | ||
"positive_cases": [5], | ||
"negative_cases": ["5"], | ||
}, | ||
{ | ||
"type_annotation": List[int], | ||
"positive_cases": [ | ||
[5], | ||
[], | ||
# Tuples are allowed to be used as lists, because both are | ||
# represented in FFI as tvm::runtime::Array. | ||
(1, 2, 3), | ||
], | ||
"negative_cases": [ | ||
None, | ||
5, | ||
["5"], | ||
], | ||
}, | ||
{ | ||
"type_annotation": Tuple[int], | ||
"positive_cases": [ | ||
(5,), | ||
], | ||
"negative_cases": [ | ||
None, | ||
(1, 2, 3), | ||
[1], | ||
5, | ||
["5"], | ||
], | ||
}, | ||
{ | ||
"type_annotation": Tuple[str, int], | ||
"positive_cases": [ | ||
("x", 5), | ||
], | ||
"negative_cases": [ | ||
42, | ||
("x", 5, 6), | ||
("x", 5, "y"), | ||
("x", 5.0), | ||
(None, 5), | ||
], | ||
}, | ||
] | ||
|
||
positive_cases = [ | ||
(config["type_annotation"], case) for config in test_cases for case in config["positive_cases"] | ||
] | ||
|
||
negative_cases = [ | ||
(config["type_annotation"], case) for config in test_cases for case in config["negative_cases"] | ||
] | ||
|
||
|
||
def format_name(type_annotation, case): | ||
try: | ||
name = type_annotation.__name__ | ||
except AttributeError: | ||
name = str(type_annotation).replace("typing.", "") | ||
|
||
return f"{name}_{case}" | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["type_annotation", "case"], | ||
positive_cases, | ||
ids=[format_name(t, c) for t, c in positive_cases], | ||
) | ||
def test_matches_type(type_annotation, case): | ||
@type_checked | ||
def func(_: type_annotation): | ||
pass | ||
|
||
func(case) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["type_annotation", "case"], | ||
negative_cases, | ||
ids=[format_name(t, c) for t, c in negative_cases], | ||
) | ||
def test_not_matches(type_annotation, case): | ||
@type_checked | ||
def func(_: type_annotation): | ||
pass | ||
|
||
with pytest.raises(TypeError): | ||
func(case) | ||
|
||
|
||
if __name__ == "__main__": | ||
sys.exit(pytest.main(sys.argv)) |