Skip to content

Commit

Permalink
Add static type check with mypy (#2195)
Browse files Browse the repository at this point in the history
* add mypy config

* fix syntax error

* fix annotations in torchvision/utils.py

* add mypy type check to CircleCI

* add mypy cache to ignore files

* try fix CI

* ignore flake8 F821 since it interferes with mypy

* add mypy type check to config generator

* explicitly set config files
  • Loading branch information
pmeier authored May 11, 2020
1 parent f71316f commit a81d99b
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 10 deletions.
16 changes: 15 additions & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,19 @@ jobs:
- run:
command: |
pip install --user --progress-bar off flake8 typing
flake8 .
flake8 --config=setup.cfg .
python_type_check:
docker:
- image: circleci/python:3.7
steps:
- checkout
- run:
command: |
pip install --user --progress-bar off numpy mypy
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off .
mypy --config-file mypy.ini
clang_format:
docker:
Expand Down Expand Up @@ -702,12 +714,14 @@ workflows:
python_version: "3.6"
cu_version: "cu101"
- python_lint
- python_type_check
- clang_format

nightly:
jobs:
- circleci_consistency
- python_lint
- python_type_check
- clang_format
- binary_linux_wheel:
cu_version: cpu
Expand Down
16 changes: 15 additions & 1 deletion .circleci/config.yml.in
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,19 @@ jobs:
- run:
command: |
pip install --user --progress-bar off flake8 typing
flake8 .
flake8 --config=setup.cfg .

python_type_check:
docker:
- image: circleci/python:3.7
steps:
- checkout
- run:
command: |
pip install --user --progress-bar off numpy mypy
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off .
mypy --config-file mypy.ini

clang_format:
docker:
Expand Down Expand Up @@ -398,12 +410,14 @@ workflows:
python_version: "3.6"
cu_version: "cu101"
- python_lint
- python_type_check
- clang_format

nightly:
{%- endif %}
jobs:
- circleci_consistency
- python_lint
- python_type_check
- clang_format
{{ workflows(prefix="nightly_", filter_branch="nightly", upload=True) }}
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ htmlcov
*.swp
*.swo
gen.yml
.mypy_cache
30 changes: 30 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[mypy]

files = torchvision
show_error_codes = True
pretty = True

[mypy-torchvision.datasets.*]

ignore_errors = True

[mypy-torchvision.io.*]

ignore_errors = True

[mypy-torchvision.models.*]

ignore_errors = True

[mypy-torchvision.ops.*]

ignore_errors = True

[mypy-torchvision.transforms.*]

ignore_errors = True

[mypy-PIL]

ignore_missing_imports = True

2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ max-line-length = 120

[flake8]
max-line-length = 120
ignore = F401,E402,F403,W503,W504
ignore = F401,E402,F403,W503,W504,F821
exclude = venv
3 changes: 2 additions & 1 deletion torchvision/io/_video_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def __init__(self):


def _validate_pts(pts_range):
# type: (List[int])
# type: (List[int]) -> None

if pts_range[1] > 0:
assert (
pts_range[0] <= pts_range[1]
Expand Down
14 changes: 8 additions & 6 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional, Sequence, Tuple, Text, BinaryIO
from typing import Union, Optional, List, Tuple, Text, BinaryIO
import io
import pathlib
import torch
Expand All @@ -7,7 +7,7 @@


def make_grid(
tensor: Union[torch.Tensor, Sequence[torch.Tensor]],
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
Expand Down Expand Up @@ -91,15 +91,17 @@ def norm_range(t, range):
for x in irange(xmaps):
if k >= nmaps:
break
grid.narrow(1, y * height + padding, height - padding)\
.narrow(2, x * width + padding, width - padding)\
.copy_(tensor[k])
# Tensor.copy_() is a valid method but seems to be missing from the stubs
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined]
2, x * width + padding, width - padding
).copy_(tensor[k])
k = k + 1
return grid


def save_image(
tensor: Union[torch.Tensor, Sequence[torch.Tensor]],
tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO],
nrow: int = 8,
padding: int = 2,
Expand Down

0 comments on commit a81d99b

Please sign in to comment.