diff --git a/.circleci/config.yml b/.circleci/config.yml
index f1dabc9d7a9..9aee7eddc25 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -201,7 +201,7 @@ jobs:
pip install --user --progress-bar off types-requests
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off git+https://github.com/pytorch/data.git
- pip install --user --progress-bar off --editable .
+ pip install --user --progress-bar off --no-build-isolation --editable .
mypy --config-file mypy.ini
docstring_parameters_sync:
@@ -235,7 +235,7 @@ jobs:
command: |
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
- pip install --user --progress-bar off .
+ pip install --user --progress-bar off --no-build-isolation .
pip install pytest
python test/test_hub.py
@@ -248,7 +248,7 @@ jobs:
command: |
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
- pip install --user --progress-bar off .
+ pip install --user --progress-bar off --no-build-isolation .
pip install --user onnx
pip install --user onnxruntime
pip install --user pytest
diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in
index c8dd8a5cbc7..7854d7497f0 100644
--- a/.circleci/config.yml.in
+++ b/.circleci/config.yml.in
@@ -201,7 +201,7 @@ jobs:
pip install --user --progress-bar off types-requests
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install --user --progress-bar off git+https://github.com/pytorch/data.git
- pip install --user --progress-bar off --editable .
+ pip install --user --progress-bar off --no-build-isolation --editable .
mypy --config-file mypy.ini
docstring_parameters_sync:
@@ -235,7 +235,7 @@ jobs:
command: |
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
- pip install --user --progress-bar off .
+ pip install --user --progress-bar off --no-build-isolation .
pip install pytest
python test/test_hub.py
@@ -248,7 +248,7 @@ jobs:
command: |
pip install --user --progress-bar off --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
# need to install torchvision dependencies due to transitive imports
- pip install --user --progress-bar off .
+ pip install --user --progress-bar off --no-build-isolation .
pip install --user onnx pip install --user onnxruntime pip install --user pytest diff --git a/.circleci/regenerate.py b/.circleci/regenerate.py index 7e2fa25cb9d..3a1fac4bb86 100755 --- a/.circleci/regenerate.py +++ b/.circleci/regenerate.py @@ -14,10 +14,11 @@ https://github.com/pytorch/vision/pull/1321#issuecomment-531033978 """ +import os.path + import jinja2 -from jinja2 import select_autoescape import yaml -import os.path +from jinja2 import select_autoescape PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] @@ -25,57 +26,66 @@ RC_PATTERN = r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/" -def build_workflows(prefix='', filter_branch=None, upload=False, indentation=6, windows_latest_only=False): +def build_workflows(prefix="", filter_branch=None, upload=False, indentation=6, windows_latest_only=False): w = [] for btype in ["wheel", "conda"]: for os_type in ["linux", "macos", "win"]: python_versions = PYTHON_VERSIONS - cu_versions_dict = {"linux": ["cpu", "cu102", "cu111", "cu113", "rocm4.1", "rocm4.2"], - "win": ["cpu", "cu102", "cu111", "cu113"], - "macos": ["cpu"]} + cu_versions_dict = { + "linux": ["cpu", "cu102", "cu111", "cu113", "rocm4.1", "rocm4.2"], + "win": ["cpu", "cu102", "cu111", "cu113"], + "macos": ["cpu"], + } cu_versions = cu_versions_dict[os_type] for python_version in python_versions: for cu_version in cu_versions: # ROCm conda packages not yet supported - if cu_version.startswith('rocm') and btype == "conda": + if cu_version.startswith("rocm") and btype == "conda": continue for unicode in [False]: fb = filter_branch - if windows_latest_only and os_type == "win" and filter_branch is None and \ - (python_version != python_versions[-1] or - (cu_version not in [cu_versions[0], cu_versions[-1]])): + if ( + windows_latest_only + and os_type == "win" + and filter_branch is None + and ( + python_version != python_versions[-1] + or (cu_version not in [cu_versions[0], cu_versions[-1]]) + ) + ): fb = "main" - if not fb and (os_type == 'linux' and - cu_version == 'cpu' and - btype == 'wheel' and - python_version == '3.7'): + if not fb and ( + os_type == "linux" and cu_version == "cpu" and btype == "wheel" and python_version == "3.7" + ): # the fields must match the build_docs "requires" dependency fb = "/.*/" w += workflow_pair( - btype, os_type, python_version, cu_version, - unicode, prefix, upload, filter_branch=fb) + btype, os_type, python_version, cu_version, unicode, prefix, upload, filter_branch=fb + ) if not filter_branch: # Build on every pull request, but upload only on nightly and tags - w += build_doc_job('/.*/') - w += upload_doc_job('nightly') + w += build_doc_job("/.*/") + w += upload_doc_job("nightly") return indent(indentation, w) -def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix='', upload=False, *, filter_branch=None): +def workflow_pair(btype, os_type, python_version, cu_version, unicode, prefix="", upload=False, *, filter_branch=None): w = [] unicode_suffix = "u" if unicode else "" base_workflow_name = f"{prefix}binary_{os_type}_{btype}_py{python_version}{unicode_suffix}_{cu_version}" - w.append(generate_base_workflow( - base_workflow_name, python_version, cu_version, - unicode, os_type, btype, filter_branch=filter_branch)) + w.append( + generate_base_workflow( + base_workflow_name, python_version, cu_version, unicode, os_type, btype, filter_branch=filter_branch + ) + ) if upload: w.append(generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, filter_branch=filter_branch)) - if filter_branch == 'nightly' and os_type in ['linux', 
'win']: - pydistro = 'pip' if btype == 'wheel' else 'conda' + if filter_branch == "nightly" and os_type in ["linux", "win"]: + pydistro = "pip" if btype == "wheel" else "conda" w.append(generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, python_version, os_type)) return w @@ -85,12 +95,13 @@ def build_doc_job(filter_branch): job = { "name": "build_docs", "python_version": "3.7", - "requires": ["binary_linux_wheel_py3.7_cpu", ], + "requires": [ + "binary_linux_wheel_py3.7_cpu", + ], } if filter_branch: - job["filters"] = gen_filter_branch_tree(filter_branch, - tags_list=RC_PATTERN) + job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN) return [{"build_docs": job}] @@ -99,12 +110,13 @@ def upload_doc_job(filter_branch): "name": "upload_docs", "context": "org-member", "python_version": "3.7", - "requires": ["build_docs", ], + "requires": [ + "build_docs", + ], } if filter_branch: - job["filters"] = gen_filter_branch_tree(filter_branch, - tags_list=RC_PATTERN) + job["filters"] = gen_filter_branch_tree(filter_branch, tags_list=RC_PATTERN) return [{"upload_docs": job}] @@ -122,24 +134,25 @@ def upload_doc_job(filter_branch): def get_manylinux_image(cu_version): if cu_version == "cpu": return "pytorch/manylinux-cuda102" - elif cu_version.startswith('cu'): - cu_suffix = cu_version[len('cu'):] + elif cu_version.startswith("cu"): + cu_suffix = cu_version[len("cu") :] return f"pytorch/manylinux-cuda{cu_suffix}" - elif cu_version.startswith('rocm'): - rocm_suffix = cu_version[len('rocm'):] + elif cu_version.startswith("rocm"): + rocm_suffix = cu_version[len("rocm") :] return f"pytorch/manylinux-rocm:{rocm_suffix}" def get_conda_image(cu_version): if cu_version == "cpu": return "pytorch/conda-builder:cpu" - elif cu_version.startswith('cu'): - cu_suffix = cu_version[len('cu'):] + elif cu_version.startswith("cu"): + cu_suffix = cu_version[len("cu") :] return f"pytorch/conda-builder:cuda{cu_suffix}" -def generate_base_workflow(base_workflow_name, python_version, cu_version, - unicode, os_type, btype, *, filter_branch=None): +def generate_base_workflow( + base_workflow_name, python_version, cu_version, unicode, os_type, btype, *, filter_branch=None +): d = { "name": base_workflow_name, @@ -148,7 +161,7 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version, } if os_type != "win" and unicode: - d["unicode_abi"] = '1' + d["unicode_abi"] = "1" if os_type != "win": d["wheel_docker_image"] = get_manylinux_image(cu_version) @@ -158,14 +171,12 @@ def generate_base_workflow(base_workflow_name, python_version, cu_version, if filter_branch is not None: d["filters"] = { - "branches": { - "only": filter_branch - }, + "branches": {"only": filter_branch}, "tags": { # Using a raw string here to avoid having to escape # anything "only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/" - } + }, } w = f"binary_{os_type}_{btype}" @@ -186,19 +197,17 @@ def generate_upload_workflow(base_workflow_name, os_type, btype, cu_version, *, "requires": [base_workflow_name], } - if btype == 'wheel': - d["subfolder"] = "" if os_type == 'macos' else cu_version + "/" + if btype == "wheel": + d["subfolder"] = "" if os_type == "macos" else cu_version + "/" if filter_branch is not None: d["filters"] = { - "branches": { - "only": filter_branch - }, + "branches": {"only": filter_branch}, "tags": { # Using a raw string here to avoid having to escape # anything "only": r"/v[0-9]+(\.[0-9]+)*-rc[0-9]+/" - } + }, } return {f"binary_{btype}_upload": d} @@ -223,8 +232,7 @@ def 
generate_smoketest_workflow(pydistro, base_workflow_name, filter_branch, pyt def indent(indentation, data_list): - return ("\n" + " " * indentation).join( - yaml.dump(data_list, default_flow_style=False).splitlines()) + return ("\n" + " " * indentation).join(yaml.dump(data_list, default_flow_style=False).splitlines()) def unittest_workflows(indentation=6): @@ -239,12 +247,12 @@ def unittest_workflows(indentation=6): "python_version": python_version, } - if device_type == 'gpu': + if device_type == "gpu": if python_version != "3.8": - job['filters'] = gen_filter_branch_tree('main', 'nightly') - job['cu_version'] = 'cu102' + job["filters"] = gen_filter_branch_tree("main", "nightly") + job["cu_version"] = "cu102" else: - job['cu_version'] = 'cpu' + job["cu_version"] = "cpu" jobs.append({f"unittest_{os_type}_{device_type}": job}) @@ -253,20 +261,17 @@ def unittest_workflows(indentation=6): def cmake_workflows(indentation=6): jobs = [] - python_version = '3.8' - for os_type in ['linux', 'windows', 'macos']: + python_version = "3.8" + for os_type in ["linux", "windows", "macos"]: # Skip OSX CUDA - device_types = ['cpu', 'gpu'] if os_type != 'macos' else ['cpu'] + device_types = ["cpu", "gpu"] if os_type != "macos" else ["cpu"] for device in device_types: - job = { - 'name': f'cmake_{os_type}_{device}', - 'python_version': python_version - } + job = {"name": f"cmake_{os_type}_{device}", "python_version": python_version} - job['cu_version'] = 'cu102' if device == 'gpu' else 'cpu' - if device == 'gpu' and os_type == 'linux': - job['wheel_docker_image'] = 'pytorch/manylinux-cuda102' - jobs.append({f'cmake_{os_type}_{device}': job}) + job["cu_version"] = "cu102" if device == "gpu" else "cpu" + if device == "gpu" and os_type == "linux": + job["wheel_docker_image"] = "pytorch/manylinux-cuda102" + jobs.append({f"cmake_{os_type}_{device}": job}) return indent(indentation, jobs) @@ -275,27 +280,27 @@ def ios_workflows(indentation=6, nightly=False): build_job_names = [] name_prefix = "nightly_" if nightly else "" env_prefix = "nightly-" if nightly else "" - for arch, platform in [('x86_64', 'SIMULATOR'), ('arm64', 'OS')]: - name = f'{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}' + for arch, platform in [("x86_64", "SIMULATOR"), ("arm64", "OS")]: + name = f"{name_prefix}binary_libtorchvision_ops_ios_12.0.0_{arch}" build_job_names.append(name) build_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}', - 'ios_arch': arch, - 'ios_platform': platform, - 'name': name, + "build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-{arch}", + "ios_arch": arch, + "ios_platform": platform, + "name": name, } if nightly: - build_job['filters'] = gen_filter_branch_tree('nightly') - jobs.append({'binary_ios_build': build_job}) + build_job["filters"] = gen_filter_branch_tree("nightly") + jobs.append({"binary_ios_build": build_job}) if nightly: upload_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload', - 'context': 'org-member', - 'filters': gen_filter_branch_tree('nightly'), - 'requires': build_job_names, + "build_environment": f"{env_prefix}binary-libtorchvision_ops-ios-12.0.0-upload", + "context": "org-member", + "filters": gen_filter_branch_tree("nightly"), + "requires": build_job_names, } - jobs.append({'binary_ios_upload': upload_job}) + jobs.append({"binary_ios_upload": upload_job}) return indent(indentation, jobs) @@ -305,23 +310,23 @@ def android_workflows(indentation=6, nightly=False): name_prefix = "nightly_" if 
nightly else "" env_prefix = "nightly-" if nightly else "" - name = f'{name_prefix}binary_libtorchvision_ops_android' + name = f"{name_prefix}binary_libtorchvision_ops_android" build_job_names.append(name) build_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-android', - 'name': name, + "build_environment": f"{env_prefix}binary-libtorchvision_ops-android", + "name": name, } if nightly: upload_job = { - 'build_environment': f'{env_prefix}binary-libtorchvision_ops-android-upload', - 'context': 'org-member', - 'filters': gen_filter_branch_tree('nightly'), - 'name': f'{name_prefix}binary_libtorchvision_ops_android_upload' + "build_environment": f"{env_prefix}binary-libtorchvision_ops-android-upload", + "context": "org-member", + "filters": gen_filter_branch_tree("nightly"), + "name": f"{name_prefix}binary_libtorchvision_ops_android_upload", } - jobs.append({'binary_android_upload': upload_job}) + jobs.append({"binary_android_upload": upload_job}) else: - jobs.append({'binary_android_build': build_job}) + jobs.append({"binary_android_build": build_job}) return indent(indentation, jobs) @@ -330,15 +335,17 @@ def android_workflows(indentation=6, nightly=False): env = jinja2.Environment( loader=jinja2.FileSystemLoader(d), lstrip_blocks=True, - autoescape=select_autoescape(enabled_extensions=('html', 'xml')), + autoescape=select_autoescape(enabled_extensions=("html", "xml")), keep_trailing_newline=True, ) - with open(os.path.join(d, 'config.yml'), 'w') as f: - f.write(env.get_template('config.yml.in').render( - build_workflows=build_workflows, - unittest_workflows=unittest_workflows, - cmake_workflows=cmake_workflows, - ios_workflows=ios_workflows, - android_workflows=android_workflows, - )) + with open(os.path.join(d, "config.yml"), "w") as f: + f.write( + env.get_template("config.yml.in").render( + build_workflows=build_workflows, + unittest_workflows=unittest_workflows, + cmake_workflows=cmake_workflows, + ios_workflows=ios_workflows, + android_workflows=android_workflows, + ) + ) diff --git a/.circleci/unittest/linux/scripts/run-clang-format.py b/.circleci/unittest/linux/scripts/run-clang-format.py index fad1dc57e56..b72246a19b5 100755 --- a/.circleci/unittest/linux/scripts/run-clang-format.py +++ b/.circleci/unittest/linux/scripts/run-clang-format.py @@ -42,7 +42,6 @@ import subprocess import sys import traceback - from functools import partial try: @@ -51,7 +50,7 @@ DEVNULL = open(os.devnull, "wb") -DEFAULT_EXTENSIONS = 'c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu' +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" class ExitStatus: @@ -75,14 +74,8 @@ def list_files(files, recursive=False, extensions=None, exclude=None): # os.walk() supports trimming down the dnames list # by modifying it in-place, # to avoid unnecessary directory listings. 
- dnames[:] = [ - x for x in dnames - if - not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) - ] - fpaths = [ - x for x in fpaths if not fnmatch.fnmatch(x, pattern) - ] + dnames[:] = [x for x in dnames if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern)] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] for f in fpaths: ext = os.path.splitext(f)[1][1:] if ext in extensions: @@ -95,11 +88,9 @@ def list_files(files, recursive=False, extensions=None, exclude=None): def make_diff(file, original, reformatted): return list( difflib.unified_diff( - original, - reformatted, - fromfile='{}\t(original)'.format(file), - tofile='{}\t(reformatted)'.format(file), - n=3)) + original, reformatted, fromfile="{}\t(original)".format(file), tofile="{}\t(reformatted)".format(file), n=3 + ) + ) class DiffError(Exception): @@ -122,13 +113,12 @@ def run_clang_format_diff_wrapper(args, file): except DiffError: raise except Exception as e: - raise UnexpectedError('{}: {}: {}'.format(file, e.__class__.__name__, - e), e) + raise UnexpectedError("{}: {}: {}".format(file, e.__class__.__name__, e), e) def run_clang_format_diff(args, file): try: - with io.open(file, 'r', encoding='utf-8') as f: + with io.open(file, "r", encoding="utf-8") as f: original = f.readlines() except IOError as exc: raise DiffError(str(exc)) @@ -153,17 +143,10 @@ def run_clang_format_diff(args, file): try: proc = subprocess.Popen( - invocation, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - universal_newlines=True, - encoding='utf-8') - except OSError as exc: - raise DiffError( - "Command '{}' failed to start: {}".format( - subprocess.list2cmdline(invocation), exc - ) + invocation, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, encoding="utf-8" ) + except OSError as exc: + raise DiffError("Command '{}' failed to start: {}".format(subprocess.list2cmdline(invocation), exc)) proc_stdout = proc.stdout proc_stderr = proc.stderr @@ -182,30 +165,30 @@ def run_clang_format_diff(args, file): def bold_red(s): - return '\x1b[1m\x1b[31m' + s + '\x1b[0m' + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" def colorize(diff_lines): def bold(s): - return '\x1b[1m' + s + '\x1b[0m' + return "\x1b[1m" + s + "\x1b[0m" def cyan(s): - return '\x1b[36m' + s + '\x1b[0m' + return "\x1b[36m" + s + "\x1b[0m" def green(s): - return '\x1b[32m' + s + '\x1b[0m' + return "\x1b[32m" + s + "\x1b[0m" def red(s): - return '\x1b[31m' + s + '\x1b[0m' + return "\x1b[31m" + s + "\x1b[0m" for line in diff_lines: - if line[:4] in ['--- ', '+++ ']: + if line[:4] in ["--- ", "+++ "]: yield bold(line) - elif line.startswith('@@ '): + elif line.startswith("@@ "): yield cyan(line) - elif line.startswith('+'): + elif line.startswith("+"): yield green(line) - elif line.startswith('-'): + elif line.startswith("-"): yield red(line) else: yield line @@ -218,7 +201,7 @@ def print_diff(diff_lines, use_color): def print_trouble(prog, message, use_colors): - error_text = 'error:' + error_text = "error:" if use_colors: error_text = bold_red(error_text) print("{}: {} {}".format(prog, error_text, message), file=sys.stderr) @@ -227,45 +210,37 @@ def print_trouble(prog, message, use_colors): def main(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '--clang-format-executable', - metavar='EXECUTABLE', - help='path to the clang-format executable', - default='clang-format') - parser.add_argument( - '--extensions', - help='comma separated list of file extensions (default: {})'.format( - DEFAULT_EXTENSIONS), - 
default=DEFAULT_EXTENSIONS) + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) parser.add_argument( - '-r', - '--recursive', - action='store_true', - help='run recursively over directories') - parser.add_argument('files', metavar='file', nargs='+') + "--extensions", + help="comma separated list of file extensions (default: {})".format(DEFAULT_EXTENSIONS), + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument("-r", "--recursive", action="store_true", help="run recursively over directories") + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") parser.add_argument( - '-q', - '--quiet', - action='store_true') - parser.add_argument( - '-j', - metavar='N', + "-j", + metavar="N", type=int, default=0, - help='run N clang-format jobs in parallel' - ' (default number of cpus + 1)') + help="run N clang-format jobs in parallel" " (default number of cpus + 1)", + ) parser.add_argument( - '--color', - default='auto', - choices=['auto', 'always', 'never'], - help='show colored diff (default: auto)') + "--color", default="auto", choices=["auto", "always", "never"], help="show colored diff (default: auto)" + ) parser.add_argument( - '-e', - '--exclude', - metavar='PATTERN', - action='append', + "-e", + "--exclude", + metavar="PATTERN", + action="append", default=[], - help='exclude paths matching the given glob-like pattern(s)' - ' from recursive search') + help="exclude paths matching the given glob-like pattern(s)" " from recursive search", + ) args = parser.parse_args() @@ -282,10 +257,10 @@ def main(): colored_stdout = False colored_stderr = False - if args.color == 'always': + if args.color == "always": colored_stdout = True colored_stderr = True - elif args.color == 'auto': + elif args.color == "auto": colored_stdout = sys.stdout.isatty() colored_stderr = sys.stderr.isatty() @@ -298,19 +273,15 @@ def main(): except OSError as e: print_trouble( parser.prog, - "Command '{}' failed to start: {}".format( - subprocess.list2cmdline(version_invocation), e - ), + "Command '{}' failed to start: {}".format(subprocess.list2cmdline(version_invocation), e), use_colors=colored_stderr, ) return ExitStatus.TROUBLE retcode = ExitStatus.SUCCESS files = list_files( - args.files, - recursive=args.recursive, - exclude=args.exclude, - extensions=args.extensions.split(',')) + args.files, recursive=args.recursive, exclude=args.exclude, extensions=args.extensions.split(",") + ) if not files: return @@ -327,8 +298,7 @@ def main(): pool = None else: pool = multiprocessing.Pool(njobs) - it = pool.imap_unordered( - partial(run_clang_format_diff_wrapper, args), files) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) while True: try: outs, errs = next(it) @@ -359,5 +329,5 @@ def main(): return retcode -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/.github/workflows/tests-schedule.yml b/.github/workflows/tests-schedule.yml index 65f805ce471..f4ac3a2fbdd 100644 --- a/.github/workflows/tests-schedule.yml +++ b/.github/workflows/tests-schedule.yml @@ -20,8 +20,8 @@ jobs: with: python-version: 3.6 - - name: Upgrade pip - run: python -m pip install --upgrade pip + - name: Upgrade system packages + run: python -m pip install --upgrade pip setuptools wheel - name: Checkout repository uses: actions/checkout@v2 @@ -30,7 +30,7 @@ jobs: run: pip install --pre torch -f 
https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html - name: Install torchvision - run: pip install -e . + run: pip install --no-build-isolation --editable . - name: Install all optional dataset requirements run: pip install scipy pandas pycocotools lmdb requests diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 86887d3b2ef..0024d0243d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,11 @@ repos: + - repo: https://github.com/omnilib/ufmt + rev: v1.3.0 + hooks: + - id: ufmt + additional_dependencies: + - black == 21.9b0 + - usort == 0.6.4 - repo: https://gitlab.com/pycqa/flake8 rev: 3.9.2 hooks: diff --git a/android/test_app/make_assets.py b/android/test_app/make_assets.py index 7860c759a57..fedee39fc52 100644 --- a/android/test_app/make_assets.py +++ b/android/test_app/make_assets.py @@ -5,11 +5,8 @@ print(torch.__version__) model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, - box_score_thresh=0.7, - rpn_post_nms_top_n_test=100, - rpn_score_thresh=0.4, - rpn_pre_nms_top_n_test=150) + pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +) model.eval() script_model = torch.jit.script(model) diff --git a/docs/source/conf.py b/docs/source/conf.py index e8e17edf283..4c2f3faec75 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -21,8 +21,8 @@ # import sys # sys.path.insert(0, os.path.abspath('.')) -import torchvision import pytorch_sphinx_theme +import torchvision # -- General configuration ------------------------------------------------ @@ -33,24 +33,24 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinx.ext.duration', - 'sphinx_gallery.gen_gallery', - 'sphinx_copybutton', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.todo", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinx.ext.duration", + "sphinx_gallery.gen_gallery", + "sphinx_copybutton", ] sphinx_gallery_conf = { - 'examples_dirs': '../../gallery/', # path to your example scripts - 'gallery_dirs': 'auto_examples', # path to where to save gallery generated output - 'backreferences_dir': 'gen_modules/backreferences', - 'doc_module': ('torchvision',), + "examples_dirs": "../../gallery/", # path to your example scripts + "gallery_dirs": "auto_examples", # path to where to save gallery generated output + "backreferences_dir": "gen_modules/backreferences", + "doc_module": ("torchvision",), } napoleon_use_ivar = True @@ -59,22 +59,22 @@ # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # source_suffix = { - '.rst': 'restructuredtext', + ".rst": "restructuredtext", } # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. 
-project = 'Torchvision' -copyright = '2017-present, Torch Contributors' -author = 'Torch Contributors' +project = "Torchvision" +copyright = "2017-present, Torch Contributors" +author = "Torch Contributors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -82,10 +82,10 @@ # # The short X.Y version. # TODO: change to [:2] at v1.0 -version = 'main (' + torchvision.__version__ + ' )' +version = "main (" + torchvision.__version__ + " )" # The full version, including alpha/beta/rc tags. # TODO: verify this works as expected -release = 'main' +release = "main" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -100,7 +100,7 @@ exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -111,7 +111,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'pytorch_sphinx_theme' +html_theme = "pytorch_sphinx_theme" html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme @@ -119,30 +119,30 @@ # documentation. # html_theme_options = { - 'collapse_navigation': False, - 'display_version': True, - 'logo_only': True, - 'pytorch_project': 'docs', - 'navigation_with_keys': True, - 'analytics_id': 'UA-117752657-2', + "collapse_navigation": False, + "display_version": True, + "logo_only": True, + "pytorch_project": "docs", + "navigation_with_keys": True, + "analytics_id": "UA-117752657-2", } -html_logo = '_static/img/pytorch-logo-dark.svg' +html_logo = "_static/img/pytorch-logo-dark.svg" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # TODO: remove this once https://github.com/pytorch/pytorch_sphinx_theme/issues/125 is fixed html_css_files = [ - 'css/custom_torchvision.css', + "css/custom_torchvision.css", ] # -- Options for HTMLHelp output ------------------------------------------ # Output file base name for HTML help builder. -htmlhelp_basename = 'PyTorchdoc' +htmlhelp_basename = "PyTorchdoc" # -- Options for LaTeX output --------------------------------------------- @@ -150,15 +150,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -169,8 +166,7 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'pytorch.tex', 'torchvision Documentation', - 'Torch Contributors', 'manual'), + (master_doc, "pytorch.tex", "torchvision Documentation", "Torch Contributors", "manual"), ] @@ -178,10 +174,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). 
-man_pages = [ - (master_doc, 'torchvision', 'torchvision Documentation', - [author], 1) -] +man_pages = [(master_doc, "torchvision", "torchvision Documentation", [author], 1)] # -- Options for Texinfo output ------------------------------------------- @@ -190,27 +183,33 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'torchvision', 'torchvision Documentation', - author, 'torchvision', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "torchvision", + "torchvision Documentation", + author, + "torchvision", + "One line description of project.", + "Miscellaneous", + ), ] # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'numpy': ('http://docs.scipy.org/doc/numpy/', None), - 'PIL': ('https://pillow.readthedocs.io/en/stable/', None), - 'matplotlib': ('https://matplotlib.org/stable/', None), + "python": ("https://docs.python.org/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "numpy": ("http://docs.scipy.org/doc/numpy/", None), + "PIL": ("https://pillow.readthedocs.io/en/stable/", None), + "matplotlib": ("https://matplotlib.org/stable/", None), } # -- A patch that prevents Sphinx from cross-referencing ivar tags ------- # See http://stackoverflow.com/a/41184353/3343043 from docutils import nodes -from sphinx.util.docfields import TypedField from sphinx import addnodes +from sphinx.util.docfields import TypedField def patched_make_field(self, types, domain, items, **kw): @@ -220,40 +219,39 @@ def patched_make_field(self, types, domain, items, **kw): # type: (list, unicode, tuple) -> nodes.field # noqa: F821 def handle_item(fieldarg, content): par = nodes.paragraph() - par += addnodes.literal_strong('', fieldarg) # Patch: this line added + par += addnodes.literal_strong("", fieldarg) # Patch: this line added # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, # addnodes.literal_strong)) if fieldarg in types: - par += nodes.Text(' (') + par += nodes.Text(" (") # NOTE: using .pop() here to prevent a single type node to be # inserted twice into the doctree, which leads to # inconsistencies later when references are resolved fieldtype = types.pop(fieldarg) if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): - typename = u''.join(n.astext() for n in fieldtype) - typename = typename.replace('int', 'python:int') - typename = typename.replace('long', 'python:long') - typename = typename.replace('float', 'python:float') - typename = typename.replace('type', 'python:type') - par.extend(self.make_xrefs(self.typerolename, domain, typename, - addnodes.literal_emphasis, **kw)) + typename = "".join(n.astext() for n in fieldtype) + typename = typename.replace("int", "python:int") + typename = typename.replace("long", "python:long") + typename = typename.replace("float", "python:float") + typename = typename.replace("type", "python:type") + par.extend(self.make_xrefs(self.typerolename, domain, typename, addnodes.literal_emphasis, **kw)) else: par += fieldtype - par += nodes.Text(')') - par += nodes.Text(' -- ') + par += nodes.Text(")") + par += nodes.Text(" -- ") par += content return par - fieldname = nodes.field_name('', self.label) + fieldname = nodes.field_name("", self.label) if len(items) == 1 and self.can_collapse: fieldarg, content = items[0] bodynode = handle_item(fieldarg, content) else: bodynode = 
self.list_type() for fieldarg, content in items: - bodynode += nodes.list_item('', handle_item(fieldarg, content)) - fieldbody = nodes.field_body('', bodynode) - return nodes.field('', fieldname, fieldbody) + bodynode += nodes.list_item("", handle_item(fieldarg, content)) + fieldbody = nodes.field_body("", bodynode) + return nodes.field("", fieldname, fieldbody) TypedField.make_field = patched_make_field @@ -286,4 +284,4 @@ def inject_minigalleries(app, what, name, obj, options, lines): def setup(app): - app.connect('autodoc-process-docstring', inject_minigalleries) + app.connect("autodoc-process-docstring", inject_minigalleries) diff --git a/hubconf.py b/hubconf.py index 8412e9e6e6b..1e79a89f426 100644 --- a/hubconf.py +++ b/hubconf.py @@ -1,27 +1,61 @@ # Optional list of dependencies required by the package -dependencies = ['torch'] +dependencies = ["torch"] # classification from torchvision.models.alexnet import alexnet from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 -from torchvision.models.inception import inception_v3 -from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\ - resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2 -from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1 -from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn +from torchvision.models.efficientnet import ( + efficientnet_b0, + efficientnet_b1, + efficientnet_b2, + efficientnet_b3, + efficientnet_b4, + efficientnet_b5, + efficientnet_b6, + efficientnet_b7, +) from torchvision.models.googlenet import googlenet -from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 +from torchvision.models.inception import inception_v3 +from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3 from torchvision.models.mobilenetv2 import mobilenet_v2 from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small -from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \ - mnasnet1_3 -from torchvision.models.efficientnet import efficientnet_b0, efficientnet_b1, efficientnet_b2, \ - efficientnet_b3, efficientnet_b4, efficientnet_b5, efficientnet_b6, efficientnet_b7 -from torchvision.models.regnet import regnet_y_400mf, regnet_y_800mf, \ - regnet_y_1_6gf, regnet_y_3_2gf, regnet_y_8gf, regnet_y_16gf, regnet_y_32gf, \ - regnet_x_400mf, regnet_x_800mf, regnet_x_1_6gf, regnet_x_3_2gf, regnet_x_8gf, \ - regnet_x_16gf, regnet_x_32gf +from torchvision.models.regnet import ( + regnet_y_400mf, + regnet_y_800mf, + regnet_y_1_6gf, + regnet_y_3_2gf, + regnet_y_8gf, + regnet_y_16gf, + regnet_y_32gf, + regnet_x_400mf, + regnet_x_800mf, + regnet_x_1_6gf, + regnet_x_3_2gf, + regnet_x_8gf, + regnet_x_16gf, + regnet_x_32gf, +) +from torchvision.models.resnet import ( + resnet18, + resnet34, + resnet50, + resnet101, + resnet152, + resnext50_32x4d, + resnext101_32x8d, + wide_resnet50_2, + wide_resnet101_2, +) # segmentation -from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \ - deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large +from torchvision.models.segmentation import ( + fcn_resnet50, + fcn_resnet101, + deeplabv3_resnet50, + deeplabv3_resnet101, + deeplabv3_mobilenet_v3_large, + lraspp_mobilenet_v3_large, +) +from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 +from torchvision.models.squeezenet 
import squeezenet1_0, squeezenet1_1 +from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn diff --git a/ios/VisionTestApp/make_assets.py b/ios/VisionTestApp/make_assets.py index 122094b3547..0f46364569b 100644 --- a/ios/VisionTestApp/make_assets.py +++ b/ios/VisionTestApp/make_assets.py @@ -5,11 +5,8 @@ print(torch.__version__) model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn( - pretrained=True, - box_score_thresh=0.7, - rpn_post_nms_top_n_test=100, - rpn_score_thresh=0.4, - rpn_pre_nms_top_n_test=150) + pretrained=True, box_score_thresh=0.7, rpn_post_nms_top_n_test=100, rpn_score_thresh=0.4, rpn_pre_nms_top_n_test=150 +) model.eval() script_model = torch.jit.script(model) diff --git a/packaging/wheel/relocate.py b/packaging/wheel/relocate.py index dd2c5d2a4ce..3dde79f08b4 100644 --- a/packaging/wheel/relocate.py +++ b/packaging/wheel/relocate.py @@ -2,46 +2,63 @@ """Helper script to package wheels and relocate binaries.""" -# Standard library imports -import os -import io -import sys import glob -import shutil -import zipfile import hashlib +import io + +# Standard library imports +import os +import os.path as osp import platform +import shutil import subprocess -import os.path as osp +import sys +import zipfile from base64 import urlsafe_b64encode # Third party imports -if sys.platform == 'linux': +if sys.platform == "linux": from auditwheel.lddtree import lddtree from wheel.bdist_wheel import get_abi_tag ALLOWLIST = { - 'libgcc_s.so.1', 'libstdc++.so.6', 'libm.so.6', - 'libdl.so.2', 'librt.so.1', 'libc.so.6', - 'libnsl.so.1', 'libutil.so.1', 'libpthread.so.0', - 'libresolv.so.2', 'libX11.so.6', 'libXext.so.6', - 'libXrender.so.1', 'libICE.so.6', 'libSM.so.6', - 'libGL.so.1', 'libgobject-2.0.so.0', 'libgthread-2.0.so.0', - 'libglib-2.0.so.0', 'ld-linux-x86-64.so.2', 'ld-2.17.so' + "libgcc_s.so.1", + "libstdc++.so.6", + "libm.so.6", + "libdl.so.2", + "librt.so.1", + "libc.so.6", + "libnsl.so.1", + "libutil.so.1", + "libpthread.so.0", + "libresolv.so.2", + "libX11.so.6", + "libXext.so.6", + "libXrender.so.1", + "libICE.so.6", + "libSM.so.6", + "libGL.so.1", + "libgobject-2.0.so.0", + "libgthread-2.0.so.0", + "libglib-2.0.so.0", + "ld-linux-x86-64.so.2", + "ld-2.17.so", } WINDOWS_ALLOWLIST = { - 'MSVCP140.dll', 'KERNEL32.dll', - 'VCRUNTIME140_1.dll', 'VCRUNTIME140.dll', - 'api-ms-win-crt-heap-l1-1-0.dll', - 'api-ms-win-crt-runtime-l1-1-0.dll', - 'api-ms-win-crt-stdio-l1-1-0.dll', - 'api-ms-win-crt-filesystem-l1-1-0.dll', - 'api-ms-win-crt-string-l1-1-0.dll', - 'api-ms-win-crt-environment-l1-1-0.dll', - 'api-ms-win-crt-math-l1-1-0.dll', - 'api-ms-win-crt-convert-l1-1-0.dll' + "MSVCP140.dll", + "KERNEL32.dll", + "VCRUNTIME140_1.dll", + "VCRUNTIME140.dll", + "api-ms-win-crt-heap-l1-1-0.dll", + "api-ms-win-crt-runtime-l1-1-0.dll", + "api-ms-win-crt-stdio-l1-1-0.dll", + "api-ms-win-crt-filesystem-l1-1-0.dll", + "api-ms-win-crt-string-l1-1-0.dll", + "api-ms-win-crt-environment-l1-1-0.dll", + "api-ms-win-crt-math-l1-1-0.dll", + "api-ms-win-crt-convert-l1-1-0.dll", } @@ -64,20 +81,18 @@ def rehash(path, blocksize=1 << 20): """Return (hash, length) for path using hashlib.sha256()""" h = hashlib.sha256() length = 0 - with open(path, 'rb') as f: + with open(path, "rb") as f: for block in read_chunks(f, size=blocksize): length += len(block) h.update(block) - digest = 'sha256=' + urlsafe_b64encode( - h.digest() - ).decode('latin1').rstrip('=') + digest = "sha256=" + urlsafe_b64encode(h.digest()).decode("latin1").rstrip("=") # 
unicode/str python2 issues return (digest, str(length)) # type: ignore def unzip_file(file, dest): """Decompress zip `file` into directory `dest`.""" - with zipfile.ZipFile(file, 'r') as zip_ref: + with zipfile.ZipFile(file, "r") as zip_ref: zip_ref.extractall(dest) @@ -88,8 +103,7 @@ def is_program_installed(basename): On macOS systems, a .app is considered installed if it exists. """ - if (sys.platform == 'darwin' and basename.endswith('.app') and - osp.exists(basename)): + if sys.platform == "darwin" and basename.endswith(".app") and osp.exists(basename): return basename for path in os.environ["PATH"].split(os.pathsep): @@ -105,9 +119,9 @@ def find_program(basename): (return None if not found) """ names = [basename] - if os.name == 'nt': + if os.name == "nt": # Windows platforms - extensions = ('.exe', '.bat', '.cmd', '.dll') + extensions = (".exe", ".bat", ".cmd", ".dll") if not basename.endswith(extensions): names = [basename + ext for ext in extensions] + [basename] for name in names: @@ -118,19 +132,18 @@ def find_program(basename): def patch_new_path(library_path, new_dir): library = osp.basename(library_path) - name, *rest = library.split('.') - rest = '.'.join(rest) - hash_id = hashlib.sha256(library_path.encode('utf-8')).hexdigest()[:8] - new_name = '.'.join([name, hash_id, rest]) + name, *rest = library.split(".") + rest = ".".join(rest) + hash_id = hashlib.sha256(library_path.encode("utf-8")).hexdigest()[:8] + new_name = ".".join([name, hash_id, rest]) return osp.join(new_dir, new_name) def find_dll_dependencies(dumpbin, binary): - out = subprocess.run([dumpbin, "/dependents", binary], - stdout=subprocess.PIPE) - out = out.stdout.strip().decode('utf-8') - start_index = out.find('dependencies:') + len('dependencies:') - end_index = out.find('Summary') + out = subprocess.run([dumpbin, "/dependents", binary], stdout=subprocess.PIPE) + out = out.stdout.strip().decode("utf-8") + start_index = out.find("dependencies:") + len("dependencies:") + end_index = out.find("Summary") dlls = out[start_index:end_index].strip() dlls = dlls.split(os.linesep) dlls = [dll.strip() for dll in dlls] @@ -145,13 +158,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): rename and copy them into the wheel while updating their respective rpaths. 
""" - print('Relocating {0}'.format(binary)) + print("Relocating {0}".format(binary)) binary_path = osp.join(output_library, binary) ld_tree = lddtree(binary_path) - tree_libs = ld_tree['libs'] + tree_libs = ld_tree["libs"] - binary_queue = [(n, binary) for n in ld_tree['needed']] + binary_queue = [(n, binary) for n in ld_tree["needed"]] binary_paths = {binary: binary_path} binary_dependencies = {} @@ -160,13 +173,13 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): library_info = tree_libs[library] print(library) - if library_info['path'] is None: - print('Omitting {0}'.format(library)) + if library_info["path"] is None: + print("Omitting {0}".format(library)) continue if library in ALLOWLIST: # Omit glibc/gcc/system libraries - print('Omitting {0}'.format(library)) + print("Omitting {0}".format(library)) continue parent_dependencies = binary_dependencies.get(parent, []) @@ -176,11 +189,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): if library in binary_paths: continue - binary_paths[library] = library_info['path'] - binary_queue += [(n, library) for n in library_info['needed']] + binary_paths[library] = library_info["path"] + binary_queue += [(n, library) for n in library_info["needed"]] - print('Copying dependencies to wheel directory') - new_libraries_path = osp.join(output_dir, 'torchvision.libs') + print("Copying dependencies to wheel directory") + new_libraries_path = osp.join(output_dir, "torchvision.libs") os.makedirs(new_libraries_path) new_names = {binary: binary_path} @@ -189,11 +202,11 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): if library != binary: library_path = binary_paths[library] new_library_path = patch_new_path(library_path, new_libraries_path) - print('{0} -> {1}'.format(library, new_library_path)) + print("{0} -> {1}".format(library, new_library_path)) shutil.copyfile(library_path, new_library_path) new_names[library] = new_library_path - print('Updating dependency names by new files') + print("Updating dependency names by new files") for library in binary_paths: if library != binary: if library not in binary_dependencies: @@ -202,59 +215,26 @@ def relocate_elf_library(patchelf, output_dir, output_library, binary): new_library_name = new_names[library] for dep in library_dependencies: new_dep = osp.basename(new_names[dep]) - print('{0}: {1} -> {2}'.format(library, dep, new_dep)) + print("{0}: {1} -> {2}".format(library, dep, new_dep)) subprocess.check_output( - [ - patchelf, - '--replace-needed', - dep, - new_dep, - new_library_name - ], - cwd=new_libraries_path) - - print('Updating library rpath') - subprocess.check_output( - [ - patchelf, - '--set-rpath', - "$ORIGIN", - new_library_name - ], - cwd=new_libraries_path) - - subprocess.check_output( - [ - patchelf, - '--print-rpath', - new_library_name - ], - cwd=new_libraries_path) + [patchelf, "--replace-needed", dep, new_dep, new_library_name], cwd=new_libraries_path + ) + + print("Updating library rpath") + subprocess.check_output([patchelf, "--set-rpath", "$ORIGIN", new_library_name], cwd=new_libraries_path) + + subprocess.check_output([patchelf, "--print-rpath", new_library_name], cwd=new_libraries_path) print("Update library dependencies") library_dependencies = binary_dependencies[binary] for dep in library_dependencies: new_dep = osp.basename(new_names[dep]) - print('{0}: {1} -> {2}'.format(binary, dep, new_dep)) - subprocess.check_output( - [ - patchelf, - '--replace-needed', - dep, - new_dep, - binary - ], - 
cwd=output_library) - - print('Update library rpath') + print("{0}: {1} -> {2}".format(binary, dep, new_dep)) + subprocess.check_output([patchelf, "--replace-needed", dep, new_dep, binary], cwd=output_library) + + print("Update library rpath") subprocess.check_output( - [ - patchelf, - '--set-rpath', - "$ORIGIN:$ORIGIN/../torchvision.libs", - binary_path - ], - cwd=output_library + [patchelf, "--set-rpath", "$ORIGIN:$ORIGIN/../torchvision.libs", binary_path], cwd=output_library ) @@ -265,7 +245,7 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): Given a shared library, find the transitive closure of its dependencies, rename and copy them into the wheel. """ - print('Relocating {0}'.format(binary)) + print("Relocating {0}".format(binary)) binary_path = osp.join(output_library, binary) library_dlls = find_dll_dependencies(dumpbin, binary_path) @@ -275,19 +255,19 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): while binary_queue != []: library, parent = binary_queue.pop(0) - if library in WINDOWS_ALLOWLIST or library.startswith('api-ms-win'): - print('Omitting {0}'.format(library)) + if library in WINDOWS_ALLOWLIST or library.startswith("api-ms-win"): + print("Omitting {0}".format(library)) continue library_path = find_program(library) if library_path is None: - print('{0} not found'.format(library)) + print("{0} not found".format(library)) continue - if osp.basename(osp.dirname(library_path)) == 'system32': + if osp.basename(osp.dirname(library_path)) == "system32": continue - print('{0}: {1}'.format(library, library_path)) + print("{0}: {1}".format(library, library_path)) parent_dependencies = binary_dependencies.get(parent, []) parent_dependencies.append(library) binary_dependencies[parent] = parent_dependencies @@ -299,55 +279,56 @@ def relocate_dll_library(dumpbin, output_dir, output_library, binary): downstream_dlls = find_dll_dependencies(dumpbin, library_path) binary_queue += [(n, library) for n in downstream_dlls] - print('Copying dependencies to wheel directory') - package_dir = osp.join(output_dir, 'torchvision') + print("Copying dependencies to wheel directory") + package_dir = osp.join(output_dir, "torchvision") for library in binary_paths: if library != binary: library_path = binary_paths[library] new_library_path = osp.join(package_dir, library) - print('{0} -> {1}'.format(library, new_library_path)) + print("{0} -> {1}".format(library, new_library_path)) shutil.copyfile(library_path, new_library_path) def compress_wheel(output_dir, wheel, wheel_dir, wheel_name): """Create RECORD file and compress wheel distribution.""" - print('Update RECORD file in wheel') - dist_info = glob.glob(osp.join(output_dir, '*.dist-info'))[0] - record_file = osp.join(dist_info, 'RECORD') + print("Update RECORD file in wheel") + dist_info = glob.glob(osp.join(output_dir, "*.dist-info"))[0] + record_file = osp.join(dist_info, "RECORD") - with open(record_file, 'w') as f: + with open(record_file, "w") as f: for root, _, files in os.walk(output_dir): for this_file in files: full_file = osp.join(root, this_file) rel_file = osp.relpath(full_file, output_dir) if full_file == record_file: - f.write('{0},,\n'.format(rel_file)) + f.write("{0},,\n".format(rel_file)) else: digest, size = rehash(full_file) - f.write('{0},{1},{2}\n'.format(rel_file, digest, size)) + f.write("{0},{1},{2}\n".format(rel_file, digest, size)) - print('Compressing wheel') + print("Compressing wheel") base_wheel_name = osp.join(wheel_dir, wheel_name) - 
shutil.make_archive(base_wheel_name, 'zip', output_dir) + shutil.make_archive(base_wheel_name, "zip", output_dir) os.remove(wheel) - shutil.move('{0}.zip'.format(base_wheel_name), wheel) + shutil.move("{0}.zip".format(base_wheel_name), wheel) shutil.rmtree(output_dir) def patch_linux(): # Get patchelf location - patchelf = find_program('patchelf') + patchelf = find_program("patchelf") if patchelf is None: - raise FileNotFoundError('Patchelf was not found in the system, please' - ' make sure that is available on the PATH.') + raise FileNotFoundError( + "Patchelf was not found in the system, please" " make sure that is available on the PATH." + ) # Find wheel - print('Finding wheels...') - wheels = glob.glob(osp.join(PACKAGE_ROOT, 'dist', '*.whl')) - output_dir = osp.join(PACKAGE_ROOT, 'dist', '.wheel-process') + print("Finding wheels...") + wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl")) + output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process") - image_binary = 'image.so' - video_binary = 'video_reader.so' + image_binary = "image.so" + video_binary = "video_reader.so" torchvision_binaries = [image_binary, video_binary] for wheel in wheels: if osp.exists(output_dir): @@ -355,37 +336,37 @@ def patch_linux(): os.makedirs(output_dir) - print('Unzipping wheel...') + print("Unzipping wheel...") wheel_file = osp.basename(wheel) wheel_dir = osp.dirname(wheel) - print('{0}'.format(wheel_file)) + print("{0}".format(wheel_file)) wheel_name, _ = osp.splitext(wheel_file) unzip_file(wheel, output_dir) - print('Finding ELF dependencies...') - output_library = osp.join(output_dir, 'torchvision') + print("Finding ELF dependencies...") + output_library = osp.join(output_dir, "torchvision") for binary in torchvision_binaries: if osp.exists(osp.join(output_library, binary)): - relocate_elf_library( - patchelf, output_dir, output_library, binary) + relocate_elf_library(patchelf, output_dir, output_library, binary) compress_wheel(output_dir, wheel, wheel_dir, wheel_name) def patch_win(): # Get dumpbin location - dumpbin = find_program('dumpbin') + dumpbin = find_program("dumpbin") if dumpbin is None: - raise FileNotFoundError('Dumpbin was not found in the system, please' - ' make sure that is available on the PATH.') + raise FileNotFoundError( + "Dumpbin was not found in the system, please" " make sure that is available on the PATH." 
+ ) # Find wheel - print('Finding wheels...') - wheels = glob.glob(osp.join(PACKAGE_ROOT, 'dist', '*.whl')) - output_dir = osp.join(PACKAGE_ROOT, 'dist', '.wheel-process') + print("Finding wheels...") + wheels = glob.glob(osp.join(PACKAGE_ROOT, "dist", "*.whl")) + output_dir = osp.join(PACKAGE_ROOT, "dist", ".wheel-process") - image_binary = 'image.pyd' - video_binary = 'video_reader.pyd' + image_binary = "image.pyd" + video_binary = "video_reader.pyd" torchvision_binaries = [image_binary, video_binary] for wheel in wheels: if osp.exists(output_dir): @@ -393,25 +374,24 @@ def patch_win(): os.makedirs(output_dir) - print('Unzipping wheel...') + print("Unzipping wheel...") wheel_file = osp.basename(wheel) wheel_dir = osp.dirname(wheel) - print('{0}'.format(wheel_file)) + print("{0}".format(wheel_file)) wheel_name, _ = osp.splitext(wheel_file) unzip_file(wheel, output_dir) - print('Finding DLL/PE dependencies...') - output_library = osp.join(output_dir, 'torchvision') + print("Finding DLL/PE dependencies...") + output_library = osp.join(output_dir, "torchvision") for binary in torchvision_binaries: if osp.exists(osp.join(output_library, binary)): - relocate_dll_library( - dumpbin, output_dir, output_library, binary) + relocate_dll_library(dumpbin, output_dir, output_library, binary) compress_wheel(output_dir, wheel, wheel_dir, wheel_name) -if __name__ == '__main__': - if sys.platform == 'linux': +if __name__ == "__main__": + if sys.platform == "linux": patch_linux() - elif sys.platform == 'win32': + elif sys.platform == "win32": patch_win() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000000..c891329458e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.usort] +first_party_detection = false + +[tool.black] + +line-length = 120 +target-version = ["py36"] + +[tool.ufmt] + +excludes = [ + "gallery", +] diff --git a/references/classification/presets.py b/references/classification/presets.py index 2eb60fe2e98..27ce486207d 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -4,8 +4,15 @@ class ClassificationPresetTrain: - def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5, - auto_augment_policy=None, random_erase_prob=0.0): + def __init__( + self, + crop_size, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + hflip_prob=0.5, + auto_augment_policy=None, + random_erase_prob=0.0, + ): trans = [transforms.RandomResizedCrop(crop_size)] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) @@ -17,11 +24,13 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2 else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy)) - trans.extend([ - transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ]) + trans.extend( + [ + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) if random_erase_prob > 0: trans.append(transforms.RandomErasing(p=random_erase_prob)) @@ -32,16 +41,24 @@ def __call__(self, img): class ClassificationPresetEval: - def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), - interpolation=InterpolationMode.BILINEAR): - - self.transforms = transforms.Compose([ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), - 
transforms.PILToTensor(), - transforms.ConvertImageDtype(torch.float), - transforms.Normalize(mean=mean, std=std), - ]) + def __init__( + self, + crop_size, + resize_size=256, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + interpolation=InterpolationMode.BILINEAR, + ): + + self.transforms = transforms.Compose( + [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + transforms.PILToTensor(), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) def __call__(self, img): return self.transforms(img) diff --git a/references/classification/train.py b/references/classification/train.py index 48ab75bc2c1..38ac592237a 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -2,16 +2,15 @@ import os import time +import presets import torch import torch.utils.data -from torch.utils.data.dataloader import default_collate -from torch import nn import torchvision -from torchvision.transforms.functional import InterpolationMode - -import presets import transforms import utils +from torch import nn +from torch.utils.data.dataloader import default_collate +from torchvision.transforms.functional import InterpolationMode try: from apex import amp @@ -19,14 +18,13 @@ amp = None -def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, - print_freq, apex=False, model_ema=None): +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}')) + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) + metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) - header = 'Epoch: [{}]'.format(epoch) + header = "Epoch: [{}]".format(epoch) for image, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() image, target = image.to(device), target.to(device) @@ -44,18 +42,18 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = image.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) - metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time)) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) + metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) if model_ema: model_ema.update_parameters(model) -def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''): +def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") - header = f'Test: {log_suffix}' + header = f"Test: {log_suffix}" with torch.no_grad(): for image, target in metric_logger.log_every(data_loader, print_freq, header): image = image.to(device, non_blocking=True) @@ -68,17 +66,18 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=' # could have been padded in distributed setup batch_size = image.shape[0] 
metric_logger.update(loss=loss.item()) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) # gather the stats from all processes metric_logger.synchronize_between_processes() - print(f'{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}') + print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}") return metric_logger.acc1.global_avg def _get_cache_path(filepath): import hashlib + h = hashlib.sha1(filepath.encode()).hexdigest() cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt") cache_path = os.path.expanduser(cache_path) @@ -90,14 +89,20 @@ def load_data(traindir, valdir, args): print("Loading data") resize_size, crop_size = 256, 224 interpolation = InterpolationMode.BILINEAR - if args.model == 'inception_v3': + if args.model == "inception_v3": resize_size, crop_size = 342, 299 - elif args.model.startswith('efficientnet_'): + elif args.model.startswith("efficientnet_"): sizes = { - 'b0': (256, 224), 'b1': (256, 240), 'b2': (288, 288), 'b3': (320, 300), - 'b4': (384, 380), 'b5': (456, 456), 'b6': (528, 528), 'b7': (600, 600), + "b0": (256, 224), + "b1": (256, 240), + "b2": (288, 288), + "b3": (320, 300), + "b4": (384, 380), + "b5": (456, 456), + "b6": (528, 528), + "b7": (600, 600), } - e_type = args.model.replace('efficientnet_', '') + e_type = args.model.replace("efficientnet_", "") resize_size, crop_size = sizes[e_type] interpolation = InterpolationMode.BICUBIC @@ -113,8 +118,10 @@ def load_data(traindir, valdir, args): random_erase_prob = getattr(args, "random_erase", 0.0) dataset = torchvision.datasets.ImageFolder( traindir, - presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy, - random_erase_prob=random_erase_prob)) + presets.ClassificationPresetTrain( + crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob + ), + ) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) @@ -130,8 +137,8 @@ def load_data(traindir, valdir, args): else: dataset_test = torchvision.datasets.ImageFolder( valdir, - presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, - interpolation=interpolation)) + presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation), + ) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) @@ -150,8 +157,10 @@ def load_data(traindir, valdir, args): def main(args): if args.apex and amp is None: - raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " - "to enable mixed-precision training.") + raise RuntimeError( + "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " + "to enable mixed-precision training." 
+ ) if args.output_dir: utils.mkdir(args.output_dir) @@ -163,8 +172,8 @@ def main(args): torch.backends.cudnn.benchmark = True - train_dir = os.path.join(args.data_path, 'train') - val_dir = os.path.join(args.data_path, 'val') + train_dir = os.path.join(args.data_path, "train") + val_dir = os.path.join(args.data_path, "val") dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) collate_fn = None @@ -178,12 +187,16 @@ def main(args): mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, pin_memory=True, - collate_fn=collate_fn) + dataset, + batch_size=args.batch_size, + sampler=train_sampler, + num_workers=args.workers, + pin_memory=True, + collate_fn=collate_fn, + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, - sampler=test_sampler, num_workers=args.workers, pin_memory=True) + dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + ) print("Creating model") model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes) @@ -197,45 +210,57 @@ def main(args): opt_name = args.opt.lower() if opt_name.startswith("sgd"): optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, - nesterov="nesterov" in opt_name) - elif opt_name == 'rmsprop': - optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, - weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) + model.parameters(), + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + nesterov="nesterov" in opt_name, + ) + elif opt_name == "rmsprop": + optimizer = torch.optim.RMSprop( + model.parameters(), + lr=args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay, + eps=0.0316, + alpha=0.9, + ) else: raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) if args.apex: - model, optimizer = amp.initialize(model, optimizer, - opt_level=args.apex_opt_level - ) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) args.lr_scheduler = args.lr_scheduler.lower() - if args.lr_scheduler == 'steplr': + if args.lr_scheduler == "steplr": main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) - elif args.lr_scheduler == 'cosineannealinglr': - main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, - T_max=args.epochs - args.lr_warmup_epochs) - elif args.lr_scheduler == 'exponentiallr': + elif args.lr_scheduler == "cosineannealinglr": + main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs - args.lr_warmup_epochs + ) + elif args.lr_scheduler == "exponentiallr": main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma) else: - raise RuntimeError("Invalid lr scheduler '{}'. Only StepLR, CosineAnnealingLR and ExponentialLR " - "are supported.".format(args.lr_scheduler)) + raise RuntimeError( + "Invalid lr scheduler '{}'. 
Only StepLR, CosineAnnealingLR and ExponentialLR " + "are supported.".format(args.lr_scheduler) + ) if args.lr_warmup_epochs > 0: - if args.lr_warmup_method == 'linear': - warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay, - total_iters=args.lr_warmup_epochs) - elif args.lr_warmup_method == 'constant': - warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, - total_iters=args.lr_warmup_epochs) + if args.lr_warmup_method == "linear": + warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs + ) + elif args.lr_warmup_method == "constant": + warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR( + optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs + ) else: - raise RuntimeError(f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant " - "are supported.") + raise RuntimeError( + f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant " "are supported." + ) lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, - schedulers=[warmup_lr_scheduler, main_lr_scheduler], - milestones=[args.lr_warmup_epochs] + optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs] ) else: lr_scheduler = main_lr_scheduler @@ -250,13 +275,13 @@ def main(args): model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay) if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if model_ema: - model_ema.load_state_dict(checkpoint['model_ema']) + model_ema.load_state_dict(checkpoint["model_ema"]) if args.test_only: evaluate(model, criterion, data_loader_test, device=device) @@ -271,64 +296,67 @@ def main(args): lr_scheduler.step() evaluate(model, criterion, data_loader_test, device=device) if model_ema: - evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA') + evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA") if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args} + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + } if model_ema: - checkpoint['model_ema'] = model_ema.state_dict() - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + checkpoint["model_ema"] = model_ema.state_dict() + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) total_time = time.time() - start_time total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help) - - parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') - parser.add_argument('--model', default='resnet18', help='model') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('-b', '--batch-size', default=32, type=int) - parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', - help='number of data loading workers (default: 16)') - parser.add_argument('--opt', default='sgd', type=str, help='optimizer') - parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--label-smoothing', default=0.0, type=float, - help='label smoothing (default: 0.0)', - dest='label_smoothing') - parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)') - parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)') - parser.add_argument('--lr-scheduler', default="steplr", help='the lr scheduler (default: steplr)') - parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)') - parser.add_argument('--lr-warmup-method', default="constant", type=str, - help='the warmup method (default: constant)') - parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr') - parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') - parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') - parser.add_argument('--print-freq', default=10, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset") + parser.add_argument("--model", default="resnet18", help="model") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument("-b", "--batch-size", default=32, type=int) + parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" + ) + parser.add_argument("--opt", default="sgd", type=str, help="optimizer") + parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + 
metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument( + "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" + ) + parser.add_argument("--mixup-alpha", default=0.0, type=float, help="mixup alpha (default: 0.0)") + parser.add_argument("--cutmix-alpha", default=0.0, type=float, help="cutmix alpha (default: 0.0)") + parser.add_argument("--lr-scheduler", default="steplr", help="the lr scheduler (default: steplr)") + parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") + parser.add_argument( + "--lr-warmup-method", default="constant", type=str, help="the warmup method (default: constant)" + ) + parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") + parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--cache-dataset", dest="cache_dataset", @@ -353,28 +381,32 @@ def get_args_parser(add_help=True): help="Use pre-trained models from the modelzoo", action="store_true", ) - parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)') - parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)') + parser.add_argument("--auto-augment", default=None, help="auto augment policy (default: None)") + parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)") # Mixed precision training parameters - parser.add_argument('--apex', action='store_true', - help='Use apex for mixed precision training') - parser.add_argument('--apex-opt-level', default='O1', type=str, - help='For apex mixed precision training' - 'O0 for FP32 training, O1 for mixed precision training.' - 'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet' - ) + parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training") + parser.add_argument( + "--apex-opt-level", + default="O1", + type=str, + help="For apex mixed precision training" + "O0 for FP32 training, O1 for mixed precision training." 
+ "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet", + ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") parser.add_argument( - '--model-ema', action='store_true', - help='enable tracking Exponential Moving Average of model parameters') + "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" + ) parser.add_argument( - '--model-ema-decay', type=float, default=0.9, - help='decay factor for Exponential Moving Average of model parameters(default: 0.9)') + "--model-ema-decay", + type=float, + default=0.9, + help="decay factor for Exponential Moving Average of model parameters(default: 0.9)", + ) return parser diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index ec945f4f58f..5bf64aea721 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -1,14 +1,14 @@ +import copy import datetime import os import time -import copy import torch +import torch.quantization import torch.utils.data -from torch import nn import torchvision -import torch.quantization import utils +from torch import nn from train import train_one_epoch, evaluate, load_data @@ -20,8 +20,7 @@ def main(args): print(args) if args.post_training_quantize and args.distributed: - raise RuntimeError("Post training quantization example should not be performed " - "on distributed mode") + raise RuntimeError("Post training quantization example should not be performed " "on distributed mode") # Set backend engine to ensure that quantized model runs on the correct kernels if args.backend not in torch.backends.quantized.supported_engines: @@ -33,17 +32,17 @@ def main(args): # Data loading code print("Loading data") - train_dir = os.path.join(args.data_path, 'train') - val_dir = os.path.join(args.data_path, 'val') + train_dir = os.path.join(args.data_path, "train") + val_dir = os.path.join(args.data_path, "val") dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, pin_memory=True) + dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.eval_batch_size, - sampler=test_sampler, num_workers=args.workers, pin_memory=True) + dataset_test, batch_size=args.eval_batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=True + ) print("Creating model", args.model) # when training quantized models, we always start from a pre-trained fp32 reference model @@ -59,12 +58,10 @@ def main(args): model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, - weight_decay=args.weight_decay) + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay + ) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, - step_size=args.lr_step_size, - 
gamma=args.lr_gamma) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) criterion = nn.CrossEntropyLoss() model_without_ddp = model @@ -73,21 +70,19 @@ def main(args): model_without_ddp = model.module if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.post_training_quantize: # perform calibration on a subset of the training dataset # for that, create a subset of the training dataset - ds = torch.utils.data.Subset( - dataset, - indices=list(range(args.batch_size * args.num_calibration_batches))) + ds = torch.utils.data.Subset(dataset, indices=list(range(args.batch_size * args.num_calibration_batches))) data_loader_calibration = torch.utils.data.DataLoader( - ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, - pin_memory=True) + ds, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True + ) model.eval() model.fuse_model() model.qconfig = torch.quantization.get_default_qconfig(args.backend) @@ -97,10 +92,9 @@ def main(args): evaluate(model, criterion, data_loader_calibration, device=device, print_freq=1) torch.quantization.convert(model, inplace=True) if args.output_dir: - print('Saving quantized model') + print("Saving quantized model") if utils.is_main_process(): - torch.save(model.state_dict(), os.path.join(args.output_dir, - 'quantized_post_train_model.pth')) + torch.save(model.state_dict(), os.path.join(args.output_dir, "quantized_post_train_model.pth")) print("Evaluating post-training quantized model") evaluate(model, criterion, data_loader_test, device=device) return @@ -115,107 +109,103 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - print('Starting training for epoch', epoch) - train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, - args.print_freq) + print("Starting training for epoch", epoch) + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq) lr_scheduler.step() with torch.no_grad(): if epoch >= args.num_observer_update_epochs: - print('Disabling observer for subseq epochs, epoch = ', epoch) + print("Disabling observer for subseq epochs, epoch = ", epoch) model.apply(torch.quantization.disable_observer) if epoch >= args.num_batch_norm_update_epochs: - print('Freezing BN for subseq epochs, epoch = ', epoch) + print("Freezing BN for subseq epochs, epoch = ", epoch) model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) - print('Evaluate QAT model') + print("Evaluate QAT model") evaluate(model, criterion, data_loader_test, device=device) quantized_eval_model = copy.deepcopy(model_without_ddp) quantized_eval_model.eval() - quantized_eval_model.to(torch.device('cpu')) + quantized_eval_model.to(torch.device("cpu")) torch.quantization.convert(quantized_eval_model, inplace=True) - print('Evaluate Quantized model') - evaluate(quantized_eval_model, criterion, data_loader_test, - device=torch.device('cpu')) + print("Evaluate Quantized model") 
+ evaluate(quantized_eval_model, criterion, data_loader_test, device=torch.device("cpu")) model.train() if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'eval_model': quantized_eval_model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args} - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) - print('Saving models after epoch ', epoch) + "model": model_without_ddp.state_dict(), + "eval_model": quantized_eval_model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + } + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) + print("Saving models after epoch ", epoch) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Quantized Classification Training', add_help=add_help) - - parser.add_argument('--data-path', - default='/datasets01/imagenet_full_size/061417/', - help='dataset') - parser.add_argument('--model', - default='mobilenet_v2', - help='model') - parser.add_argument('--backend', - default='qnnpack', - help='fbgemm or qnnpack') - parser.add_argument('--device', - default='cuda', - help='device') - - parser.add_argument('-b', '--batch-size', default=32, type=int, - help='batch size for calibration/training') - parser.add_argument('--eval-batch-size', default=128, type=int, - help='batch size for evaluation') - parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('--num-observer-update-epochs', - default=4, type=int, metavar='N', - help='number of total epochs to update observers') - parser.add_argument('--num-batch-norm-update-epochs', default=3, - type=int, metavar='N', - help='number of total epochs to update batch norm stats') - parser.add_argument('--num-calibration-batches', - default=32, type=int, metavar='N', - help='number of batches of training set for \ - observer calibration ') - - parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', - help='number of data loading workers (default: 16)') - parser.add_argument('--lr', - default=0.0001, type=float, - help='initial learning rate') - parser.add_argument('--momentum', - default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-step-size', default=30, type=int, - help='decrease lr every step-size epochs') - parser.add_argument('--lr-gamma', default=0.1, type=float, - help='decrease lr by a factor of lr-gamma') - parser.add_argument('--print-freq', default=10, type=int, - help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = 
argparse.ArgumentParser(description="PyTorch Quantized Classification Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", help="dataset") + parser.add_argument("--model", default="mobilenet_v2", help="model") + parser.add_argument("--backend", default="qnnpack", help="fbgemm or qnnpack") + parser.add_argument("--device", default="cuda", help="device") + + parser.add_argument("-b", "--batch-size", default=32, type=int, help="batch size for calibration/training") + parser.add_argument("--eval-batch-size", default=128, type=int, help="batch size for evaluation") + parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "--num-observer-update-epochs", + default=4, + type=int, + metavar="N", + help="number of total epochs to update observers", + ) + parser.add_argument( + "--num-batch-norm-update-epochs", + default=3, + type=int, + metavar="N", + help="number of total epochs to update batch norm stats", + ) + parser.add_argument( + "--num-calibration-batches", + default=32, + type=int, + metavar="N", + help="number of batches of training set for \ + observer calibration ", + ) + + parser.add_argument( + "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" + ) + parser.add_argument("--lr", default=0.0001, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--cache-dataset", dest="cache_dataset", @@ -243,11 +233,8 @@ def get_args_parser(add_help=True): ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', - default='env://', - help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") return parser diff --git a/references/classification/transforms.py b/references/classification/transforms.py index c4d83ce410c..7788c9e5c3f 100644 --- a/references/classification/transforms.py +++ b/references/classification/transforms.py @@ -1,7 +1,7 @@ import math -import torch - from typing import Tuple + +import torch from torch import Tensor from torchvision.transforms import functional as F @@ -19,9 +19,7 @@ class RandomMixup(torch.nn.Module): inplace (bool): boolean to make this transform inplace. Default set to False. 
""" - def __init__(self, num_classes: int, - p: float = 0.5, alpha: float = 1.0, - inplace: bool = False) -> None: + def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: super().__init__() assert num_classes > 0, "Please provide a valid positive value for the num_classes." assert alpha > 0, "Alpha param can't be zero." @@ -45,7 +43,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: elif target.ndim != 1: raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) elif not batch.is_floating_point(): - raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) + raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype)) elif target.dtype != torch.int64: raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) @@ -74,12 +72,12 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: return batch, target def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_classes={num_classes}' - s += ', p={p}' - s += ', alpha={alpha}' - s += ', inplace={inplace}' - s += ')' + s = self.__class__.__name__ + "(" + s += "num_classes={num_classes}" + s += ", p={p}" + s += ", alpha={alpha}" + s += ", inplace={inplace}" + s += ")" return s.format(**self.__dict__) @@ -97,9 +95,7 @@ class RandomCutmix(torch.nn.Module): inplace (bool): boolean to make this transform inplace. Default set to False. """ - def __init__(self, num_classes: int, - p: float = 0.5, alpha: float = 1.0, - inplace: bool = False) -> None: + def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None: super().__init__() assert num_classes > 0, "Please provide a valid positive value for the num_classes." assert alpha > 0, "Alpha param can't be zero." @@ -123,7 +119,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: elif target.ndim != 1: raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) elif not batch.is_floating_point(): - raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) + raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype)) elif target.dtype != torch.int64: raise TypeError("Target dtype should be torch.int64. 
Got {}".format(target.dtype)) @@ -166,10 +162,10 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: return batch, target def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_classes={num_classes}' - s += ', p={p}' - s += ', alpha={alpha}' - s += ', inplace={inplace}' - s += ')' + s = self.__class__.__name__ + "(" + s += "num_classes={num_classes}" + s += ", p={p}" + s += ", alpha={alpha}" + s += ", inplace={inplace}" + s += ")" return s.format(**self.__dict__) diff --git a/references/classification/utils.py b/references/classification/utils.py index fad607636e5..707418ee217 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -1,14 +1,14 @@ -from collections import defaultdict, deque, OrderedDict import copy import datetime +import errno import hashlib +import os import time +from collections import defaultdict, deque, OrderedDict + import torch import torch.distributed as dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -34,7 +34,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -65,11 +65,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) class MetricLogger(object): @@ -89,15 +86,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -110,31 +104,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: 
data_time.update(time.time() - end) @@ -144,21 +135,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {}'.format(header, total_time_str)) + print("{} Total time: {}".format(header, total_time_str)) class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): @@ -167,9 +165,9 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): `torch.optim.swa_utils.AveragedModel `_ is used to compute the EMA. """ - def __init__(self, model, decay, device='cpu'): - ema_avg = (lambda avg_model_param, model_param, num_averaged: - decay * avg_model_param + (1 - decay) * model_param) + + def __init__(self, model, decay, device="cpu"): + ema_avg = lambda avg_model_param, model_param, num_averaged: decay * avg_model_param + (1 - decay) * model_param super().__init__(model, device, ema_avg) def update_parameters(self, model): @@ -179,8 +177,7 @@ def update_parameters(self, model): if self.n_averaged == 0: p_swa.detach().copy_(p_model_) else: - p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, - self.n_averaged.to(device))) + p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_, self.n_averaged.to(device))) self.n_averaged += 1 @@ -216,10 +213,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -256,28 +254,28 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, 
rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) setup_for_distributed(args.rank == 0) @@ -300,9 +298,7 @@ def average_checkpoints(inputs): with open(fpath, "rb") as f: state = torch.load( f, - map_location=( - lambda s, _: torch.serialization.default_restore_location(s, "cpu") - ), + map_location=(lambda s, _: torch.serialization.default_restore_location(s, "cpu")), ) # Copies over the settings from the first checkpoint if new_state is None: @@ -336,7 +332,7 @@ def average_checkpoints(inputs): return new_state -def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=True): +def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=True): """ This method can be used to prepare weights files for new models. It receives as input a model architecture and a checkpoint from the training script and produces @@ -382,7 +378,7 @@ def store_model_weights(model, checkpoint_path, checkpoint_key='model', strict=T # Deep copy to avoid side-effects on the model object. model = copy.deepcopy(model) - checkpoint = torch.load(checkpoint_path, map_location='cpu') + checkpoint = torch.load(checkpoint_path, map_location="cpu") # Load the weights to the model to validate that everything works # and remove unnecessary weights (such as auxiliaries, etc) diff --git a/references/detection/coco_eval.py b/references/detection/coco_eval.py index 14a3d5457a0..ec0709c5d91 100644 --- a/references/detection/coco_eval.py +++ b/references/detection/coco_eval.py @@ -5,10 +5,9 @@ import numpy as np import pycocotools.mask as mask_util import torch -from pycocotools.cocoeval import COCOeval -from pycocotools.coco import COCO - import utils +from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval class CocoEvaluator: @@ -104,8 +103,7 @@ def prepare_for_coco_segmentation(self, predictions): labels = prediction["labels"].tolist() rles = [ - mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] - for mask in masks + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks ] for rle in rles: rle["counts"] = rle["counts"].decode("utf-8") @@ -141,7 +139,7 @@ def prepare_for_coco_keypoint(self, predictions): { "image_id": original_id, "category_id": labels[k], - 'keypoints': keypoint, + "keypoints": keypoint, "score": scores[k], } for k, keypoint in enumerate(keypoints) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index 26701a2cbee..ad657068252 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -1,16 +1,14 @@ import copy import os -from PIL import Image import torch import torch.utils.data import torchvision - +import transforms as T +from PIL import Image from pycocotools import mask as coco_mask from pycocotools.coco import COCO -import transforms as T - class FilterAndRemapCocoCategories(object): def __init__(self, categories, remap=True): @@ -56,7 +54,7 @@ def __call__(self, image, target): anno = target["annotations"] - anno = [obj for obj in anno if obj['iscrowd'] == 0] + anno = [obj for obj in anno if obj["iscrowd"] == 0] boxes = [obj["bbox"] for obj in anno] # guard against no boxes via resizing @@ -147,7 +145,7 @@ def convert_to_coco_api(ds): coco_ds = COCO() # annotation IDs need to 
start at 1, not 0, see torchvision issue #1530 ann_id = 1 - dataset = {'images': [], 'categories': [], 'annotations': []} + dataset = {"images": [], "categories": [], "annotations": []} categories = set() for img_idx in range(len(ds)): # find better way to get target @@ -155,41 +153,41 @@ def convert_to_coco_api(ds): img, targets = ds[img_idx] image_id = targets["image_id"].item() img_dict = {} - img_dict['id'] = image_id - img_dict['height'] = img.shape[-2] - img_dict['width'] = img.shape[-1] - dataset['images'].append(img_dict) + img_dict["id"] = image_id + img_dict["height"] = img.shape[-2] + img_dict["width"] = img.shape[-1] + dataset["images"].append(img_dict) bboxes = targets["boxes"] bboxes[:, 2:] -= bboxes[:, :2] bboxes = bboxes.tolist() - labels = targets['labels'].tolist() - areas = targets['area'].tolist() - iscrowd = targets['iscrowd'].tolist() - if 'masks' in targets: - masks = targets['masks'] + labels = targets["labels"].tolist() + areas = targets["area"].tolist() + iscrowd = targets["iscrowd"].tolist() + if "masks" in targets: + masks = targets["masks"] # make masks Fortran contiguous for coco_mask masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) - if 'keypoints' in targets: - keypoints = targets['keypoints'] + if "keypoints" in targets: + keypoints = targets["keypoints"] keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() num_objs = len(bboxes) for i in range(num_objs): ann = {} - ann['image_id'] = image_id - ann['bbox'] = bboxes[i] - ann['category_id'] = labels[i] + ann["image_id"] = image_id + ann["bbox"] = bboxes[i] + ann["category_id"] = labels[i] categories.add(labels[i]) - ann['area'] = areas[i] - ann['iscrowd'] = iscrowd[i] - ann['id'] = ann_id - if 'masks' in targets: + ann["area"] = areas[i] + ann["iscrowd"] = iscrowd[i] + ann["id"] = ann_id + if "masks" in targets: ann["segmentation"] = coco_mask.encode(masks[i].numpy()) - if 'keypoints' in targets: - ann['keypoints'] = keypoints[i] - ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) - dataset['annotations'].append(ann) + if "keypoints" in targets: + ann["keypoints"] = keypoints[i] + ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3]) + dataset["annotations"].append(ann) ann_id += 1 - dataset['categories'] = [{'id': i} for i in sorted(categories)] + dataset["categories"] = [{"id": i} for i in sorted(categories)] coco_ds.dataset = dataset coco_ds.createIndex() return coco_ds @@ -220,7 +218,7 @@ def __getitem__(self, idx): return img, target -def get_coco(root, image_set, transforms, mode='instances'): +def get_coco(root, image_set, transforms, mode="instances"): anno_file_template = "{}_{}2017.json" PATHS = { "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), diff --git a/references/detection/engine.py b/references/detection/engine.py index 82c23c178b1..2ca7df808ef 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -1,28 +1,28 @@ import math import sys import time -import torch +import torch import torchvision.models.detection.mask_rcnn - -from coco_utils import get_coco_api_from_dataset -from coco_eval import CocoEvaluator import utils +from coco_eval import CocoEvaluator +from coco_utils import get_coco_api_from_dataset def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) - header = 'Epoch: [{}]'.format(epoch) + 
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) + header = "Epoch: [{}]".format(epoch) lr_scheduler = None if epoch == 0: - warmup_factor = 1. / 1000 + warmup_factor = 1.0 / 1000 warmup_iters = min(1000, len(data_loader) - 1) - lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=warmup_factor, - total_iters=warmup_iters) + lr_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=warmup_factor, total_iters=warmup_iters + ) for images, targets in metric_logger.log_every(data_loader, print_freq, header): images = list(image.to(device) for image in images) @@ -76,7 +76,7 @@ def evaluate(model, data_loader, device): cpu_device = torch.device("cpu") model.eval() metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = "Test:" coco = get_coco_api_from_dataset(data_loader.dataset) iou_types = _get_iou_types(model) diff --git a/references/detection/group_by_aspect_ratio.py b/references/detection/group_by_aspect_ratio.py index 1b76f4c64f7..8d680f1b18b 100644 --- a/references/detection/group_by_aspect_ratio.py +++ b/references/detection/group_by_aspect_ratio.py @@ -1,17 +1,16 @@ import bisect -from collections import defaultdict import copy -from itertools import repeat, chain import math -import numpy as np +from collections import defaultdict +from itertools import repeat, chain +import numpy as np import torch import torch.utils.data -from torch.utils.data.sampler import BatchSampler, Sampler -from torch.utils.model_zoo import tqdm import torchvision - from PIL import Image +from torch.utils.data.sampler import BatchSampler, Sampler +from torch.utils.model_zoo import tqdm def _repeat_to_at_least(iterable, n): @@ -34,11 +33,11 @@ class GroupedBatchSampler(BatchSampler): 0, i.e. they must be in the range [0, num_groups). batch_size (int): Size of mini-batch. """ + def __init__(self, sampler, group_ids, batch_size): if not isinstance(sampler, Sampler): raise ValueError( - "sampler should be an instance of " - "torch.utils.data.Sampler, but got sampler={}".format(sampler) + "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler) ) self.sampler = sampler self.group_ids = group_ids @@ -68,8 +67,7 @@ def __iter__(self): if num_remaining > 0: # for the remaining batches, take first the buffers with largest number # of elements - for group_id, _ in sorted(buffer_per_group.items(), - key=lambda x: len(x[1]), reverse=True): + for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True): remaining = self.batch_size - len(buffer_per_group[group_id]) samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining) buffer_per_group[group_id].extend(samples_from_group_id[:remaining]) @@ -85,10 +83,12 @@ def __len__(self): def _compute_aspect_ratios_slow(dataset, indices=None): - print("Your dataset doesn't support the fast path for " - "computing the aspect ratios, so will iterate over " - "the full dataset and load every image instead. " - "This might take some time...") + print( + "Your dataset doesn't support the fast path for " + "computing the aspect ratios, so will iterate over " + "the full dataset and load every image instead. " + "This might take some time..." 
+ ) if indices is None: indices = range(len(dataset)) @@ -104,9 +104,12 @@ def __len__(self): sampler = SubsetSampler(indices) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=1, sampler=sampler, + dataset, + batch_size=1, + sampler=sampler, num_workers=14, # you might want to increase it for faster processing - collate_fn=lambda x: x[0]) + collate_fn=lambda x: x[0], + ) aspect_ratios = [] with tqdm(total=len(dataset)) as pbar: for _i, (img, _) in enumerate(data_loader): diff --git a/references/detection/presets.py b/references/detection/presets.py index 04e0680043a..88d8c697d2a 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -1,32 +1,37 @@ import torch - import transforms as T class DetectionPresetTrain: - def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)): - if data_augmentation == 'hflip': - self.transforms = T.Compose([ - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ]) - elif data_augmentation == 'ssd': - self.transforms = T.Compose([ - T.RandomPhotometricDistort(), - T.RandomZoomOut(fill=list(mean)), - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ]) - elif data_augmentation == 'ssdlite': - self.transforms = T.Compose([ - T.RandomIoUCrop(), - T.RandomHorizontalFlip(p=hflip_prob), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - ]) + def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)): + if data_augmentation == "hflip": + self.transforms = T.Compose( + [ + T.RandomHorizontalFlip(p=hflip_prob), + T.PILToTensor(), + T.ConvertImageDtype(torch.float), + ] + ) + elif data_augmentation == "ssd": + self.transforms = T.Compose( + [ + T.RandomPhotometricDistort(), + T.RandomZoomOut(fill=list(mean)), + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + T.PILToTensor(), + T.ConvertImageDtype(torch.float), + ] + ) + elif data_augmentation == "ssdlite": + self.transforms = T.Compose( + [ + T.RandomIoUCrop(), + T.RandomHorizontalFlip(p=hflip_prob), + T.PILToTensor(), + T.ConvertImageDtype(torch.float), + ] + ) else: raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"') diff --git a/references/detection/train.py b/references/detection/train.py index cd4148e9bf7..e86762342cc 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -21,26 +21,20 @@ import os import time +import presets import torch import torch.utils.data import torchvision import torchvision.models.detection import torchvision.models.detection.mask_rcnn - +import utils from coco_utils import get_coco, get_coco_kp - -from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups from engine import train_one_epoch, evaluate - -import presets -import utils +from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups def get_dataset(name, image_set, transform, data_path): - paths = { - "coco": (data_path, get_coco, 91), - "coco_kp": (data_path, get_coco_kp, 2) - } + paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] ds = ds_fn(p, image_set=image_set, transforms=transform) @@ -53,42 +47,60 @@ def get_transform(train, data_augmentation): def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Detection Training', add_help=add_help) - - parser.add_argument('--data-path', 
default='/datasets01/COCO/022719/', help='dataset') - parser.add_argument('--dataset', default='coco', help='dataset') - parser.add_argument('--model', default='maskrcnn_resnet50_fpn', help='model') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('-b', '--batch-size', default=2, type=int, - help='images per gpu, the total batch size is $NGPU x batch_size') - parser.add_argument('--epochs', default=26, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') - parser.add_argument('--lr', default=0.02, type=float, - help='initial learning rate, 0.02 is the default value for training ' - 'on 8 gpus and 2 images_per_gpu') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-scheduler', default="multisteplr", help='the lr scheduler (default: multisteplr)') - parser.add_argument('--lr-step-size', default=8, type=int, - help='decrease lr every step-size epochs (multisteplr scheduler only)') - parser.add_argument('--lr-steps', default=[16, 22], nargs='+', type=int, - help='decrease lr every step-size epochs (multisteplr scheduler only)') - parser.add_argument('--lr-gamma', default=0.1, type=float, - help='decrease lr by a factor of lr-gamma (multisteplr scheduler only)') - parser.add_argument('--print-freq', default=20, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start_epoch', default=0, type=int, help='start epoch') - parser.add_argument('--aspect-ratio-group-factor', default=3, type=int) - parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn') - parser.add_argument('--trainable-backbone-layers', default=None, type=int, - help='number of trainable layers of backbone') - parser.add_argument('--data-augmentation', default="hflip", help='data augmentation policy (default: hflip)') + + parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/COCO/022719/", help="dataset") + parser.add_argument("--dataset", default="coco", help="dataset") + parser.add_argument("--model", default="maskrcnn_resnet50_fpn", help="model") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument( + "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" + ) + parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)" + ) + parser.add_argument( + "--lr", + default=0.02, + type=float, + help="initial learning rate, 0.02 is the default value for training " "on 8 gpus and 2 images_per_gpu", + ) + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + 
parser.add_argument("--lr-scheduler", default="multisteplr", help="the lr scheduler (default: multisteplr)") + parser.add_argument( + "--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)" + ) + parser.add_argument( + "--lr-steps", + default=[16, 22], + nargs="+", + type=int, + help="decrease lr every step-size epochs (multisteplr scheduler only)", + ) + parser.add_argument( + "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)" + ) + parser.add_argument("--print-freq", default=20, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start_epoch", default=0, type=int, help="start epoch") + parser.add_argument("--aspect-ratio-group-factor", default=3, type=int) + parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn") + parser.add_argument( + "--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone" + ) + parser.add_argument("--data-augmentation", default="hflip", help="data augmentation policy (default: hflip)") parser.add_argument( "--sync-bn", dest="sync_bn", @@ -109,9 +121,8 @@ def get_args_parser(add_help=True): ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") return parser @@ -128,8 +139,9 @@ def main(args): # Data loading code print("Loading data") - dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args.data_augmentation), - args.data_path) + dataset, num_classes = get_dataset( + args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path + ) dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args.data_augmentation), args.data_path) print("Creating data loaders") @@ -144,27 +156,24 @@ def main(args): group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor) train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size) else: - train_batch_sampler = torch.utils.data.BatchSampler( - train_sampler, args.batch_size, drop_last=True) + train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True) data_loader = torch.utils.data.DataLoader( - dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn) + dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=1, - sampler=test_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn) + dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn + ) print("Creating model") - kwargs = { - "trainable_backbone_layers": args.trainable_backbone_layers - } + kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers} if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = 
args.rpn_score_thresh - model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained, - **kwargs) + model = torchvision.models.detection.__dict__[args.model]( + num_classes=num_classes, pretrained=args.pretrained, **kwargs + ) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -175,24 +184,25 @@ def main(args): model_without_ddp = model.module params = [p for p in model.parameters() if p.requires_grad] - optimizer = torch.optim.SGD( - params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) args.lr_scheduler = args.lr_scheduler.lower() - if args.lr_scheduler == 'multisteplr': + if args.lr_scheduler == "multisteplr": lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma) - elif args.lr_scheduler == 'cosineannealinglr': + elif args.lr_scheduler == "cosineannealinglr": lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) else: - raise RuntimeError("Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR " - "are supported.".format(args.lr_scheduler)) + raise RuntimeError( + "Invalid lr scheduler '{}'. Only MultiStepLR and CosineAnnealingLR " + "are supported.".format(args.lr_scheduler) + ) if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.test_only: evaluate(model, data_loader_test, device=device) @@ -207,25 +217,21 @@ def main(args): lr_scheduler.step() if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'args': args, - 'epoch': epoch + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "args": args, + "epoch": epoch, } - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) # evaluate after every epoch evaluate(model, data_loader_test, device=device) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) if __name__ == "__main__": diff --git a/references/detection/transforms.py b/references/detection/transforms.py index c65535750b5..787bb75a5c5 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -28,8 +28,9 @@ def __call__(self, image, target): class RandomHorizontalFlip(T.RandomHorizontalFlip): - def forward(self, image: Tensor, - target: Optional[Dict[str, 
Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if torch.rand(1) < self.p: image = F.hflip(image) if target is not None: @@ -45,16 +46,18 @@ def forward(self, image: Tensor, class ToTensor(nn.Module): - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: image = F.pil_to_tensor(image) image = F.convert_image_dtype(image) return image, target class PILToTensor(nn.Module): - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: image = F.pil_to_tensor(image) return image, target @@ -64,15 +67,23 @@ def __init__(self, dtype: torch.dtype) -> None: super().__init__() self.dtype = dtype - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: image = F.convert_image_dtype(image, self.dtype) return image, target class RandomIoUCrop(nn.Module): - def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5, - max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40): + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1.0, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2.0, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + ): super().__init__() # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 self.min_scale = min_scale @@ -84,14 +95,15 @@ def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ra self.options = sampler_options self.trials = trials - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if target is None: raise ValueError("The targets can't be None for this transform.") if isinstance(image, torch.Tensor): if image.ndimension() not in {2, 3}: - raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) + raise ValueError("image should be 2/3 dimensional. 
Got {} dimensions.".format(image.ndimension())) elif image.ndimension() == 2: image = image.unsqueeze(0) @@ -131,8 +143,9 @@ def forward(self, image: Tensor, # check at least 1 box with jaccard limitations boxes = target["boxes"][is_within_crop_area] - ious = torchvision.ops.boxes.box_iou(boxes, torch.tensor([[left, top, right, bottom]], - dtype=boxes.dtype, device=boxes.device)) + ious = torchvision.ops.boxes.box_iou( + boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device) + ) if ious.max() < min_jaccard_overlap: continue @@ -149,13 +162,15 @@ def forward(self, image: Tensor, class RandomZoomOut(nn.Module): - def __init__(self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1., 4.), p: float = 0.5): + def __init__( + self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5 + ): super().__init__() if fill is None: - fill = [0., 0., 0.] + fill = [0.0, 0.0, 0.0] self.fill = fill self.side_range = side_range - if side_range[0] < 1. or side_range[0] > side_range[1]: + if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError("Invalid canvas side range provided {}.".format(side_range)) self.p = p @@ -165,11 +180,12 @@ def _get_fill_value(self, is_pil): # We fake the type to make it work on JIT return tuple(int(x) for x in self.fill) if is_pil else 0 - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if isinstance(image, torch.Tensor): if image.ndimension() not in {2, 3}: - raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) + raise ValueError("image should be 2/3 dimensional. 
Got {} dimensions.".format(image.ndimension())) elif image.ndimension() == 2: image = image.unsqueeze(0) @@ -196,8 +212,9 @@ def forward(self, image: Tensor, image = F.pad(image, [left, top, right, bottom], fill=fill) if isinstance(image, torch.Tensor): v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1) - image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h):, :] = \ - image[..., :, (left + orig_w):] = v + image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[ + ..., :, (left + orig_w) : + ] = v if target is not None: target["boxes"][:, 0::2] += left @@ -207,8 +224,14 @@ def forward(self, image: Tensor, class RandomPhotometricDistort(nn.Module): - def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] = (0.5, 1.5), - hue: Tuple[float] = (-0.05, 0.05), brightness: Tuple[float] = (0.875, 1.125), p: float = 0.5): + def __init__( + self, + contrast: Tuple[float] = (0.5, 1.5), + saturation: Tuple[float] = (0.5, 1.5), + hue: Tuple[float] = (-0.05, 0.05), + brightness: Tuple[float] = (0.875, 1.125), + p: float = 0.5, + ): super().__init__() self._brightness = T.ColorJitter(brightness=brightness) self._contrast = T.ColorJitter(contrast=contrast) @@ -216,11 +239,12 @@ def __init__(self, contrast: Tuple[float] = (0.5, 1.5), saturation: Tuple[float] self._saturation = T.ColorJitter(saturation=saturation) self.p = p - def forward(self, image: Tensor, - target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def forward( + self, image: Tensor, target: Optional[Dict[str, Tensor]] = None + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if isinstance(image, torch.Tensor): if image.ndimension() not in {2, 3}: - raise ValueError('image should be 2/3 dimensional. Got {} dimensions.'.format(image.ndimension())) + raise ValueError("image should be 2/3 dimensional. 
Got {} dimensions.".format(image.ndimension())) elif image.ndimension() == 2: image = image.unsqueeze(0) diff --git a/references/detection/utils.py b/references/detection/utils.py index 11fcd3060e4..c708ca05413 100644 --- a/references/detection/utils.py +++ b/references/detection/utils.py @@ -1,8 +1,8 @@ -from collections import defaultdict, deque import datetime import errno import os import time +from collections import defaultdict, deque import torch import torch.distributed as dist @@ -32,7 +32,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -63,11 +63,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) def all_gather(data): @@ -130,15 +127,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -151,31 +145,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -185,22 +176,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - 
print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {} ({:.4f} s / it)'.format( - header, total_time_str, total_time / len(iterable))) + print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable))) def collate_fn(batch): @@ -220,10 +217,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -260,25 +258,25 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) torch.distributed.barrier() setup_for_distributed(args.rank == 0) diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py index c86d5495247..83091c75d95 100644 --- a/references/segmentation/coco_utils.py +++ b/references/segmentation/coco_utils.py @@ -1,13 +1,11 @@ import copy +import os + import torch import torch.utils.data import torchvision from PIL import Image - -import os - from pycocotools import mask as coco_mask - from transforms import Compose @@ -90,14 +88,9 @@ def get_coco(root, image_set, transforms): "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), # "train": ("val2017", os.path.join("annotations", "instances_val2017.json")) } - CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, - 1, 64, 20, 63, 7, 72] - - transforms = Compose([ - FilterAndRemapCocoCategories(CAT_LIST, remap=True), - ConvertCocoPolysToMask(), - transforms - ]) + CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72] + + transforms = Compose([FilterAndRemapCocoCategories(CAT_LIST, remap=True), ConvertCocoPolysToMask(), transforms]) img_folder, ann_file = PATHS[image_set] img_folder = os.path.join(root, img_folder) diff --git 
a/references/segmentation/presets.py b/references/segmentation/presets.py index 96334356fcb..8cada98ac95 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -1,5 +1,4 @@ import torch - import transforms as T @@ -11,12 +10,14 @@ def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.4 trans = [T.RandomResize(min_size, max_size)] if hflip_prob > 0: trans.append(T.RandomHorizontalFlip(hflip_prob)) - trans.extend([ - T.RandomCrop(crop_size), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - T.Normalize(mean=mean, std=std), - ]) + trans.extend( + [ + T.RandomCrop(crop_size), + T.PILToTensor(), + T.ConvertImageDtype(torch.float), + T.Normalize(mean=mean, std=std), + ] + ) self.transforms = T.Compose(trans) def __call__(self, img, target): @@ -25,12 +26,14 @@ def __call__(self, img, target): class SegmentationPresetEval: def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)): - self.transforms = T.Compose([ - T.RandomResize(base_size, base_size), - T.PILToTensor(), - T.ConvertImageDtype(torch.float), - T.Normalize(mean=mean, std=std), - ]) + self.transforms = T.Compose( + [ + T.RandomResize(base_size, base_size), + T.PILToTensor(), + T.ConvertImageDtype(torch.float), + T.Normalize(mean=mean, std=std), + ] + ) def __call__(self, img, target): return self.transforms(img, target) diff --git a/references/segmentation/train.py b/references/segmentation/train.py index 83277de9c2c..3a41f86ba87 100644 --- a/references/segmentation/train.py +++ b/references/segmentation/train.py @@ -2,23 +2,23 @@ import os import time +import presets import torch import torch.utils.data -from torch import nn import torchvision - -from coco_utils import get_coco -import presets import utils +from coco_utils import get_coco +from torch import nn def get_dataset(dir_path, name, image_set, transform): def sbd(*args, **kwargs): - return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs) + return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) + paths = { "voc": (dir_path, torchvision.datasets.VOCSegmentation, 21), "voc_aug": (dir_path, sbd, 21), - "coco": (dir_path, get_coco, 21) + "coco": (dir_path, get_coco, 21), } p, ds_fn, num_classes = paths[name] @@ -39,21 +39,21 @@ def criterion(inputs, target): losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255) if len(losses) == 1: - return losses['out'] + return losses["out"] - return losses['out'] + 0.5 * losses['aux'] + return losses["out"] + 0.5 * losses["aux"] def evaluate(model, data_loader, device, num_classes): model.eval() confmat = utils.ConfusionMatrix(num_classes) metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = "Test:" with torch.no_grad(): for image, target in metric_logger.log_every(data_loader, 100, header): image, target = image.to(device), target.to(device) output = model(image) - output = output['out'] + output = output["out"] confmat.update(target.flatten(), output.argmax(1).flatten()) @@ -65,8 +65,8 @@ def evaluate(model, data_loader, device, num_classes): def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - header = 'Epoch: [{}]'.format(epoch) + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) + header = "Epoch: [{}]".format(epoch) for image, 
target in metric_logger.log_every(data_loader, print_freq, header): image, target = image.to(device), target.to(device) output = model(image) @@ -101,18 +101,21 @@ def main(args): test_sampler = torch.utils.data.SequentialSampler(dataset_test) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn, drop_last=True) + dataset, + batch_size=args.batch_size, + sampler=train_sampler, + num_workers=args.workers, + collate_fn=utils.collate_fn, + drop_last=True, + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=1, - sampler=test_sampler, num_workers=args.workers, - collate_fn=utils.collate_fn) + dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn + ) - model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, - aux_loss=args.aux_loss, - pretrained=args.pretrained) + model = torchvision.models.segmentation.__dict__[args.model]( + num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained + ) model.to(device) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -129,42 +132,42 @@ def main(args): if args.aux_loss: params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad] params_to_optimize.append({"params": params, "lr": args.lr * 10}) - optimizer = torch.optim.SGD( - params_to_optimize, - lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + optimizer = torch.optim.SGD(params_to_optimize, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) iters_per_epoch = len(data_loader) main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR( - optimizer, - lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9) + optimizer, lambda x: (1 - x / (iters_per_epoch * (args.epochs - args.lr_warmup_epochs))) ** 0.9 + ) if args.lr_warmup_epochs > 0: warmup_iters = iters_per_epoch * args.lr_warmup_epochs args.lr_warmup_method = args.lr_warmup_method.lower() - if args.lr_warmup_method == 'linear': - warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay, - total_iters=warmup_iters) - elif args.lr_warmup_method == 'constant': - warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, - total_iters=warmup_iters) + if args.lr_warmup_method == "linear": + warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters + ) + elif args.lr_warmup_method == "constant": + warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR( + optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters + ) else: - raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant " - "are supported.".format(args.lr_warmup_method)) + raise RuntimeError( + "Invalid warmup lr method '{}'. 
Only linear and constant " + "are supported.".format(args.lr_warmup_method) + ) lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, - schedulers=[warmup_lr_scheduler, main_lr_scheduler], - milestones=[warmup_iters] + optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters] ) else: lr_scheduler = main_lr_scheduler if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model'], strict=not args.test_only) + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"], strict=not args.test_only) if not args.test_only: - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.test_only: confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) @@ -179,53 +182,54 @@ def main(args): confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) print(confmat) checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, } - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def get_args_parser(add_help=True): import argparse - parser = argparse.ArgumentParser(description='PyTorch Segmentation Training', add_help=add_help) - - parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path') - parser.add_argument('--dataset', default='coco', help='dataset name') - parser.add_argument('--model', default='fcn_resnet101', help='model') - parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('-b', '--batch-size', default=8, type=int) - parser.add_argument('--epochs', default=30, type=int, metavar='N', - help='number of total epochs to run') - - parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', - help='number of data loading workers (default: 16)') - parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-warmup-epochs', default=0, type=int, help='the number of epochs to warmup (default: 0)') - parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: 
linear)') - parser.add_argument('--lr-warmup-decay', default=0.01, type=float, help='the decay for lr') - parser.add_argument('--print-freq', default=10, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = argparse.ArgumentParser(description="PyTorch Segmentation Training", add_help=add_help) + + parser.add_argument("--data-path", default="/datasets01/COCO/022719/", help="dataset path") + parser.add_argument("--dataset", default="coco", help="dataset name") + parser.add_argument("--model", default="fcn_resnet101", help="model") + parser.add_argument("--aux-loss", action="store_true", help="auxiliar loss") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument("-b", "--batch-size", default=8, type=int) + parser.add_argument("--epochs", default=30, type=int, metavar="N", help="number of total epochs to run") + + parser.add_argument( + "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)" + ) + parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--lr-warmup-epochs", default=0, type=int, help="the number of epochs to warmup (default: 0)") + parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)") + parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--test-only", dest="test_only", @@ -239,9 +243,8 @@ def get_args_parser(add_help=True): action="store_true", ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") return parser diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py index b67c18052fb..2bb5451289a 100644 --- a/references/segmentation/utils.py +++ b/references/segmentation/utils.py @@ -1,12 +1,12 @@ -from collections import defaultdict, deque import datetime +import errno +import os import time +from collections import defaultdict, deque + import torch import torch.distributed as dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -32,7 +32,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = 
torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -63,11 +63,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) class ConfusionMatrix(object): @@ -82,7 +79,7 @@ def update(self, a, b): with torch.no_grad(): k = (a >= 0) & (a < n) inds = n * a[k].to(torch.int64) + b[k] - self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n) def reset(self): self.mat.zero_() @@ -104,15 +101,12 @@ def reduce_from_all_processes(self): def __str__(self): acc_global, acc, iu = self.compute() - return ( - 'global correct: {:.1f}\n' - 'average row correct: {}\n' - 'IoU: {}\n' - 'mean IoU: {:.1f}').format( - acc_global.item() * 100, - ['{:.1f}'.format(i) for i in (acc * 100).tolist()], - ['{:.1f}'.format(i) for i in (iu * 100).tolist()], - iu.mean().item() * 100) + return ("global correct: {:.1f}\n" "average row correct: {}\n" "IoU: {}\n" "mean IoU: {:.1f}").format( + acc_global.item() * 100, + ["{:.1f}".format(i) for i in (acc * 100).tolist()], + ["{:.1f}".format(i) for i in (iu * 100).tolist()], + iu.mean().item() * 100, + ) class MetricLogger(object): @@ -132,15 +126,12 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -153,31 +144,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -187,21 +175,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - 
time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {}'.format(header, total_time_str)) + print("{} Total time: {}".format(header, total_time_str)) def cat_list(images, fill_value=0): @@ -209,7 +204,7 @@ def cat_list(images, fill_value=0): batch_shape = (len(images),) + max_size batched_imgs = images[0].new(*batch_shape).fill_(fill_value) for img, pad_img in zip(images, batched_imgs): - pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) + pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) return batched_imgs @@ -233,10 +228,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -273,26 +269,26 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) setup_for_distributed(args.rank == 0) diff --git a/references/similarity/loss.py b/references/similarity/loss.py index 1fa4a89c762..237ad8e9e11 100644 --- a/references/similarity/loss.py +++ b/references/similarity/loss.py @@ -1,21 +1,21 @@ -''' +""" Pytorch adaptation of https://omoindrot.github.io/triplet-loss https://github.com/omoindrot/tensorflow-triplet-loss -''' +""" import torch import torch.nn as nn class TripletMarginLoss(nn.Module): - def __init__(self, margin=1.0, p=2., mining='batch_all'): + def __init__(self, margin=1.0, p=2.0, mining="batch_all"): super(TripletMarginLoss, self).__init__() self.margin = margin self.p = p self.mining = mining - if 
mining == 'batch_all': + if mining == "batch_all": self.loss_fn = batch_all_triplet_loss - if mining == 'batch_hard': + if mining == "batch_hard": self.loss_fn = batch_hard_triplet_loss def forward(self, embeddings, labels): diff --git a/references/similarity/sampler.py b/references/similarity/sampler.py index 0ae6d07a77c..591155fb449 100644 --- a/references/similarity/sampler.py +++ b/references/similarity/sampler.py @@ -1,7 +1,8 @@ +import random +from collections import defaultdict + import torch from torch.utils.data.sampler import Sampler -from collections import defaultdict -import random def create_groups(groups, k): diff --git a/references/similarity/test.py b/references/similarity/test.py index 8381e02e740..b7de7a405e1 100644 --- a/references/similarity/test.py +++ b/references/similarity/test.py @@ -1,15 +1,13 @@ import unittest from collections import defaultdict -from torch.utils.data import DataLoader -from torchvision.datasets import FakeData import torchvision.transforms as transforms - from sampler import PKSampler +from torch.utils.data import DataLoader +from torchvision.datasets import FakeData class Tester(unittest.TestCase): - def test_pksampler(self): p, k = 16, 4 @@ -19,8 +17,7 @@ def test_pksampler(self): self.assertRaises(AssertionError, PKSampler, targets, p, k) # Ensure p, k constraints on batch - dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), - transform=transforms.ToTensor()) + dataset = FakeData(size=1000, num_classes=100, image_size=(3, 1, 1), transform=transforms.ToTensor()) targets = [target.item() for _, target in dataset] sampler = PKSampler(targets, p, k) loader = DataLoader(dataset, batch_size=p * k, sampler=sampler) @@ -38,5 +35,5 @@ def test_pksampler(self): self.assertEqual(bins[b], k) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/references/similarity/train.py b/references/similarity/train.py index 9a166a14b38..c8f041acdad 100644 --- a/references/similarity/train.py +++ b/references/similarity/train.py @@ -1,15 +1,13 @@ import os import torch -from torch.optim import Adam -from torch.utils.data import DataLoader - import torchvision.transforms as transforms -from torchvision.datasets import FashionMNIST - from loss import TripletMarginLoss -from sampler import PKSampler from model import EmbeddingNet +from sampler import PKSampler +from torch.optim import Adam +from torch.utils.data import DataLoader +from torchvision.datasets import FashionMNIST def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_freq): @@ -33,7 +31,7 @@ def train_epoch(model, optimizer, criterion, data_loader, device, epoch, print_f i += 1 avg_loss = running_loss / print_freq avg_trip = 100.0 * running_frac_pos_triplets / print_freq - print('[{:d}, {:d}] | loss: {:.4f} | % avg hard triplets: {:.2f}%'.format(epoch, i, avg_loss, avg_trip)) + print("[{:d}, {:d}] | loss: {:.4f} | % avg hard triplets: {:.2f}%".format(epoch, i, avg_loss, avg_trip)) running_loss = 0 running_frac_pos_triplets = 0 @@ -79,17 +77,17 @@ def evaluate(model, loader, device): threshold, accuracy = find_best_threshold(dists, targets, device) - print('accuracy: {:.3f}%, threshold: {:.2f}'.format(accuracy, threshold)) + print("accuracy: {:.3f}%, threshold: {:.2f}".format(accuracy, threshold)) def save(model, epoch, save_dir, file_name): - file_name = 'epoch_' + str(epoch) + '__' + file_name + file_name = "epoch_" + str(epoch) + "__" + file_name save_path = os.path.join(save_dir, file_name) torch.save(model.state_dict(), 
save_path) def main(args): - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") p = args.labels_per_batch k = args.samples_per_label batch_size = p * k @@ -103,9 +101,9 @@ def main(args): criterion = TripletMarginLoss(margin=args.margin) optimizer = Adam(model.parameters(), lr=args.lr) - transform = transforms.Compose([transforms.Lambda(lambda image: image.convert('RGB')), - transforms.Resize((224, 224)), - transforms.ToTensor()]) + transform = transforms.Compose( + [transforms.Lambda(lambda image: image.convert("RGB")), transforms.Resize((224, 224)), transforms.ToTensor()] + ) # Using FMNIST to demonstrate embedding learning using triplet loss. This dataset can # be replaced with any classification dataset. @@ -118,48 +116,44 @@ def main(args): # targets attribute with the same format. targets = train_dataset.targets.tolist() - train_loader = DataLoader(train_dataset, batch_size=batch_size, - sampler=PKSampler(targets, p, k), - num_workers=args.workers) - test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, - shuffle=False, - num_workers=args.workers) + train_loader = DataLoader( + train_dataset, batch_size=batch_size, sampler=PKSampler(targets, p, k), num_workers=args.workers + ) + test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size, shuffle=False, num_workers=args.workers) for epoch in range(1, args.epochs + 1): - print('Training...') + print("Training...") train_epoch(model, optimizer, criterion, train_loader, device, epoch, args.print_freq) - print('Evaluating...') + print("Evaluating...") evaluate(model, test_loader, device) - print('Saving...') - save(model, epoch, args.save_dir, 'ckpt.pth') + print("Saving...") + save(model, epoch, args.save_dir, "ckpt.pth") def parse_args(): import argparse - parser = argparse.ArgumentParser(description='PyTorch Embedding Learning') - - parser.add_argument('--dataset-dir', default='/tmp/fmnist/', - help='FashionMNIST dataset directory path') - parser.add_argument('-p', '--labels-per-batch', default=8, type=int, - help='Number of unique labels/classes per batch') - parser.add_argument('-k', '--samples-per-label', default=8, type=int, - help='Number of samples per label in a batch') - parser.add_argument('--eval-batch-size', default=512, type=int) - parser.add_argument('--epochs', default=10, type=int, metavar='N', - help='Number of training epochs to run') - parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='Number of data loading workers') - parser.add_argument('--lr', default=0.0001, type=float, help='Learning rate') - parser.add_argument('--margin', default=0.2, type=float, help='Triplet loss margin') - parser.add_argument('--print-freq', default=20, type=int, help='Print frequency') - parser.add_argument('--save-dir', default='.', help='Model save directory') - parser.add_argument('--resume', default='', help='Resume from checkpoint') + + parser = argparse.ArgumentParser(description="PyTorch Embedding Learning") + + parser.add_argument("--dataset-dir", default="/tmp/fmnist/", help="FashionMNIST dataset directory path") + parser.add_argument( + "-p", "--labels-per-batch", default=8, type=int, help="Number of unique labels/classes per batch" + ) + parser.add_argument("-k", "--samples-per-label", default=8, type=int, help="Number of samples per label in a batch") + parser.add_argument("--eval-batch-size", default=512, type=int) + parser.add_argument("--epochs", default=10, type=int, 
metavar="N", help="Number of training epochs to run") + parser.add_argument("-j", "--workers", default=4, type=int, metavar="N", help="Number of data loading workers") + parser.add_argument("--lr", default=0.0001, type=float, help="Learning rate") + parser.add_argument("--margin", default=0.2, type=float, help="Triplet loss margin") + parser.add_argument("--print-freq", default=20, type=int, help="Print frequency") + parser.add_argument("--save-dir", default=".", help="Model save directory") + parser.add_argument("--resume", default="", help="Resume from checkpoint") return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() main(args) diff --git a/references/video_classification/presets.py b/references/video_classification/presets.py index 3ee679ad5af..04039c9a4f1 100644 --- a/references/video_classification/presets.py +++ b/references/video_classification/presets.py @@ -1,12 +1,17 @@ import torch - from torchvision.transforms import transforms from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW class VideoClassificationPresetTrain: - def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989), - hflip_prob=0.5): + def __init__( + self, + resize_size, + crop_size, + mean=(0.43216, 0.394666, 0.37645), + std=(0.22803, 0.22145, 0.216989), + hflip_prob=0.5, + ): trans = [ ConvertBHWCtoBCHW(), transforms.ConvertImageDtype(torch.float32), @@ -14,11 +19,7 @@ def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), st ] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) - trans.extend([ - transforms.Normalize(mean=mean, std=std), - transforms.RandomCrop(crop_size), - ConvertBCHWtoCBHW() - ]) + trans.extend([transforms.Normalize(mean=mean, std=std), transforms.RandomCrop(crop_size), ConvertBCHWtoCBHW()]) self.transforms = transforms.Compose(trans) def __call__(self, x): @@ -27,14 +28,16 @@ def __call__(self, x): class VideoClassificationPresetEval: def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)): - self.transforms = transforms.Compose([ - ConvertBHWCtoBCHW(), - transforms.ConvertImageDtype(torch.float32), - transforms.Resize(resize_size), - transforms.Normalize(mean=mean, std=std), - transforms.CenterCrop(crop_size), - ConvertBCHWtoCBHW() - ]) + self.transforms = transforms.Compose( + [ + ConvertBHWCtoBCHW(), + transforms.ConvertImageDtype(torch.float32), + transforms.Resize(resize_size), + transforms.Normalize(mean=mean, std=std), + transforms.CenterCrop(crop_size), + ConvertBCHWtoCBHW(), + ] + ) def __call__(self, x): return self.transforms(x) diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 0eefbc0b282..f944cff7794 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -1,16 +1,16 @@ import datetime import os import time + +import presets import torch import torch.utils.data -from torch.utils.data.dataloader import default_collate -from torch import nn import torchvision import torchvision.datasets.video_utils -from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler - -import presets import utils +from torch import nn +from torch.utils.data.dataloader import default_collate +from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler try: from apex import amp @@ -21,10 +21,10 @@ def 
train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False): model.train() metric_logger = utils.MetricLogger(delimiter=" ") - metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) - metric_logger.add_meter('clips/s', utils.SmoothedValue(window_size=10, fmt='{value:.3f}')) + metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) + metric_logger.add_meter("clips/s", utils.SmoothedValue(window_size=10, fmt="{value:.3f}")) - header = 'Epoch: [{}]'.format(epoch) + header = "Epoch: [{}]".format(epoch) for video, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() video, target = video.to(device), target.to(device) @@ -42,16 +42,16 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) batch_size = video.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) - metric_logger.meters['clips/s'].update(batch_size / (time.time() - start_time)) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) + metric_logger.meters["clips/s"].update(batch_size / (time.time() - start_time)) lr_scheduler.step() def evaluate(model, criterion, data_loader, device): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") - header = 'Test:' + header = "Test:" with torch.no_grad(): for video, target in metric_logger.log_every(data_loader, 100, header): video = video.to(device, non_blocking=True) @@ -64,18 +64,22 @@ def evaluate(model, criterion, data_loader, device): # could have been padded in distributed setup batch_size = video.shape[0] metric_logger.update(loss=loss.item()) - metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) - metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) + metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) # gather the stats from all processes metric_logger.synchronize_between_processes() - print(' * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}' - .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) + print( + " * Clip Acc@1 {top1.global_avg:.3f} Clip Acc@5 {top5.global_avg:.3f}".format( + top1=metric_logger.acc1, top5=metric_logger.acc5 + ) + ) return metric_logger.acc1.global_avg def _get_cache_path(filepath): import hashlib + h = hashlib.sha1(filepath.encode()).hexdigest() cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt") cache_path = os.path.expanduser(cache_path) @@ -90,8 +94,10 @@ def collate_fn(batch): def main(args): if args.apex and amp is None: - raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " - "to enable mixed-precision training.") + raise RuntimeError( + "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex " + "to enable mixed-precision training." 
+ ) if args.output_dir: utils.mkdir(args.output_dir) @@ -121,15 +127,17 @@ def main(args): dataset.transform = transform_train else: if args.distributed: - print("It is recommended to pre-compute the dataset cache " - "on a single-gpu first, as it will be faster") + print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") dataset = torchvision.datasets.Kinetics400( traindir, frames_per_clip=args.clip_len, step_between_clips=1, transform=transform_train, frame_rate=15, - extensions=('avi', 'mp4', ) + extensions=( + "avi", + "mp4", + ), ) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) @@ -149,15 +157,17 @@ def main(args): dataset_test.transform = transform_test else: if args.distributed: - print("It is recommended to pre-compute the dataset cache " - "on a single-gpu first, as it will be faster") + print("It is recommended to pre-compute the dataset cache " "on a single-gpu first, as it will be faster") dataset_test = torchvision.datasets.Kinetics400( valdir, frames_per_clip=args.clip_len, step_between_clips=1, transform=transform_test, frame_rate=15, - extensions=('avi', 'mp4',) + extensions=( + "avi", + "mp4", + ), ) if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) @@ -172,14 +182,22 @@ def main(args): test_sampler = DistributedSampler(test_sampler) data_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, - sampler=train_sampler, num_workers=args.workers, - pin_memory=True, collate_fn=collate_fn) + dataset, + batch_size=args.batch_size, + sampler=train_sampler, + num_workers=args.workers, + pin_memory=True, + collate_fn=collate_fn, + ) data_loader_test = torch.utils.data.DataLoader( - dataset_test, batch_size=args.batch_size, - sampler=test_sampler, num_workers=args.workers, - pin_memory=True, collate_fn=collate_fn) + dataset_test, + batch_size=args.batch_size, + sampler=test_sampler, + num_workers=args.workers, + pin_memory=True, + collate_fn=collate_fn, + ) print("Creating model") model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) @@ -190,13 +208,10 @@ def main(args): criterion = nn.CrossEntropyLoss() lr = args.lr * args.world_size - optimizer = torch.optim.SGD( - model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) + optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.apex: - model, optimizer = amp.initialize(model, optimizer, - opt_level=args.apex_opt_level - ) + model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level) # convert scheduler to be per iteration, not per epoch, for warmup that lasts # between different epochs @@ -207,20 +222,22 @@ def main(args): if args.lr_warmup_epochs > 0: warmup_iters = iters_per_epoch * args.lr_warmup_epochs args.lr_warmup_method = args.lr_warmup_method.lower() - if args.lr_warmup_method == 'linear': - warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_warmup_decay, - total_iters=warmup_iters) - elif args.lr_warmup_method == 'constant': - warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=args.lr_warmup_decay, - total_iters=warmup_iters) + if args.lr_warmup_method == "linear": + warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters + ) + elif args.lr_warmup_method == "constant": + warmup_lr_scheduler = 
torch.optim.lr_scheduler.ConstantLR( + optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters + ) else: - raise RuntimeError("Invalid warmup lr method '{}'. Only linear and constant " - "are supported.".format(args.lr_warmup_method)) + raise RuntimeError( + "Invalid warmup lr method '{}'. Only linear and constant " + "are supported.".format(args.lr_warmup_method) + ) lr_scheduler = torch.optim.lr_scheduler.SequentialLR( - optimizer, - schedulers=[warmup_lr_scheduler, main_lr_scheduler], - milestones=[warmup_iters] + optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters] ) else: lr_scheduler = main_lr_scheduler @@ -231,11 +248,11 @@ def main(args): model_without_ddp = model.module if args.resume: - checkpoint = torch.load(args.resume, map_location='cpu') - model_without_ddp.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - args.start_epoch = checkpoint['epoch'] + 1 + checkpoint = torch.load(args.resume, map_location="cpu") + model_without_ddp.load_state_dict(checkpoint["model"]) + optimizer.load_state_dict(checkpoint["optimizer"]) + lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + args.start_epoch = checkpoint["epoch"] + 1 if args.test_only: evaluate(model, criterion, data_loader_test, device=device) @@ -246,62 +263,65 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: train_sampler.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, - device, epoch, args.print_freq, args.apex) + train_one_epoch( + model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex + ) evaluate(model, criterion, data_loader_test, device=device) if args.output_dir: checkpoint = { - 'model': model_without_ddp.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'epoch': epoch, - 'args': args} - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) - utils.save_on_master( - checkpoint, - os.path.join(args.output_dir, 'checkpoint.pth')) + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args, + } + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))) + utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth")) total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('Training time {}'.format(total_time_str)) + print("Training time {}".format(total_time_str)) def parse_args(): import argparse - parser = argparse.ArgumentParser(description='PyTorch Video Classification Training') - - parser.add_argument('--data-path', default='/datasets01_101/kinetics/070618/', help='dataset') - parser.add_argument('--train-dir', default='train_avi-480p', help='name of train dir') - parser.add_argument('--val-dir', default='val_avi-480p', help='name of val dir') - parser.add_argument('--model', default='r2plus1d_18', help='model') - parser.add_argument('--device', default='cuda', help='device') - parser.add_argument('--clip-len', default=16, type=int, metavar='N', - help='number of frames per clip') - parser.add_argument('--clips-per-video', default=5, type=int, metavar='N', - help='maximum number of clips per video to consider') - parser.add_argument('-b', 
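The resume/save hunks above round-trip a single checkpoint dict keyed by "model", "optimizer", "lr_scheduler" and "epoch"; a minimal sketch of that pattern with a toy model (the model and file name are illustrative):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 30, 40], gamma=0.1)

# save: everything needed to resume training goes into one dict
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "epoch": 5,
    },
    "checkpoint.pth",
)

# resume: load onto CPU first, then restore each component and continue from the next epoch
checkpoint = torch.load("checkpoint.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
start_epoch = checkpoint["epoch"] + 1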
'--batch-size', default=24, type=int) - parser.add_argument('--epochs', default=45, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('-j', '--workers', default=10, type=int, metavar='N', - help='number of data loading workers (default: 10)') - parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)', - dest='weight_decay') - parser.add_argument('--lr-milestones', nargs='+', default=[20, 30, 40], type=int, help='decrease lr on milestones') - parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') - parser.add_argument('--lr-warmup-epochs', default=10, type=int, help='the number of epochs to warmup (default: 10)') - parser.add_argument('--lr-warmup-method', default="linear", type=str, help='the warmup method (default: linear)') - parser.add_argument('--lr-warmup-decay', default=0.001, type=float, help='the decay for lr') - parser.add_argument('--print-freq', default=10, type=int, help='print frequency') - parser.add_argument('--output-dir', default='.', help='path where to save') - parser.add_argument('--resume', default='', help='resume from checkpoint') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='start epoch') + + parser = argparse.ArgumentParser(description="PyTorch Video Classification Training") + + parser.add_argument("--data-path", default="/datasets01_101/kinetics/070618/", help="dataset") + parser.add_argument("--train-dir", default="train_avi-480p", help="name of train dir") + parser.add_argument("--val-dir", default="val_avi-480p", help="name of val dir") + parser.add_argument("--model", default="r2plus1d_18", help="model") + parser.add_argument("--device", default="cuda", help="device") + parser.add_argument("--clip-len", default=16, type=int, metavar="N", help="number of frames per clip") + parser.add_argument( + "--clips-per-video", default=5, type=int, metavar="N", help="maximum number of clips per video to consider" + ) + parser.add_argument("-b", "--batch-size", default=24, type=int) + parser.add_argument("--epochs", default=45, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "-j", "--workers", default=10, type=int, metavar="N", help="number of data loading workers (default: 10)" + ) + parser.add_argument("--lr", default=0.01, type=float, help="initial learning rate") + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument("--lr-milestones", nargs="+", default=[20, 30, 40], type=int, help="decrease lr on milestones") + parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma") + parser.add_argument("--lr-warmup-epochs", default=10, type=int, help="the number of epochs to warmup (default: 10)") + parser.add_argument("--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)") + parser.add_argument("--lr-warmup-decay", default=0.001, type=float, help="the decay for lr") + parser.add_argument("--print-freq", default=10, type=int, help="print frequency") + parser.add_argument("--output-dir", 
default=".", help="path where to save") + parser.add_argument("--resume", default="", help="resume from checkpoint") + parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch") parser.add_argument( "--cache-dataset", dest="cache_dataset", @@ -328,18 +348,19 @@ def parse_args(): ) # Mixed precision training parameters - parser.add_argument('--apex', action='store_true', - help='Use apex for mixed precision training') - parser.add_argument('--apex-opt-level', default='O1', type=str, - help='For apex mixed precision training' - 'O0 for FP32 training, O1 for mixed precision training.' - 'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet' - ) + parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training") + parser.add_argument( + "--apex-opt-level", + default="O1", + type=str, + help="For apex mixed precision training" + "O0 for FP32 training, O1 for mixed precision training." + "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet", + ) # distributed training parameters - parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') - parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") args = parser.parse_args() diff --git a/references/video_classification/transforms.py b/references/video_classification/transforms.py index 27f6c75450a..a0ce691bae7 100644 --- a/references/video_classification/transforms.py +++ b/references/video_classification/transforms.py @@ -3,16 +3,14 @@ class ConvertBHWCtoBCHW(nn.Module): - """Convert tensor from (B, H, W, C) to (B, C, H, W) - """ + """Convert tensor from (B, H, W, C) to (B, C, H, W)""" def forward(self, vid: torch.Tensor) -> torch.Tensor: return vid.permute(0, 3, 1, 2) class ConvertBCHWtoCBHW(nn.Module): - """Convert tensor from (B, C, H, W) to (C, B, H, W) - """ + """Convert tensor from (B, C, H, W) to (C, B, H, W)""" def forward(self, vid: torch.Tensor) -> torch.Tensor: return vid.permute(1, 0, 2, 3) diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py index 3573b84d780..956c4f85239 100644 --- a/references/video_classification/utils.py +++ b/references/video_classification/utils.py @@ -1,12 +1,12 @@ -from collections import defaultdict, deque import datetime +import errno +import os import time +from collections import defaultdict, deque + import torch import torch.distributed as dist -import errno -import os - class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a @@ -32,7 +32,7 @@ def synchronize_between_processes(self): """ if not is_dist_avail_and_initialized(): return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") dist.barrier() dist.all_reduce(t) t = t.tolist() @@ -63,11 +63,8 @@ def value(self): def __str__(self): return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) class MetricLogger(object): @@ -87,15 +84,12 @@ def __getattr__(self, attr): return 
self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) + loss_str.append("{}: {}".format(name, str(meter))) return self.delimiter.join(loss_str) def synchronize_between_processes(self): @@ -108,31 +102,28 @@ def add_meter(self, name, meter): def log_every(self, iterable, print_freq, header=None): i = 0 if not header: - header = '' + header = "" start_time = time.time() end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"] + ) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) @@ -142,21 +133,28 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {}'.format(header, total_time_str)) + print("{} Total time: {}".format(header, total_time_str)) def accuracy(output, target, topk=(1,)): @@ -189,10 +187,11 @@ def setup_for_distributed(is_master): This function disables printing when not in master process """ import builtins as __builtin__ + builtin_print = __builtin__.print def print(*args, **kwargs): - force = kwargs.pop('force', False) + force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -229,26 +228,26 @@ def save_on_master(*args, **kwargs): def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - 
args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + elif "SLURM_PROCID" in os.environ: + args.rank = int(os.environ["SLURM_PROCID"]) args.gpu = args.rank % torch.cuda.device_count() elif hasattr(args, "rank"): pass else: - print('Not using distributed mode') + print("Not using distributed mode") args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) + args.dist_backend = "nccl" + print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) setup_for_distributed(args.rank == 0) diff --git a/setup.cfg b/setup.cfg index fd3b74c47de..8929a5fd37f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,7 +9,13 @@ max-line-length = 120 [flake8] max-line-length = 120 -ignore = F401,E402,F403,W503,W504,F821 +ignore = E203, E402, W503, W504, F821 +per-file-ignores = + __init__.py: F401, F403, F405 + ./hubconf.py: F401 + torchvision/models/mobilenet.py: F401, F403 + torchvision/models/quantization/mobilenet.py: F401, F403 + test/smoke_test.py: F401 exclude = venv [pydocstyle] diff --git a/setup.py b/setup.py index 4c9e734f31b..fb7b0bf3786 100644 --- a/setup.py +++ b/setup.py @@ -1,25 +1,22 @@ -import os -import io -import re -import sys -from setuptools import setup, find_packages -from pkg_resources import parse_version, get_distribution, DistributionNotFound -import subprocess import distutils.command.clean import distutils.spawn -from distutils.version import StrictVersion import glob +import io +import os +import re import shutil +import subprocess +import sys +from distutils.version import StrictVersion import torch +from pkg_resources import parse_version, get_distribution, DistributionNotFound +from setuptools import setup, find_packages from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME def read(*names, **kwargs): - with io.open( - os.path.join(os.path.dirname(__file__), *names), - encoding=kwargs.get("encoding", "utf8") - ) as fp: + with io.open(os.path.join(os.path.dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")) as fp: return fp.read() @@ -32,26 +29,26 @@ def get_dist(pkgname): cwd = os.path.dirname(os.path.abspath(__file__)) -version_txt = os.path.join(cwd, 'version.txt') -with open(version_txt, 'r') as f: +version_txt = os.path.join(cwd, "version.txt") +with open(version_txt, "r") as f: version = f.readline().strip() -sha = 'Unknown' -package_name = 'torchvision' +sha = "Unknown" +package_name = "torchvision" try: - sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() + sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode("ascii").strip() except Exception: pass -if os.getenv('BUILD_VERSION'): - version = os.getenv('BUILD_VERSION') -elif sha != 'Unknown': - version += '+' + sha[:7] +if os.getenv("BUILD_VERSION"): + version = os.getenv("BUILD_VERSION") +elif sha != "Unknown": + version += "+" + sha[:7] def write_version_file(): - version_path = 
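init_distributed_mode above keys the process rank off RANK/WORLD_SIZE/LOCAL_RANK (or SLURM_PROCID) environment variables; a stripped-down sketch of the env:// NCCL setup it performs, assuming a launcher has already exported those variables:

import os
import torch

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

# each process binds to its own GPU before joining the NCCL process group
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(
    backend="nccl", init_method="env://", world_size=world_size, rank=rank
)
print("| distributed init (rank {}): env://".format(rank), flush=True)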
os.path.join(cwd, 'torchvision', 'version.py') - with open(version_path, 'w') as f: + version_path = os.path.join(cwd, "torchvision", "version.py") + with open(version_path, "w") as f: f.write("__version__ = '{}'\n".format(version)) f.write("git_version = {}\n".format(repr(sha))) f.write("from torchvision.extension import _check_cuda_version\n") @@ -59,34 +56,34 @@ def write_version_file(): f.write(" cuda = _check_cuda_version()\n") -pytorch_dep = 'torch' -if os.getenv('PYTORCH_VERSION'): - pytorch_dep += "==" + os.getenv('PYTORCH_VERSION') +pytorch_dep = "torch" +if os.getenv("PYTORCH_VERSION"): + pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") requirements = [ - 'numpy', + "numpy", pytorch_dep, ] # Excluding 8.3.0 because of https://github.com/pytorch/vision/issues/4146 -pillow_ver = ' >= 5.3.0, !=8.3.0' -pillow_req = 'pillow-simd' if get_dist('pillow-simd') is not None else 'pillow' +pillow_ver = " >= 5.3.0, !=8.3.0" +pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow" requirements.append(pillow_req + pillow_ver) def find_library(name, vision_include): this_dir = os.path.dirname(os.path.abspath(__file__)) - build_prefix = os.environ.get('BUILD_PREFIX', None) + build_prefix = os.environ.get("BUILD_PREFIX", None) is_conda_build = build_prefix is not None library_found = False conda_installed = False lib_folder = None include_folder = None - library_header = '{0}.h'.format(name) + library_header = "{0}.h".format(name) # Lookup in TORCHVISION_INCLUDE or in the package file - package_path = [os.path.join(this_dir, 'torchvision')] + package_path = [os.path.join(this_dir, "torchvision")] for folder in vision_include + package_path: candidate_path = os.path.join(folder, library_header) library_found = os.path.exists(candidate_path) @@ -94,67 +91,66 @@ def find_library(name, vision_include): break if not library_found: - print('Running build on conda-build: {0}'.format(is_conda_build)) + print("Running build on conda-build: {0}".format(is_conda_build)) if is_conda_build: # Add conda headers/libraries - if os.name == 'nt': - build_prefix = os.path.join(build_prefix, 'Library') - include_folder = os.path.join(build_prefix, 'include') - lib_folder = os.path.join(build_prefix, 'lib') - library_header_path = os.path.join( - include_folder, library_header) + if os.name == "nt": + build_prefix = os.path.join(build_prefix, "Library") + include_folder = os.path.join(build_prefix, "include") + lib_folder = os.path.join(build_prefix, "lib") + library_header_path = os.path.join(include_folder, library_header) library_found = os.path.isfile(library_header_path) conda_installed = library_found else: # Check if using Anaconda to produce wheels - conda = distutils.spawn.find_executable('conda') + conda = distutils.spawn.find_executable("conda") is_conda = conda is not None - print('Running build on conda: {0}'.format(is_conda)) + print("Running build on conda: {0}".format(is_conda)) if is_conda: python_executable = sys.executable py_folder = os.path.dirname(python_executable) - if os.name == 'nt': - env_path = os.path.join(py_folder, 'Library') + if os.name == "nt": + env_path = os.path.join(py_folder, "Library") else: env_path = os.path.dirname(py_folder) - lib_folder = os.path.join(env_path, 'lib') - include_folder = os.path.join(env_path, 'include') - library_header_path = os.path.join( - include_folder, library_header) + lib_folder = os.path.join(env_path, "lib") + include_folder = os.path.join(env_path, "include") + library_header_path = os.path.join(include_folder, 
library_header) library_found = os.path.isfile(library_header_path) conda_installed = library_found if not library_found: - if sys.platform == 'linux': - library_found = os.path.exists('/usr/include/{0}'.format( - library_header)) - library_found = library_found or os.path.exists( - '/usr/local/include/{0}'.format(library_header)) + if sys.platform == "linux": + library_found = os.path.exists("/usr/include/{0}".format(library_header)) + library_found = library_found or os.path.exists("/usr/local/include/{0}".format(library_header)) return library_found, conda_installed, include_folder, lib_folder def get_extensions(): this_dir = os.path.dirname(os.path.abspath(__file__)) - extensions_dir = os.path.join(this_dir, 'torchvision', 'csrc') + extensions_dir = os.path.join(this_dir, "torchvision", "csrc") - main_file = glob.glob(os.path.join(extensions_dir, '*.cpp')) + glob.glob(os.path.join(extensions_dir, 'ops', - '*.cpp')) + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + glob.glob( + os.path.join(extensions_dir, "ops", "*.cpp") + ) source_cpu = ( - glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp')) + - glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) + - glob.glob(os.path.join(extensions_dir, 'ops', 'quantized', 'cpu', '*.cpp')) + glob.glob(os.path.join(extensions_dir, "ops", "autograd", "*.cpp")) + + glob.glob(os.path.join(extensions_dir, "ops", "cpu", "*.cpp")) + + glob.glob(os.path.join(extensions_dir, "ops", "quantized", "cpu", "*.cpp")) ) is_rocm_pytorch = False - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): from torch.utils.cpp_extension import ROCM_HOME + is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False if is_rocm_pytorch: from torch.utils.hipify import hipify_python + hipify_python.hipify( project_directory=this_dir, output_directory=this_dir, @@ -162,25 +158,25 @@ def get_extensions(): show_detailed=True, is_pytorch_extension=True, ) - source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'hip', '*.hip')) + source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "hip", "*.hip")) # Copy over additional files for file in glob.glob(r"torchvision/csrc/ops/cuda/*.h"): shutil.copy(file, "torchvision/csrc/ops/hip") else: - source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.cu')) + source_cuda = glob.glob(os.path.join(extensions_dir, "ops", "cuda", "*.cu")) - source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', '*.cpp')) + source_cuda += glob.glob(os.path.join(extensions_dir, "ops", "autocast", "*.cpp")) sources = main_file + source_cpu extension = CppExtension - compile_cpp_tests = os.getenv('WITH_CPP_MODELS_TEST', '0') == '1' + compile_cpp_tests = os.getenv("WITH_CPP_MODELS_TEST", "0") == "1" if compile_cpp_tests: - test_dir = os.path.join(this_dir, 'test') - models_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'models') - test_file = glob.glob(os.path.join(test_dir, '*.cpp')) - source_models = glob.glob(os.path.join(models_dir, '*.cpp')) + test_dir = os.path.join(this_dir, "test") + models_dir = os.path.join(this_dir, "torchvision", "csrc", "models") + test_file = glob.glob(os.path.join(test_dir, "*.cpp")) + source_models = glob.glob(os.path.join(models_dir, "*.cpp")) test_file = [os.path.join(test_dir, s) for s in 
test_file] source_models = [os.path.join(models_dir, s) for s in source_models] @@ -189,39 +185,38 @@ def get_extensions(): define_macros = [] - extra_compile_args = {'cxx': []} - if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) \ - or os.getenv('FORCE_CUDA', '0') == '1': + extra_compile_args = {"cxx": []} + if (torch.cuda.is_available() and ((CUDA_HOME is not None) or is_rocm_pytorch)) or os.getenv( + "FORCE_CUDA", "0" + ) == "1": extension = CUDAExtension sources += source_cuda if not is_rocm_pytorch: - define_macros += [('WITH_CUDA', None)] - nvcc_flags = os.getenv('NVCC_FLAGS', '') - if nvcc_flags == '': + define_macros += [("WITH_CUDA", None)] + nvcc_flags = os.getenv("NVCC_FLAGS", "") + if nvcc_flags == "": nvcc_flags = [] else: - nvcc_flags = nvcc_flags.split(' ') + nvcc_flags = nvcc_flags.split(" ") else: - define_macros += [('WITH_HIP', None)] + define_macros += [("WITH_HIP", None)] nvcc_flags = [] extra_compile_args["nvcc"] = nvcc_flags - if sys.platform == 'win32': - define_macros += [('torchvision_EXPORTS', None)] + if sys.platform == "win32": + define_macros += [("torchvision_EXPORTS", None)] - extra_compile_args['cxx'].append('/MP') + extra_compile_args["cxx"].append("/MP") - debug_mode = os.getenv('DEBUG', '0') == '1' + debug_mode = os.getenv("DEBUG", "0") == "1" if debug_mode: print("Compile in debug mode") - extra_compile_args['cxx'].append("-g") - extra_compile_args['cxx'].append("-O0") + extra_compile_args["cxx"].append("-g") + extra_compile_args["cxx"].append("-O0") if "nvcc" in extra_compile_args: # we have to remove "-OX" and "-g" flag if exists and append nvcc_flags = extra_compile_args["nvcc"] - extra_compile_args["nvcc"] = [ - f for f in nvcc_flags if not ("-O" in f or "-g" in f) - ] + extra_compile_args["nvcc"] = [f for f in nvcc_flags if not ("-O" in f or "-g" in f)] extra_compile_args["nvcc"].append("-O0") extra_compile_args["nvcc"].append("-g") @@ -231,7 +226,7 @@ def get_extensions(): ext_modules = [ extension( - 'torchvision._C', + "torchvision._C", sorted(sources), include_dirs=include_dirs, define_macros=define_macros, @@ -241,7 +236,7 @@ def get_extensions(): if compile_cpp_tests: ext_modules.append( extension( - 'torchvision._C_tests', + "torchvision._C_tests", tests, include_dirs=tests_include_dirs, define_macros=define_macros, @@ -250,12 +245,10 @@ def get_extensions(): ) # ------------------- Torchvision extra extensions ------------------------ - vision_include = os.environ.get('TORCHVISION_INCLUDE', None) - vision_library = os.environ.get('TORCHVISION_LIBRARY', None) - vision_include = (vision_include.split(os.pathsep) - if vision_include is not None else []) - vision_library = (vision_library.split(os.pathsep) - if vision_library is not None else []) + vision_include = os.environ.get("TORCHVISION_INCLUDE", None) + vision_library = os.environ.get("TORCHVISION_LIBRARY", None) + vision_include = vision_include.split(os.pathsep) if vision_include is not None else [] + vision_library = vision_library.split(os.pathsep) if vision_library is not None else [] include_dirs += vision_include library_dirs = vision_library @@ -266,56 +259,49 @@ def get_extensions(): image_link_flags = [] # Locating libPNG - libpng = distutils.spawn.find_executable('libpng-config') - pngfix = distutils.spawn.find_executable('pngfix') + libpng = distutils.spawn.find_executable("libpng-config") + pngfix = distutils.spawn.find_executable("pngfix") png_found = libpng is not None or pngfix is not None - print('PNG found: {0}'.format(png_found)) + 
print("PNG found: {0}".format(png_found)) if png_found: if libpng is not None: # Linux / Mac - png_version = subprocess.run([libpng, '--version'], - stdout=subprocess.PIPE) - png_version = png_version.stdout.strip().decode('utf-8') - print('libpng version: {0}'.format(png_version)) + png_version = subprocess.run([libpng, "--version"], stdout=subprocess.PIPE) + png_version = png_version.stdout.strip().decode("utf-8") + print("libpng version: {0}".format(png_version)) png_version = parse_version(png_version) if png_version >= parse_version("1.6.0"): - print('Building torchvision with PNG image support') - png_lib = subprocess.run([libpng, '--libdir'], - stdout=subprocess.PIPE) - png_lib = png_lib.stdout.strip().decode('utf-8') - if 'disabled' not in png_lib: + print("Building torchvision with PNG image support") + png_lib = subprocess.run([libpng, "--libdir"], stdout=subprocess.PIPE) + png_lib = png_lib.stdout.strip().decode("utf-8") + if "disabled" not in png_lib: image_library += [png_lib] - png_include = subprocess.run([libpng, '--I_opts'], - stdout=subprocess.PIPE) - png_include = png_include.stdout.strip().decode('utf-8') - _, png_include = png_include.split('-I') - print('libpng include path: {0}'.format(png_include)) + png_include = subprocess.run([libpng, "--I_opts"], stdout=subprocess.PIPE) + png_include = png_include.stdout.strip().decode("utf-8") + _, png_include = png_include.split("-I") + print("libpng include path: {0}".format(png_include)) image_include += [png_include] - image_link_flags.append('png') + image_link_flags.append("png") else: - print('libpng installed version is less than 1.6.0, ' - 'disabling PNG support') + print("libpng installed version is less than 1.6.0, " "disabling PNG support") png_found = False else: # Windows - png_lib = os.path.join( - os.path.dirname(os.path.dirname(pngfix)), 'lib') - png_include = os.path.join(os.path.dirname( - os.path.dirname(pngfix)), 'include', 'libpng16') + png_lib = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "lib") + png_include = os.path.join(os.path.dirname(os.path.dirname(pngfix)), "include", "libpng16") image_library += [png_lib] image_include += [png_include] - image_link_flags.append('libpng') + image_link_flags.append("libpng") # Locating libjpeg - (jpeg_found, jpeg_conda, - jpeg_include, jpeg_lib) = find_library('jpeglib', vision_include) + (jpeg_found, jpeg_conda, jpeg_include, jpeg_lib) = find_library("jpeglib", vision_include) - print('JPEG found: {0}'.format(jpeg_found)) - image_macros += [('PNG_FOUND', str(int(png_found)))] - image_macros += [('JPEG_FOUND', str(int(jpeg_found)))] + print("JPEG found: {0}".format(jpeg_found)) + image_macros += [("PNG_FOUND", str(int(png_found)))] + image_macros += [("JPEG_FOUND", str(int(jpeg_found)))] if jpeg_found: - print('Building torchvision with JPEG image support') - image_link_flags.append('jpeg') + print("Building torchvision with JPEG image support") + image_link_flags.append("jpeg") if jpeg_conda: image_library += [jpeg_lib] image_include += [jpeg_include] @@ -323,80 +309,71 @@ def get_extensions(): # Locating nvjpeg # Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI nvjpeg_found = ( - extension is CUDAExtension and - CUDA_HOME is not None and - os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) + extension is CUDAExtension + and CUDA_HOME is not None + and os.path.exists(os.path.join(CUDA_HOME, "include", "nvjpeg.h")) ) - print('NVJPEG found: {0}'.format(nvjpeg_found)) - image_macros += 
[('NVJPEG_FOUND', str(int(nvjpeg_found)))] + print("NVJPEG found: {0}".format(nvjpeg_found)) + image_macros += [("NVJPEG_FOUND", str(int(nvjpeg_found)))] if nvjpeg_found: - print('Building torchvision with NVJPEG image support') - image_link_flags.append('nvjpeg') - - image_path = os.path.join(extensions_dir, 'io', 'image') - image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) - + glob.glob(os.path.join(image_path, 'cuda', '*.cpp'))) + print("Building torchvision with NVJPEG image support") + image_link_flags.append("nvjpeg") + + image_path = os.path.join(extensions_dir, "io", "image") + image_src = ( + glob.glob(os.path.join(image_path, "*.cpp")) + + glob.glob(os.path.join(image_path, "cpu", "*.cpp")) + + glob.glob(os.path.join(image_path, "cuda", "*.cpp")) + ) if png_found or jpeg_found: - ext_modules.append(extension( - 'torchvision.image', - image_src, - include_dirs=image_include + include_dirs + [image_path], - library_dirs=image_library + library_dirs, - define_macros=image_macros, - libraries=image_link_flags, - extra_compile_args=extra_compile_args - )) - - ffmpeg_exe = distutils.spawn.find_executable('ffmpeg') + ext_modules.append( + extension( + "torchvision.image", + image_src, + include_dirs=image_include + include_dirs + [image_path], + library_dirs=image_library + library_dirs, + define_macros=image_macros, + libraries=image_link_flags, + extra_compile_args=extra_compile_args, + ) + ) + + ffmpeg_exe = distutils.spawn.find_executable("ffmpeg") has_ffmpeg = ffmpeg_exe is not None # FIXME: Building torchvision with ffmpeg on MacOS or with Python 3.9 # FIXME: causes crash. See the following GitHub issues for more details. # FIXME: https://github.com/pytorch/pytorch/issues/65000 # FIXME: https://github.com/pytorch/vision/issues/3367 - if sys.platform != 'linux' or ( - sys.version_info.major == 3 and sys.version_info.minor == 9): + if sys.platform != "linux" or (sys.version_info.major == 3 and sys.version_info.minor == 9): has_ffmpeg = False if has_ffmpeg: try: # This is to check if ffmpeg is installed properly. 
subprocess.check_output(["ffmpeg", "-version"]) except subprocess.CalledProcessError: - print('Error fetching ffmpeg version, ignoring ffmpeg.') + print("Error fetching ffmpeg version, ignoring ffmpeg.") has_ffmpeg = False print("FFmpeg found: {}".format(has_ffmpeg)) if has_ffmpeg: - ffmpeg_libraries = { - 'libavcodec', - 'libavformat', - 'libavutil', - 'libswresample', - 'libswscale' - } + ffmpeg_libraries = {"libavcodec", "libavformat", "libavutil", "libswresample", "libswscale"} ffmpeg_bin = os.path.dirname(ffmpeg_exe) ffmpeg_root = os.path.dirname(ffmpeg_bin) - ffmpeg_include_dir = os.path.join(ffmpeg_root, 'include') - ffmpeg_library_dir = os.path.join(ffmpeg_root, 'lib') + ffmpeg_include_dir = os.path.join(ffmpeg_root, "include") + ffmpeg_library_dir = os.path.join(ffmpeg_root, "lib") - gcc = distutils.spawn.find_executable('gcc') - platform_tag = subprocess.run( - [gcc, '-print-multiarch'], stdout=subprocess.PIPE) - platform_tag = platform_tag.stdout.strip().decode('utf-8') + gcc = distutils.spawn.find_executable("gcc") + platform_tag = subprocess.run([gcc, "-print-multiarch"], stdout=subprocess.PIPE) + platform_tag = platform_tag.stdout.strip().decode("utf-8") if platform_tag: # Most probably a Debian-based distribution - ffmpeg_include_dir = [ - ffmpeg_include_dir, - os.path.join(ffmpeg_include_dir, platform_tag) - ] - ffmpeg_library_dir = [ - ffmpeg_library_dir, - os.path.join(ffmpeg_library_dir, platform_tag) - ] + ffmpeg_include_dir = [ffmpeg_include_dir, os.path.join(ffmpeg_include_dir, platform_tag)] + ffmpeg_library_dir = [ffmpeg_library_dir, os.path.join(ffmpeg_library_dir, platform_tag)] else: ffmpeg_include_dir = [ffmpeg_include_dir] ffmpeg_library_dir = [ffmpeg_library_dir] @@ -405,11 +382,11 @@ def get_extensions(): for library in ffmpeg_libraries: library_found = False for search_path in ffmpeg_include_dir + include_dirs: - full_path = os.path.join(search_path, library, '*.h') + full_path = os.path.join(search_path, library, "*.h") library_found |= len(glob.glob(full_path)) > 0 if not library_found: - print(f'{library} header files were not found, disabling ffmpeg support') + print(f"{library} header files were not found, disabling ffmpeg support") has_ffmpeg = False if has_ffmpeg: @@ -417,22 +394,21 @@ def get_extensions(): print("ffmpeg library_dir: {}".format(ffmpeg_library_dir)) # TorchVision base decoder + video reader - video_reader_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video_reader') + video_reader_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video_reader") video_reader_src = glob.glob(os.path.join(video_reader_src_dir, "*.cpp")) - base_decoder_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'decoder') - base_decoder_src = glob.glob( - os.path.join(base_decoder_src_dir, "*.cpp")) + base_decoder_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "decoder") + base_decoder_src = glob.glob(os.path.join(base_decoder_src_dir, "*.cpp")) # Torchvision video API - videoapi_src_dir = os.path.join(this_dir, 'torchvision', 'csrc', 'io', 'video') + videoapi_src_dir = os.path.join(this_dir, "torchvision", "csrc", "io", "video") videoapi_src = glob.glob(os.path.join(videoapi_src_dir, "*.cpp")) # exclude tests - base_decoder_src = [x for x in base_decoder_src if '_test.cpp' not in x] + base_decoder_src = [x for x in base_decoder_src if "_test.cpp" not in x] combined_src = video_reader_src + base_decoder_src + videoapi_src ext_modules.append( CppExtension( - 'torchvision.video_reader', + 
"torchvision.video_reader", combined_src, include_dirs=[ base_decoder_src_dir, @@ -440,18 +416,18 @@ def get_extensions(): videoapi_src_dir, extensions_dir, *ffmpeg_include_dir, - *include_dirs + *include_dirs, ], library_dirs=ffmpeg_library_dir + library_dirs, libraries=[ - 'avcodec', - 'avformat', - 'avutil', - 'swresample', - 'swscale', + "avcodec", + "avformat", + "avutil", + "swresample", + "swscale", ], - extra_compile_args=["-std=c++14"] if os.name != 'nt' else ['/std:c++14', '/MP'], - extra_link_args=["-std=c++14" if os.name != 'nt' else '/std:c++14'], + extra_compile_args=["-std=c++14"] if os.name != "nt" else ["/std:c++14", "/MP"], + extra_link_args=["-std=c++14" if os.name != "nt" else "/std:c++14"], ) ) @@ -460,9 +436,9 @@ def get_extensions(): class clean(distutils.command.clean.clean): def run(self): - with open('.gitignore', 'r') as f: + with open(".gitignore", "r") as f: ignores = f.read() - for wildcard in filter(None, ignores.split('\n')): + for wildcard in filter(None, ignores.split("\n")): for filename in glob.glob(wildcard): try: os.remove(filename) @@ -478,25 +454,22 @@ def run(self): write_version_file() - with open('README.rst') as f: + with open("README.rst") as f: readme = f.read() setup( # Metadata name=package_name, version=version, - author='PyTorch Core Team', - author_email='soumith@pytorch.org', - url='https://github.com/pytorch/vision', - description='image and video datasets and models for torch deep learning', + author="PyTorch Core Team", + author_email="soumith@pytorch.org", + url="https://github.com/pytorch/vision", + description="image and video datasets and models for torch deep learning", long_description=readme, - license='BSD', - + license="BSD", # Package info - packages=find_packages(exclude=('test',)), - package_data={ - package_name: ['*.dll', '*.dylib', '*.so', '*.categories'] - }, + packages=find_packages(exclude=("test",)), + package_data={package_name: ["*.dll", "*.dylib", "*.so", "*.categories"]}, zip_safe=False, install_requires=requirements, extras_require={ @@ -504,7 +477,7 @@ def run(self): }, ext_modules=get_extensions(), cmdclass={ - 'build_ext': BuildExtension.with_options(no_python_abi_suffix=True), - 'clean': clean, - } + "build_ext": BuildExtension.with_options(no_python_abi_suffix=True), + "clean": clean, + }, ) diff --git a/test/common_utils.py b/test/common_utils.py index a9f6703fcd0..7469478670a 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -1,30 +1,30 @@ +import argparse +import contextlib +import functools +import inspect import os +import random import shutil +import sys import tempfile -import contextlib import unittest +from collections import OrderedDict +from numbers import Number + +import numpy as np import pytest -import argparse -import sys import torch -import __main__ -import random -import inspect -import functools - -from numbers import Number +from PIL import Image from torch._six import string_classes -from collections import OrderedDict from torchvision import io -import numpy as np -from PIL import Image +import __main__ -IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true' +IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == "true" IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" -CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available' +CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" CIRCLECI_GPU_NO_CUDA_MSG = "We're in a CircleCI GPU machine, and this test doesn't need cuda." 
@@ -95,7 +95,7 @@ def freeze_rng_state(): def cycle_over(objs): for idx, obj1 in enumerate(objs): - for obj2 in objs[:idx] + objs[idx + 1:]: + for obj2 in objs[:idx] + objs[idx + 1 :]: yield obj1, obj2 @@ -117,11 +117,13 @@ def disable_console_output(): def cpu_and_gpu(): import pytest # noqa - return ('cpu', pytest.param('cuda', marks=pytest.mark.needs_cuda)) + + return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) def needs_cuda(test_func): import pytest # noqa + return pytest.mark.needs_cuda(test_func) @@ -139,12 +141,7 @@ def _create_data(height=3, width=3, channels=3, device="cpu"): def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu"): # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture - batch_tensor = torch.randint( - 0, 256, - (num_samples, channels, height, width), - dtype=torch.uint8, - device=device - ) + batch_tensor = torch.randint(0, 256, (num_samples, channels, height, width), dtype=torch.uint8, device=device) return batch_tensor @@ -180,8 +177,9 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): assert_equal(tensor.cpu(), pil_tensor, msg=msg) -def _assert_approx_equal_tensor_to_pil(tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", - allowed_percentage_diff=None): +def _assert_approx_equal_tensor_to_pil( + tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None +): # TODO: we could just merge this into _assert_equal_tensor_to_pil np_pil_image = np.array(pil_image) if np_pil_image.ndim == 2: diff --git a/test/conftest.py b/test/conftest.py index a84b9f8dd52..a8b9054a4e5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,18 +1,15 @@ -from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG -import torch -import numpy as np import random + +import numpy as np import pytest +import torch +from common_utils import IN_CIRCLE_CI, CIRCLECI_GPU_NO_CUDA_MSG, IN_FBCODE, IN_RE_WORKER, CUDA_NOT_AVAILABLE_MSG def pytest_configure(config): # register an additional marker (see pytest_collection_modifyitems) - config.addinivalue_line( - "markers", "needs_cuda: mark for tests that rely on a CUDA device" - ) - config.addinivalue_line( - "markers", "dont_collect: mark for tests that should not be collected" - ) + config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device") + config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected") def pytest_collection_modifyitems(items): @@ -34,7 +31,7 @@ def pytest_collection_modifyitems(items): # @pytest.mark.parametrize('device', cpu_and_gpu()) # the "instances" of the tests where device == 'cuda' will have the 'needs_cuda' mark, # and the ones with device == 'cpu' won't have the mark. - needs_cuda = item.get_closest_marker('needs_cuda') is not None + needs_cuda = item.get_closest_marker("needs_cuda") is not None if needs_cuda and not torch.cuda.is_available(): # In general, we skip cuda tests on machines without a GPU @@ -59,7 +56,7 @@ def pytest_collection_modifyitems(items): # to run the CPU-only tests. 
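cpu_and_gpu and needs_cuda above drive device parametrization across the suite; a small self-contained sketch of that pattern (the needs_cuda marker still has to be registered in conftest.py, as the diff does, and the test body is illustrative):

import pytest
import torch

def cpu_and_gpu():
    # the 'cuda' instance carries the needs_cuda marker so it can be skipped on CPU-only machines
    return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))

@pytest.mark.parametrize("device", cpu_and_gpu())
def test_ones_sum(device):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA device not available")
    assert torch.ones(2, device=device).sum().item() == 2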
item.add_marker(pytest.mark.skip(reason=CIRCLECI_GPU_NO_CUDA_MSG)) - if item.get_closest_marker('dont_collect') is not None: + if item.get_closest_marker("dont_collect") is not None: # currently, this is only used for some tests we're sure we dont want to run on fbcode continue diff --git a/test/datasets_utils.py b/test/datasets_utils.py index ca182b8d25f..646babdda1e 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -18,7 +18,6 @@ import torch import torchvision.datasets import torchvision.io - from common_utils import get_tmp_dir, disable_console_output @@ -419,7 +418,7 @@ def _populate_private_class_attributes(cls): defaults.append( { kwarg: default - for kwarg, default in zip(argspec.args[-len(argspec.defaults):], argspec.defaults) + for kwarg, default in zip(argspec.args[-len(argspec.defaults) :], argspec.defaults) if not kwarg.startswith("_") } ) @@ -640,7 +639,7 @@ def __init__(self, *args, **kwargs): def _set_default_frames_per_clip(self, inject_fake_data): argspec = inspect.getfullargspec(self.DATASET_CLASS.__init__) - args_without_default = argspec.args[1:(-len(argspec.defaults) if argspec.defaults else None)] + args_without_default = argspec.args[1 : (-len(argspec.defaults) if argspec.defaults else None)] frames_per_clip_last = args_without_default[-1] == "frames_per_clip" @functools.wraps(inject_fake_data) diff --git a/test/preprocess-bench.py b/test/preprocess-bench.py index 4ba3ca46dbc..df557b29197 100644 --- a/test/preprocess-bench.py +++ b/test/preprocess-bench.py @@ -1,47 +1,50 @@ import argparse import os from timeit import default_timer as timer -from torch.utils.model_zoo import tqdm + import torch import torch.utils.data import torchvision -import torchvision.transforms as transforms import torchvision.datasets as datasets +import torchvision.transforms as transforms +from torch.utils.model_zoo import tqdm -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') -parser.add_argument('--data', metavar='PATH', required=True, - help='path to dataset') -parser.add_argument('--nThreads', '-j', default=2, type=int, metavar='N', - help='number of data loading threads (default: 2)') -parser.add_argument('--batchSize', '-b', default=256, type=int, metavar='N', - help='mini-batch size (1 = pure stochastic) Default: 256') -parser.add_argument('--accimage', action='store_true', - help='use accimage') +parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") +parser.add_argument("--data", metavar="PATH", required=True, help="path to dataset") +parser.add_argument( + "--nThreads", "-j", default=2, type=int, metavar="N", help="number of data loading threads (default: 2)" +) +parser.add_argument( + "--batchSize", "-b", default=256, type=int, metavar="N", help="mini-batch size (1 = pure stochastic) Default: 256" +) +parser.add_argument("--accimage", action="store_true", help="use accimage") if __name__ == "__main__": args = parser.parse_args() if args.accimage: - torchvision.set_image_backend('accimage') - print('Using {}'.format(torchvision.get_image_backend())) + torchvision.set_image_backend("accimage") + print("Using {}".format(torchvision.get_image_backend())) # Data loading code - transform = transforms.Compose([ - transforms.RandomSizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) - - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') + transform = transforms.Compose( + [ + 
transforms.RandomSizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + traindir = os.path.join(args.data, "train") + valdir = os.path.join(args.data, "val") train = datasets.ImageFolder(traindir, transform) val = datasets.ImageFolder(valdir, transform) train_loader = torch.utils.data.DataLoader( - train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads) + train, batch_size=args.batchSize, shuffle=True, num_workers=args.nThreads + ) train_iter = iter(train_loader) start_time = timer() @@ -51,9 +54,12 @@ pbar.update(1) batch = next(train_iter) end_time = timer() - print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch," - " {image:.2f} ms/image {rate:.0f} images/sec" - .format(dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0), - batch=(end_time - start_time) / float(batch_count) * 1.0e+3, - image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e+3, - rate=(batch_count * args.batchSize) / (end_time - start_time))) + print( + "Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch," + " {image:.2f} ms/image {rate:.0f} images/sec".format( + dataset=(end_time - start_time) * (float(len(train_loader)) / batch_count / 60.0), + batch=(end_time - start_time) / float(batch_count) * 1.0e3, + image=(end_time - start_time) / (batch_count * args.batchSize) * 1.0e3, + rate=(batch_count * args.batchSize) / (end_time - start_time), + ) + ) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index a29509bfe87..1b0c0251169 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,31 +1,28 @@ +import random from functools import partial from itertools import chain -import random +import pytest import torch -from torchvision import models import torchvision +from common_utils import set_rng_seed +from torchvision import models +from torchvision.models._utils import IntermediateLayerGetter from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.feature_extraction import create_feature_extractor from torchvision.models.feature_extraction import get_graph_node_names -from torchvision.models._utils import IntermediateLayerGetter - -import pytest - -from common_utils import set_rng_seed def get_available_models(): # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.__dict__.items() - if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] -@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) +@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50")) def test_resnet_fpn_backbone(backbone_name): - x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') + x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu") y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) - assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] + assert list(y.keys()) == ["0", "1", "2", "3", "pool"] # Needed by TestFxFeatureExtraction.test_leaf_module_and_function @@ -64,16 +61,21 @@ def forward(self, x): test_module_nodes = [ - 'x', 'submodule.add', 'submodule.add_1', 'submodule.relu', - 'submodule.relu_1', 'add', 'add_1', 'relu', 'relu_1'] + "x", + "submodule.add", + "submodule.add_1", + "submodule.relu", + "submodule.relu_1", + "add", + "add_1", + 
"relu", + "relu_1", +] class TestFxFeatureExtraction: - inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu') - model_defaults = { - 'num_classes': 1, - 'pretrained': False - } + inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu") + model_defaults = {"num_classes": 1, "pretrained": False} leaf_modules = [] def _create_feature_extractor(self, *args, **kwargs): @@ -81,41 +83,36 @@ def _create_feature_extractor(self, *args, **kwargs): Apply leaf modules """ tracer_kwargs = {} - if 'tracer_kwargs' not in kwargs: - tracer_kwargs = {'leaf_modules': self.leaf_modules} + if "tracer_kwargs" not in kwargs: + tracer_kwargs = {"leaf_modules": self.leaf_modules} else: - tracer_kwargs = kwargs.pop('tracer_kwargs') - return create_feature_extractor( - *args, **kwargs, - tracer_kwargs=tracer_kwargs, - suppress_diff_warning=True) + tracer_kwargs = kwargs.pop("tracer_kwargs") + return create_feature_extractor(*args, **kwargs, tracer_kwargs=tracer_kwargs, suppress_diff_warning=True) def _get_return_nodes(self, model): set_rng_seed(0) - exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk'] + exclude_nodes_filter = ["getitem", "floordiv", "size", "chunk"] train_nodes, eval_nodes = get_graph_node_names( - model, tracer_kwargs={'leaf_modules': self.leaf_modules}, - suppress_diff_warning=True) + model, tracer_kwargs={"leaf_modules": self.leaf_modules}, suppress_diff_warning=True + ) # Get rid of any nodes that don't return tensors as they cause issues # when testing backward pass. - train_nodes = [n for n in train_nodes - if not any(x in n for x in exclude_nodes_filter)] - eval_nodes = [n for n in eval_nodes - if not any(x in n for x in exclude_nodes_filter)] + train_nodes = [n for n in train_nodes if not any(x in n for x in exclude_nodes_filter)] + eval_nodes = [n for n in eval_nodes if not any(x in n for x in exclude_nodes_filter)] return random.sample(train_nodes, 10), random.sample(eval_nodes, 10) - @pytest.mark.parametrize('model_name', get_available_models()) + @pytest.mark.parametrize("model_name", get_available_models()) def test_build_fx_feature_extractor(self, model_name): set_rng_seed(0) model = models.__dict__[model_name](**self.model_defaults).eval() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) # Check that it works with both a list and dict for return nodes self._create_feature_extractor( - model, train_return_nodes={v: v for v in train_return_nodes}, - eval_return_nodes=eval_return_nodes) + model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes + ) self._create_feature_extractor( - model, train_return_nodes=train_return_nodes, - eval_return_nodes=eval_return_nodes) + model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes + ) # Check must specify return nodes with pytest.raises(AssertionError): self._create_feature_extractor(model) @@ -123,19 +120,16 @@ def test_build_fx_feature_extractor(self, model_name): # mutual exclusivity with pytest.raises(AssertionError): self._create_feature_extractor( - model, return_nodes=train_return_nodes, - train_return_nodes=train_return_nodes) + model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes + ) # Check train_return_nodes / eval_return nodes must both be specified with pytest.raises(AssertionError): - self._create_feature_extractor( - model, train_return_nodes=train_return_nodes) + self._create_feature_extractor(model, train_return_nodes=train_return_nodes) # Check invalid node name raises ValueError 
with pytest.raises(ValueError): # First just double check that this node really doesn't exist - if not any(n.startswith('l') or n.startswith('l.') for n - in chain(train_return_nodes, eval_return_nodes)): - self._create_feature_extractor( - model, train_return_nodes=['l'], eval_return_nodes=['l']) + if not any(n.startswith("l") or n.startswith("l.") for n in chain(train_return_nodes, eval_return_nodes)): + self._create_feature_extractor(model, train_return_nodes=["l"], eval_return_nodes=["l"]) else: # otherwise skip this check raise ValueError @@ -144,32 +138,25 @@ def test_node_name_conventions(self): train_nodes, _ = get_graph_node_names(model) assert all(a == b for a, b in zip(train_nodes, test_module_nodes)) - @pytest.mark.parametrize('model_name', get_available_models()) + @pytest.mark.parametrize("model_name", get_available_models()) def test_forward_backward(self, model_name): model = models.__dict__[model_name](**self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) model = self._create_feature_extractor( - model, train_return_nodes=train_return_nodes, - eval_return_nodes=eval_return_nodes) + model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes + ) out = model(self.inp) sum([o.mean() for o in out.values()]).backward() def test_feature_extraction_methods_equivalence(self): model = models.resnet18(**self.model_defaults).eval() - return_layers = { - 'layer1': 'layer1', - 'layer2': 'layer2', - 'layer3': 'layer3', - 'layer4': 'layer4' - } - - ilg_model = IntermediateLayerGetter( - model, return_layers).eval() + return_layers = {"layer1": "layer1", "layer2": "layer2", "layer3": "layer3", "layer4": "layer4"} + + ilg_model = IntermediateLayerGetter(model, return_layers).eval() fx_model = self._create_feature_extractor(model, return_layers) # Check that we have same parameters - for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), - fx_model.named_parameters()): + for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), fx_model.named_parameters()): assert n1 == n2 assert p1.equal(p2) @@ -181,14 +168,14 @@ def test_feature_extraction_methods_equivalence(self): for k in ilg_out.keys(): assert ilg_out[k].equal(fgn_out[k]) - @pytest.mark.parametrize('model_name', get_available_models()) + @pytest.mark.parametrize("model_name", get_available_models()) def test_jit_forward_backward(self, model_name): set_rng_seed(0) model = models.__dict__[model_name](**self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) model = self._create_feature_extractor( - model, train_return_nodes=train_return_nodes, - eval_return_nodes=eval_return_nodes) + model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes + ) model = torch.jit.script(model) fgn_out = model(self.inp) sum([o.mean() for o in fgn_out.values()]).backward() @@ -197,7 +184,7 @@ def test_train_eval(self): class TestModel(torch.nn.Module): def __init__(self): super().__init__() - self.dropout = torch.nn.Dropout(p=1.) 
+ self.dropout = torch.nn.Dropout(p=1.0) def forward(self, x): x = x.mean() @@ -211,54 +198,54 @@ def forward(self, x): model = TestModel() - train_return_nodes = ['dropout', 'add', 'sub'] - eval_return_nodes = ['dropout', 'mul', 'sub'] + train_return_nodes = ["dropout", "add", "sub"] + eval_return_nodes = ["dropout", "mul", "sub"] def checks(model, mode): with torch.no_grad(): out = model(torch.ones(10, 10)) - if mode == 'train': + if mode == "train": # Check that dropout is respected - assert out['dropout'].item() == 0 + assert out["dropout"].item() == 0 # Check that control flow dependent on training_mode is respected - assert out['sub'].item() == 100 - assert 'add' in out - assert 'mul' not in out - elif mode == 'eval': + assert out["sub"].item() == 100 + assert "add" in out + assert "mul" not in out + elif mode == "eval": # Check that dropout is respected - assert out['dropout'].item() == 1 + assert out["dropout"].item() == 1 # Check that control flow dependent on training_mode is respected - assert out['sub'].item() == 0 - assert 'mul' in out - assert 'add' not in out + assert out["sub"].item() == 0 + assert "mul" in out + assert "add" not in out # Starting from train mode model.train() fx_model = self._create_feature_extractor( - model, train_return_nodes=train_return_nodes, - eval_return_nodes=eval_return_nodes) + model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes + ) # Check that the models stay in their original training state assert model.training assert fx_model.training # Check outputs - checks(fx_model, 'train') + checks(fx_model, "train") # Check outputs after switching to eval mode fx_model.eval() - checks(fx_model, 'eval') + checks(fx_model, "eval") # Starting from eval mode model.eval() fx_model = self._create_feature_extractor( - model, train_return_nodes=train_return_nodes, - eval_return_nodes=eval_return_nodes) + model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes + ) # Check that the models stay in their original training state assert not model.training assert not fx_model.training # Check outputs - checks(fx_model, 'eval') + checks(fx_model, "eval") # Check outputs after switching to train mode fx_model.train() - checks(fx_model, 'train') + checks(fx_model, "train") def test_leaf_module_and_function(self): class LeafModule(torch.nn.Module): @@ -279,15 +266,16 @@ def forward(self, x): return self.leaf_module(x) model = self._create_feature_extractor( - TestModule(), return_nodes=['leaf_module'], - tracer_kwargs={'leaf_modules': [LeafModule], - 'autowrap_functions': [leaf_function]}).train() + TestModule(), + return_nodes=["leaf_module"], + tracer_kwargs={"leaf_modules": [LeafModule], "autowrap_functions": [leaf_function]}, + ).train() # Check that LeafModule is not in the list of nodes - assert 'relu' not in [str(n) for n in model.graph.nodes] - assert 'leaf_module' in [str(n) for n in model.graph.nodes] + assert "relu" not in [str(n) for n in model.graph.nodes] + assert "leaf_module" in [str(n) for n in model.graph.nodes] # Check forward out = model(self.inp) # And backward - out['leaf_module'].mean().backward() + out["leaf_module"].mean().backward() diff --git a/test/test_cpp_models.py b/test/test_cpp_models.py index 2307051ff60..25b5c79c984 100644 --- a/test/test_cpp_models.py +++ b/test/test_cpp_models.py @@ -1,11 +1,11 @@ -import torch import os -import unittest -from torchvision import models, transforms import sys +import unittest -from PIL import Image +import torch import 
torchvision.transforms.functional as F +from PIL import Image +from torchvision import models, transforms try: from torchvision import _C_tests @@ -21,12 +21,13 @@ def process_model(model, tensor, func, name): py_output = model.forward(tensor) cpp_output = func("model.pt", tensor) - assert torch.allclose(py_output, cpp_output), 'Output mismatch of ' + name + ' models' + assert torch.allclose(py_output, cpp_output), "Output mismatch of " + name + " models" def read_image1(): - image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', - 'grace_hopper_517x606.jpg') + image_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" + ) image = Image.open(image_path) image = image.resize((224, 224)) x = F.to_tensor(image) @@ -34,8 +35,9 @@ def read_image1(): def read_image2(): - image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', - 'grace_hopper_517x606.jpg') + image_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" + ) image = Image.open(image_path) image = image.resize((299, 299)) x = F.to_tensor(image) @@ -46,107 +48,110 @@ def read_image2(): @unittest.skipIf( sys.platform == "darwin" or True, "C++ models are broken on OS X at the moment, and there's a BC breakage on main; " - "see https://github.com/pytorch/vision/issues/1191") + "see https://github.com/pytorch/vision/issues/1191", +) class Tester(unittest.TestCase): pretrained = False image = read_image1() def test_alexnet(self): - process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, 'Alexnet') + process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, "Alexnet") def test_vgg11(self): - process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, 'VGG11') + process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, "VGG11") def test_vgg13(self): - process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, 'VGG13') + process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, "VGG13") def test_vgg16(self): - process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, 'VGG16') + process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, "VGG16") def test_vgg19(self): - process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, 'VGG19') + process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, "VGG19") def test_vgg11_bn(self): - process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, 'VGG11BN') + process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, "VGG11BN") def test_vgg13_bn(self): - process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, 'VGG13BN') + process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, "VGG13BN") def test_vgg16_bn(self): - process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, 'VGG16BN') + process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, "VGG16BN") def test_vgg19_bn(self): - process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, 'VGG19BN') + process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, "VGG19BN") def test_resnet18(self): - 
process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, 'Resnet18') + process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, "Resnet18") def test_resnet34(self): - process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, 'Resnet34') + process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, "Resnet34") def test_resnet50(self): - process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, 'Resnet50') + process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, "Resnet50") def test_resnet101(self): - process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, 'Resnet101') + process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, "Resnet101") def test_resnet152(self): - process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, 'Resnet152') + process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, "Resnet152") def test_resnext50_32x4d(self): - process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, 'ResNext50_32x4d') + process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d") def test_resnext101_32x8d(self): - process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d') + process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, "ResNext101_32x8d") def test_wide_resnet50_2(self): - process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, 'WideResNet50_2') + process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, "WideResNet50_2") def test_wide_resnet101_2(self): - process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, 'WideResNet101_2') + process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2") def test_squeezenet1_0(self): - process_model(models.squeezenet1_0(self.pretrained), self.image, - _C_tests.forward_squeezenet1_0, 'Squeezenet1.0') + process_model( + models.squeezenet1_0(self.pretrained), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0" + ) def test_squeezenet1_1(self): - process_model(models.squeezenet1_1(self.pretrained), self.image, - _C_tests.forward_squeezenet1_1, 'Squeezenet1.1') + process_model( + models.squeezenet1_1(self.pretrained), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1" + ) def test_densenet121(self): - process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, 'Densenet121') + process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, "Densenet121") def test_densenet169(self): - process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, 'Densenet169') + process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, "Densenet169") def test_densenet201(self): - process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, 'Densenet201') + process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, "Densenet201") def test_densenet161(self): - process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, 'Densenet161') + 
process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, "Densenet161") def test_mobilenet_v2(self): - process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, 'MobileNet') + process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, "MobileNet") def test_googlenet(self): - process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, 'GoogLeNet') + process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, "GoogLeNet") def test_mnasnet0_5(self): - process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, 'MNASNet0_5') + process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5") def test_mnasnet0_75(self): - process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, 'MNASNet0_75') + process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75") def test_mnasnet1_0(self): - process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, 'MNASNet1_0') + process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0") def test_mnasnet1_3(self): - process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, 'MNASNet1_3') + process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3") def test_inception_v3(self): self.image = read_image2() - process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, 'Inceptionv3') + process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, "Inceptionv3") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_datasets.py b/test/test_datasets.py index 2f4662fbac9..d2dc4ea6958 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -2,10 +2,10 @@ import contextlib import io import itertools +import json import os import pathlib import pickle -import json import random import shutil import string @@ -13,9 +13,9 @@ import xml.etree.ElementTree as ET import zipfile -import PIL import datasets_utils import numpy as np +import PIL import pytest import torch import torch.nn.functional as F @@ -24,8 +24,7 @@ class STL10TestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.STL10 - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( - split=("train", "test", "unlabeled", "train+unlabeled")) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test", "unlabeled", "train+unlabeled")) @staticmethod def _make_binary_file(num_elements, root, name): @@ -206,11 +205,11 @@ def inject_fake_data(self, tmpdir, config): class WIDERFaceTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.WIDERFace FEATURE_TYPES = (PIL.Image.Image, (dict, type(None))) # test split returns None as target - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val', 'test')) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test")) def inject_fake_data(self, tmpdir, config): - widerface_dir = pathlib.Path(tmpdir) / 'widerface' - annotations_dir = widerface_dir / 'wider_face_split' + widerface_dir = pathlib.Path(tmpdir) / "widerface" + annotations_dir = widerface_dir / "wider_face_split" os.makedirs(annotations_dir) split_to_idx 
= split_to_num_examples = { @@ -220,21 +219,21 @@ def inject_fake_data(self, tmpdir, config): } # We need to create all folders regardless of the split in config - for split in ('train', 'val', 'test'): + for split in ("train", "val", "test"): split_idx = split_to_idx[split] num_examples = split_to_num_examples[split] datasets_utils.create_image_folder( root=tmpdir, - name=widerface_dir / f'WIDER_{split}' / 'images' / '0--Parade', + name=widerface_dir / f"WIDER_{split}" / "images" / "0--Parade", file_name_fn=lambda image_idx: f"0_Parade_marchingband_1_{split_idx + image_idx}.jpg", num_examples=num_examples, ) annotation_file_name = { - 'train': annotations_dir / 'wider_face_train_bbx_gt.txt', - 'val': annotations_dir / 'wider_face_val_bbx_gt.txt', - 'test': annotations_dir / 'wider_face_test_filelist.txt', + "train": annotations_dir / "wider_face_train_bbx_gt.txt", + "val": annotations_dir / "wider_face_val_bbx_gt.txt", + "test": annotations_dir / "wider_face_test_filelist.txt", }[split] annotation_content = { @@ -267,9 +266,7 @@ class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): "color", ) ADDITIONAL_CONFIGS = ( - *datasets_utils.combinations_grid( - mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES - ), + *datasets_utils.combinations_grid(mode=("fine",), split=("train", "test", "val"), target_type=TARGET_TYPES), *datasets_utils.combinations_grid( mode=("coarse",), split=("train", "train_extra", "val"), @@ -324,6 +321,7 @@ def inject_fake_data(self, tmpdir, config): gt_dir = tmpdir / f"gt{mode}" for split in mode_to_splits[mode]: for city in cities: + def make_image(name, size=10): datasets_utils.create_image_folder( root=gt_dir / split, @@ -332,6 +330,7 @@ def make_image(name, size=10): size=size, num_examples=1, ) + make_image(f"{city}_000000_000000_gt{mode}_instanceIds.png") make_image(f"{city}_000000_000000_gt{mode}_labelIds.png") make_image(f"{city}_000000_000000_gt{mode}_color.png", size=(4, 10, 10)) @@ -341,7 +340,7 @@ def make_image(name, size=10): json.dump(polygon_target, outfile) # Create leftImg8bit folder - for split in ['test', 'train_extra', 'train', 'val']: + for split in ["test", "train_extra", "train", "val"]: for city in cities: datasets_utils.create_image_folder( root=tmpdir / "leftImg8bit" / split, @@ -350,13 +349,13 @@ def make_image(name, size=10): num_examples=1, ) - info = {'num_examples': len(cities)} - if config['target_type'] == 'polygon': - info['expected_polygon_target'] = polygon_target + info = {"num_examples": len(cities)} + if config["target_type"] == "polygon": + info["expected_polygon_target"] = polygon_target return info def test_combined_targets(self): - target_types = ['semantic', 'polygon', 'color'] + target_types = ["semantic", "polygon", "color"] with self.create_dataset(target_type=target_types) as (dataset, _): output = dataset[0] @@ -370,32 +369,32 @@ def test_combined_targets(self): assert isinstance(output[1][2], PIL.Image.Image) # color def test_feature_types_target_color(self): - with self.create_dataset(target_type='color') as (dataset, _): + with self.create_dataset(target_type="color") as (dataset, _): color_img, color_target = dataset[0] assert isinstance(color_img, PIL.Image.Image) assert np.array(color_target).shape[2] == 4 def test_feature_types_target_polygon(self): - with self.create_dataset(target_type='polygon') as (dataset, info): + with self.create_dataset(target_type="polygon") as (dataset, info): polygon_img, polygon_target = dataset[0] assert isinstance(polygon_img, PIL.Image.Image) - 
(polygon_target, info['expected_polygon_target']) + (polygon_target, info["expected_polygon_target"]) class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): DATASET_CLASS = datasets.ImageNet - REQUIRED_PACKAGES = ('scipy',) - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=('train', 'val')) + REQUIRED_PACKAGES = ("scipy",) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val")) def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) - wnid = 'n01234567' - if config['split'] == 'train': + wnid = "n01234567" + if config["split"] == "train": num_examples = 3 datasets_utils.create_image_folder( root=tmpdir, - name=tmpdir / 'train' / wnid / wnid, + name=tmpdir / "train" / wnid / wnid, file_name_fn=lambda image_idx: f"{wnid}_{image_idx}.JPEG", num_examples=num_examples, ) @@ -403,13 +402,13 @@ def inject_fake_data(self, tmpdir, config): num_examples = 1 datasets_utils.create_image_folder( root=tmpdir, - name=tmpdir / 'val' / wnid, + name=tmpdir / "val" / wnid, file_name_fn=lambda image_ifx: "ILSVRC2012_val_0000000{image_idx}.JPEG", num_examples=num_examples, ) wnid_to_classes = {wnid: [1]} - torch.save((wnid_to_classes, None), tmpdir / 'meta.bin') + torch.save((wnid_to_classes, None), tmpdir / "meta.bin") return num_examples @@ -596,7 +595,7 @@ def test_attr_names(self): assert tuple(dataset.attr_names) == info["attr_names"] def test_images_names_split(self): - with self.create_dataset(split='all') as (dataset, _): + with self.create_dataset(split="all") as (dataset, _): all_imgs_names = set(dataset.filename) merged_imgs_names = set() @@ -888,10 +887,7 @@ def inject_fake_data(self, tmpdir, config): return num_images @contextlib.contextmanager - def create_dataset( - self, - *args, **kwargs - ): + def create_dataset(self, *args, **kwargs): with super().create_dataset(*args, **kwargs) as output: yield output # Currently datasets.LSUN caches the keys in the current directory rather than in the root directory. Thus, @@ -951,14 +947,12 @@ def test_not_found_or_corrupted(self): class KineticsTestCase(datasets_utils.VideoDatasetTestCase): DATASET_CLASS = datasets.Kinetics - ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( - split=("train", "val"), num_classes=("400", "600", "700") - ) + ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"), num_classes=("400", "600", "700")) def inject_fake_data(self, tmpdir, config): classes = ("Abseiling", "Zumba") num_videos_per_class = 2 - tmpdir = pathlib.Path(tmpdir) / config['split'] + tmpdir = pathlib.Path(tmpdir) / config["split"] digits = string.ascii_letters + string.digits + "-_" for cls in classes: datasets_utils.create_video_folder( @@ -1582,7 +1576,7 @@ def test_is_valid_file(self, config): # We need to explicitly pass extensions=None here or otherwise it would be filled by the value from the # DEFAULT_CONFIG. 
with self.create_dataset( - config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions + config, extensions=None, is_valid_file=lambda file: pathlib.Path(file).suffix[1:] in extensions ) as (dataset, info): assert len(dataset) == info["num_examples"] @@ -1668,7 +1662,7 @@ def inject_fake_data(self, tmpdir, config): file = f"{split}_32x32.mat" images = np.zeros((32, 32, 3, num_examples), dtype=np.uint8) targets = np.zeros((num_examples,), dtype=np.uint8) - sio.savemat(os.path.join(tmpdir, file), {'X': images, 'y': targets}) + sio.savemat(os.path.join(tmpdir, file), {"X": images, "y": targets}) return num_examples @@ -1703,8 +1697,7 @@ class Places365TestCase(datasets_utils.ImageDatasetTestCase): # (file, idx) _FILE_LIST_CONTENT = ( ("Places365_val_00000001.png", 0), - *((f"{category}/Places365_train_00000001.png", idx) - for category, idx in _CATEGORIES_CONTENT), + *((f"{category}/Places365_train_00000001.png", idx) for category, idx in _CATEGORIES_CONTENT), ) @staticmethod @@ -1744,8 +1737,8 @@ def _make_images_archive(root, split, small): return [(os.path.join(root, folder_name, image), idx) for image, idx in zip(images, idcs)] def inject_fake_data(self, tmpdir, config): - self._make_devkit_archive(tmpdir, config['split']) - return len(self._make_images_archive(tmpdir, config['split'], config['small'])) + self._make_devkit_archive(tmpdir, config["split"]) + return len(self._make_images_archive(tmpdir, config["split"], config["small"])) def test_classes(self): classes = list(map(lambda x: x[0], self._CATEGORIES_CONTENT)) @@ -1759,7 +1752,7 @@ def test_class_to_idx(self): def test_images_download_preexisting(self): with pytest.raises(RuntimeError): - with self.create_dataset({'download': True}): + with self.create_dataset({"download": True}): pass @@ -1805,22 +1798,17 @@ class LFWPeopleTestCase(datasets_utils.DatasetTestCase): DATASET_CLASS = datasets.LFWPeople FEATURE_TYPES = (PIL.Image.Image, int) ADDITIONAL_CONFIGS = datasets_utils.combinations_grid( - split=('10fold', 'train', 'test'), - image_set=('original', 'funneled', 'deepfunneled') + split=("10fold", "train", "test"), image_set=("original", "funneled", "deepfunneled") ) - _IMAGES_DIR = { - "original": "lfw", - "funneled": "lfw_funneled", - "deepfunneled": "lfw-deepfunneled" - } - _file_id = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'} + _IMAGES_DIR = {"original": "lfw", "funneled": "lfw_funneled", "deepfunneled": "lfw-deepfunneled"} + _file_id = {"10fold": "", "train": "DevTrain", "test": "DevTest"} def inject_fake_data(self, tmpdir, config): tmpdir = pathlib.Path(tmpdir) / "lfw-py" os.makedirs(tmpdir, exist_ok=True) return dict( num_examples=self._create_images_dir(tmpdir, self._IMAGES_DIR[config["image_set"]], config["split"]), - split=config["split"] + split=config["split"], ) def _create_images_dir(self, root, idir, split): diff --git a/test/test_datasets_download.py b/test/test_datasets_download.py index 0cf86918575..4bf31eba92b 100644 --- a/test/test_datasets_download.py +++ b/test/test_datasets_download.py @@ -1,18 +1,17 @@ import contextlib import itertools +import tempfile import time import unittest.mock +import warnings from datetime import datetime from distutils import dir_util from os import path from urllib.error import HTTPError, URLError from urllib.parse import urlparse from urllib.request import urlopen, Request -import tempfile -import warnings import pytest - from torchvision import datasets from torchvision.datasets.utils import ( download_url, diff --git 
a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py index c76fd1849fc..56c2a930dc8 100644 --- a/test/test_datasets_samplers.py +++ b/test/test_datasets_samplers.py @@ -1,9 +1,11 @@ import contextlib -import sys import os -import torch -import pytest +import sys +import pytest +import torch +from common_utils import get_list_of_videos, assert_equal +from torchvision import get_video_backend from torchvision import io from torchvision.datasets.samplers import ( DistributedSampler, @@ -11,9 +13,6 @@ UniformClipSampler, ) from torchvision.datasets.video_utils import VideoClips, unfold -from torchvision import get_video_backend - -from common_utils import get_list_of_videos, assert_equal @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") @@ -24,7 +23,7 @@ def test_random_clip_sampler(self, tmpdir): sampler = RandomClipSampler(video_clips, 3) assert len(sampler) == 3 * 3 indices = torch.tensor(list(iter(sampler))) - videos = torch.div(indices, 5, rounding_mode='floor') + videos = torch.div(indices, 5, rounding_mode="floor") v_idxs, count = torch.unique(videos, return_counts=True) assert_equal(v_idxs, torch.tensor([0, 1, 2])) assert_equal(count, torch.tensor([3, 3, 3])) @@ -41,7 +40,7 @@ def test_random_clip_sampler_unequal(self, tmpdir): indices.remove(0) indices.remove(1) indices = torch.tensor(indices) - 2 - videos = torch.div(indices, 5, rounding_mode='floor') + videos = torch.div(indices, 5, rounding_mode="floor") v_idxs, count = torch.unique(videos, return_counts=True) assert_equal(v_idxs, torch.tensor([0, 1])) assert_equal(count, torch.tensor([3, 3])) @@ -52,7 +51,7 @@ def test_uniform_clip_sampler(self, tmpdir): sampler = UniformClipSampler(video_clips, 3) assert len(sampler) == 3 * 3 indices = torch.tensor(list(iter(sampler))) - videos = torch.div(indices, 5, rounding_mode='floor') + videos = torch.div(indices, 5, rounding_mode="floor") v_idxs, count = torch.unique(videos, return_counts=True) assert_equal(v_idxs, torch.tensor([0, 1, 2])) assert_equal(count, torch.tensor([3, 3, 3])) @@ -92,5 +91,5 @@ def test_distributed_sampler_and_uniform_clip_sampler(self, tmpdir): assert_equal(indices, torch.tensor([5, 7, 9, 0, 2, 4])) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_datasets_utils.py b/test/test_datasets_utils.py index 8753314e789..00cdf8ad2c7 100644 --- a/test/test_datasets_utils.py +++ b/test/test_datasets_utils.py @@ -1,22 +1,23 @@ import bz2 +import contextlib +import gzip +import itertools +import lzma import os -import torchvision.datasets.utils as utils -import pytest -import zipfile import tarfile -import gzip import warnings -from torch._utils_internal import get_file_path_2 +import zipfile from urllib.error import URLError -import itertools -import lzma -import contextlib +import pytest +import torchvision.datasets.utils as utils +from torch._utils_internal import get_file_path_2 from torchvision.datasets.utils import _COMPRESSED_FILE_OPENERS TEST_FILE = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" +) def patch_url_redirection(mocker, redirect_url): @@ -60,16 +61,16 @@ def test_get_redirect_url_max_hops_exceeded(self, mocker): def test_check_md5(self): fpath = TEST_FILE - correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' - false_md5 = '' + correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" + false_md5 = "" assert 
utils.check_md5(fpath, correct_md5) assert not utils.check_md5(fpath, false_md5) def test_check_integrity(self): existing_fpath = TEST_FILE - nonexisting_fpath = '' - correct_md5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc' - false_md5 = '' + nonexisting_fpath = "" + correct_md5 = "9c0bb82894bb3af7f7675ef2b3b6dcdc" + false_md5 = "" assert utils.check_integrity(existing_fpath, correct_md5) assert not utils.check_integrity(existing_fpath, false_md5) assert utils.check_integrity(existing_fpath) @@ -87,31 +88,35 @@ def test_get_google_drive_file_id_invalid_url(self): assert utils._get_google_drive_file_id(url) is None - @pytest.mark.parametrize('file, expected', [ - ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), - ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), - ("foo.tar", (".tar", ".tar", None)), - ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), - ("foo.tbz", (".tbz", ".tar", ".bz2")), - ("foo.tbz2", (".tbz2", ".tar", ".bz2")), - ("foo.tgz", (".tgz", ".tar", ".gz")), - ("foo.bz2", (".bz2", None, ".bz2")), - ("foo.gz", (".gz", None, ".gz")), - ("foo.zip", (".zip", ".zip", None)), - ("foo.xz", (".xz", None, ".xz")), - ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), - ("foo.bar.gz", (".gz", None, ".gz")), - ("foo.bar.zip", (".zip", ".zip", None))]) + @pytest.mark.parametrize( + "file, expected", + [ + ("foo.tar.bz2", (".tar.bz2", ".tar", ".bz2")), + ("foo.tar.xz", (".tar.xz", ".tar", ".xz")), + ("foo.tar", (".tar", ".tar", None)), + ("foo.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.tbz", (".tbz", ".tar", ".bz2")), + ("foo.tbz2", (".tbz2", ".tar", ".bz2")), + ("foo.tgz", (".tgz", ".tar", ".gz")), + ("foo.bz2", (".bz2", None, ".bz2")), + ("foo.gz", (".gz", None, ".gz")), + ("foo.zip", (".zip", ".zip", None)), + ("foo.xz", (".xz", None, ".xz")), + ("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")), + ("foo.bar.gz", (".gz", None, ".gz")), + ("foo.bar.zip", (".zip", ".zip", None)), + ], + ) def test_detect_file_type(self, file, expected): assert utils._detect_file_type(file) == expected - @pytest.mark.parametrize('file', ["foo", "foo.tar.baz", "foo.bar"]) + @pytest.mark.parametrize("file", ["foo", "foo.tar.baz", "foo.bar"]) def test_detect_file_type_incompatible(self, file): # tests detect file type for no extension, unknown compression and unknown partial extension with pytest.raises(RuntimeError): utils._detect_file_type(file) - @pytest.mark.parametrize('extension', [".bz2", ".gz", ".xz"]) + @pytest.mark.parametrize("extension", [".bz2", ".gz", ".xz"]) def test_decompress(self, extension, tmpdir): def create_compressed(root, content="this is the content"): file = os.path.join(root, "file") @@ -152,8 +157,8 @@ def create_compressed(root, content="this is the content"): assert not os.path.exists(compressed) - @pytest.mark.parametrize('extension', [".gz", ".xz"]) - @pytest.mark.parametrize('remove_finished', [True, False]) + @pytest.mark.parametrize("extension", [".gz", ".xz"]) + @pytest.mark.parametrize("remove_finished", [True, False]) def test_extract_archive_defer_to_decompress(self, extension, remove_finished, mocker): filename = "foo" file = f"{filename}{extension}" @@ -182,8 +187,9 @@ def create_archive(root, content="this is the content"): with open(file, "r") as fh: assert fh.read() == content - @pytest.mark.parametrize('extension, mode', [ - ('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz'), ('.tar.xz', 'w:xz')]) + @pytest.mark.parametrize( + "extension, mode", [(".tar", "w"), (".tar.gz", "w:gz"), (".tgz", "w:gz"), (".tar.xz", "w:xz")] + ) def test_extract_tar(self, extension, mode, tmpdir): def 
create_archive(root, extension, mode, content="this is the content"): src = os.path.join(root, "src.txt") @@ -213,5 +219,5 @@ def test_verify_str_arg(self): pytest.raises(ValueError, utils.verify_str_arg, "b", ("a",), "arg") -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py index 9671d1d8f4c..3377577a047 100644 --- a/test/test_datasets_video_utils.py +++ b/test/test_datasets_video_utils.py @@ -1,39 +1,37 @@ import contextlib import os -import torch -import pytest +import pytest +import torch +from common_utils import get_list_of_videos, assert_equal from torchvision import io from torchvision.datasets.video_utils import VideoClips, unfold -from common_utils import get_list_of_videos, assert_equal - class TestVideo: - def test_unfold(self): a = torch.arange(7) r = unfold(a, 3, 3, 1) - expected = torch.tensor([ - [0, 1, 2], - [3, 4, 5], - ]) + expected = torch.tensor( + [ + [0, 1, 2], + [3, 4, 5], + ] + ) assert_equal(r, expected) r = unfold(a, 3, 2, 1) - expected = torch.tensor([ - [0, 1, 2], - [2, 3, 4], - [4, 5, 6] - ]) + expected = torch.tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]]) assert_equal(r, expected) r = unfold(a, 3, 2, 2) - expected = torch.tensor([ - [0, 2, 4], - [2, 4, 6], - ]) + expected = torch.tensor( + [ + [0, 2, 4], + [2, 4, 6], + ] + ) assert_equal(r, expected) @pytest.mark.skipif(not io.video._av_available(), reason="this test requires av") @@ -79,8 +77,7 @@ def test_compute_clips_for_video(self): orig_fps = 30 duration = float(len(video_pts)) / orig_fps new_fps = 13 - clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, - orig_fps, new_fps) + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps) assert len(clips) == 1 assert_equal(clips, idxs) @@ -91,8 +88,7 @@ def test_compute_clips_for_video(self): orig_fps = 30 duration = float(len(video_pts)) / orig_fps new_fps = 12 - clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, - orig_fps, new_fps) + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) resampled_idxs = VideoClips._resample_video_idx(int(duration * new_fps), orig_fps, new_fps) assert len(clips) == 3 assert_equal(clips, idxs) @@ -103,11 +99,10 @@ def test_compute_clips_for_video(self): orig_fps = 30 new_fps = 13 with pytest.warns(UserWarning): - clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, - orig_fps, new_fps) + clips, idxs = VideoClips.compute_clips_for_video(video_pts, num_frames, num_frames, orig_fps, new_fps) assert len(clips) == 0 assert len(idxs) == 0 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_datasets_video_utils_opt.py b/test/test_datasets_video_utils_opt.py index 8075c701ed9..993e77cdf4c 100644 --- a/test/test_datasets_video_utils_opt.py +++ b/test/test_datasets_video_utils_opt.py @@ -1,11 +1,12 @@ import unittest -from torchvision import set_video_backend + import test_datasets_video_utils +from torchvision import set_video_backend # Disabling the video backend switching temporarily # set_video_backend('video_reader') -if __name__ == '__main__': +if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(test_datasets_video_utils) unittest.TextTestRunner(verbosity=1).run(suite) 
diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 1bdc752839e..de05d0b81fc 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1,19 +1,17 @@ -from functools import partial -import itertools -import os import colorsys +import itertools import math +import os +from functools import partial +from typing import Dict, List, Sequence, Tuple import numpy as np import pytest - import torch -import torchvision.transforms.functional_tensor as F_t -import torchvision.transforms.functional_pil as F_pil -import torchvision.transforms.functional as F import torchvision.transforms as T -from torchvision.transforms import InterpolationMode - +import torchvision.transforms.functional as F +import torchvision.transforms.functional_pil as F_pil +import torchvision.transforms.functional_tensor as F_t from common_utils import ( cpu_and_gpu, needs_cuda, @@ -24,15 +22,14 @@ _test_fn_on_batch, assert_equal, ) - -from typing import Dict, List, Sequence, Tuple +from torchvision.transforms import InterpolationMode NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('fn', [F.get_image_size, F.get_image_num_channels]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels]) def test_image_sizes(device, fn): script_F = torch.jit.script(fn) @@ -57,10 +54,10 @@ def test_scale_channel(): # TODO: when # https://github.com/pytorch/pytorch/issues/53194 is fixed, # only use bincount and remove that test. size = (1_000,) - img_chan = torch.randint(0, 256, size=size).to('cpu') + img_chan = torch.randint(0, 256, size=size).to("cpu") scaled_cpu = F_t._scale_channel(img_chan) - scaled_cuda = F_t._scale_channel(img_chan.to('cuda')) - assert_equal(scaled_cpu, scaled_cuda.to('cpu')) + scaled_cuda = F_t._scale_channel(img_chan.to("cuda")) + assert_equal(scaled_cpu, scaled_cuda.to("cpu")) class TestRotate: @@ -69,18 +66,33 @@ class TestRotate: scripted_rotate = torch.jit.script(F.rotate) IMG_W = 26 - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, IMG_W), (32, IMG_W)]) - @pytest.mark.parametrize('center', [ - None, - (int(IMG_W * 0.3), int(IMG_W * 0.4)), - [int(IMG_W * 0.5), int(IMG_W * 0.6)], - ]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('angle', range(-180, 180, 17)) - @pytest.mark.parametrize('expand', [True, False]) - @pytest.mark.parametrize('fill', [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]) - @pytest.mark.parametrize('fn', [F.rotate, scripted_rotate]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, IMG_W), (32, IMG_W)]) + @pytest.mark.parametrize( + "center", + [ + None, + (int(IMG_W * 0.3), int(IMG_W * 0.4)), + [int(IMG_W * 0.5), int(IMG_W * 0.6)], + ], + ) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize("angle", range(-180, 180, 17)) + @pytest.mark.parametrize("expand", [True, False]) + @pytest.mark.parametrize( + "fill", + [ + None, + [0, 0, 0], + (1, 2, 3), + [255, 255, 255], + [ + 1, + ], + (2.0,), + ], + ) + @pytest.mark.parametrize("fn", [F.rotate, scripted_rotate]) def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn): tensor, pil_img = _create_data(height, width, device=device) @@ -101,8 +113,8 @@ def test_rotate(self, device, height, width, 
center, dt, angle, expand, fill, fn out_tensor = out_tensor.to(torch.uint8) assert out_tensor.shape == out_pil_tensor.shape, ( - f"{(height, width, NEAREST, dt, angle, expand, center)}: " - f"{out_tensor.shape} vs {out_pil_tensor.shape}") + f"{(height, width, NEAREST, dt, angle, expand, center)}: " f"{out_tensor.shape} vs {out_pil_tensor.shape}" + ) num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] @@ -110,10 +122,11 @@ def test_rotate(self, device, height, width, center, dt, angle, expand, fill, fn assert ratio_diff_pixels < 0.03, ( f"{(height, width, NEAREST, dt, angle, expand, center, fill)}: " f"{ratio_diff_pixels}\n{out_tensor[0, :7, :7]} vs \n" - f"{out_pil_tensor[0, :7, :7]}") + f"{out_pil_tensor[0, :7, :7]}" + ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('dt', ALL_DTYPES) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dt", ALL_DTYPES) def test_rotate_batch(self, device, dt): if dt == torch.float16 and device == "cpu": # skip float16 on CPU case @@ -124,9 +137,7 @@ def test_rotate_batch(self, device, dt): batch_tensors = batch_tensors.to(dtype=dt) center = (20, 22) - _test_fn_on_batch( - batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center - ) + _test_fn_on_batch(batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center) def test_rotate_deprecation_resample(self): tensor, _ = _create_data(26, 26) @@ -150,9 +161,9 @@ class TestAffine: ALL_DTYPES = [None, torch.float32, torch.float64, torch.float16] scripted_affine = torch.jit.script(F.affine) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26), (32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) def test_identity_map(self, device, height, width, dt): # Tests on square and rectangular images tensor, pil_img = _create_data(height, width, device=device) @@ -173,19 +184,22 @@ def test_identity_map(self, device, height, width, dt): ) assert_equal(tensor, out_tensor, msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('angle, config', [ - (90, {'k': 1, 'dims': (-1, -2)}), - (45, None), - (30, None), - (-30, None), - (-45, None), - (-90, {'k': -1, 'dims': (-1, -2)}), - (180, {'k': 2, 'dims': (-1, -2)}), - ]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize( + "angle, config", + [ + (90, {"k": 1, "dims": (-1, -2)}), + (45, None), + (30, None), + (-30, None), + (-45, None), + (-90, {"k": -1, "dims": (-1, -2)}), + (180, {"k": 2, "dims": (-1, -2)}), + ], + ) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_square_rotations(self, device, height, width, dt, angle, config, fn): # 2) Test rotation tensor, pil_img = _create_data(height, width, device=device) @@ -202,9 +216,7 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn): ) out_pil_tensor = 
torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(device) - out_tensor = fn( - tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST - ) + out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) if config is not None: assert_equal(torch.rot90(tensor, **config), out_tensor) @@ -218,11 +230,11 @@ def test_square_rotations(self, device, height, width, dt, angle, config, fn): ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('angle', [90, 45, 15, -30, -60, -120]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120]) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_rect_rotations(self, device, height, width, dt, angle, fn): # Tests on rectangular images tensor, pil_img = _create_data(height, width, device=device) @@ -239,9 +251,7 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn): ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) - out_tensor = fn( - tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST - ).cpu() + out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST).cpu() if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -253,11 +263,11 @@ def test_rect_rotations(self, device, height, width, dt, angle, fn): angle, ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26), (32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('t', [[10, 12], (-12, -13)]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize("t", [[10, 12], (-12, -13)]) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_translations(self, device, height, width, dt, t, fn): # 3) Test translation tensor, pil_img = _create_data(height, width, device=device) @@ -278,22 +288,41 @@ def test_translations(self, device, height, width, dt, t, fn): _assert_equal_tensor_to_pil(out_tensor, out_pil_img) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('height, width', [(26, 26), (32, 26)]) - @pytest.mark.parametrize('dt', ALL_DTYPES) - @pytest.mark.parametrize('a, t, s, sh, f', [ - (45.5, [5, 6], 1.0, [0.0, 0.0], None), - (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]), - (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)), - (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]), - (85, (10, -10), 0.7, [0.0, 0.0], [1, ]), - (0, [0, 0], 1.0, [35.0, ], (2.0, )), - (-25, [0, 0], 1.2, [0.0, 15.0], None), - (-45, [-10, 0], 0.7, [2.0, 5.0], None), - (-45, [-10, -10], 1.2, [4.0, 5.0], None), - (-90, [0, 0], 1.0, [0.0, 0.0], None), - ]) - @pytest.mark.parametrize('fn', [F.affine, scripted_affine]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + 
@pytest.mark.parametrize("height, width", [(26, 26), (32, 26)]) + @pytest.mark.parametrize("dt", ALL_DTYPES) + @pytest.mark.parametrize( + "a, t, s, sh, f", + [ + (45.5, [5, 6], 1.0, [0.0, 0.0], None), + (33, (5, -4), 1.0, [0.0, 0.0], [0, 0, 0]), + (45, [-5, 4], 1.2, [0.0, 0.0], (1, 2, 3)), + (33, (-4, -8), 2.0, [0.0, 0.0], [255, 255, 255]), + ( + 85, + (10, -10), + 0.7, + [0.0, 0.0], + [ + 1, + ], + ), + ( + 0, + [0, 0], + 1.0, + [ + 35.0, + ], + (2.0,), + ), + (-25, [0, 0], 1.2, [0.0, 15.0], None), + (-45, [-10, 0], 0.7, [2.0, 5.0], None), + (-45, [-10, -10], 1.2, [4.0, 5.0], None), + (-90, [0, 0], 1.0, [0.0, 0.0], None), + ], + ) + @pytest.mark.parametrize("fn", [F.affine, scripted_affine]) def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn): # 4) Test rotation + translation + scale + shear tensor, pil_img = _create_data(height, width, device=device) @@ -322,8 +351,8 @@ def test_all_ops(self, device, height, width, dt, a, t, s, sh, f, fn): (NEAREST, a, t, s, sh, f), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7] ) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('dt', ALL_DTYPES) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dt", ALL_DTYPES) def test_batches(self, device, dt): if dt == torch.float16 and device == "cpu": # skip float16 on CPU case @@ -333,11 +362,9 @@ def test_batches(self, device, dt): if dt is not None: batch_tensors = batch_tensors.to(dtype=dt) - _test_fn_on_batch( - batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0] - ) + _test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]) - @pytest.mark.parametrize('device', cpu_and_gpu()) + @pytest.mark.parametrize("device", cpu_and_gpu()) def test_warnings(self, device): tensor, pil_img = _create_data(26, 26, device=device) @@ -379,18 +406,27 @@ def _get_data_dims_and_points_for_perspective(): n = 10 for dim in data_dims: - points += [ - (dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n)) - for i in range(n) - ] + points += [(dim, T.RandomPerspective.get_params(dim[1], dim[0], i / n)) for i in range(n)] return dims_and_points -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('fill', (None, [0, 0, 0], [1, 2, 3], [255, 255, 255], [1, ], (2.0, ))) -@pytest.mark.parametrize('fn', [F.perspective, torch.jit.script(F.perspective)]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize( + "fill", + ( + None, + [0, 0, 0], + [1, 2, 3], + [255, 255, 255], + [ + 1, + ], + (2.0,), + ), +) +@pytest.mark.parametrize("fn", [F.perspective, torch.jit.script(F.perspective)]) def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): if dt == torch.float16 and device == "cpu": @@ -405,8 +441,9 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): interpolation = NEAREST fill_pil = int(fill[0]) if fill is not None and len(fill) == 1 else fill - out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=interpolation, - fill=fill_pil) + out_pil_img = F.perspective( + pil_img, startpoints=spoints, 
endpoints=epoints, interpolation=interpolation, fill=fill_pil + ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) out_tensor = fn(tensor, startpoints=spoints, endpoints=epoints, interpolation=interpolation, fill=fill).cpu() @@ -419,9 +456,9 @@ def test_perspective_pil_vs_tensor(device, dims_and_points, dt, fill, fn): assert ratio_diff_pixels < 0.05 -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dims_and_points', _get_data_dims_and_points_for_perspective()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dims_and_points", _get_data_dims_and_points_for_perspective()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) def test_perspective_batch(device, dims_and_points, dt): if dt == torch.float16 and device == "cpu": @@ -438,8 +475,12 @@ def test_perspective_batch(device, dims_and_points, dt): # the border may be entirely different due to small rounding errors. scripted_fn_atol = -1 if (dt == torch.float16 and device == "cuda") else 1e-8 _test_fn_on_batch( - batch_tensors, F.perspective, scripted_fn_atol=scripted_fn_atol, - startpoints=spoints, endpoints=epoints, interpolation=NEAREST + batch_tensors, + F.perspective, + scripted_fn_atol=scripted_fn_atol, + startpoints=spoints, + endpoints=epoints, + interpolation=NEAREST, ) @@ -454,11 +495,23 @@ def test_perspective_interpolation_warning(): assert_equal(res1, res2) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('size', [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]) -@pytest.mark.parametrize('max_size', [None, 34, 40, 1000]) -@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize( + "size", + [ + 32, + 26, + [ + 32, + ], + [32, 32], + (32, 32), + [26, 35], + ], +) +@pytest.mark.parametrize("max_size", [None, 34, 40, 1000]) +@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST]) def test_resize(device, dt, size, max_size, interpolation): if dt == torch.float16 and device == "cpu": @@ -483,7 +536,9 @@ def test_resize(device, dt, size, max_size, interpolation): assert resized_tensor.size()[1:] == resized_pil_img.size[::-1] - if interpolation not in [NEAREST, ]: + if interpolation not in [ + NEAREST, + ]: # We can not check values if mode = NEAREST, as results are different # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. 
resized_pil_img = [[a, b, c, c, d, e, f, ...]] @@ -496,21 +551,19 @@ def test_resize(device, dt, size, max_size, interpolation): _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0) if isinstance(size, int): - script_size = [size, ] + script_size = [ + size, + ] else: script_size = size - resize_result = script_fn( - tensor, size=script_size, interpolation=interpolation, max_size=max_size - ) + resize_result = script_fn(tensor, size=script_size, interpolation=interpolation, max_size=max_size) assert_equal(resized_tensor, resize_result) - _test_fn_on_batch( - batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size - ) + _test_fn_on_batch(batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_resize_asserts(device): tensor, pil_img = _create_data(26, 36, device=device) @@ -530,10 +583,10 @@ def test_resize_asserts(device): F.resize(img, size=32, max_size=32) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('size', [[96, 72], [96, 420], [420, 72]]) -@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("size", [[96, 72], [96, 420], [420, 72]]) +@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC]) def test_resize_antialias(device, dt, size, interpolation): if dt == torch.float16 and device == "cpu": @@ -558,9 +611,7 @@ def test_resize_antialias(device, dt, size, interpolation): if resized_tensor_f.dtype == torch.uint8: resized_tensor_f = resized_tensor_f.to(torch.float) - _assert_approx_equal_tensor_to_pil( - resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}" - ) + _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=0.5, msg=f"{size}, {interpolation}, {dt}") accepted_tol = 1.0 + 1e-5 if interpolation == BICUBIC: @@ -571,12 +622,13 @@ def test_resize_antialias(device, dt, size, interpolation): accepted_tol = 15.0 _assert_approx_equal_tensor_to_pil( - resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", - msg=f"{size}, {interpolation}, {dt}" + resized_tensor_f, resized_pil_img, tol=accepted_tol, agg_method="max", msg=f"{size}, {interpolation}, {dt}" ) if isinstance(size, int): - script_size = [size, ] + script_size = [ + size, + ] else: script_size = size @@ -585,7 +637,7 @@ def test_resize_antialias(device, dt, size, interpolation): @needs_cuda -@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC]) +@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC]) def test_assert_resize_antialias(interpolation): # Checks implementation on very large scales @@ -597,10 +649,10 @@ def test_assert_resize_antialias(interpolation): F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]]) -@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dt", [torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("size", [[10, 7], [10, 
42], [42, 7]]) +@pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC]) def test_interpolate_antialias_backward(device, dt, size, interpolation): if dt == torch.float16 and device == "cpu": @@ -616,7 +668,6 @@ def test_interpolate_antialias_backward(device, dt, size, interpolation): backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward class F(torch.autograd.Function): - @staticmethod def forward(ctx, i): result = forward_op(i, size, False) @@ -630,14 +681,10 @@ def backward(ctx, grad_output): oshape = result.shape[2:] return backward_op(grad_output, oshape, ishape, False) - x = ( - torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True), - ) + x = (torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),) assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False) - x = ( - torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True), - ) + x = (torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),) assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False) @@ -678,10 +725,10 @@ def check_functional_vs_PIL_vs_scripted( _test_fn_on_batch(batch_tensors, fn, scripted_fn_atol=atol, **config) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)]) +@pytest.mark.parametrize("channels", [1, 3]) def test_adjust_brightness(device, dtype, config, channels): check_functional_vs_PIL_vs_scripted( F.adjust_brightness, @@ -694,26 +741,18 @@ def test_adjust_brightness(device, dtype, config, channels): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("channels", [1, 3]) def test_invert(device, dtype, channels): check_functional_vs_PIL_vs_scripted( - F.invert, - F_pil.invert, - F_t.invert, - {}, - device, - dtype, - channels, - tol=1.0, - agg_method="max" + F.invert, F_pil.invert, F_t.invert, {}, device, dtype, channels, tol=1.0, agg_method="max" ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('config', [{"bits": bits} for bits in range(0, 8)]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("config", [{"bits": bits} for bits in range(0, 8)]) +@pytest.mark.parametrize("channels", [1, 3]) def test_posterize(device, config, channels): check_functional_vs_PIL_vs_scripted( F.posterize, @@ -728,9 +767,9 @@ def test_posterize(device, config, channels): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0, 64, 
128, 192, 255]]) +@pytest.mark.parametrize("channels", [1, 3]) def test_solarize1(device, config, channels): check_functional_vs_PIL_vs_scripted( F.solarize, @@ -745,10 +784,10 @@ def test_solarize1(device, config, channels): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]]) +@pytest.mark.parametrize("channels", [1, 3]) def test_solarize2(device, dtype, config, channels): check_functional_vs_PIL_vs_scripted( F.solarize, @@ -763,10 +802,10 @@ def test_solarize2(device, dtype, config, channels): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) +@pytest.mark.parametrize("channels", [1, 3]) def test_adjust_sharpness(device, dtype, config, channels): check_functional_vs_PIL_vs_scripted( F.adjust_sharpness, @@ -779,25 +818,17 @@ def test_adjust_sharpness(device, dtype, config, channels): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("channels", [1, 3]) def test_autocontrast(device, dtype, channels): check_functional_vs_PIL_vs_scripted( - F.autocontrast, - F_pil.autocontrast, - F_t.autocontrast, - {}, - device, - dtype, - channels, - tol=1.0, - agg_method="max" + F.autocontrast, F_pil.autocontrast, F_t.autocontrast, {}, device, dtype, channels, tol=1.0, agg_method="max" ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("channels", [1, 3]) def test_equalize(device, channels): torch.use_deterministic_algorithms(False) check_functional_vs_PIL_vs_scripted( @@ -813,60 +844,40 @@ def test_equalize(device, channels): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]) +@pytest.mark.parametrize("channels", [1, 3]) def test_adjust_contrast(device, dtype, config, channels): check_functional_vs_PIL_vs_scripted( - F.adjust_contrast, - F_pil.adjust_contrast, - F_t.adjust_contrast, - config, - device, - dtype, - channels + F.adjust_contrast, F_pil.adjust_contrast, F_t.adjust_contrast, config, device, dtype, 
channels ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]) +@pytest.mark.parametrize("channels", [1, 3]) def test_adjust_saturation(device, dtype, config, channels): check_functional_vs_PIL_vs_scripted( - F.adjust_saturation, - F_pil.adjust_saturation, - F_t.adjust_saturation, - config, - device, - dtype, - channels + F.adjust_saturation, F_pil.adjust_saturation, F_t.adjust_saturation, config, device, dtype, channels ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]]) +@pytest.mark.parametrize("channels", [1, 3]) def test_adjust_hue(device, dtype, config, channels): check_functional_vs_PIL_vs_scripted( - F.adjust_hue, - F_pil.adjust_hue, - F_t.adjust_hue, - config, - device, - dtype, - channels, - tol=16.1, - agg_method="max" + F.adjust_hue, F_pil.adjust_hue, F_t.adjust_hue, config, device, dtype, channels, tol=16.1, agg_method="max" ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64)) -@pytest.mark.parametrize('config', [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", (None, torch.float32, torch.float64)) +@pytest.mark.parametrize("config", [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]) +@pytest.mark.parametrize("channels", [1, 3]) def test_adjust_gamma(device, dtype, config, channels): check_functional_vs_PIL_vs_scripted( F.adjust_gamma, @@ -879,17 +890,31 @@ def test_adjust_gamma(device, dtype, config, channels): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('pad', [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]) -@pytest.mark.parametrize('config', [ - {"padding_mode": "constant", "fill": 0}, - {"padding_mode": "constant", "fill": 10}, - {"padding_mode": "constant", "fill": 20}, - {"padding_mode": "edge"}, - {"padding_mode": "reflect"}, - {"padding_mode": "symmetric"}, -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize( + "pad", + [ + 2, + [ + 3, + ], + [0, 3], + (3, 3), + [4, 2, 4, 3], + ], +) +@pytest.mark.parametrize( + "config", + [ + {"padding_mode": "constant", "fill": 0}, + {"padding_mode": "constant", "fill": 10}, + {"padding_mode": "constant", "fill": 20}, + {"padding_mode": "edge"}, + {"padding_mode": "reflect"}, + {"padding_mode": "symmetric"}, + ], +) def test_pad(device, dt, pad, config): script_fn = 
torch.jit.script(F.pad) tensor, pil_img = _create_data(7, 8, device=device) @@ -915,7 +940,9 @@ def test_pad(device, dt, pad, config): _assert_equal_tensor_to_pil(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, config)) if isinstance(pad, int): - script_pad = [pad, ] + script_pad = [ + pad, + ] else: script_pad = pad pad_tensor_script = script_fn(tensor, script_pad, **config) @@ -924,8 +951,8 @@ def test_pad(device, dt, pad, config): _test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **config) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('mode', [NEAREST, BILINEAR, BICUBIC]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("mode", [NEAREST, BILINEAR, BICUBIC]) def test_resized_crop(device, mode): # test values of F.resized_crop in several cases: # 1) resize to the same size, crop to the same size => should be identity @@ -950,19 +977,46 @@ def test_resized_crop(device, mode): ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('func, args', [ - (F_t.get_image_size, ()), (F_t.vflip, ()), - (F_t.hflip, ()), (F_t.crop, (1, 2, 4, 5)), - (F_t.adjust_brightness, (0., )), (F_t.adjust_contrast, (1., )), - (F_t.adjust_hue, (-0.5, )), (F_t.adjust_saturation, (2., )), - (F_t.pad, ([2, ], 2, "constant")), - (F_t.resize, ([10, 11], )), (F_t.perspective, ([0.2, ])), - (F_t.gaussian_blur, ((2, 2), (0.7, 0.5))), - (F_t.invert, ()), (F_t.posterize, (0, )), - (F_t.solarize, (0.3, )), (F_t.adjust_sharpness, (0.3, )), - (F_t.autocontrast, ()), (F_t.equalize, ()) -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "func, args", + [ + (F_t.get_image_size, ()), + (F_t.vflip, ()), + (F_t.hflip, ()), + (F_t.crop, (1, 2, 4, 5)), + (F_t.adjust_brightness, (0.0,)), + (F_t.adjust_contrast, (1.0,)), + (F_t.adjust_hue, (-0.5,)), + (F_t.adjust_saturation, (2.0,)), + ( + F_t.pad, + ( + [ + 2, + ], + 2, + "constant", + ), + ), + (F_t.resize, ([10, 11],)), + ( + F_t.perspective, + ( + [ + 0.2, + ] + ), + ), + (F_t.gaussian_blur, ((2, 2), (0.7, 0.5))), + (F_t.invert, ()), + (F_t.posterize, (0,)), + (F_t.solarize, (0.3,)), + (F_t.adjust_sharpness, (0.3,)), + (F_t.autocontrast, ()), + (F_t.equalize, ()), + ], +) def test_assert_image_tensor(device, func, args): shape = (100,) tensor = torch.rand(*shape, dtype=torch.float, device=device) @@ -970,7 +1024,7 @@ def test_assert_image_tensor(device, func, args): func(tensor, *args) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_vflip(device): script_vflip = torch.jit.script(F.vflip) @@ -987,7 +1041,7 @@ def test_vflip(device): _test_fn_on_batch(batch_tensors, F.vflip) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_hflip(device): script_hflip = torch.jit.script(F.hflip) @@ -1004,13 +1058,16 @@ def test_hflip(device): _test_fn_on_batch(batch_tensors, F.hflip) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('top, left, height, width', [ - (1, 2, 4, 5), # crop inside top-left corner - (2, 12, 3, 4), # crop inside top-right corner - (8, 3, 5, 6), # crop inside bottom-left corner - (8, 11, 4, 3), # crop inside bottom-right corner -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "top, left, height, width", + [ + (1, 2, 4, 5), # crop inside top-left corner + (2, 12, 3, 4), # crop inside top-right corner + (8, 3, 5, 6), # crop inside bottom-left corner + (8, 11, 4, 3), # 
crop inside bottom-right corner + ], +) def test_crop(device, top, left, height, width): script_crop = torch.jit.script(F.crop) @@ -1028,12 +1085,12 @@ def test_crop(device, top, left, height, width): _test_fn_on_batch(batch_tensors, F.crop, top=top, left=left, height=height, width=width) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('image_size', ('small', 'large')) -@pytest.mark.parametrize('dt', [None, torch.float32, torch.float64, torch.float16]) -@pytest.mark.parametrize('ksize', [(3, 3), [3, 5], (23, 23)]) -@pytest.mark.parametrize('sigma', [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) -@pytest.mark.parametrize('fn', [F.gaussian_blur, torch.jit.script(F.gaussian_blur)]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("image_size", ("small", "large")) +@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) +@pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) +@pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) +@pytest.mark.parametrize("fn", [F.gaussian_blur, torch.jit.script(F.gaussian_blur)]) def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): # true_cv2_results = { @@ -1050,17 +1107,15 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): # # cv2.GaussianBlur(np_img2, ksize=(23, 23), sigmaX=1.7) # "23_23_1.7": ... # } - p = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'gaussian_blur_opencv_results.pt') + p = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "gaussian_blur_opencv_results.pt") true_cv2_results = torch.load(p) - if image_size == 'small': - tensor = torch.from_numpy( - np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) - ).permute(2, 0, 1).to(device) + if image_size == "small": + tensor = ( + torch.from_numpy(np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3))).permute(2, 0, 1).to(device) + ) else: - tensor = torch.from_numpy( - np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28)) - ).to(device) + tensor = torch.from_numpy(np.arange(26 * 28, dtype="uint8").reshape((1, 26, 28))).to(device) if dt == torch.float16 and device == "cpu": # skip float16 on CPU case @@ -1072,22 +1127,19 @@ def test_gaussian_blur(device, image_size, dt, ksize, sigma, fn): _ksize = (ksize, ksize) if isinstance(ksize, int) else ksize _sigma = sigma[0] if sigma is not None else None shape = tensor.shape - gt_key = "{}_{}_{}__{}_{}_{}".format( - shape[-2], shape[-1], shape[-3], - _ksize[0], _ksize[1], _sigma - ) + gt_key = "{}_{}_{}__{}_{}_{}".format(shape[-2], shape[-1], shape[-3], _ksize[0], _ksize[1], _sigma) if gt_key not in true_cv2_results: return - true_out = torch.tensor( - true_cv2_results[gt_key] - ).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) + true_out = ( + torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) + ) out = fn(tensor, kernel_size=ksize, sigma=sigma) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg="{}, {}".format(ksize, sigma)) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_hsv2rgb(device): scripted_fn = torch.jit.script(F_t._hsv2rgb) shape = (3, 100, 150) @@ -1096,7 +1148,11 @@ def test_hsv2rgb(device): rgb_img = F_t._hsv2rgb(hsv_img) ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1) - h, s, v, = hsv_img.unbind(0) + ( + h, + s, + v, + ) = hsv_img.unbind(0) h = h.flatten().cpu().numpy() s = 
s.flatten().cpu().numpy() v = v.flatten().cpu().numpy() @@ -1114,7 +1170,7 @@ def test_hsv2rgb(device): _test_fn_on_batch(batch_tensors, F_t._hsv2rgb) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_rgb2hsv(device): scripted_fn = torch.jit.script(F_t._rgb2hsv) shape = (3, 150, 100) @@ -1123,7 +1179,11 @@ def test_rgb2hsv(device): hsv_img = F_t._rgb2hsv(rgb_img) ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1) - r, g, b, = rgb_img.unbind(dim=-3) + ( + r, + g, + b, + ) = rgb_img.unbind(dim=-3) r = r.flatten().cpu().numpy() g = g.flatten().cpu().numpy() b = b.flatten().cpu().numpy() @@ -1149,8 +1209,8 @@ def test_rgb2hsv(device): _test_fn_on_batch(batch_tensors, F_t._rgb2hsv) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('num_output_channels', (3, 1)) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("num_output_channels", (3, 1)) def test_rgb_to_grayscale(device, num_output_channels): script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale) @@ -1168,7 +1228,7 @@ def test_rgb_to_grayscale(device, num_output_channels): _test_fn_on_batch(batch_tensors, F.rgb_to_grayscale, num_output_channels=num_output_channels) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_center_crop(device): script_center_crop = torch.jit.script(F.center_crop) @@ -1186,7 +1246,7 @@ def test_center_crop(device): _test_fn_on_batch(batch_tensors, F.center_crop, output_size=[10, 11]) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_five_crop(device): script_five_crop = torch.jit.script(F.five_crop) @@ -1220,7 +1280,7 @@ def test_five_crop(device): assert_equal(transformed_batch, s_transformed_batch) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_ten_crop(device): script_ten_crop = torch.jit.script(F.ten_crop) @@ -1254,5 +1314,5 @@ def test_ten_crop(device): assert_equal(transformed_batch, s_transformed_batch) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_hub.py b/test/test_hub.py index 9c9e417933e..5c791bf9d7a 100644 --- a/test/test_hub.py +++ b/test/test_hub.py @@ -1,9 +1,10 @@ -import torch.hub as hub -import tempfile -import shutil import os +import shutil import sys +import tempfile + import pytest +import torch.hub as hub def sum_of_model_parameters(model): @@ -16,8 +17,7 @@ def sum_of_model_parameters(model): SUM_OF_PRETRAINED_RESNET18_PARAMS = -12703.9931640625 -@pytest.mark.skipif('torchvision' in sys.modules, - reason='TestHub must start without torchvision imported') +@pytest.mark.skipif("torchvision" in sys.modules, reason="TestHub must start without torchvision imported") class TestHub: # Only run this check ONCE before all tests start. # - If torchvision is imported before all tests start, e.g. we might find _C.so @@ -26,28 +26,20 @@ class TestHub: # Python cache as we run all hub tests in the same python process. 
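# (Illustrative sketch, not part of the patch.) The comment above is about
# Python's module cache: once a package has been imported anywhere in the
# test process, sys.modules keeps it, so a later hub.load("pytorch/vision", ...)
# could silently reuse cached modules (for example a previously built _C
# extension) instead of the freshly downloaded checkout. The snippet below
# demonstrates that caching with a made-up module name ("fake_pkg"); it uses
# only the standard library and does not reflect any torchvision API.
import sys
import types

cached = types.ModuleType("fake_pkg")
cached.VERSION = "old"
sys.modules["fake_pkg"] = cached

# A later import returns the cached object, even though no such package
# exists on disk -- which is why these hub tests must start in a process
# where torchvision has never been imported.
import fake_pkg

assert fake_pkg.VERSION == "old"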
def test_load_from_github(self): - hub_model = hub.load( - 'pytorch/vision', - 'resnet18', - pretrained=True, - progress=False) + hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) def test_set_dir(self): temp_dir = tempfile.gettempdir() hub.set_dir(temp_dir) - hub_model = hub.load( - 'pytorch/vision', - 'resnet18', - pretrained=True, - progress=False) + hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False) assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS) - assert os.path.exists(temp_dir + '/pytorch_vision_master') - shutil.rmtree(temp_dir + '/pytorch_vision_master') + assert os.path.exists(temp_dir + "/pytorch_vision_master") + shutil.rmtree(temp_dir + "/pytorch_vision_master") def test_list_entrypoints(self): - entry_lists = hub.list('pytorch/vision', force_reload=True) - assert 'resnet18' in entry_lists + entry_lists = hub.list("pytorch/vision", force_reload=True) + assert "resnet18" in entry_lists if __name__ == "__main__": diff --git a/test/test_image.py b/test/test_image.py index 5630d5d8226..35ec677ba5c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -4,25 +4,34 @@ import sys from pathlib import Path -import pytest import numpy as np +import pytest import torch -from PIL import Image, __version__ as PILLOW_VERSION import torchvision.transforms.functional as F from common_utils import needs_cuda, assert_equal - +from PIL import Image, __version__ as PILLOW_VERSION from torchvision.io.image import ( - decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, - encode_png, write_png, write_file, ImageReadMode, read_image) + decode_png, + decode_jpeg, + encode_jpeg, + write_jpeg, + decode_image, + read_file, + encode_png, + write_png, + write_file, + ImageReadMode, + read_image, +) IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata") IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder") -DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg') +DAMAGED_JPEG = os.path.join(IMAGE_ROOT, "damaged_jpeg") ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg") INTERLACED_PNG = os.path.join(IMAGE_ROOT, "interlaced_png") -IS_WINDOWS = sys.platform in ('win32', 'cygwin') -PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) +IS_WINDOWS = sys.platform in ("win32", "cygwin") +PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) def _get_safe_image_name(name): @@ -35,9 +44,9 @@ def _get_safe_image_name(name): def get_images(directory, img_ext): assert os.path.isdir(directory) - image_paths = glob.glob(directory + f'/**/*{img_ext}', recursive=True) + image_paths = glob.glob(directory + f"/**/*{img_ext}", recursive=True) for path in image_paths: - if path.split(os.sep)[-2] not in ['damaged_jpeg', 'jpeg_write']: + if path.split(os.sep)[-2] not in ["damaged_jpeg", "jpeg_write"]: yield path @@ -54,15 +63,18 @@ def normalize_dimensions(img_pil): return img_pil -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(IMAGE_ROOT, ".jpg") -]) -@pytest.mark.parametrize('pil_mode, mode', [ - (None, ImageReadMode.UNCHANGED), - ("L", ImageReadMode.GRAY), - ("RGB", ImageReadMode.RGB), -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for 
jpeg_path in get_images(IMAGE_ROOT, ".jpg")], +) +@pytest.mark.parametrize( + "pil_mode, mode", + [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("RGB", ImageReadMode.RGB), + ], +) def test_decode_jpeg(img_path, pil_mode, mode): with Image.open(img_path) as img: @@ -100,18 +112,21 @@ def test_decode_jpeg_errors(): def test_decode_bad_huffman_images(): # sanity check: make sure we can decode the bad Huffman encoding - bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) + bad_huff = read_file(os.path.join(DAMAGED_JPEG, "bad_huffman.jpg")) decode_jpeg(bad_huff) -@pytest.mark.parametrize('img_path', [ - pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) - for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, 'corrupt*.jpg')) -]) +@pytest.mark.parametrize( + "img_path", + [ + pytest.param(truncated_image, id=_get_safe_image_name(truncated_image)) + for truncated_image in glob.glob(os.path.join(DAMAGED_JPEG, "corrupt*.jpg")) + ], +) def test_damaged_corrupt_images(img_path): # Truncated images should raise an exception data = read_file(img_path) - if 'corrupt34' in img_path: + if "corrupt34" in img_path: match_message = "Image is incomplete or truncated" else: match_message = "Unsupported marker type" @@ -119,17 +134,20 @@ def test_damaged_corrupt_images(img_path): decode_jpeg(data) -@pytest.mark.parametrize('img_path', [ - pytest.param(png_path, id=_get_safe_image_name(png_path)) - for png_path in get_images(FAKEDATA_DIR, ".png") -]) -@pytest.mark.parametrize('pil_mode, mode', [ - (None, ImageReadMode.UNCHANGED), - ("L", ImageReadMode.GRAY), - ("LA", ImageReadMode.GRAY_ALPHA), - ("RGB", ImageReadMode.RGB), - ("RGBA", ImageReadMode.RGB_ALPHA), -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(FAKEDATA_DIR, ".png")], +) +@pytest.mark.parametrize( + "pil_mode, mode", + [ + (None, ImageReadMode.UNCHANGED), + ("L", ImageReadMode.GRAY), + ("LA", ImageReadMode.GRAY_ALPHA), + ("RGB", ImageReadMode.RGB), + ("RGBA", ImageReadMode.RGB_ALPHA), + ], +) def test_decode_png(img_path, pil_mode, mode): with Image.open(img_path) as img: @@ -160,10 +178,10 @@ def test_decode_png_errors(): decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8)) -@pytest.mark.parametrize('img_path', [ - pytest.param(png_path, id=_get_safe_image_name(png_path)) - for png_path in get_images(IMAGE_DIR, ".png") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], +) def test_encode_png(img_path): pil_image = Image.open(img_path) img_pil = torch.from_numpy(np.array(pil_image)) @@ -182,28 +200,26 @@ def test_encode_png_errors(): encode_png(torch.empty((3, 100, 100), dtype=torch.float32)) with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=-1) + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=-1) with pytest.raises(RuntimeError, match="Compression level should be between 0 and 9"): - encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), - compression_level=10) + encode_png(torch.empty((3, 100, 100), dtype=torch.uint8), compression_level=10) with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): encode_png(torch.empty((5, 100, 100), dtype=torch.uint8)) -@pytest.mark.parametrize('img_path', [ - 
pytest.param(png_path, id=_get_safe_image_name(png_path)) - for png_path in get_images(IMAGE_DIR, ".png") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")], +) def test_write_png(img_path, tmpdir): pil_image = Image.open(img_path) img_pil = torch.from_numpy(np.array(pil_image)) img_pil = img_pil.permute(2, 0, 1) filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_png = os.path.join(tmpdir, '{0}_torch.png'.format(filename)) + torch_png = os.path.join(tmpdir, "{0}_torch.png".format(filename)) write_png(img_pil, torch_png, compression_level=6) saved_image = torch.from_numpy(np.array(Image.open(torch_png))) saved_image = saved_image.permute(2, 0, 1) @@ -212,9 +228,9 @@ def test_write_png(img_path, tmpdir): def test_read_file(tmpdir): - fname, content = 'test1.bin', b'TorchVision\211\n' + fname, content = "test1.bin", b"TorchVision\211\n" fpath = os.path.join(tmpdir, fname) - with open(fpath, 'wb') as f: + with open(fpath, "wb") as f: f.write(content) data = read_file(fpath) @@ -223,13 +239,13 @@ def test_read_file(tmpdir): assert_equal(data, expected) with pytest.raises(RuntimeError, match="No such file or directory: 'tst'"): - read_file('tst') + read_file("tst") def test_read_file_non_ascii(tmpdir): - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fname, content = "日本語(Japanese).bin", b"TorchVision\211\n" fpath = os.path.join(tmpdir, fname) - with open(fpath, 'wb') as f: + with open(fpath, "wb") as f: f.write(content) data = read_file(fpath) @@ -239,37 +255,40 @@ def test_read_file_non_ascii(tmpdir): def test_write_file(tmpdir): - fname, content = 'test1.bin', b'TorchVision\211\n' + fname, content = "test1.bin", b"TorchVision\211\n" fpath = os.path.join(tmpdir, fname) content_tensor = torch.tensor(list(content), dtype=torch.uint8) write_file(fpath, content_tensor) - with open(fpath, 'rb') as f: + with open(fpath, "rb") as f: saved_content = f.read() os.unlink(fpath) assert content == saved_content def test_write_file_non_ascii(tmpdir): - fname, content = '日本語(Japanese).bin', b'TorchVision\211\n' + fname, content = "日本語(Japanese).bin", b"TorchVision\211\n" fpath = os.path.join(tmpdir, fname) content_tensor = torch.tensor(list(content), dtype=torch.uint8) write_file(fpath, content_tensor) - with open(fpath, 'rb') as f: + with open(fpath, "rb") as f: saved_content = f.read() os.unlink(fpath) assert content == saved_content -@pytest.mark.parametrize('shape', [ - (27, 27), - (60, 60), - (105, 105), -]) +@pytest.mark.parametrize( + "shape", + [ + (27, 27), + (60, 60), + (105, 105), + ], +) def test_read_1_bit_png(shape, tmpdir): np_rng = np.random.RandomState(0) - image_path = os.path.join(tmpdir, f'test_{shape}.png') + image_path = os.path.join(tmpdir, f"test_{shape}.png") pixels = np_rng.rand(*shape) > 0.5 img = Image.fromarray(pixels) img.save(image_path) @@ -278,18 +297,24 @@ def test_read_1_bit_png(shape, tmpdir): assert_equal(img1, img2) -@pytest.mark.parametrize('shape', [ - (27, 27), - (60, 60), - (105, 105), -]) -@pytest.mark.parametrize('mode', [ - ImageReadMode.UNCHANGED, - ImageReadMode.GRAY, -]) +@pytest.mark.parametrize( + "shape", + [ + (27, 27), + (60, 60), + (105, 105), + ], +) +@pytest.mark.parametrize( + "mode", + [ + ImageReadMode.UNCHANGED, + ImageReadMode.GRAY, + ], +) def test_read_1_bit_png_consistency(shape, mode, tmpdir): np_rng = np.random.RandomState(0) - image_path = os.path.join(tmpdir, f'test_{shape}.png') + image_path = 
os.path.join(tmpdir, f"test_{shape}.png") pixels = np_rng.rand(*shape) > 0.5 img = Image.fromarray(pixels) img.save(image_path) @@ -308,30 +333,30 @@ def test_read_interlaced_png(): @needs_cuda -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(IMAGE_ROOT, ".jpg") -]) -@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) -@pytest.mark.parametrize('scripted', (False, True)) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")], +) +@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize("scripted", (False, True)) def test_decode_jpeg_cuda(mode, img_path, scripted): - if 'cmyk' in img_path: + if "cmyk" in img_path: pytest.xfail("Decoding a CMYK jpeg isn't supported") data = read_file(img_path) img = decode_image(data, mode=mode) f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg - img_nvjpeg = f(data, mode=mode, device='cuda') + img_nvjpeg = f(data, mode=mode, device="cuda") # Some difference expected between jpeg implementations assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 @needs_cuda -@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda'))) +@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda"))) def test_decode_jpeg_cuda_device_param(cuda_device): """Make sure we can pass a string or a torch.device as device param""" - path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if 'cmyk' not in path) + path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path) data = read_file(path) decode_jpeg(data, device=cuda_device) @@ -340,13 +365,13 @@ def test_decode_jpeg_cuda_device_param(cuda_device): def test_decode_jpeg_cuda_errors(): data = read_file(next(get_images(IMAGE_ROOT, ".jpg"))) with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"): - decode_jpeg(data.reshape(-1, 1), device='cuda') + decode_jpeg(data.reshape(-1, 1), device="cuda") with pytest.raises(RuntimeError, match="input tensor must be on CPU"): - decode_jpeg(data.to('cuda'), device='cuda') + decode_jpeg(data.to("cuda"), device="cuda") with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"): - decode_jpeg(data.to(torch.float), device='cuda') + decode_jpeg(data.to(torch.float), device="cuda") with pytest.raises(RuntimeError, match="Expected a cuda device"): - torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, 'cpu') + torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu") def test_encode_jpeg_errors(): @@ -354,12 +379,10 @@ def test_encode_jpeg_errors(): with pytest.raises(RuntimeError, match="Input tensor dtype should be uint8"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32)) - with pytest.raises(ValueError, match="Image quality should be a positive number " - "between 1 and 100"): + with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1) - with pytest.raises(ValueError, match="Image quality should be a positive number " - "between 1 and 100"): + with pytest.raises(ValueError, match="Image quality should be a positive number " "between 1 and 100"): encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101) 
with pytest.raises(RuntimeError, match="The number of channels should be 1 or 3, got: 5"): @@ -380,14 +403,15 @@ def _inner(test_func): return test_func else: return pytest.mark.dont_collect(test_func) + return _inner @_collect_if(cond=IS_WINDOWS) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_encode_jpeg_reference(img_path): # This test is *wrong*. # It compares a torchvision-encoded jpeg with a PIL-encoded jpeg (the reference), but it @@ -401,12 +425,11 @@ def test_encode_jpeg_reference(img_path): # FIXME: make the correct tests pass on windows and remove this. dirname = os.path.dirname(img_path) filename, _ = os.path.splitext(os.path.basename(img_path)) - write_folder = os.path.join(dirname, 'jpeg_write') - expected_file = os.path.join( - write_folder, '{0}_pil.jpg'.format(filename)) + write_folder = os.path.join(dirname, "jpeg_write") + expected_file = os.path.join(write_folder, "{0}_pil.jpg".format(filename)) img = decode_jpeg(read_file(img_path)) - with open(expected_file, 'rb') as f: + with open(expected_file, "rb") as f: pil_bytes = f.read() pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8) for src_img in [img, img.contiguous()]: @@ -416,10 +439,10 @@ def test_encode_jpeg_reference(img_path): @_collect_if(cond=IS_WINDOWS) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_write_jpeg_reference(img_path, tmpdir): # FIXME: Remove this eventually, see test_encode_jpeg_reference data = read_file(img_path) @@ -427,35 +450,31 @@ def test_write_jpeg_reference(img_path, tmpdir): basedir = os.path.dirname(img_path) filename, _ = os.path.splitext(os.path.basename(img_path)) - torch_jpeg = os.path.join( - tmpdir, '{0}_torch.jpg'.format(filename)) - pil_jpeg = os.path.join( - basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename)) + torch_jpeg = os.path.join(tmpdir, "{0}_torch.jpg".format(filename)) + pil_jpeg = os.path.join(basedir, "jpeg_write", "{0}_pil.jpg".format(filename)) write_jpeg(img, torch_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: + with open(torch_jpeg, "rb") as f: torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: + with open(pil_jpeg, "rb") as f: pil_bytes = f.read() assert_equal(torch_bytes, pil_bytes) -@pytest.mark.skipif(IS_WINDOWS, reason=( - 'this test fails on windows because PIL uses libjpeg-turbo on windows' -)) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows")) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_encode_jpeg(img_path): img = read_image(img_path) pil_img = F.to_pil_image(img) buf = io.BytesIO() - pil_img.save(buf, format='JPEG', quality=75) + pil_img.save(buf, format="JPEG", quality=75) # pytorch can't read from raw bytes so we go through numpy pil_bytes = 
np.frombuffer(buf.getvalue(), dtype=np.uint8) @@ -466,28 +485,26 @@ def test_encode_jpeg(img_path): assert_equal(encoded_jpeg_torch, encoded_jpeg_pil) -@pytest.mark.skipif(IS_WINDOWS, reason=( - 'this test fails on windows because PIL uses libjpeg-turbo on windows' -)) -@pytest.mark.parametrize('img_path', [ - pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) - for jpeg_path in get_images(ENCODE_JPEG, ".jpg") -]) +@pytest.mark.skipif(IS_WINDOWS, reason=("this test fails on windows because PIL uses libjpeg-turbo on windows")) +@pytest.mark.parametrize( + "img_path", + [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")], +) def test_write_jpeg(img_path, tmpdir): tmpdir = Path(tmpdir) img = read_image(img_path) pil_img = F.to_pil_image(img) - torch_jpeg = str(tmpdir / 'torch.jpg') - pil_jpeg = str(tmpdir / 'pil.jpg') + torch_jpeg = str(tmpdir / "torch.jpg") + pil_jpeg = str(tmpdir / "pil.jpg") write_jpeg(img, torch_jpeg, quality=75) pil_img.save(pil_jpeg, quality=75) - with open(torch_jpeg, 'rb') as f: + with open(torch_jpeg, "rb") as f: torch_bytes = f.read() - with open(pil_jpeg, 'rb') as f: + with open(pil_jpeg, "rb") as f: pil_bytes = f.read() assert_equal(torch_bytes, pil_bytes) diff --git a/test/test_internet.py b/test/test_internet.py index ec749daf9b4..1d63f38d5af 100644 --- a/test/test_internet.py +++ b/test/test_internet.py @@ -6,10 +6,10 @@ """ import os -import pytest import warnings from urllib.error import URLError +import pytest import torchvision.datasets.utils as utils @@ -42,11 +42,11 @@ def test_download_url_dispatch_download_from_google_drive(self, mocker, tmpdir): filename = "filename" md5 = "md5" - mocked = mocker.patch('torchvision.datasets.utils.download_file_from_google_drive') + mocked = mocker.patch("torchvision.datasets.utils.download_file_from_google_drive") utils.download_url(url, tmpdir, filename, md5) mocked.assert_called_once_with(id, tmpdir, filename, md5) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_io.py b/test/test_io.py index 150d66f0814..73356bb8092 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -1,19 +1,20 @@ -import pytest -import os import contextlib +import os import sys import tempfile -import torch -import torchvision.io as io -from torchvision import get_video_backend import warnings from urllib.error import URLError +import pytest +import torch +import torchvision.io as io from common_utils import assert_equal +from torchvision import get_video_backend try: import av + # Do a version test too io.video._check_av_available() except ImportError: @@ -42,29 +43,30 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None, raise ValueError("video_codec can't be specified together with lossless") if options is not None: raise ValueError("options can't be specified together with lossless") - video_codec = 'libx264rgb' - options = {'crf': '0'} + video_codec = "libx264rgb" + options = {"crf": "0"} if video_codec is None: if get_video_backend() == "pyav": - video_codec = 'libx264' + video_codec = "libx264" else: # when video_codec is not set, we assume it is libx264rgb which accepts # RGB pixel formats as input instead of YUV - video_codec = 'libx264rgb' + video_codec = "libx264rgb" if options is None: options = {} data = _create_video_frames(num_frames, height, width) - with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: f.close() 
io.write_video(f.name, data, fps=fps, video_codec=video_codec, options=options) yield f.name, data os.unlink(f.name) -@pytest.mark.skipif(get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, - reason="video_reader backend not available") +@pytest.mark.skipif( + get_video_backend() != "pyav" and not io._HAS_VIDEO_OPT, reason="video_reader backend not available" +) @pytest.mark.skipif(av is None, reason="PyAV unavailable") class TestVideo: # compression adds artifacts, thus we add a tolerance of @@ -107,14 +109,14 @@ def test_read_timestamps(self): assert pts == expected_pts - @pytest.mark.parametrize('start', range(5)) - @pytest.mark.parametrize('offset', range(1, 4)) + @pytest.mark.parametrize("start", range(5)) + @pytest.mark.parametrize("offset", range(1, 4)) def test_read_partial_video(self, start, offset): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): pts, _ = io.read_video_timestamps(f_name) lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start:(start + offset)] + s_data = data[start : (start + offset)] assert len(lv) == offset assert_equal(s_data, lv) @@ -125,22 +127,22 @@ def test_read_partial_video(self, start, offset): assert len(lv) == 4 assert_equal(data[4:8], lv) - @pytest.mark.parametrize('start', range(0, 80, 20)) - @pytest.mark.parametrize('offset', range(1, 4)) + @pytest.mark.parametrize("start", range(0, 80, 20)) + @pytest.mark.parametrize("offset", range(1, 4)) def test_read_partial_video_bframes(self, start, offset): # do not use lossless encoding, to test the presence of B-frames - options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'} + options = {"bframes": "16", "keyint": "10", "min-keyint": "4"} with temp_video(100, 300, 300, 5, options=options) as (f_name, data): pts, _ = io.read_video_timestamps(f_name) lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1]) - s_data = data[start:(start + offset)] + s_data = data[start : (start + offset)] assert len(lv) == offset assert_equal(s_data, lv, rtol=0.0, atol=self.TOLERANCE) lv, _, _ = io.read_video(f_name, pts[4] + 1, pts[7]) # TODO fix this - if get_video_backend() == 'pyav': + if get_video_backend() == "pyav": assert len(lv) == 4 assert_equal(data[4:8], lv, rtol=0.0, atol=self.TOLERANCE) else: @@ -156,7 +158,7 @@ def test_read_packed_b_frames_divx_file(self): assert fps == 30 def test_read_timestamps_from_packet(self): - with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data): + with temp_video(10, 300, 300, 5, video_codec="mpeg4") as (f_name, data): pts, _ = io.read_video_timestamps(f_name) # note: not all formats/codecs provide accurate information for computing the # timestamps. 
For the format that we use here, this information is available, @@ -164,7 +166,7 @@ def test_read_timestamps_from_packet(self): with av.open(f_name) as container: stream = container.streams[0] # make sure we went through the optimized codepath - assert b'Lavc' in stream.codec_context.extradata + assert b"Lavc" in stream.codec_context.extradata pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) expected_pts = [i * pts_step for i in range(num_frames)] @@ -173,7 +175,7 @@ def test_read_timestamps_from_packet(self): def test_read_video_pts_unit_sec(self): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - lv, _, info = io.read_video(f_name, pts_unit='sec') + lv, _, info = io.read_video(f_name, pts_unit="sec") assert_equal(data, lv) assert info["video_fps"] == 5 @@ -181,7 +183,7 @@ def test_read_video_pts_unit_sec(self): def test_read_timestamps_pts_unit_sec(self): with temp_video(10, 300, 300, 5) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') + pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") with av.open(f_name) as container: stream = container.streams[0] @@ -191,22 +193,22 @@ def test_read_timestamps_pts_unit_sec(self): assert pts == expected_pts - @pytest.mark.parametrize('start', range(5)) - @pytest.mark.parametrize('offset', range(1, 4)) + @pytest.mark.parametrize("start", range(5)) + @pytest.mark.parametrize("offset", range(1, 4)) def test_read_partial_video_pts_unit_sec(self, start, offset): with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data): - pts, _ = io.read_video_timestamps(f_name, pts_unit='sec') + pts, _ = io.read_video_timestamps(f_name, pts_unit="sec") - lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit='sec') - s_data = data[start:(start + offset)] + lv, _, _ = io.read_video(f_name, pts[start], pts[start + offset - 1], pts_unit="sec") + s_data = data[start : (start + offset)] assert len(lv) == offset assert_equal(s_data, lv) with av.open(f_name) as container: stream = container.streams[0] - lv, _, _ = io.read_video(f_name, - int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], - pts_unit='sec') + lv, _, _ = io.read_video( + f_name, int(pts[4] * (1.0 / stream.time_base) + 1) * stream.time_base, pts[7], pts_unit="sec" + ) if get_video_backend() == "pyav": # for "video_reader" backend, we don't decode the closest early frame # when the given start pts is not matching any frame pts @@ -214,8 +216,8 @@ def test_read_partial_video_pts_unit_sec(self, start, offset): assert_equal(data[4:8], lv) def test_read_video_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix='.mp4') as f: - f.write(b'This is not an mpg4 file') + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"This is not an mpg4 file") video, audio, info = io.read_video(f.name) assert isinstance(video, torch.Tensor) assert isinstance(audio, torch.Tensor) @@ -224,8 +226,8 @@ def test_read_video_corrupted_file(self): assert info == {} def test_read_video_timestamps_corrupted_file(self): - with tempfile.NamedTemporaryFile(suffix='.mp4') as f: - f.write(b'This is not an mpg4 file') + with tempfile.NamedTemporaryFile(suffix=".mp4") as f: + f.write(b"This is not an mpg4 file") video_pts, video_fps = io.read_video_timestamps(f.name) assert video_pts == [] assert video_fps is None @@ -233,18 +235,18 @@ def test_read_video_timestamps_corrupted_file(self): 
@pytest.mark.skip(reason="Temporarily disabled due to new pyav") def test_read_video_partially_corrupted_file(self): with temp_video(5, 4, 4, 5, lossless=True) as (f_name, data): - with open(f_name, 'r+b') as f: + with open(f_name, "r+b") as f: size = os.path.getsize(f_name) bytes_to_overwrite = size // 10 # seek to the middle of the file f.seek(5 * bytes_to_overwrite) # corrupt 10% of the file from the middle - f.write(b'\xff' * bytes_to_overwrite) + f.write(b"\xff" * bytes_to_overwrite) # this exercises the container.decode assertion check - video, audio, info = io.read_video(f.name, pts_unit='sec') + video, audio, info = io.read_video(f.name, pts_unit="sec") # check that size is not equal to 5, but 3 # TODO fix this - if get_video_backend() == 'pyav': + if get_video_backend() == "pyav": assert len(video) == 3 else: assert len(video) == 4 @@ -254,7 +256,7 @@ def test_read_video_partially_corrupted_file(self): with pytest.raises(AssertionError): assert_equal(video, data) - @pytest.mark.skipif(sys.platform == 'win32', reason='temporarily disabled on Windows') + @pytest.mark.skipif(sys.platform == "win32", reason="temporarily disabled on Windows") def test_write_video_with_audio(self, tmpdir): f_name = os.path.join(VIDEO_DIR, "R6llTwEh07w.mp4") video_tensor, audio_tensor, info = io.read_video(f_name, pts_unit="sec") @@ -265,15 +267,13 @@ def test_write_video_with_audio(self, tmpdir): video_tensor, round(info["video_fps"]), video_codec="libx264rgb", - options={'crf': '0'}, + options={"crf": "0"}, audio_array=audio_tensor, audio_fps=info["audio_fps"], audio_codec="aac", ) - out_video_tensor, out_audio_tensor, out_info = io.read_video( - out_f_name, pts_unit="sec" - ) + out_video_tensor, out_audio_tensor, out_info = io.read_video(out_f_name, pts_unit="sec") assert info["video_fps"] == out_info["video_fps"] assert_equal(video_tensor, out_video_tensor) @@ -289,5 +289,5 @@ def test_write_video_with_audio(self, tmpdir): # TODO add tests for audio -if __name__ == '__main__': +if __name__ == "__main__": pytest.main(__file__) diff --git a/test/test_io_opt.py b/test/test_io_opt.py index 87698b34624..be0456665ec 100644 --- a/test/test_io_opt.py +++ b/test/test_io_opt.py @@ -1,12 +1,13 @@ import unittest -from torchvision import set_video_backend + import test_io +from torchvision import set_video_backend # Disabling the video backend switching temporarily # set_video_backend('video_reader') -if __name__ == '__main__': +if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(test_io) unittest.TextTestRunner(verbosity=1).run(suite) diff --git a/test/test_models.py b/test/test_models.py index 9e376bedce5..a108c2633ec 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,22 +1,23 @@ -import os +import functools import io +import operator +import os import sys -from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda -from _utils_internal import get_relative_path +import traceback +import warnings from collections import OrderedDict -import functools -import operator + +import pytest import torch import torch.fx import torch.nn as nn import torchvision +from _utils_internal import get_relative_path +from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda from torchvision import models -import pytest -import warnings -import traceback -ACCEPT = os.getenv('EXPECTTEST_ACCEPT', '0') == '1' +ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1" def 
get_available_classification_models(): @@ -50,7 +51,7 @@ def _get_expected_file(name=None): # Note: for legacy reasons, the reference file names all had "ModelTest.test_" in their names # We hardcode it here to avoid having to re-generate the reference files - expected_file = expected_file = os.path.join(expected_file_base, 'ModelTester.test_' + name) + expected_file = expected_file = os.path.join(expected_file_base, "ModelTester.test_" + name) expected_file += "_expect.pkl" if not ACCEPT and not os.path.exists(expected_file): @@ -92,6 +93,7 @@ def _check_jit_scriptable(nn_module, args, unwrapper=None, skip=False): def assert_export_import_module(m, args): """Check that the results of a model are the same after saving and loading""" + def get_export_import_copy(m): """Save and load a TorchScript model""" buffer = io.BytesIO() @@ -115,15 +117,17 @@ def get_export_import_copy(m): if a is not None: torch.testing.assert_close(a, b, atol=tol, rtol=tol) - TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1' + TEST_WITH_SLOW = os.getenv("PYTORCH_TEST_WITH_SLOW", "0") == "1" if not TEST_WITH_SLOW or skip: # TorchScript is not enabled, skip these tests - msg = "The check_jit_scriptable test for {} was skipped. " \ - "This test checks if the module's results in TorchScript " \ - "match eager and that it can be exported. To run these " \ - "tests make sure you set the environment variable " \ - "PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \ - "manually skipped.".format(nn_module.__class__.__name__) + msg = ( + "The check_jit_scriptable test for {} was skipped. " + "This test checks if the module's results in TorchScript " + "match eager and that it can be exported. To run these " + "tests make sure you set the environment variable " + "PYTORCH_TEST_WITH_SLOW=1 and that the test is not " + "manually skipped.".format(nn_module.__class__.__name__) + ) warnings.warn(msg, RuntimeWarning) return None @@ -181,8 +185,8 @@ def _check_input_backprop(model, inputs): # before they are compared to the eager model outputs. This is useful if the # model outputs are different between TorchScript / Eager mode script_model_unwrapper = { - 'googlenet': lambda x: x.logits, - 'inception_v3': lambda x: x.logits, + "googlenet": lambda x: x.logits, + "inception_v3": lambda x: x.logits, "fasterrcnn_resnet50_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1], "fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1], @@ -221,43 +225,41 @@ def _check_input_backprop(model, inputs): # The following contains configuration parameters for all models which are used by # the _test_*_model methods. 
_model_params = { - 'inception_v3': { - 'input_shape': (1, 3, 299, 299) + "inception_v3": {"input_shape": (1, 3, 299, 299)}, + "retinanet_resnet50_fpn": { + "num_classes": 20, + "score_thresh": 0.01, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), }, - 'retinanet_resnet50_fpn': { - 'num_classes': 20, - 'score_thresh': 0.01, - 'min_size': 224, - 'max_size': 224, - 'input_shape': (3, 224, 224), + "keypointrcnn_resnet50_fpn": { + "num_classes": 2, + "min_size": 224, + "max_size": 224, + "box_score_thresh": 0.15, + "input_shape": (3, 224, 224), }, - 'keypointrcnn_resnet50_fpn': { - 'num_classes': 2, - 'min_size': 224, - 'max_size': 224, - 'box_score_thresh': 0.15, - 'input_shape': (3, 224, 224), + "fasterrcnn_resnet50_fpn": { + "num_classes": 20, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), }, - 'fasterrcnn_resnet50_fpn': { - 'num_classes': 20, - 'min_size': 224, - 'max_size': 224, - 'input_shape': (3, 224, 224), + "maskrcnn_resnet50_fpn": { + "num_classes": 10, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), }, - 'maskrcnn_resnet50_fpn': { - 'num_classes': 10, - 'min_size': 224, - 'max_size': 224, - 'input_shape': (3, 224, 224), + "fasterrcnn_mobilenet_v3_large_fpn": { + "box_score_thresh": 0.02076, }, - 'fasterrcnn_mobilenet_v3_large_fpn': { - 'box_score_thresh': 0.02076, + "fasterrcnn_mobilenet_v3_large_320_fpn": { + "box_score_thresh": 0.02076, + "rpn_pre_nms_top_n_test": 1000, + "rpn_post_nms_top_n_test": 1000, }, - 'fasterrcnn_mobilenet_v3_large_320_fpn': { - 'box_score_thresh': 0.02076, - 'rpn_pre_nms_top_n_test': 1000, - 'rpn_post_nms_top_n_test': 1000, - } } @@ -271,7 +273,7 @@ def _make_sliced_model(model, stop_layer): return new_model -@pytest.mark.parametrize('model_name', ['densenet121', 'densenet169', 'densenet201', 'densenet161']) +@pytest.mark.parametrize("model_name", ["densenet121", "densenet169", "densenet201", "densenet161"]) def test_memory_efficient_densenet(model_name): input_shape = (1, 3, 300, 300) x = torch.rand(input_shape) @@ -296,9 +298,9 @@ def test_memory_efficient_densenet(model_name): _check_input_backprop(model2, x) -@pytest.mark.parametrize('dilate_layer_2', (True, False)) -@pytest.mark.parametrize('dilate_layer_3', (True, False)) -@pytest.mark.parametrize('dilate_layer_4', (True, False)) +@pytest.mark.parametrize("dilate_layer_2", (True, False)) +@pytest.mark.parametrize("dilate_layer_3", (True, False)) +@pytest.mark.parametrize("dilate_layer_4", (True, False)) def test_resnet_dilation(dilate_layer_2, dilate_layer_3, dilate_layer_4): # TODO improve tests to also check that each layer has the right dimensionality model = models.__dict__["resnet50"](replace_stride_with_dilation=(dilate_layer_2, dilate_layer_3, dilate_layer_4)) @@ -318,7 +320,7 @@ def test_mobilenet_v2_residual_setting(): assert out.shape[-1] == 1000 -@pytest.mark.parametrize('model_name', ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]) +@pytest.mark.parametrize("model_name", ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]) def test_mobilenet_norm_layer(model_name): model = models.__dict__[model_name]() assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules()) @@ -327,16 +329,16 @@ def get_gn(num_channels): return nn.GroupNorm(32, num_channels) model = models.__dict__[model_name](norm_layer=get_gn) - assert not(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) assert any(isinstance(x, nn.GroupNorm) for x in 
model.modules()) def test_inception_v3_eval(): # replacement for models.inception_v3(pretrained=True) that does not download weights kwargs = {} - kwargs['transform_input'] = True - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + kwargs["transform_input"] = True + kwargs["aux_logits"] = True + kwargs["init_weights"] = False name = "inception_v3" model = models.Inception3(**kwargs) model.aux_logits = False @@ -366,9 +368,9 @@ def test_fasterrcnn_double(): def test_googlenet_eval(): # replacement for models.googlenet(pretrained=True) that does not download weights kwargs = {} - kwargs['transform_input'] = True - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + kwargs["transform_input"] = True + kwargs["aux_logits"] = True + kwargs["init_weights"] = False name = "googlenet" model = models.GoogLeNet(**kwargs) model.aux_logits = False @@ -392,7 +394,7 @@ def checkOut(out): model.cuda() model.eval() input_shape = (3, 300, 300) - x = torch.rand(input_shape, device='cuda') + x = torch.rand(input_shape, device="cuda") model_input = [x] out = model(model_input) assert model_input[0] is x @@ -422,30 +424,29 @@ def test_generalizedrcnn_transform_repr(): image_mean = [0.485, 0.456, 0.406] image_std = [0.229, 0.224, 0.225] - t = models.detection.transform.GeneralizedRCNNTransform(min_size=min_size, - max_size=max_size, - image_mean=image_mean, - image_std=image_std) + t = models.detection.transform.GeneralizedRCNNTransform( + min_size=min_size, max_size=max_size, image_mean=image_mean, image_std=image_std + ) # Check integrity of object __repr__ attribute - expected_string = 'GeneralizedRCNNTransform(' - _indent = '\n ' - expected_string += '{0}Normalize(mean={1}, std={2})'.format(_indent, image_mean, image_std) - expected_string += '{0}Resize(min_size=({1},), max_size={2}, '.format(_indent, min_size, max_size) + expected_string = "GeneralizedRCNNTransform(" + _indent = "\n " + expected_string += "{0}Normalize(mean={1}, std={2})".format(_indent, image_mean, image_std) + expected_string += "{0}Resize(min_size=({1},), max_size={2}, ".format(_indent, min_size, max_size) expected_string += "mode='bilinear')\n)" assert t.__repr__() == expected_string -@pytest.mark.parametrize('model_name', get_available_classification_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_classification_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_classification_model(model_name, dev): set_rng_seed(0) defaults = { - 'num_classes': 50, - 'input_shape': (1, 3, 224, 224), + "num_classes": 50, + "input_shape": (1, 3, 224, 224), } kwargs = {**defaults, **_model_params.get(model_name, {})} - input_shape = kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") model = models.__dict__[model_name](**kwargs) model.eval().to(device=dev) @@ -468,17 +469,17 @@ def test_classification_model(model_name, dev): _check_input_backprop(model, x) -@pytest.mark.parametrize('model_name', get_available_segmentation_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_segmentation_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_segmentation_model(model_name, dev): set_rng_seed(0) defaults = { - 'num_classes': 10, - 'pretrained_backbone': False, - 'input_shape': (1, 3, 32, 32), + "num_classes": 10, + "pretrained_backbone": False, + "input_shape": (1, 3, 32, 32), } kwargs = {**defaults, **_model_params.get(model_name, {})} - input_shape = 
kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") model = models.segmentation.__dict__[model_name](**kwargs) model.eval().to(device=dev) @@ -517,27 +518,29 @@ def check_out(out): full_validation &= check_out(out) if not full_validation: - msg = "The output of {} could only be partially validated. " \ - "This is likely due to unit-test flakiness, but you may " \ - "want to do additional manual checks if you made " \ - "significant changes to the codebase.".format(test_segmentation_model.__name__) + msg = ( + "The output of {} could only be partially validated. " + "This is likely due to unit-test flakiness, but you may " + "want to do additional manual checks if you made " + "significant changes to the codebase.".format(test_segmentation_model.__name__) + ) warnings.warn(msg, RuntimeWarning) pytest.skip(msg) _check_input_backprop(model, x) -@pytest.mark.parametrize('model_name', get_available_detection_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_detection_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_detection_model(model_name, dev): set_rng_seed(0) defaults = { - 'num_classes': 50, - 'pretrained_backbone': False, - 'input_shape': (3, 300, 300), + "num_classes": 50, + "pretrained_backbone": False, + "input_shape": (3, 300, 300), } kwargs = {**defaults, **_model_params.get(model_name, {})} - input_shape = kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") model = models.detection.__dict__[model_name](**kwargs) model.eval().to(device=dev) @@ -565,7 +568,7 @@ def subsample_tensor(tensor): return tensor ith_index = num_elems // num_samples - return tensor[ith_index - 1::ith_index] + return tensor[ith_index - 1 :: ith_index] def compute_mean_std(tensor): # can't compute mean of integral tensor @@ -588,8 +591,9 @@ def compute_mean_std(tensor): # scores. expected_file = _get_expected_file(model_name) expected = torch.load(expected_file) - torch.testing.assert_close(output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, - check_device=False, check_dtype=False) + torch.testing.assert_close( + output[0]["scores"], expected[0]["scores"], rtol=prec, atol=prec, check_device=False, check_dtype=False + ) # Note: Fmassa proposed turning off NMS by adapting the threshold # and then using the Hungarian algorithm as in DETR to find the @@ -610,17 +614,19 @@ def compute_mean_std(tensor): full_validation &= check_out(out) if not full_validation: - msg = "The output of {} could only be partially validated. " \ - "This is likely due to unit-test flakiness, but you may " \ - "want to do additional manual checks if you made " \ - "significant changes to the codebase.".format(test_detection_model.__name__) + msg = ( + "The output of {} could only be partially validated. 
" + "This is likely due to unit-test flakiness, but you may " + "want to do additional manual checks if you made " + "significant changes to the codebase.".format(test_detection_model.__name__) + ) warnings.warn(msg, RuntimeWarning) pytest.skip(msg) _check_input_backprop(model, model_input) -@pytest.mark.parametrize('model_name', get_available_detection_models()) +@pytest.mark.parametrize("model_name", get_available_detection_models()) def test_detection_model_validation(model_name): set_rng_seed(0) model = models.detection.__dict__[model_name](num_classes=50, pretrained_backbone=False) @@ -632,25 +638,25 @@ def test_detection_model_validation(model_name): model(x) # validate type - targets = [{'boxes': 0.}] + targets = [{"boxes": 0.0}] with pytest.raises(ValueError): model(x, targets=targets) # validate boxes shape for boxes in (torch.rand((4,)), torch.rand((1, 5))): - targets = [{'boxes': boxes}] + targets = [{"boxes": boxes}] with pytest.raises(ValueError): model(x, targets=targets) # validate that no degenerate boxes are present boxes = torch.tensor([[1, 3, 1, 4], [2, 4, 3, 4]]) - targets = [{'boxes': boxes}] + targets = [{"boxes": boxes}] with pytest.raises(ValueError): model(x, targets=targets) -@pytest.mark.parametrize('model_name', get_available_video_models()) -@pytest.mark.parametrize('dev', cpu_and_gpu()) +@pytest.mark.parametrize("model_name", get_available_video_models()) +@pytest.mark.parametrize("dev", cpu_and_gpu()) def test_video_model(model_name, dev): # the default input shape is # bs * num_channels * clip_len * h *w @@ -673,25 +679,29 @@ def test_video_model(model_name, dev): _check_input_backprop(model, x) -@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and - 'qnnpack' in torch.backends.quantized.supported_engines), - reason="This Pytorch Build has not been built with fbgemm and qnnpack") -@pytest.mark.parametrize('model_name', get_available_quantizable_models()) +@pytest.mark.skipif( + not ( + "fbgemm" in torch.backends.quantized.supported_engines + and "qnnpack" in torch.backends.quantized.supported_engines + ), + reason="This Pytorch Build has not been built with fbgemm and qnnpack", +) +@pytest.mark.parametrize("model_name", get_available_quantizable_models()) def test_quantized_classification_model(model_name): defaults = { - 'input_shape': (1, 3, 224, 224), - 'pretrained': False, - 'quantize': True, + "input_shape": (1, 3, 224, 224), + "pretrained": False, + "quantize": True, } kwargs = {**defaults, **_model_params.get(model_name, {})} - input_shape = kwargs.pop('input_shape') + input_shape = kwargs.pop("input_shape") # First check if quantize=True provides models that can run with input data model = torchvision.models.quantization.__dict__[model_name](**kwargs) x = torch.rand(input_shape) model(x) - kwargs['quantize'] = False + kwargs["quantize"] = False for eval_mode in [True, False]: model = torchvision.models.quantization.__dict__[model_name](**kwargs) if eval_mode: @@ -717,5 +727,5 @@ def test_quantized_classification_model(model_name): raise AssertionError(f"model cannot be scripted. 
Traceback = {str(tb)}") from e -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_models_detection_anchor_utils.py b/test/test_models_detection_anchor_utils.py index 694ef315a1f..0e99a462158 100644 --- a/test/test_models_detection_anchor_utils.py +++ b/test/test_models_detection_anchor_utils.py @@ -1,13 +1,16 @@ +import pytest import torch from common_utils import assert_equal from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator from torchvision.models.detection.image_list import ImageList -import pytest class Tester: def test_incorrect_anchors(self): - incorrect_sizes = ((2, 4, 8), (32, 8), ) + incorrect_sizes = ( + (2, 4, 8), + (32, 8), + ) incorrect_aspects = (0.5, 1.0) anc = AnchorGenerator(incorrect_sizes, incorrect_aspects) image1 = torch.randn(3, 800, 800) @@ -49,15 +52,19 @@ def test_anchor_generator(self): for sizes, num_anchors_per_loc in zip(grid_sizes, model.num_anchors_per_location()): num_anchors_estimated += sizes[0] * sizes[1] * num_anchors_per_loc - anchors_output = torch.tensor([[-5., -5., 5., 5.], - [0., -5., 10., 5.], - [5., -5., 15., 5.], - [-5., 0., 5., 10.], - [0., 0., 10., 10.], - [5., 0., 15., 10.], - [-5., 5., 5., 15.], - [0., 5., 10., 15.], - [5., 5., 15., 15.]]) + anchors_output = torch.tensor( + [ + [-5.0, -5.0, 5.0, 5.0], + [0.0, -5.0, 10.0, 5.0], + [5.0, -5.0, 15.0, 5.0], + [-5.0, 0.0, 5.0, 10.0], + [0.0, 0.0, 10.0, 10.0], + [5.0, 0.0, 15.0, 10.0], + [-5.0, 5.0, 5.0, 15.0], + [0.0, 5.0, 10.0, 15.0], + [5.0, 5.0, 15.0, 15.0], + ] + ) assert num_anchors_estimated == 9 assert len(anchors) == 2 @@ -76,12 +83,14 @@ def test_defaultbox_generator(self): model.eval() dboxes = model(images, features) - dboxes_output = torch.tensor([ - [6.3750, 6.3750, 8.6250, 8.6250], - [4.7443, 4.7443, 10.2557, 10.2557], - [5.9090, 6.7045, 9.0910, 8.2955], - [6.7045, 5.9090, 8.2955, 9.0910] - ]) + dboxes_output = torch.tensor( + [ + [6.3750, 6.3750, 8.6250, 8.6250], + [4.7443, 4.7443, 10.2557, 10.2557], + [5.9090, 6.7045, 9.0910, 8.2955], + [6.7045, 5.9090, 8.2955, 9.0910], + ] + ) assert len(dboxes) == 2 assert tuple(dboxes[0].shape) == (4, 4) diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index a4b7064b338..921fd0c7ba2 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -1,25 +1,24 @@ +import pytest import torch - import torchvision.models -from torchvision.ops import MultiScaleRoIAlign -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork -from torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead - -import pytest from common_utils import assert_equal +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead +from torchvision.models.detection.roi_heads import RoIHeads +from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from torchvision.ops import MultiScaleRoIAlign class TestModelsDetectionNegativeSamples: - def _make_empty_sample(self, add_masks=False, add_keypoints=False): images = [torch.rand((3, 100, 100), dtype=torch.float32)] boxes = torch.zeros((0, 4), dtype=torch.float32) - negative_target = {"boxes": boxes, - "labels": torch.zeros(0, dtype=torch.int64), - "image_id": 4, - "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), - "iscrowd": torch.zeros((0,), 
dtype=torch.int64)} + negative_target = { + "boxes": boxes, + "labels": torch.zeros(0, dtype=torch.int64), + "image_id": 4, + "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]), + "iscrowd": torch.zeros((0,), dtype=torch.int64), + } if add_masks: negative_target["masks"] = torch.zeros(0, 100, 100, dtype=torch.uint8) @@ -36,16 +35,10 @@ def test_targets_to_anchors(self): anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - rpn_anchor_generator = AnchorGenerator( - anchor_sizes, aspect_ratios - ) + rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) rpn_head = RPNHead(4, rpn_anchor_generator.num_anchors_per_location()[0]) - head = RegionProposalNetwork( - rpn_anchor_generator, rpn_head, - 0.5, 0.3, - 256, 0.5, - 2000, 2000, 0.7, 0.05) + head = RegionProposalNetwork(rpn_anchor_generator, rpn_head, 0.5, 0.3, 256, 0.5, 2000, 2000, 0.7, 0.05) labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets) @@ -63,29 +56,29 @@ def test_assign_targets_to_proposals(self): gt_boxes = [torch.zeros((0, 4), dtype=torch.float32)] gt_labels = [torch.tensor([[0]], dtype=torch.int64)] - box_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=7, - sampling_ratio=2) + box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2) resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead( - 4 * resolution ** 2, - representation_size) + box_head = TwoMLPHead(4 * resolution ** 2, representation_size) representation_size = 1024 - box_predictor = FastRCNNPredictor( - representation_size, - 2) + box_predictor = FastRCNNPredictor(representation_size, 2) roi_heads = RoIHeads( # Box - box_roi_pool, box_head, box_predictor, - 0.5, 0.5, - 512, 0.25, + box_roi_pool, + box_head, + box_predictor, + 0.5, + 0.5, + 512, + 0.25, None, - 0.05, 0.5, 100) + 0.05, + 0.5, + 100, + ) matched_idxs, labels = roi_heads.assign_targets_to_proposals(proposals, gt_boxes, gt_labels) @@ -97,61 +90,61 @@ def test_assign_targets_to_proposals(self): assert labels[0].shape == torch.Size([proposals[0].shape[0]]) assert labels[0].dtype == torch.int64 - @pytest.mark.parametrize('name', [ - "fasterrcnn_resnet50_fpn", - "fasterrcnn_mobilenet_v3_large_fpn", - "fasterrcnn_mobilenet_v3_large_320_fpn", - ]) + @pytest.mark.parametrize( + "name", + [ + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + ], + ) def test_forward_negative_sample_frcnn(self, name): - model = torchvision.models.detection.__dict__[name]( - num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.__dict__[name](num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample() loss_dict = model(images, targets) - assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) + assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) def test_forward_negative_sample_mrcnn(self): - model = torchvision.models.detection.maskrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample(add_masks=True) loss_dict = model(images, targets) - assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) - 
assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_mask"], torch.tensor(0.)) + assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_mask"], torch.tensor(0.0)) def test_forward_negative_sample_krcnn(self): - model = torchvision.models.detection.keypointrcnn_resnet50_fpn( - num_classes=2, min_size=100, max_size=100) + model = torchvision.models.detection.keypointrcnn_resnet50_fpn(num_classes=2, min_size=100, max_size=100) images, targets = self._make_empty_sample(add_keypoints=True) loss_dict = model(images, targets) - assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.)) - assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.)) + assert_equal(loss_dict["loss_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_rpn_box_reg"], torch.tensor(0.0)) + assert_equal(loss_dict["loss_keypoint"], torch.tensor(0.0)) def test_forward_negative_sample_retinanet(self): model = torchvision.models.detection.retinanet_resnet50_fpn( - num_classes=2, min_size=100, max_size=100, pretrained_backbone=False) + num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + ) images, targets = self._make_empty_sample() loss_dict = model(images, targets) - assert_equal(loss_dict["bbox_regression"], torch.tensor(0.)) + assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) def test_forward_negative_sample_ssd(self): - model = torchvision.models.detection.ssd300_vgg16( - num_classes=2, pretrained_backbone=False) + model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False) images, targets = self._make_empty_sample() loss_dict = model(images, targets) - assert_equal(loss_dict["bbox_regression"], torch.tensor(0.)) + assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index b599bbeaea1..8d686023b1d 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -1,14 +1,14 @@ import copy + +import pytest import torch +from common_utils import assert_equal from torchvision.models.detection import _utils -from torchvision.models.detection.transform import GeneralizedRCNNTransform -import pytest from torchvision.models.detection import backbone_utils -from common_utils import assert_equal +from torchvision.models.detection.transform import GeneralizedRCNNTransform class TestModelsDetectionUtils: - def test_balanced_positive_negative_sampler(self): sampler = _utils.BalancedPositiveNegativeSampler(4, 0.25) # keep all 6 negatives first, then add 3 positives, last two are ignore @@ -22,16 +22,13 @@ def test_balanced_positive_negative_sampler(self): assert neg[0].sum() == 3 assert neg[0][0:6].sum() == 3 - @pytest.mark.parametrize('train_layers, exp_froz_params', [ - (0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0) - ]) + @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)]) def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # we know how many initial layers and parameters of the network should # be frozen for each trainable_backbone_layers parameter value # i.e all 53 params are frozen if trainable_backbone_layers=0 # ad first 24 params are frozen if 
trainable_backbone_layers=2 - model = backbone_utils.resnet_fpn_backbone( - 'resnet50', pretrained=False, trainable_layers=train_layers) + model = backbone_utils.resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=train_layers) # boolean list that is true if the param at that index is frozen is_frozen = [not parameter.requires_grad for _, parameter in model.named_parameters()] # check that expected initial number of layers are frozen @@ -40,34 +37,37 @@ def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): def test_validate_resnet_inputs_detection(self): # default number of backbone layers to train ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3) + pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3 + ) assert ret == 3 # can't go beyond 5 with pytest.raises(AssertionError): ret = backbone_utils._validate_trainable_layers( - pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3) + pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3 + ) # if not pretrained, should use all trainable layers and warn with pytest.warns(UserWarning): ret = backbone_utils._validate_trainable_layers( - pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3) + pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3 + ) assert ret == 5 def test_transform_copy_targets(self): transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)] - targets = [{'boxes': torch.rand(3, 4)}, {'boxes': torch.rand(2, 4)}] + targets = [{"boxes": torch.rand(3, 4)}, {"boxes": torch.rand(2, 4)}] targets_copy = copy.deepcopy(targets) out = transform(image, targets) # noqa: F841 - assert_equal(targets[0]['boxes'], targets_copy[0]['boxes']) - assert_equal(targets[1]['boxes'], targets_copy[1]['boxes']) + assert_equal(targets[0]["boxes"], targets_copy[0]["boxes"]) + assert_equal(targets[1]["boxes"], targets_copy[1]["boxes"]) def test_not_float_normalize(self): transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) image = [torch.randint(0, 255, (3, 200, 300), dtype=torch.uint8)] - targets = [{'boxes': torch.rand(3, 4)}] + targets = [{"boxes": torch.rand(3, 4)}] with pytest.raises(TypeError): out = transform(image, targets) # noqa: F841 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_onnx.py b/test/test_onnx.py index cd3239cef16..c81d490a882 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -1,18 +1,17 @@ -from common_utils import set_rng_seed, assert_equal import io +from collections import OrderedDict +from typing import List, Tuple + import pytest import torch -from torchvision import ops +from common_utils import set_rng_seed, assert_equal from torchvision import models +from torchvision import ops +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead from torchvision.models.detection.image_list import ImageList -from torchvision.models.detection.transform import GeneralizedRCNNTransform -from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork from torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, TwoMLPHead - -from collections import OrderedDict -from typing import List, Tuple - +from 
torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork +from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.ops._register_onnx_ops import _onnx_opset_version # In environments without onnxruntime we prefer to @@ -25,8 +24,16 @@ class TestONNXExporter: def setup_class(cls): torch.manual_seed(123) - def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None, - output_names=None, input_names=None): + def run_model( + self, + model, + inputs_list, + tolerate_small_mismatch=False, + do_constant_folding=True, + dynamic_axes=None, + output_names=None, + input_names=None, + ): model.eval() onnx_io = io.BytesIO() @@ -35,14 +42,20 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_consta else: torch_onnx_input = inputs_list[0] # export to onnx with the first input - torch.onnx.export(model, torch_onnx_input, onnx_io, - do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, - dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) + torch.onnx.export( + model, + torch_onnx_input, + onnx_io, + do_constant_folding=do_constant_folding, + opset_version=_onnx_opset_version, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + ) # validate the exported model with onnx runtime for test_inputs in inputs_list: with torch.no_grad(): - if isinstance(test_inputs, torch.Tensor) or \ - isinstance(test_inputs, list): + if isinstance(test_inputs, torch.Tensor) or isinstance(test_inputs, list): test_inputs = (test_inputs,) test_ouputs = model(*test_inputs) if isinstance(test_ouputs, torch.Tensor): @@ -113,9 +126,9 @@ class Module(torch.nn.Module): def forward(self, boxes, size): return ops.boxes.clip_boxes_to_image(boxes, size.shape) - self.run_model(Module(), [(boxes, size), (boxes, size_2)], - input_names=["boxes", "size"], - dynamic_axes={"size": [0, 1]}) + self.run_model( + Module(), [(boxes, size), (boxes, size_2)], input_names=["boxes", "size"], dynamic_axes={"size": [0, 1]} + ) def test_roi_align(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) @@ -180,11 +193,11 @@ def forward(self_module, images): input = torch.rand(3, 10, 20) input_test = torch.rand(3, 100, 150) - self.run_model(TransformModule(), [(input,), (input_test,)], - input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]}) + self.run_model( + TransformModule(), [(input,), (input_test,)], input_names=["input1"], dynamic_axes={"input1": [0, 1, 2]} + ) def test_transform_images(self): - class TransformModule(torch.nn.Module): def __init__(self_module): super(TransformModule, self_module).__init__() @@ -221,11 +234,17 @@ def _init_test_rpn(self): rpn_score_thresh = 0.0 rpn = RegionProposalNetwork( - rpn_anchor_generator, rpn_head, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, - rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, - score_thresh=rpn_score_thresh) + rpn_anchor_generator, + rpn_head, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, + rpn_pre_nms_top_n, + rpn_post_nms_top_n, + rpn_nms_thresh, + score_thresh=rpn_score_thresh, + ) return rpn def _init_test_roi_heads_faster_rcnn(self): @@ -241,38 +260,38 @@ def _init_test_roi_heads_faster_rcnn(self): box_nms_thresh = 0.5 box_detections_per_img = 100 - box_roi_pool = ops.MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=7, - 
sampling_ratio=2) + box_roi_pool = ops.MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2) resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead( - out_channels * resolution ** 2, - representation_size) + box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) representation_size = 1024 - box_predictor = FastRCNNPredictor( - representation_size, - num_classes) + box_predictor = FastRCNNPredictor(representation_size, num_classes) roi_heads = RoIHeads( - box_roi_pool, box_head, box_predictor, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, + box_roi_pool, + box_head, + box_predictor, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, bbox_reg_weights, - box_score_thresh, box_nms_thresh, box_detections_per_img) + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + ) return roi_heads def get_features(self, images): s0, s1 = images.shape[-2:] features = [ - ('0', torch.rand(2, 256, s0 // 4, s1 // 4)), - ('1', torch.rand(2, 256, s0 // 8, s1 // 8)), - ('2', torch.rand(2, 256, s0 // 16, s1 // 16)), - ('3', torch.rand(2, 256, s0 // 32, s1 // 32)), - ('4', torch.rand(2, 256, s0 // 64, s1 // 64)), + ("0", torch.rand(2, 256, s0 // 4, s1 // 4)), + ("1", torch.rand(2, 256, s0 // 8, s1 // 8)), + ("2", torch.rand(2, 256, s0 // 16, s1 // 16)), + ("3", torch.rand(2, 256, s0 // 32, s1 // 32)), + ("4", torch.rand(2, 256, s0 // 64, s1 // 64)), ] features = OrderedDict(features) return features @@ -298,36 +317,56 @@ def forward(self_module, images, features): model.eval() model(images, features) - self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True, - input_names=["input1", "input2", "input3", "input4", "input5", "input6"], - dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], - "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], - "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) + self.run_model( + model, + [(images, features), (images2, test_features)], + tolerate_small_mismatch=True, + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={ + "input1": [0, 1, 2, 3], + "input2": [0, 1, 2, 3], + "input3": [0, 1, 2, 3], + "input4": [0, 1, 2, 3], + "input5": [0, 1, 2, 3], + "input6": [0, 1, 2, 3], + }, + ) def test_multi_scale_roi_align(self): - class TransformModule(torch.nn.Module): def __init__(self): super(TransformModule, self).__init__() - self.model = ops.MultiScaleRoIAlign(['feat1', 'feat2'], 3, 2) + self.model = ops.MultiScaleRoIAlign(["feat1", "feat2"], 3, 2) self.image_sizes = [(512, 512)] def forward(self, input, boxes): return self.model(input, boxes, self.image_sizes) i = OrderedDict() - i['feat1'] = torch.rand(1, 5, 64, 64) - i['feat2'] = torch.rand(1, 5, 16, 16) + i["feat1"] = torch.rand(1, 5, 64, 64) + i["feat2"] = torch.rand(1, 5, 16, 16) boxes = torch.rand(6, 4) * 256 boxes[:, 2:] += boxes[:, :2] i1 = OrderedDict() - i1['feat1'] = torch.rand(1, 5, 64, 64) - i1['feat2'] = torch.rand(1, 5, 16, 16) + i1["feat1"] = torch.rand(1, 5, 64, 64) + i1["feat2"] = torch.rand(1, 5, 16, 16) boxes1 = torch.rand(6, 4) * 256 boxes1[:, 2:] += boxes1[:, :2] - self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)]) + self.run_model( + TransformModule(), + [ + ( + i, + [boxes], + ), + ( + i1, + [boxes1], + ), + ], + ) def test_roi_heads(self): class RoiHeadsModule(torch.nn.Module): @@ -342,9 +381,7 @@ def forward(self_module, 
images, features): images = ImageList(images, [i.shape[-2:] for i in images]) proposals, _ = self_module.rpn(images, features) detections, _ = self_module.roi_heads(features, proposals, images.image_sizes) - detections = self_module.transform.postprocess(detections, - images.image_sizes, - original_image_sizes) + detections = self_module.transform.postprocess(detections, images.image_sizes, original_image_sizes) return detections images = torch.rand(2, 3, 100, 100) @@ -356,13 +393,24 @@ def forward(self_module, images, features): model.eval() model(images, features) - self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True, - input_names=["input1", "input2", "input3", "input4", "input5", "input6"], - dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], - "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) + self.run_model( + model, + [(images, features), (images2, test_features)], + tolerate_small_mismatch=True, + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={ + "input1": [0, 1, 2, 3], + "input2": [0, 1, 2, 3], + "input3": [0, 1, 2, 3], + "input4": [0, 1, 2, 3], + "input5": [0, 1, 2, 3], + "input6": [0, 1, 2, 3], + }, + ) def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor: import os + from PIL import Image from torchvision import transforms @@ -373,8 +421,10 @@ def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor: return transforms.ToTensor()(image) def get_test_images(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - return ([self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))], - [self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))]) + return ( + [self.get_image("encode_jpeg/grace_hopper_517x606.jpg", (100, 320))], + [self.get_image("fakedata/logos/rgb_pytorch.png", (250, 380))], + ) def test_faster_rcnn(self): images, test_images = self.get_test_images() @@ -383,15 +433,23 @@ def test_faster_rcnn(self): model.eval() model(images) # Test exported model on images of different size, or dummy input - self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"], - output_names=["outputs"], - dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(images,), (test_images,), (dummy_image,)], + input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) # Test exported model for an image with no detections on other images - self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"], - output_names=["outputs"], - dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(dummy_image,), (images,)], + input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2], "outputs": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) # Verify that paste_mask_in_image beahves the same in tracing. 
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image @@ -403,11 +461,11 @@ def test_paste_mask_in_image(self): boxes *= 50 o_im_s = (100, 100) from torchvision.models.detection.roi_heads import paste_masks_in_image + out = paste_masks_in_image(masks, boxes, o_im_s) - jit_trace = torch.jit.trace(paste_masks_in_image, - (masks, boxes, - [torch.tensor(o_im_s[0]), - torch.tensor(o_im_s[1])])) + jit_trace = torch.jit.trace( + paste_masks_in_image, (masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]) + ) out_trace = jit_trace(masks, boxes, [torch.tensor(o_im_s[0]), torch.tensor(o_im_s[1])]) assert torch.all(out.eq(out_trace)) @@ -418,6 +476,7 @@ def test_paste_mask_in_image(self): boxes2 *= 100 o_im_s2 = (200, 200) from torchvision.models.detection.roi_heads import paste_masks_in_image + out2 = paste_masks_in_image(masks2, boxes2, o_im_s2) out_trace2 = jit_trace(masks2, boxes2, [torch.tensor(o_im_s2[0]), torch.tensor(o_im_s2[1])]) @@ -430,19 +489,35 @@ def test_mask_rcnn(self): model.eval() model(images) # Test exported model on images of different size, or dummy input - self.run_model(model, [(images,), (test_images,), (dummy_image,)], - input_names=["images_tensors"], - output_names=["boxes", "labels", "scores", "masks"], - dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], - "scores": [0], "masks": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(images,), (test_images,), (dummy_image,)], + input_names=["images_tensors"], + output_names=["boxes", "labels", "scores", "masks"], + dynamic_axes={ + "images_tensors": [0, 1, 2], + "boxes": [0, 1], + "labels": [0], + "scores": [0], + "masks": [0, 1, 2], + }, + tolerate_small_mismatch=True, + ) # Test exported model for an image with no detections on other images - self.run_model(model, [(dummy_image,), (images,)], - input_names=["images_tensors"], - output_names=["boxes", "labels", "scores", "masks"], - dynamic_axes={"images_tensors": [0, 1, 2], "boxes": [0, 1], "labels": [0], - "scores": [0], "masks": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(dummy_image,), (images,)], + input_names=["images_tensors"], + output_names=["boxes", "labels", "scores", "masks"], + dynamic_axes={ + "images_tensors": [0, 1, 2], + "boxes": [0, 1], + "labels": [0], + "scores": [0], + "masks": [0, 1, 2], + }, + tolerate_small_mismatch=True, + ) # Verify that heatmaps_to_keypoints behaves the same in tracing. 
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints @@ -451,6 +526,7 @@ def test_heatmaps_to_keypoints(self): maps = torch.rand(10, 1, 26, 26) rois = torch.rand(10, 4) from torchvision.models.detection.roi_heads import heatmaps_to_keypoints + out = heatmaps_to_keypoints(maps, rois) jit_trace = torch.jit.trace(heatmaps_to_keypoints, (maps, rois)) out_trace = jit_trace(maps, rois) @@ -461,6 +537,7 @@ def test_heatmaps_to_keypoints(self): maps2 = torch.rand(20, 2, 21, 21) rois2 = torch.rand(20, 4) from torchvision.models.detection.roi_heads import heatmaps_to_keypoints + out2 = heatmaps_to_keypoints(maps2, rois2) out_trace2 = jit_trace(maps2, rois2) @@ -473,29 +550,38 @@ def test_keypoint_rcnn(self): model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) model.eval() model(images) - self.run_model(model, [(images,), (test_images,), (dummy_images,)], - input_names=["images_tensors"], - output_names=["outputs1", "outputs2", "outputs3", "outputs4"], - dynamic_axes={"images_tensors": [0, 1, 2]}, - tolerate_small_mismatch=True) - - self.run_model(model, [(dummy_images,), (test_images,)], - input_names=["images_tensors"], - output_names=["outputs1", "outputs2", "outputs3", "outputs4"], - dynamic_axes={"images_tensors": [0, 1, 2]}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(images,), (test_images,), (dummy_images,)], + input_names=["images_tensors"], + output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) + + self.run_model( + model, + [(dummy_images,), (test_images,)], + input_names=["images_tensors"], + output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2]}, + tolerate_small_mismatch=True, + ) def test_shufflenet_v2_dynamic_axes(self): model = models.shufflenet_v2_x0_5(pretrained=True) dummy_input = torch.randn(1, 3, 224, 224, requires_grad=True) test_inputs = torch.cat([dummy_input, dummy_input, dummy_input], 0) - self.run_model(model, [(dummy_input,), (test_inputs,)], - input_names=["input_images"], - output_names=["output"], - dynamic_axes={"input_images": {0: 'batch_size'}, "output": {0: 'batch_size'}}, - tolerate_small_mismatch=True) + self.run_model( + model, + [(dummy_input,), (test_inputs,)], + input_names=["input_images"], + output_names=["output"], + dynamic_axes={"input_images": {0: "batch_size"}, "output": {0: "batch_size"}}, + tolerate_small_mismatch=True, + ) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_ops.py b/test/test_ops.py index 9823cb1a81a..64329936b72 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,26 +1,25 @@ -from common_utils import needs_cuda, cpu_and_gpu, assert_equal import math +import os from abc import ABC, abstractmethod -import pytest +from functools import lru_cache +from typing import Tuple import numpy as np -import os - -from PIL import Image +import pytest import torch -from functools import lru_cache +from common_utils import needs_cuda, cpu_and_gpu, assert_equal +from PIL import Image from torch import Tensor from torch.autograd import gradcheck from torch.nn.modules.utils import _pair from torchvision import ops -from typing import Tuple class RoIOpTester(ABC): dtype = torch.float64 - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + 
@pytest.mark.parametrize("contiguous", (True, False)) def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs): x_dtype = self.dtype if x_dtype is None else x_dtype rois_dtype = self.dtype if rois_dtype is None else rois_dtype @@ -30,33 +29,33 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwar x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device) if not contiguous: x = x.permute(0, 1, 3, 2) - rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy) - [0, 0, 5, 4, 9], - [0, 5, 5, 9, 9], - [1, 0, 0, 9, 9]], - dtype=rois_dtype, device=device) + rois = torch.tensor( + [[0, 0, 0, 9, 9], [0, 0, 5, 4, 9], [0, 5, 5, 9, 9], [1, 0, 0, 9, 9]], # format is (xyxy) + dtype=rois_dtype, + device=device, + ) pool_h, pool_w = pool_size, pool_size y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs) # the following should be true whether we're running an autocast test or not. assert y.dtype == x.dtype - gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1, - sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs) + gt_y = self.expected_fn( + x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs + ) tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5 torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) def test_backward(self, device, contiguous): pool_size = 2 x = torch.rand(1, 2 * (pool_size ** 2), 5, 5, dtype=self.dtype, device=device, requires_grad=True) if not contiguous: x = x.permute(0, 1, 3, 2) - rois = torch.tensor([[0, 0, 0, 4, 4], # format is (xyxy) - [0, 0, 2, 3, 4], - [0, 2, 2, 4, 4]], - dtype=self.dtype, device=device) + rois = torch.tensor( + [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device # format is (xyxy) + ) def func(z): return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1) @@ -67,8 +66,8 @@ def func(z): gradcheck(script_func, (x,)) @needs_cuda - @pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) - @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) + @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) + @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) def test_autocast(self, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype) @@ -107,8 +106,9 @@ def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.roi_pool) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, - device=None, dtype=torch.float64): + def expected_fn( + self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64 + ): if device is None: device = torch.device("cpu") @@ -121,7 +121,7 @@ def get_slice(k, block): for roi_idx, roi in enumerate(rois): batch_idx = int(roi[0]) j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) - roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1] + roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1] roi_h, roi_w = roi_x.shape[-2:] bin_h = roi_h / pool_h @@ -146,8 +146,9 @@ def 
get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.ps_roi_pool) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, - device=None, dtype=torch.float64): + def expected_fn( + self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=None, dtype=torch.float64 + ): if device is None: device = torch.device("cpu") n_input_channels = x.size(1) @@ -161,7 +162,7 @@ def get_slice(k, block): for roi_idx, roi in enumerate(rois): batch_idx = int(roi[0]) j_begin, i_begin, j_end, i_end = (int(round(x.item() * spatial_scale)) for x in roi[1:]) - roi_x = x[batch_idx, :, i_begin:i_end + 1, j_begin:j_end + 1] + roi_x = x[batch_idx, :, i_begin : i_end + 1, j_begin : j_end + 1] roi_height = max(i_end - i_begin, 1) roi_width = max(j_end - j_begin, 1) @@ -216,21 +217,32 @@ def bilinear_interpolate(data, y, x, snap_border=False): class TestRoIAlign(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs): - return ops.RoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, - sampling_ratio=sampling_ratio, aligned=aligned)(x, rois) + return ops.RoIAlign( + (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned + )(x, rois) def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.roi_align) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, - device=None, dtype=torch.float64): + def expected_fn( + self, + in_data, + rois, + pool_h, + pool_w, + spatial_scale=1, + sampling_ratio=-1, + aligned=False, + device=None, + dtype=torch.float64, + ): if device is None: device = torch.device("cpu") n_channels = in_data.size(1) out_data = torch.zeros(rois.size(0), n_channels, pool_h, pool_w, dtype=dtype, device=device) - offset = 0.5 if aligned else 0. 
+ offset = 0.5 if aligned else 0.0 for r, roi in enumerate(rois): batch_idx = int(roi[0]) @@ -264,21 +276,23 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r def test_boxes_shape(self): self._helper_boxes_shape(ops.roi_align) - @pytest.mark.parametrize('aligned', (True, False)) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None): - super().test_forward(device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, - aligned=aligned) + super().test_forward( + device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned + ) @needs_cuda - @pytest.mark.parametrize('aligned', (True, False)) - @pytest.mark.parametrize('x_dtype', (torch.float, torch.half)) - @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half)) + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("x_dtype", (torch.float, torch.half)) + @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half)) def test_autocast(self, aligned, x_dtype, rois_dtype): with torch.cuda.amp.autocast(): - self.test_forward(torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, - rois_dtype=rois_dtype) + self.test_forward( + torch.device("cuda"), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype + ) def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype) @@ -286,9 +300,9 @@ def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000): rois[:, 3:] += rois[:, 1:3] # make sure boxes aren't degenerate return rois - @pytest.mark.parametrize('aligned', (True, False)) - @pytest.mark.parametrize('scale, zero_point', ((1, 0), (2, 10), (0.1, 50))) - @pytest.mark.parametrize('qdtype', (torch.qint8, torch.quint8, torch.qint32)) + @pytest.mark.parametrize("aligned", (True, False)) + @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 10), (0.1, 50))) + @pytest.mark.parametrize("qdtype", (torch.qint8, torch.quint8, torch.qint32)) def test_qroialign(self, aligned, scale, zero_point, qdtype): """Make sure quantized version of RoIAlign is close to float version""" pool_size = 5 @@ -338,7 +352,7 @@ def test_qroialign(self, aligned, scale, zero_point, qdtype): # - any difference between qy and quantized_float_y is == scale diff_idx = torch.where(qy != quantized_float_y) num_diff = diff_idx[0].numel() - assert num_diff / qy.numel() < .05 + assert num_diff / qy.numel() < 0.05 abs_diff = torch.abs(qy[diff_idx].dequantize() - quantized_float_y[diff_idx].dequantize()) t_scale = torch.full_like(abs_diff, fill_value=scale) @@ -356,15 +370,15 @@ def test_qroi_align_multiple_images(self): class TestPSRoIAlign(RoIOpTester): def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs): - return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, - sampling_ratio=sampling_ratio)(x, rois) + return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois) def get_script_fn(self, rois, pool_size): scriped = torch.jit.script(ops.ps_roi_align) return lambda x: scriped(x, rois, pool_size) - def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, - 
sampling_ratio=-1, dtype=torch.float64): + def expected_fn( + self, in_data, rois, pool_h, pool_w, device, spatial_scale=1, sampling_ratio=-1, dtype=torch.float64 + ): if device is None: device = torch.device("cpu") n_input_channels = in_data.size(1) @@ -407,15 +421,17 @@ def test_boxes_shape(self): class TestMultiScaleRoIAlign: def test_msroialign_repr(self): - fmap_names = ['0'] + fmap_names = ["0"] output_size = (7, 7) sampling_ratio = 2 # Pass mock feature map names t = ops.poolers.MultiScaleRoIAlign(fmap_names, output_size, sampling_ratio) # Check integrity of object __repr__ attribute - expected_string = (f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, " - f"sampling_ratio={sampling_ratio})") + expected_string = ( + f"MultiScaleRoIAlign(featmap_names={fmap_names}, output_size={output_size}, " + f"sampling_ratio={sampling_ratio})" + ) assert repr(t) == expected_string @@ -460,9 +476,9 @@ def _create_tensors_with_iou(self, N, iou_thresh): scores = torch.rand(N) return boxes, scores - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_ref(self, iou): - err_msg = 'NMS incompatible between CPU and reference implementation for IoU={}' + err_msg = "NMS incompatible between CPU and reference implementation for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) keep_ref = self._reference_nms(boxes, scores, iou) keep = ops.nms(boxes, scores, iou) @@ -478,13 +494,13 @@ def test_nms_input_errors(self): with pytest.raises(RuntimeError): ops.nms(torch.rand(3, 4), torch.rand(4), 0.5) - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("scale, zero_point", ((1, 0), (2, 50), (3, 10))) def test_qnms(self, iou, scale, zero_point): # Note: we compare qnms vs nms instead of qnms vs reference implementation. 
# This is because with the int convertion, the trick used in _create_tensors_with_iou # doesn't really work (in fact, nms vs reference implem will also fail with ints) - err_msg = 'NMS and QNMS give different results for IoU={}' + err_msg = "NMS and QNMS give different results for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) scores *= 100 # otherwise most scores would be 0 or 1 after int convertion @@ -500,10 +516,10 @@ def test_qnms(self, iou, scale, zero_point): assert torch.allclose(qkeep, keep), err_msg.format(iou) @needs_cuda - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) def test_nms_cuda(self, iou, dtype=torch.float64): tol = 1e-3 if dtype is torch.half else 1e-5 - err_msg = 'NMS incompatible between CPU and CUDA for IoU={}' + err_msg = "NMS incompatible between CPU and CUDA for IoU={}" boxes, scores = self._create_tensors_with_iou(1000, iou) r_cpu = ops.nms(boxes, scores, iou) @@ -517,7 +533,7 @@ def test_nms_cuda(self, iou, dtype=torch.float64): assert is_eq, err_msg.format(iou) @needs_cuda - @pytest.mark.parametrize("iou", (.2, .5, .8)) + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) @pytest.mark.parametrize("dtype", (torch.float, torch.half)) def test_autocast(self, iou, dtype): with torch.cuda.amp.autocast(): @@ -525,9 +541,13 @@ def test_autocast(self, iou, dtype): @needs_cuda def test_nms_cuda_float16(self): - boxes = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]]).cuda() + boxes = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + ).cuda() scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda() iou_thres = 0.2 @@ -539,7 +559,7 @@ def test_batched_nms_implementations(self): """Make sure that both implementations of batched_nms yield identical results""" num_boxes = 1000 - iou_threshold = .9 + iou_threshold = 0.9 boxes = torch.cat((torch.rand(num_boxes, 2), torch.rand(num_boxes, 2) + 10), dim=1) assert max(boxes[:, 0]) < min(boxes[:, 2]) # x1 < x2 @@ -603,8 +623,11 @@ def expected_fn(self, x, weight, offset, mask, bias, stride=1, padding=0, dilati if mask is not None: mask_value = mask[b, mask_idx, i, j] - out[b, c_out, i, j] += (mask_value * weight[c_out, c, di, dj] * - bilinear_interpolate(x[b, c_in, :, :], pi, pj)) + out[b, c_out, i, j] += ( + mask_value + * weight[c_out, c, di, dj] + * bilinear_interpolate(x[b, c_in, :, :], pi, pj) + ) out += bias.view(1, n_out_channels, 1, 1) return out @@ -630,14 +653,29 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype): x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=dtype, requires_grad=True) - offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w, - device=device, dtype=dtype, requires_grad=True) + offset = torch.randn( + batch_sz, + n_offset_grps * 2 * weight_h * weight_w, + out_h, + out_w, + device=device, + dtype=dtype, + requires_grad=True, + ) - mask = torch.randn(batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, - device=device, dtype=dtype, requires_grad=True) + mask = torch.randn( + batch_sz, n_offset_grps * weight_h * weight_w, out_h, out_w, device=device, dtype=dtype, requires_grad=True + ) - weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w, - device=device, dtype=dtype, requires_grad=True) + weight = torch.randn( + n_out_channels, 
+ n_in_channels // n_weight_grps, + weight_h, + weight_w, + device=device, + dtype=dtype, + requires_grad=True, + ) bias = torch.randn(n_out_channels, device=device, dtype=dtype, requires_grad=True) @@ -649,9 +687,9 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype): return x, weight, offset, mask, bias, stride, pad, dilation - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) - @pytest.mark.parametrize('batch_sz', (0, 33)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.parametrize("batch_sz", (0, 33)) def test_forward(self, device, contiguous, batch_sz, dtype=None): dtype = dtype or self.dtype x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) @@ -661,8 +699,9 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): groups = 2 tol = 2e-3 if dtype is torch.half else 1e-5 - layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=dilation, groups=groups).to(device=x.device, dtype=dtype) + layer = ops.DeformConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups + ).to(device=x.device, dtype=dtype) res = layer(x, offset, mask) weight = layer.weight.data @@ -670,7 +709,7 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): expected = self.expected_fn(x, weight, offset, mask, bias, stride=stride, padding=padding, dilation=dilation) torch.testing.assert_close( - res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) + res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected) ) # no modulation test @@ -678,7 +717,7 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None): expected = self.expected_fn(x, weight, offset, None, bias, stride=stride, padding=padding, dilation=dilation) torch.testing.assert_close( - res.to(expected), expected, rtol=tol, atol=tol, msg='\nres:\n{}\nexpected:\n{}'.format(res, expected) + res.to(expected), expected, rtol=tol, atol=tol, msg="\nres:\n{}\nexpected:\n{}".format(res, expected) ) def test_wrong_sizes(self): @@ -686,57 +725,72 @@ def test_wrong_sizes(self): out_channels = 2 kernel_size = (3, 2) groups = 2 - x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args('cpu', contiguous=True, - batch_sz=10, dtype=self.dtype) - layer = ops.DeformConv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, - dilation=dilation, groups=groups) + x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args( + "cpu", contiguous=True, batch_sz=10, dtype=self.dtype + ) + layer = ops.DeformConv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups + ) with pytest.raises(RuntimeError, match="the shape of the offset"): wrong_offset = torch.rand_like(offset[:, :2]) layer(x, wrong_offset) - with pytest.raises(RuntimeError, match=r'mask.shape\[1\] is not valid'): + with pytest.raises(RuntimeError, match=r"mask.shape\[1\] is not valid"): wrong_mask = torch.rand_like(mask[:, :2]) layer(x, offset, wrong_mask) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('contiguous', (True, False)) - @pytest.mark.parametrize('batch_sz', (0, 33)) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("contiguous", (True, 
False)) + @pytest.mark.parametrize("batch_sz", (0, 33)) def test_backward(self, device, contiguous, batch_sz): - x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(device, contiguous, - batch_sz, self.dtype) + x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args( + device, contiguous, batch_sz, self.dtype + ) def func(x_, offset_, mask_, weight_, bias_): - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, - padding=padding, dilation=dilation, mask=mask_) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=mask_ + ) gradcheck(func, (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) def func_no_mask(x_, offset_, weight_, bias_): - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride, - padding=padding, dilation=dilation, mask=None) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride, padding=padding, dilation=dilation, mask=None + ) gradcheck(func_no_mask, (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) @torch.jit.script def script_func(x_, offset_, mask_, weight_, bias_, stride_, pad_, dilation_): # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, - padding=pad_, dilation=dilation_, mask=mask_) - - gradcheck(lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation), - (x, offset, mask, weight, bias), nondet_tol=1e-5, fast_mode=True) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=mask_ + ) + + gradcheck( + lambda z, off, msk, wei, bi: script_func(z, off, msk, wei, bi, stride, padding, dilation), + (x, offset, mask, weight, bias), + nondet_tol=1e-5, + fast_mode=True, + ) @torch.jit.script def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor - return ops.deform_conv2d(x_, offset_, weight_, bias_, stride=stride_, - padding=pad_, dilation=dilation_, mask=None) - - gradcheck(lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation), - (x, offset, weight, bias), nondet_tol=1e-5, fast_mode=True) + return ops.deform_conv2d( + x_, offset_, weight_, bias_, stride=stride_, padding=pad_, dilation=dilation_, mask=None + ) + + gradcheck( + lambda z, off, wei, bi: script_func_no_mask(z, off, wei, bi, stride, padding, dilation), + (x, offset, weight, bias), + nondet_tol=1e-5, + fast_mode=True, + ) @needs_cuda - @pytest.mark.parametrize('contiguous', (True, False)) + @pytest.mark.parametrize("contiguous", (True, False)) def test_compare_cpu_cuda_grads(self, contiguous): # Test from https://github.com/pytorch/vision/issues/2598 # Run on CUDA only @@ -770,8 +824,8 @@ def test_compare_cpu_cuda_grads(self, contiguous): torch.testing.assert_close(true_cpu_grads, res_grads) @needs_cuda - @pytest.mark.parametrize('batch_sz', (0, 33)) - @pytest.mark.parametrize('dtype', (torch.float, torch.half)) + @pytest.mark.parametrize("batch_sz", (0, 33)) + @pytest.mark.parametrize("dtype", (torch.float, torch.half)) def test_autocast(self, batch_sz, dtype): with torch.cuda.amp.autocast(): self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) @@ -794,11 +848,13 @@ def test_frozenbatchnorm2d_repr(self): def 
test_frozenbatchnorm2d_eps(self): sample_size = (4, 32, 28, 28) x = torch.rand(sample_size) - state_dict = dict(weight=torch.rand(sample_size[1]), - bias=torch.rand(sample_size[1]), - running_mean=torch.rand(sample_size[1]), - running_var=torch.rand(sample_size[1]), - num_batches_tracked=torch.tensor(100)) + state_dict = dict( + weight=torch.rand(sample_size[1]), + bias=torch.rand(sample_size[1]), + running_mean=torch.rand(sample_size[1]), + running_var=torch.rand(sample_size[1]), + num_batches_tracked=torch.tensor(100), + ) # Check that default eps is equal to the one of BN fbn = ops.misc.FrozenBatchNorm2d(sample_size[1]) @@ -826,17 +882,19 @@ class TestBoxConversion: def _get_box_sequences(): # Define here the argument type of `boxes` supported by region pooling operations box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float) - box_list = [torch.tensor([[0, 0, 100, 100]], dtype=torch.float), - torch.tensor([[0, 0, 100, 100]], dtype=torch.float)] + box_list = [ + torch.tensor([[0, 0, 100, 100]], dtype=torch.float), + torch.tensor([[0, 0, 100, 100]], dtype=torch.float), + ] box_tuple = tuple(box_list) return box_tensor, box_list, box_tuple - @pytest.mark.parametrize('box_sequence', _get_box_sequences()) + @pytest.mark.parametrize("box_sequence", _get_box_sequences()) def test_check_roi_boxes_shape(self, box_sequence): # Ensure common sequences of tensors are supported ops._utils.check_roi_boxes_shape(box_sequence) - @pytest.mark.parametrize('box_sequence', _get_box_sequences()) + @pytest.mark.parametrize("box_sequence", _get_box_sequences()) def test_convert_boxes_to_roi_format(self, box_sequence): # Ensure common sequences of tensors yield the same result ref_tensor = None @@ -848,11 +906,11 @@ def test_convert_boxes_to_roi_format(self, box_sequence): class TestBox: def test_bbox_same(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) - exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + exp_xyxy = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) assert exp_xyxy.size() == torch.Size([4, 4]) assert_equal(ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xyxy"), exp_xyxy) @@ -862,10 +920,10 @@ def test_bbox_same(self): def test_bbox_xyxy_xywh(self): # Simple test convert boxes to xywh and back. Make sure they are same. # box_tensor is in x1 y1 x2 y2 format. - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) - exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) + exp_xywh = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) assert exp_xywh.size() == torch.Size([4, 4]) box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") @@ -878,10 +936,12 @@ def test_bbox_xyxy_xywh(self): def test_bbox_xyxy_cxcywh(self): # Simple test convert boxes to xywh and back. Make sure they are same. # box_tensor is in x1 y1 x2 y2 format. 
- box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) - exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], - [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) + exp_cxcywh = torch.tensor( + [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float + ) assert exp_cxcywh.size() == torch.Size([4, 4]) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") @@ -892,12 +952,14 @@ def test_bbox_xyxy_cxcywh(self): assert_equal(box_xyxy, box_tensor) def test_bbox_xywh_cxcywh(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float + ) # This is wrong - exp_cxcywh = torch.tensor([[50, 50, 100, 100], [0, 0, 0, 0], - [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float) + exp_cxcywh = torch.tensor( + [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float + ) assert exp_cxcywh.size() == torch.Size([4, 4]) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xywh", out_fmt="cxcywh") @@ -907,28 +969,30 @@ def test_bbox_xywh_cxcywh(self): box_xywh = ops.box_convert(box_cxcywh, in_fmt="cxcywh", out_fmt="xywh") assert_equal(box_xywh, box_tensor) - @pytest.mark.parametrize('inv_infmt', ["xwyh", "cxwyh"]) - @pytest.mark.parametrize('inv_outfmt', ["xwcx", "xhwcy"]) + @pytest.mark.parametrize("inv_infmt", ["xwyh", "cxwyh"]) + @pytest.mark.parametrize("inv_outfmt", ["xwcx", "xhwcy"]) def test_bbox_invalid(self, inv_infmt, inv_outfmt): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float + ) with pytest.raises(ValueError): ops.box_convert(box_tensor, inv_infmt, inv_outfmt) def test_bbox_convert_jit(self): - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + box_tensor = torch.tensor( + [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float + ) scripted_fn = torch.jit.script(ops.box_convert) TOLERANCE = 1e-3 box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") - scripted_xywh = scripted_fn(box_tensor, 'xyxy', 'xywh') + scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh") torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") - scripted_cxcywh = scripted_fn(box_tensor, 'xyxy', 'cxcywh') + scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh") torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) @@ -946,16 +1010,22 @@ def area_check(box, expected, tolerance=1e-4): # Check for float32 and float64 boxes for dtype in [torch.float32, torch.float64]: - box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + box_tensor = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + dtype=dtype, + ) expected = 
torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) area_check(box_tensor, expected, tolerance=0.05) # Check for float16 box - box_tensor = torch.tensor([[285.25, 185.625, 1194.0, 851.5], - [285.25, 188.75, 1192.0, 851.0], - [279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16) + box_tensor = torch.tensor( + [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], + dtype=torch.float16, + ) expected = torch.tensor([605113.875, 600495.1875, 592247.25]) area_check(box_tensor, expected) @@ -982,9 +1052,14 @@ def iou_check(box, expected, tolerance=1e-4): # Check for float boxes for dtype in [torch.float16, torch.float32, torch.float64]: - box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + box_tensor = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + dtype=dtype, + ) expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4) @@ -1011,9 +1086,14 @@ def gen_iou_check(box, expected, tolerance=1e-4): # Check for float boxes for dtype in [torch.float16, torch.float32, torch.float64]: - box_tensor = torch.tensor([[285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019]], dtype=dtype) + box_tensor = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + dtype=dtype, + ) expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) @@ -1048,8 +1128,18 @@ def _create_masks(image, masks): return masks - expected = torch.tensor([[127, 2, 165, 40], [2, 50, 44, 92], [56, 63, 98, 100], [139, 68, 175, 104], - [160, 112, 198, 145], [49, 138, 99, 182], [108, 148, 152, 213]], dtype=torch.float) + expected = torch.tensor( + [ + [127, 2, 165, 40], + [2, 50, 44, 92], + [56, 63, 98, 100], + [139, 68, 175, 104], + [160, 112, 198, 145], + [49, 138, 99, 182], + [108, 148, 152, 213], + ], + dtype=torch.float, + ) image = _get_image() for dtype in [torch.float16, torch.float32, torch.float64]: @@ -1059,8 +1149,8 @@ def _create_masks(image, masks): class TestStochasticDepth: - @pytest.mark.parametrize('p', [0.2, 0.5, 0.8]) - @pytest.mark.parametrize('mode', ["batch", "row"]) + @pytest.mark.parametrize("p", [0.2, 0.5, 0.8]) + @pytest.mark.parametrize("mode", ["batch", "row"]) def test_stochastic_depth(self, mode, p): stats = pytest.importorskip("scipy.stats") batch_size = 5 @@ -1086,5 +1176,5 @@ def test_stochastic_depth(self, mode, p): assert p_value > 0.0001 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_transforms.py b/test/test_transforms.py index 72821446f3a..3712e592cc4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1,14 +1,16 @@ +import math import os +import random + +import numpy as np +import pytest import torch import torchvision.transforms as transforms import torchvision.transforms.functional as F import torchvision.transforms.functional_tensor as F_t -from torch._utils_internal import get_file_path_2 -import math 
-import random -import numpy as np -import pytest from PIL import Image +from torch._utils_internal import get_file_path_2 + try: import accimage except ImportError: @@ -23,17 +25,18 @@ GRACE_HOPPER = get_file_path_2( - os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg') + os.path.dirname(os.path.abspath(__file__)), "assets", "encode_jpeg", "grace_hopper_517x606.jpg" +) def _get_grayscale_test_image(img, fill=None): - img = img.convert('L') - fill = (fill[0], ) if isinstance(fill, tuple) else fill + img = img.convert("L") + fill = (fill[0],) if isinstance(fill, tuple) else fill return img, fill class TestConvertImageDtype: - @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes())) + @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(float_dtypes())) def test_float_to_float(self, input_dtype, output_dtype): input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) transform = transforms.ConvertImageDtype(output_dtype) @@ -50,15 +53,15 @@ def test_float_to_float(self, input_dtype, output_dtype): assert abs(actual_min - desired_min) < 1e-7 assert abs(actual_max - desired_max) < 1e-7 - @pytest.mark.parametrize('input_dtype', float_dtypes()) - @pytest.mark.parametrize('output_dtype', int_dtypes()) + @pytest.mark.parametrize("input_dtype", float_dtypes()) + @pytest.mark.parametrize("output_dtype", int_dtypes()) def test_float_to_int(self, input_dtype, output_dtype): input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) transform = transforms.ConvertImageDtype(output_dtype) transform_script = torch.jit.script(F.convert_image_dtype) if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( - input_dtype == torch.float64 and output_dtype == torch.int64 + input_dtype == torch.float64 and output_dtype == torch.int64 ): with pytest.raises(RuntimeError): transform(input_image) @@ -74,8 +77,8 @@ def test_float_to_int(self, input_dtype, output_dtype): assert actual_min == desired_min assert actual_max == desired_max - @pytest.mark.parametrize('input_dtype', int_dtypes()) - @pytest.mark.parametrize('output_dtype', float_dtypes()) + @pytest.mark.parametrize("input_dtype", int_dtypes()) + @pytest.mark.parametrize("output_dtype", float_dtypes()) def test_int_to_float(self, input_dtype, output_dtype): input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) transform = transforms.ConvertImageDtype(output_dtype) @@ -94,7 +97,7 @@ def test_int_to_float(self, input_dtype, output_dtype): assert abs(actual_max - desired_max) < 1e-7 assert actual_max <= desired_max - @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) + @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes())) def test_dtype_int_to_int(self, input_dtype, output_dtype): input_max = torch.iinfo(input_dtype).max input_image = torch.tensor((0, input_max), dtype=input_dtype) @@ -126,7 +129,7 @@ def test_dtype_int_to_int(self, input_dtype, output_dtype): assert actual_min == desired_min assert actual_max == (desired_max + error_term) - @pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(int_dtypes())) + @pytest.mark.parametrize("input_dtype, output_dtype", cycle_over(int_dtypes())) def test_int_to_int_consistency(self, input_dtype, output_dtype): input_max = torch.iinfo(input_dtype).max input_image = torch.tensor((0, input_max), dtype=input_dtype) @@ -148,11 +151,10 @@ def test_int_to_int_consistency(self, input_dtype, output_dtype): 
@pytest.mark.skipif(accimage is None, reason="accimage not available") class TestAccImage: - def test_accimage_to_tensor(self): trans = transforms.ToTensor() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) torch.testing.assert_close(output, expected_output) @@ -160,22 +162,24 @@ def test_accimage_to_tensor(self): def test_accimage_pil_to_tensor(self): trans = transforms.PILToTensor() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) assert expected_output.size() == output.size() torch.testing.assert_close(output, expected_output) def test_accimage_resize(self): - trans = transforms.Compose([ - transforms.Resize(256, interpolation=Image.LINEAR), - transforms.ToTensor(), - ]) + trans = transforms.Compose( + [ + transforms.Resize(256, interpolation=Image.LINEAR), + transforms.ToTensor(), + ] + ) # Checking if Compose, Resize and ToTensor can be printed as string trans.__repr__() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) assert expected_output.size() == output.size() @@ -185,15 +189,17 @@ def test_accimage_resize(self): torch.testing.assert_close(output.numpy(), expected_output.numpy(), rtol=1e-5, atol=5e-2) def test_accimage_crop(self): - trans = transforms.Compose([ - transforms.CenterCrop(256), - transforms.ToTensor(), - ]) + trans = transforms.Compose( + [ + transforms.CenterCrop(256), + transforms.ToTensor(), + ] + ) # Checking if Compose, CenterCrop and ToTensor can be printed as string trans.__repr__() - expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + expected_output = trans(Image.open(GRACE_HOPPER).convert("RGB")) output = trans(accimage.Image(GRACE_HOPPER)) assert expected_output.size() == output.size() @@ -201,8 +207,7 @@ def test_accimage_crop(self): class TestToTensor: - - @pytest.mark.parametrize('channels', [1, 3, 4]) + @pytest.mark.parametrize("channels", [1, 3, 4]) def test_to_tensor(self, channels): height, width = 4, 4 trans = transforms.ToTensor() @@ -225,7 +230,7 @@ def test_to_tensor(self, channels): # separate test for mode '1' PIL images input_data = torch.ByteTensor(1, height, width).bernoulli_() - img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + img = transforms.ToPILImage()(input_data.mul(255)).convert("1") output = trans(img) torch.testing.assert_close(input_data, output, check_dtype=False) @@ -243,7 +248,7 @@ def test_to_tensor_errors(self): with pytest.raises(ValueError): trans(np_rng.rand(1, 1, height, width)) - @pytest.mark.parametrize('dtype', [torch.float16, torch.float, torch.double]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.float, torch.double]) def test_to_tensor_with_other_default_dtypes(self, dtype): np_rng = np.random.RandomState(0) current_def_dtype = torch.get_default_dtype() @@ -258,7 +263,7 @@ def test_to_tensor_with_other_default_dtypes(self, dtype): torch.set_default_dtype(current_def_dtype) - @pytest.mark.parametrize('channels', [1, 3, 4]) + @pytest.mark.parametrize("channels", [1, 3, 4]) def test_pil_to_tensor(self, channels): height, width = 4, 4 trans = transforms.PILToTensor() @@ -283,7 +288,7 @@ def test_pil_to_tensor(self, channels): # separate test for mode '1' PIL images input_data = 
torch.ByteTensor(1, height, width).bernoulli_() - img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + img = transforms.ToPILImage()(input_data.mul(255)).convert("1") output = trans(img).view(torch.uint8).bool().to(torch.uint8) torch.testing.assert_close(input_data, output) @@ -316,34 +321,47 @@ def test_randomresized_params(): randresizecrop = transforms.RandomResizedCrop(size, scale_range, aspect_ratio_range) i, j, h, w = randresizecrop.get_params(img, scale_range, aspect_ratio_range) aspect_ratio_obtained = w / h - assert((min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained and - aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon) or - aspect_ratio_obtained == 1.0) + assert ( + min(aspect_ratio_range) - epsilon <= aspect_ratio_obtained + and aspect_ratio_obtained <= max(aspect_ratio_range) + epsilon + ) or aspect_ratio_obtained == 1.0 assert isinstance(i, int) assert isinstance(j, int) assert isinstance(h, int) assert isinstance(w, int) -@pytest.mark.parametrize('height, width', [ - # height, width - # square image - (28, 28), - (27, 27), - # rectangular image: h < w - (28, 34), - (29, 35), - # rectangular image: h > w - (34, 28), - (35, 29), -]) -@pytest.mark.parametrize('osize', [ - # single integer - 22, 27, 28, 36, - # single integer in tuple/list - [22, ], (27, ), -]) -@pytest.mark.parametrize('max_size', (None, 37, 1000)) +@pytest.mark.parametrize( + "height, width", + [ + # height, width + # square image + (28, 28), + (27, 27), + # rectangular image: h < w + (28, 34), + (29, 35), + # rectangular image: h > w + (34, 28), + (35, 29), + ], +) +@pytest.mark.parametrize( + "osize", + [ + # single integer + 22, + 27, + 28, + 36, + # single integer in tuple/list + [ + 22, + ], + (27,), + ], +) +@pytest.mark.parametrize("max_size", (None, 37, 1000)) def test_resize(height, width, osize, max_size): img = Image.new("RGB", size=(width, height), color=127) @@ -371,24 +389,36 @@ def test_resize(height, width, osize, max_size): assert result.size == (exp_w, exp_h), msg -@pytest.mark.parametrize('height, width', [ - # height, width - # square image - (28, 28), - (27, 27), - # rectangular image: h < w - (28, 34), - (29, 35), - # rectangular image: h > w - (34, 28), - (35, 29), -]) -@pytest.mark.parametrize('osize', [ - # two integers sequence output - [22, 22], [22, 28], [22, 36], - [27, 22], [36, 22], [28, 28], - [28, 37], [37, 27], [37, 37] -]) +@pytest.mark.parametrize( + "height, width", + [ + # height, width + # square image + (28, 28), + (27, 27), + # rectangular image: h < w + (28, 34), + (29, 35), + # rectangular image: h > w + (34, 28), + (35, 29), + ], +) +@pytest.mark.parametrize( + "osize", + [ + # two integers sequence output + [22, 22], + [22, 28], + [22, 36], + [27, 22], + [36, 22], + [28, 28], + [28, 37], + [37, 27], + [37, 37], + ], +) def test_resize_sequence_output(height, width, osize): img = Image.new("RGB", size=(width, height), color=127) oheight, owidth = osize @@ -409,18 +439,19 @@ def test_resize_antialias_error(): class TestPad: - def test_pad(self): height = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2 img = torch.ones(3, height, width) padding = random.randint(1, 20) fill = random.randint(1, 50) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.Pad(padding, fill=fill), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Pad(padding, fill=fill), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == height + 2 * padding assert 
result.size(2) == width + 2 * padding # check that all elements in the padded region correspond @@ -429,14 +460,9 @@ def test_pad(self): eps = 1e-5 h_padded = result[:, :padding, :] w_padded = result[:, :, :padding] - torch.testing.assert_close( - h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps - ) - torch.testing.assert_close( - w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps - ) - pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), - transforms.ToPILImage()(img)) + torch.testing.assert_close(h_padded, torch.full_like(h_padded, fill_value=fill_v), rtol=0.0, atol=eps) + torch.testing.assert_close(w_padded, torch.full_like(w_padded, fill_value=fill_v), rtol=0.0, atol=eps) + pytest.raises(ValueError, transforms.Pad(padding, fill=(1, 2)), transforms.ToPILImage()(img)) def test_pad_with_tuple_of_pad_values(self): height = random.randint(10, 32) * 2 @@ -463,7 +489,7 @@ def test_pad_with_non_constant_padding_modes(self): img = F.pad(img, 1, (200, 200, 200)) # pad 3 to all sidess - edge_padded_img = F.pad(img, 3, padding_mode='edge') + edge_padded_img = F.pad(img, 3, padding_mode="edge") # First 6 elements of leftmost edge in the middle of the image, values are in order: # edge_pad, edge_pad, edge_pad, constant_pad, constant value added to leftmost edge, 0 edge_middle_slice = np.asarray(edge_padded_img).transpose(2, 0, 1)[0][17][:6] @@ -471,7 +497,7 @@ def test_pad_with_non_constant_padding_modes(self): assert transforms.ToTensor()(edge_padded_img).size() == (3, 35, 35) # Pad 3 to left/right, 2 to top/bottom - reflect_padded_img = F.pad(img, (3, 2), padding_mode='reflect') + reflect_padded_img = F.pad(img, (3, 2), padding_mode="reflect") # First 6 elements of leftmost edge in the middle of the image, values are in order: # reflect_pad, reflect_pad, reflect_pad, constant_pad, constant value added to leftmost edge, 0 reflect_middle_slice = np.asarray(reflect_padded_img).transpose(2, 0, 1)[0][17][:6] @@ -479,7 +505,7 @@ def test_pad_with_non_constant_padding_modes(self): assert transforms.ToTensor()(reflect_padded_img).size() == (3, 33, 35) # Pad 3 to left, 2 to top, 2 to right, 1 to bottom - symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode='symmetric') + symmetric_padded_img = F.pad(img, (3, 2, 2, 1), padding_mode="symmetric") # First 6 elements of leftmost edge in the middle of the image, values are in order: # sym_pad, sym_pad, sym_pad, constant_pad, constant value added to leftmost edge, 0 symmetric_middle_slice = np.asarray(symmetric_padded_img).transpose(2, 0, 1)[0][17][:6] @@ -489,7 +515,7 @@ def test_pad_with_non_constant_padding_modes(self): # Check negative padding explicitly for symmetric case, since it is not # implemented for tensor case to compare to # Crop 1 to left, pad 2 to top, pad 3 to right, crop 3 to bottom - symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode='symmetric') + symmetric_padded_img_neg = F.pad(img, (-1, 2, 3, -3), padding_mode="symmetric") symmetric_neg_middle_left = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][:3] symmetric_neg_middle_right = np.asarray(symmetric_padded_img_neg).transpose(2, 0, 1)[0][17][-4:] assert_equal(symmetric_neg_middle_left, np.asarray([1, 0, 0], dtype=np.uint8)) @@ -516,14 +542,18 @@ def test_pad_with_mode_F_images(self): @pytest.mark.skipif(stats is None, reason="scipy.stats not available") -@pytest.mark.parametrize('fn, trans, config', [ - (F.invert, transforms.RandomInvert, {}), - (F.posterize, transforms.RandomPosterize, 
{"bits": 4}), - (F.solarize, transforms.RandomSolarize, {"threshold": 192}), - (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), - (F.autocontrast, transforms.RandomAutocontrast, {}), - (F.equalize, transforms.RandomEqualize, {})]) -@pytest.mark.parametrize('p', (.5, .7)) +@pytest.mark.parametrize( + "fn, trans, config", + [ + (F.invert, transforms.RandomInvert, {}), + (F.posterize, transforms.RandomPosterize, {"bits": 4}), + (F.solarize, transforms.RandomSolarize, {"threshold": 192}), + (F.adjust_sharpness, transforms.RandomAdjustSharpness, {"sharpness_factor": 2.0}), + (F.autocontrast, transforms.RandomAutocontrast, {}), + (F.equalize, transforms.RandomEqualize, {}), + ], +) +@pytest.mark.parametrize("p", (0.5, 0.7)) def test_randomness(fn, trans, config, p): random_state = random.getstate() random.seed(42) @@ -546,43 +576,42 @@ def test_randomness(fn, trans, config, p): class TestToPil: - def _get_1_channel_tensor_various_types(): img_data_float = torch.Tensor(1, 4, 4).uniform_() expected_output = img_data_float.mul(255).int().float().div(255).numpy() - yield img_data_float, expected_output, 'L' + yield img_data_float, expected_output, "L" img_data_byte = torch.ByteTensor(1, 4, 4).random_(0, 255) expected_output = img_data_byte.float().div(255.0).numpy() - yield img_data_byte, expected_output, 'L' + yield img_data_byte, expected_output, "L" img_data_short = torch.ShortTensor(1, 4, 4).random_() expected_output = img_data_short.numpy() - yield img_data_short, expected_output, 'I;16' + yield img_data_short, expected_output, "I;16" img_data_int = torch.IntTensor(1, 4, 4).random_() expected_output = img_data_int.numpy() - yield img_data_int, expected_output, 'I' + yield img_data_int, expected_output, "I" def _get_2d_tensor_various_types(): img_data_float = torch.Tensor(4, 4).uniform_() expected_output = img_data_float.mul(255).int().float().div(255).numpy() - yield img_data_float, expected_output, 'L' + yield img_data_float, expected_output, "L" img_data_byte = torch.ByteTensor(4, 4).random_(0, 255) expected_output = img_data_byte.float().div(255.0).numpy() - yield img_data_byte, expected_output, 'L' + yield img_data_byte, expected_output, "L" img_data_short = torch.ShortTensor(4, 4).random_() expected_output = img_data_short.numpy() - yield img_data_short, expected_output, 'I;16' + yield img_data_short, expected_output, "I;16" img_data_int = torch.IntTensor(4, 4).random_() expected_output = img_data_int.numpy() - yield img_data_int, expected_output, 'I' + yield img_data_int, expected_output, "I" - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_output, expected_mode', _get_1_channel_tensor_various_types()) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_1_channel_tensor_various_types()) def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() to_tensor = transforms.ToTensor() @@ -594,19 +623,22 @@ def test_1_channel_tensor_to_pil_image(self, with_mode, img_data, expected_outpu def test_1_channel_float_tensor_to_pil_image(self): img_data = torch.Tensor(1, 4, 4).uniform_() # 'F' mode for torch.FloatTensor - img_F_mode = transforms.ToPILImage(mode='F')(img_data) - assert img_F_mode.mode == 'F' + img_F_mode = transforms.ToPILImage(mode="F")(img_data) + assert img_F_mode.mode == "F" 
torch.testing.assert_close( - np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode='F')), np.array(img_F_mode) + np.array(Image.fromarray(img_data.squeeze(0).numpy(), mode="F")), np.array(img_F_mode) ) - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_mode', [ - (torch.Tensor(4, 4, 1).uniform_().numpy(), 'F'), - (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), 'L'), - (torch.ShortTensor(4, 4, 1).random_().numpy(), 'I;16'), - (torch.IntTensor(4, 4, 1).random_().numpy(), 'I'), - ]) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize( + "img_data, expected_mode", + [ + (torch.Tensor(4, 4, 1).uniform_().numpy(), "F"), + (torch.ByteTensor(4, 4, 1).random_(0, 255).numpy(), "L"), + (torch.ShortTensor(4, 4, 1).random_().numpy(), "I;16"), + (torch.IntTensor(4, 4, 1).random_().numpy(), "I"), + ], + ) def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) @@ -615,13 +647,13 @@ def test_1_channel_ndarray_to_pil_image(self, with_mode, img_data, expected_mode # and otherwise assert_close wouldn't be able to construct a tensor from the uint16 array torch.testing.assert_close(img_data[:, :, 0], np.asarray(img).astype(img_data.dtype)) - @pytest.mark.parametrize('expected_mode', [None, 'LA']) + @pytest.mark.parametrize("expected_mode", [None, "LA"]) def test_2_channel_ndarray_to_pil_image(self, expected_mode): img_data = torch.ByteTensor(4, 4, 2).random_(0, 255).numpy() if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'LA' # default should assume LA + assert img.mode == "LA" # default should assume LA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -635,19 +667,19 @@ def test_2_channel_ndarray_to_pil_image_error(self): # should raise if we try a mode for 4 or 1 or 3 channel images with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) - @pytest.mark.parametrize('expected_mode', [None, 'LA']) + @pytest.mark.parametrize("expected_mode", [None, "LA"]) def test_2_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(2, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'LA' # default should assume LA + assert img.mode == "LA" # default should assume LA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -661,14 +693,14 @@ def test_2_channel_tensor_to_pil_image_error(self): # should raise if we try a mode for 4 or 1 or 3 channel images with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - 
transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=r"Only modes \['LA'\] are supported for 2D inputs"): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_output, expected_mode', _get_2d_tensor_various_types()) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize("img_data, expected_output, expected_mode", _get_2d_tensor_various_types()) def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() to_tensor = transforms.ToTensor() @@ -677,27 +709,30 @@ def test_2d_tensor_to_pil_image(self, with_mode, img_data, expected_output, expe assert img.mode == expected_mode torch.testing.assert_close(expected_output, to_tensor(img).numpy()[0]) - @pytest.mark.parametrize('with_mode', [False, True]) - @pytest.mark.parametrize('img_data, expected_mode', [ - (torch.Tensor(4, 4).uniform_().numpy(), 'F'), - (torch.ByteTensor(4, 4).random_(0, 255).numpy(), 'L'), - (torch.ShortTensor(4, 4).random_().numpy(), 'I;16'), - (torch.IntTensor(4, 4).random_().numpy(), 'I'), - ]) + @pytest.mark.parametrize("with_mode", [False, True]) + @pytest.mark.parametrize( + "img_data, expected_mode", + [ + (torch.Tensor(4, 4).uniform_().numpy(), "F"), + (torch.ByteTensor(4, 4).random_(0, 255).numpy(), "L"), + (torch.ShortTensor(4, 4).random_().numpy(), "I;16"), + (torch.IntTensor(4, 4).random_().numpy(), "I"), + ], + ) def test_2d_ndarray_to_pil_image(self, with_mode, img_data, expected_mode): transform = transforms.ToPILImage(mode=expected_mode) if with_mode else transforms.ToPILImage() img = transform(img_data) assert img.mode == expected_mode np.testing.assert_allclose(img_data, img) - @pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr']) + @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"]) def test_3_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(3, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGB' # default should assume RGB + assert img.mode == "RGB" # default should assume RGB else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -710,22 +745,22 @@ def test_3_channel_tensor_to_pil_image_error(self): error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs" # should raise if we try a mode for 4 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) - with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): + with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. 
Got \d+ dimensions."): transforms.ToPILImage()(torch.Tensor(1, 3, 4, 4).uniform_()) - @pytest.mark.parametrize('expected_mode', [None, 'RGB', 'HSV', 'YCbCr']) + @pytest.mark.parametrize("expected_mode", [None, "RGB", "HSV", "YCbCr"]) def test_3_channel_ndarray_to_pil_image(self, expected_mode): img_data = torch.ByteTensor(4, 4, 3).random_(0, 255).numpy() if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGB' # default should assume RGB + assert img.mode == "RGB" # default should assume RGB else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -742,20 +777,20 @@ def test_3_channel_ndarray_to_pil_image_error(self): error_message_3d = r"Only modes \['RGB', 'YCbCr', 'HSV'\] are supported for 3D inputs" # should raise if we try a mode for 4 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='RGBA')(img_data) + transforms.ToPILImage(mode="RGBA")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_3d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) - @pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX']) + @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"]) def test_4_channel_tensor_to_pil_image(self, expected_mode): img_data = torch.Tensor(4, 4, 4).uniform_() expected_output = img_data.mul(255).int().float().div(255) if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGBA' # default should assume RGBA + assert img.mode == "RGBA" # default should assume RGBA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -770,19 +805,19 @@ def test_4_channel_tensor_to_pil_image_error(self): error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs" # should raise if we try a mode for 3 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) - @pytest.mark.parametrize('expected_mode', [None, 'RGBA', 'CMYK', 'RGBX']) + @pytest.mark.parametrize("expected_mode", [None, "RGBA", "CMYK", "RGBX"]) def test_4_channel_ndarray_to_pil_image(self, expected_mode): img_data = torch.ByteTensor(4, 4, 4).random_(0, 255).numpy() if expected_mode is None: img = transforms.ToPILImage()(img_data) - assert img.mode == 'RGBA' # default should assume RGBA + assert img.mode == "RGBA" # default should assume RGBA else: img = transforms.ToPILImage(mode=expected_mode)(img_data) assert img.mode == expected_mode @@ -796,15 +831,15 @@ def test_4_channel_ndarray_to_pil_image_error(self): error_message_4d = r"Only modes \['RGBA', 'CMYK', 'RGBX'\] are supported for 4D inputs" # should raise if we try a mode for 3 or 1 or 2 channel images with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='RGB')(img_data) + transforms.ToPILImage(mode="RGB")(img_data) with pytest.raises(ValueError, match=error_message_4d): - 
transforms.ToPILImage(mode='P')(img_data) + transforms.ToPILImage(mode="P")(img_data) with pytest.raises(ValueError, match=error_message_4d): - transforms.ToPILImage(mode='LA')(img_data) + transforms.ToPILImage(mode="LA")(img_data) def test_ndarray_bad_types_to_pil_image(self): trans = transforms.ToPILImage() - reg_msg = r'Input type \w+ is not supported' + reg_msg = r"Input type \w+ is not supported" with pytest.raises(TypeError, match=reg_msg): trans(np.ones([4, 4, 1], np.int64)) with pytest.raises(TypeError, match=reg_msg): @@ -814,15 +849,15 @@ def test_ndarray_bad_types_to_pil_image(self): with pytest.raises(TypeError, match=reg_msg): trans(np.ones([4, 4, 1], np.float64)) - with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): + with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."): transforms.ToPILImage()(np.ones([1, 4, 4, 3])) - with pytest.raises(ValueError, match=r'pic should not have > 4 channels. Got \d+ channels.'): + with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."): transforms.ToPILImage()(np.ones([4, 4, 6])) def test_tensor_bad_types_to_pil_image(self): - with pytest.raises(ValueError, match=r'pic should be 2/3 dimensional. Got \d+ dimensions.'): + with pytest.raises(ValueError, match=r"pic should be 2/3 dimensional. Got \d+ dimensions."): transforms.ToPILImage()(torch.ones(1, 3, 4, 4)) - with pytest.raises(ValueError, match=r'pic should not have > 4 channels. Got \d+ channels.'): + with pytest.raises(ValueError, match=r"pic should not have > 4 channels. Got \d+ channels."): transforms.ToPILImage()(torch.ones(6, 4, 4)) @@ -830,7 +865,7 @@ def test_adjust_brightness(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_brightness(x_pil, 1) @@ -856,7 +891,7 @@ def test_adjust_contrast(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_contrast(x_pil, 1) @@ -878,12 +913,12 @@ def test_adjust_contrast(): torch.testing.assert_close(y_np, y_ans) -@pytest.mark.skipif(Image.__version__ >= '7', reason="Temporarily disabled") +@pytest.mark.skipif(Image.__version__ >= "7", reason="Temporarily disabled") def test_adjust_saturation(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_saturation(x_pil, 1) @@ -909,7 +944,7 @@ def test_adjust_hue(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") with pytest.raises(ValueError): F.adjust_hue(x_pil, -0.7) @@ -940,11 +975,58 @@ def test_adjust_hue(): def test_adjust_sharpness(): x_shape = [4, 4, 3] - x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, - 0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105, - 111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + x_data = [ + 75, + 121, + 114, + 105, 
+ 97, + 107, + 105, + 32, + 66, + 111, + 117, + 114, + 99, + 104, + 97, + 0, + 0, + 65, + 108, + 101, + 120, + 97, + 110, + 100, + 101, + 114, + 32, + 86, + 114, + 121, + 110, + 105, + 111, + 116, + 105, + 115, + 0, + 0, + 73, + 32, + 108, + 111, + 118, + 101, + 32, + 121, + 111, + 117, + ] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_sharpness(x_pil, 1) @@ -954,18 +1036,112 @@ def test_adjust_sharpness(): # test 1 y_pil = F.adjust_sharpness(x_pil, 0.5) y_np = np.array(y_pil) - y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30, - 30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101, - 107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = [ + 75, + 121, + 114, + 105, + 97, + 107, + 105, + 32, + 66, + 111, + 117, + 114, + 99, + 104, + 97, + 30, + 30, + 74, + 103, + 96, + 114, + 97, + 110, + 100, + 101, + 114, + 32, + 81, + 103, + 108, + 102, + 101, + 107, + 116, + 105, + 115, + 0, + 0, + 73, + 32, + 108, + 111, + 118, + 101, + 32, + 121, + 111, + 117, + ] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) torch.testing.assert_close(y_np, y_ans) # test 2 y_pil = F.adjust_sharpness(x_pil, 2) y_np = np.array(y_pil) - y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0, - 0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112, - 119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117] + y_ans = [ + 75, + 121, + 114, + 105, + 97, + 107, + 105, + 32, + 66, + 111, + 117, + 114, + 99, + 104, + 97, + 0, + 0, + 46, + 118, + 111, + 132, + 97, + 110, + 100, + 101, + 114, + 32, + 95, + 135, + 146, + 126, + 112, + 119, + 116, + 105, + 115, + 0, + 0, + 73, + 32, + 108, + 111, + 118, + 101, + 32, + 121, + 111, + 117, + ] y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) torch.testing.assert_close(y_np, y_ans) @@ -973,7 +1149,7 @@ def test_adjust_sharpness(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") x_th = torch.tensor(x_np.transpose(2, 0, 1)) y_pil = F.adjust_sharpness(x_pil, 2) y_np = np.array(y_pil).transpose(2, 0, 1) @@ -985,7 +1161,7 @@ def test_adjust_gamma(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') + x_pil = Image.fromarray(x_np, mode="RGB") # test 0 y_pil = F.adjust_gamma(x_pil, 1) @@ -1011,15 +1187,15 @@ def test_adjusts_L_mode(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_rgb = Image.fromarray(x_np, mode='RGB') + x_rgb = Image.fromarray(x_np, mode="RGB") - x_l = x_rgb.convert('L') - assert F.adjust_brightness(x_l, 2).mode == 'L' - assert F.adjust_saturation(x_l, 2).mode == 'L' - assert F.adjust_contrast(x_l, 2).mode == 'L' - assert F.adjust_hue(x_l, 0.4).mode == 'L' - assert F.adjust_sharpness(x_l, 2).mode == 'L' - assert F.adjust_gamma(x_l, 0.5).mode == 'L' + x_l = x_rgb.convert("L") + assert F.adjust_brightness(x_l, 2).mode == "L" + assert F.adjust_saturation(x_l, 2).mode == "L" + assert F.adjust_contrast(x_l, 2).mode == "L" + assert F.adjust_hue(x_l, 0.4).mode == "L" + assert F.adjust_sharpness(x_l, 2).mode == 
"L" + assert F.adjust_gamma(x_l, 0.5).mode == "L" def test_rotate(): @@ -1058,7 +1234,7 @@ def test_rotate(): assert_equal(np.array(result_a), np.array(result_b)) -@pytest.mark.parametrize('mode', ["L", "RGB", "F"]) +@pytest.mark.parametrize("mode", ["L", "RGB", "F"]) def test_rotate_fill(mode): img = F.to_pil_image(np.ones((100, 100, 3), dtype=np.uint8) * 255, "RGB") @@ -1141,8 +1317,8 @@ def test_to_grayscale(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) # Test Set: Grayscale an image with desired number of output channels @@ -1150,16 +1326,16 @@ def test_to_grayscale(): trans1 = transforms.Grayscale(num_output_channels=1) gray_pil_1 = trans1(x_pil) gray_np_1 = np.array(gray_pil_1) - assert gray_pil_1.mode == 'L', 'mode should be L' - assert gray_np_1.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_1.mode == "L", "mode should be L" + assert gray_np_1.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_1) # Case 2: RGB -> 3 channel grayscale trans2 = transforms.Grayscale(num_output_channels=3) gray_pil_2 = trans2(x_pil) gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == 'RGB', 'mode should be RGB' - assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_2.mode == "RGB", "mode should be RGB" + assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np, gray_np_2[:, :, 0]) @@ -1168,16 +1344,16 @@ def test_to_grayscale(): trans3 = transforms.Grayscale(num_output_channels=1) gray_pil_3 = trans3(x_pil_2) gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == 'L', 'mode should be L' - assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_3.mode == "L", "mode should be L" + assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_3) # Case 4: 1 channel grayscale -> 3 channel grayscale trans4 = transforms.Grayscale(num_output_channels=3) gray_pil_4 = trans4(x_pil_2) gray_np_4 = np.array(gray_pil_4) - assert gray_pil_4.mode == 'RGB', 'mode should be RGB' - assert gray_np_4.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_4.mode == "RGB", "mode should be RGB" + assert gray_np_4.shape == tuple(x_shape), "should be 3 channel" assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1]) assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2]) assert_equal(gray_np, gray_np_4[:, :, 0]) @@ -1196,8 +1372,8 @@ def test_random_grayscale(): random.seed(42) x_shape = [2, 2, 3] x_np = np_rng.randint(0, 256, x_shape, np.uint8) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) num_samples = 250 @@ -1205,9 +1381,11 @@ def test_random_grayscale(): for _ in range(num_samples): gray_pil_2 = transforms.RandomGrayscale(p=0.5)(x_pil) gray_np_2 = np.array(gray_pil_2) - if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \ - np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \ - np.array_equal(gray_np, gray_np_2[:, :, 0]): + if ( + np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) + and 
np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) + and np.array_equal(gray_np, gray_np_2[:, :, 0]) + ): num_gray = num_gray + 1 p_value = stats.binom_test(num_gray, num_samples, p=0.5) @@ -1219,8 +1397,8 @@ def test_random_grayscale(): random.seed(42) x_shape = [2, 2, 3] x_np = np_rng.randint(0, 256, x_shape, np.uint8) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) num_samples = 250 @@ -1239,16 +1417,16 @@ def test_random_grayscale(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") gray_np = np.array(x_pil_2) # Case 3a: RGB -> 3 channel grayscale (grayscaled) trans2 = transforms.RandomGrayscale(p=1.0) gray_pil_2 = trans2(x_pil) gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == 'RGB', 'mode should be RGB' - assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_2.mode == "RGB", "mode should be RGB" + assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) assert_equal(gray_np, gray_np_2[:, :, 0]) @@ -1257,31 +1435,31 @@ def test_random_grayscale(): trans2 = transforms.RandomGrayscale(p=0.0) gray_pil_2 = trans2(x_pil) gray_np_2 = np.array(gray_pil_2) - assert gray_pil_2.mode == 'RGB', 'mode should be RGB' - assert gray_np_2.shape == tuple(x_shape), 'should be 3 channel' + assert gray_pil_2.mode == "RGB", "mode should be RGB" + assert gray_np_2.shape == tuple(x_shape), "should be 3 channel" assert_equal(x_np, gray_np_2) # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled) trans3 = transforms.RandomGrayscale(p=1.0) gray_pil_3 = trans3(x_pil_2) gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == 'L', 'mode should be L' - assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_3.mode == "L", "mode should be L" + assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_3) # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged) trans3 = transforms.RandomGrayscale(p=0.0) gray_pil_3 = trans3(x_pil_2) gray_np_3 = np.array(gray_pil_3) - assert gray_pil_3.mode == 'L', 'mode should be L' - assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel' + assert gray_pil_3.mode == "L", "mode should be L" + assert gray_np_3.shape == tuple(x_shape[0:2]), "should be 1 channel" assert_equal(gray_np, gray_np_3) # Checking if RandomGrayscale can be printed as string trans3.__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_apply(): random_state = random.getstate() random.seed(42) @@ -1290,7 +1468,8 @@ def test_random_apply(): transforms.RandomRotation((-45, 45)), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), - ], p=0.75 + ], + p=0.75, ) img = transforms.ToPILImage()(torch.rand(3, 10, 10)) num_samples = 250 @@ -1308,17 +1487,12 @@ def test_random_apply(): random_apply_transform.__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def 
test_random_choice(): random_state = random.getstate() random.seed(42) random_choice_transform = transforms.RandomChoice( - [ - transforms.Resize(15), - transforms.Resize(20), - transforms.CenterCrop(10) - ], - [1 / 3, 1 / 3, 1 / 3] + [transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10)], [1 / 3, 1 / 3, 1 / 3] ) img = transforms.ToPILImage()(torch.rand(3, 25, 25)) num_samples = 250 @@ -1346,16 +1520,11 @@ def test_random_choice(): random_choice_transform.__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_order(): random_state = random.getstate() random.seed(42) - random_order_transform = transforms.RandomOrder( - [ - transforms.Resize(20), - transforms.CenterCrop(10) - ] - ) + random_order_transform = transforms.RandomOrder([transforms.Resize(20), transforms.CenterCrop(10)]) img = transforms.ToPILImage()(torch.rand(3, 25, 25)) num_samples = 250 num_normal_order = 0 @@ -1381,10 +1550,10 @@ def test_linear_transformation(): sigma = torch.mm(flat_x.t(), flat_x) / flat_x.size(0) u, s, _ = np.linalg.svd(sigma.numpy()) zca_epsilon = 1e-10 # avoid division by 0 - d = torch.Tensor(np.diag(1. / np.sqrt(s + zca_epsilon))) + d = torch.Tensor(np.diag(1.0 / np.sqrt(s + zca_epsilon))) u = torch.Tensor(u) principal_components = torch.mm(torch.mm(u, d), u.t()) - mean_vector = (torch.sum(flat_x, dim=0) / flat_x.size(0)) + mean_vector = torch.sum(flat_x, dim=0) / flat_x.size(0) # initialize whitening matrix whitening = transforms.LinearTransformation(principal_components, mean_vector) # estimate covariance and mean using weak law of large number @@ -1397,16 +1566,18 @@ def test_linear_transformation(): cov += np.dot(xwhite, xwhite.T) / num_features mean += np.sum(xwhite) / num_features # if rtol for std = 1e-3 then rtol for cov = 2e-3 as std**2 = cov - torch.testing.assert_close(cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, - msg="cov not close to 1") - torch.testing.assert_close(mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, - msg="mean not close to 0") + torch.testing.assert_close( + cov / num_samples, np.identity(1), rtol=2e-3, atol=1e-8, check_dtype=False, msg="cov not close to 1" + ) + torch.testing.assert_close( + mean / num_samples, 0, rtol=1e-3, atol=1e-8, check_dtype=False, msg="mean not close to 0" + ) # Checking if LinearTransformation can be printed as string whitening.__repr__() -@pytest.mark.parametrize('dtype', int_dtypes()) +@pytest.mark.parametrize("dtype", int_dtypes()) def test_max_value(dtype): assert F_t._max_value(dtype) == torch.iinfo(dtype).max @@ -1416,8 +1587,8 @@ def test_max_value(dtype): # self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max) -@pytest.mark.parametrize('should_vflip', [True, False]) -@pytest.mark.parametrize('single_dim', [True, False]) +@pytest.mark.parametrize("should_vflip", [True, False]) +@pytest.mark.parametrize("single_dim", [True, False]) def test_ten_crop(should_vflip, single_dim): to_pil_image = transforms.ToPILImage() h = random.randint(5, 25) @@ -1427,12 +1598,10 @@ def test_ten_crop(should_vflip, single_dim): if single_dim: crop_h = min(crop_h, crop_w) crop_w = crop_h - transform = transforms.TenCrop(crop_h, - vertical_flip=should_vflip) + transform = transforms.TenCrop(crop_h, vertical_flip=should_vflip) five_crop = transforms.FiveCrop(crop_h) else: - transform = transforms.TenCrop((crop_h, crop_w), - vertical_flip=should_vflip) + transform = 
transforms.TenCrop((crop_h, crop_w), vertical_flip=should_vflip) five_crop = transforms.FiveCrop((crop_h, crop_w)) img = to_pil_image(torch.FloatTensor(3, h, w).uniform_()) @@ -1454,7 +1623,7 @@ def test_ten_crop(should_vflip, single_dim): assert results == expected_output -@pytest.mark.parametrize('single_dim', [True, False]) +@pytest.mark.parametrize("single_dim", [True, False]) def test_five_crop(single_dim): to_pil_image = transforms.ToPILImage() h = random.randint(5, 25) @@ -1478,17 +1647,17 @@ def test_five_crop(single_dim): to_pil_image = transforms.ToPILImage() tl = to_pil_image(img[:, 0:crop_h, 0:crop_w]) - tr = to_pil_image(img[:, 0:crop_h, w - crop_w:]) - bl = to_pil_image(img[:, h - crop_h:, 0:crop_w]) - br = to_pil_image(img[:, h - crop_h:, w - crop_w:]) + tr = to_pil_image(img[:, 0:crop_h, w - crop_w :]) + bl = to_pil_image(img[:, h - crop_h :, 0:crop_w]) + br = to_pil_image(img[:, h - crop_h :, w - crop_w :]) center = transforms.CenterCrop((crop_h, crop_w))(to_pil_image(img)) expected_output = (tl, tr, bl, br, center) assert results == expected_output -@pytest.mark.parametrize('policy', transforms.AutoAugmentPolicy) -@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) -@pytest.mark.parametrize('grayscale', [True, False]) +@pytest.mark.parametrize("policy", transforms.AutoAugmentPolicy) +@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize("grayscale", [True, False]) def test_autoaugment(policy, fill, grayscale): random.seed(42) img = Image.open(GRACE_HOPPER) @@ -1500,10 +1669,10 @@ def test_autoaugment(policy, fill, grayscale): transform.__repr__() -@pytest.mark.parametrize('num_ops', [1, 2, 3]) -@pytest.mark.parametrize('magnitude', [7, 9, 11]) -@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) -@pytest.mark.parametrize('grayscale', [True, False]) +@pytest.mark.parametrize("num_ops", [1, 2, 3]) +@pytest.mark.parametrize("magnitude", [7, 9, 11]) +@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize("grayscale", [True, False]) def test_randaugment(num_ops, magnitude, fill, grayscale): random.seed(42) img = Image.open(GRACE_HOPPER) @@ -1515,9 +1684,9 @@ def test_randaugment(num_ops, magnitude, fill, grayscale): transform.__repr__() -@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) -@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30]) -@pytest.mark.parametrize('grayscale', [True, False]) +@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize("num_magnitude_bins", [10, 13, 30]) +@pytest.mark.parametrize("grayscale", [True, False]) def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale): random.seed(42) img = Image.open(GRACE_HOPPER) @@ -1535,37 +1704,41 @@ def test_random_crop(): oheight = random.randint(5, (height - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2 img = torch.ones(3, height, width) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.RandomCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == oheight assert result.size(2) == owidth padding = random.randint(1, 20) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((oheight, owidth), padding=padding), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + 
transforms.RandomCrop((oheight, owidth), padding=padding), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == oheight assert result.size(2) == owidth - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((height, width)), - transforms.ToTensor() - ])(img) + result = transforms.Compose( + [transforms.ToPILImage(), transforms.RandomCrop((height, width)), transforms.ToTensor()] + )(img) assert result.size(1) == height assert result.size(2) == width torch.testing.assert_close(result, img) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.RandomCrop((height + 1, width + 1), pad_if_needed=True), + transforms.ToTensor(), + ] + )(img) assert result.size(1) == height + 1 assert result.size(2) == width + 1 @@ -1584,41 +1757,47 @@ def test_center_crop(): img = torch.ones(3, height, width) oh1 = (height - oheight) // 2 ow1 = (width - owidth) // 2 - imgnarrow = img[:, oh1:oh1 + oheight, ow1:ow1 + owidth] + imgnarrow = img[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth] imgnarrow.fill_(0) - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) assert result.sum() == 0 oheight += 1 owidth += 1 - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) sum1 = result.sum() assert sum1 > 1 oheight += 1 owidth += 1 - result = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop((oheight, owidth)), - transforms.ToTensor(), - ])(img) + result = transforms.Compose( + [ + transforms.ToPILImage(), + transforms.CenterCrop((oheight, owidth)), + transforms.ToTensor(), + ] + )(img) sum2 = result.sum() assert sum2 > 0 assert sum2 > sum1 -@pytest.mark.parametrize('odd_image_size', (True, False)) -@pytest.mark.parametrize('delta', (1, 3, 5)) -@pytest.mark.parametrize('delta_width', (-2, -1, 0, 1, 2)) -@pytest.mark.parametrize('delta_height', (-2, -1, 0, 1, 2)) +@pytest.mark.parametrize("odd_image_size", (True, False)) +@pytest.mark.parametrize("delta", (1, 3, 5)) +@pytest.mark.parametrize("delta_width", (-2, -1, 0, 1, 2)) +@pytest.mark.parametrize("delta_height", (-2, -1, 0, 1, 2)) def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): - """ Tests when center crop size is larger than image size, along any dimension""" + """Tests when center crop size is larger than image size, along any dimension""" # Since height is independent of width, we can ignore images with odd height and even width and vice-versa. 
input_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2) @@ -1632,10 +1811,8 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width) # Test both transforms, one with PIL input and one with tensor - output_pil = transforms.Compose([ - transforms.ToPILImage(), - transforms.CenterCrop(crop_size), - transforms.ToTensor()], + output_pil = transforms.Compose( + [transforms.ToPILImage(), transforms.CenterCrop(crop_size), transforms.ToTensor()], )(img) assert output_pil.size()[1:3] == crop_size @@ -1660,14 +1837,14 @@ def test_center_crop_2(odd_image_size, delta, delta_width, delta_height): output_center = output_pil[ :, - crop_center_tl[0]:crop_center_tl[0] + center_size[0], - crop_center_tl[1]:crop_center_tl[1] + center_size[1] + crop_center_tl[0] : crop_center_tl[0] + center_size[0], + crop_center_tl[1] : crop_center_tl[1] + center_size[1], ] img_center = img[ :, - input_center_tl[0]:input_center_tl[0] + center_size[0], - input_center_tl[1]:input_center_tl[1] + center_size[1] + input_center_tl[0] : input_center_tl[0] + center_size[0], + input_center_tl[1] : input_center_tl[1] + center_size[1], ] assert_equal(output_center, img_center) @@ -1679,8 +1856,8 @@ def test_color_jitter(): x_shape = [2, 2, 3] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) - x_pil = Image.fromarray(x_np, mode='RGB') - x_pil_2 = x_pil.convert('L') + x_pil = Image.fromarray(x_np, mode="RGB") + x_pil_2 = x_pil.convert("L") for _ in range(10): y_pil = color_jitter(x_pil) @@ -1697,18 +1874,32 @@ def test_color_jitter(): def test_random_erasing(): img = torch.ones(3, 128, 128) - t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.)) - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + t = transforms.RandomErasing(scale=(0.1, 0.1), ratio=(1 / 3, 3.0)) + y, x, h, w, v = t.get_params( + img, + t.scale, + t.ratio, + [ + t.value, + ], + ) aspect_ratio = h / w # Add some tolerance due to the rounding and int conversion used in the transform tol = 0.05 - assert (1 / 3 - tol <= aspect_ratio <= 3 + tol) + assert 1 / 3 - tol <= aspect_ratio <= 3 + tol aspect_ratios = [] random.seed(42) trial = 1000 for _ in range(trial): - y, x, h, w, v = t.get_params(img, t.scale, t.ratio, [t.value, ]) + y, x, h, w, v = t.get_params( + img, + t.scale, + t.ratio, + [ + t.value, + ], + ) aspect_ratios.append(h / w) count_bigger_then_ones = len([1 for aspect_ratio in aspect_ratios if aspect_ratio > 1]) @@ -1735,11 +1926,11 @@ def test_random_rotation(): t = transforms.RandomRotation(10) angle = t.get_params(t.degrees) - assert (angle > -10 and angle < 10) + assert angle > -10 and angle < 10 t = transforms.RandomRotation((-10, 10)) angle = t.get_params(t.degrees) - assert (-10 < angle < 10) + assert -10 < angle < 10 # Checking if RandomRotation can be printed as string t.__repr__() @@ -1775,11 +1966,12 @@ def test_randomperspective(): tr_img = F.to_tensor(tr_img) assert img.size[0] == width assert img.size[1] == height - assert (torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > - torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img))) + assert torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3 > torch.nn.functional.mse_loss( + tr_img2, F.to_tensor(img) + ) -@pytest.mark.parametrize('mode', ["L", "RGB", "F"]) +@pytest.mark.parametrize("mode", ["L", "RGB", "F"]) def test_randomperspective_fill(mode): # assert fill being 
either a Sequence or a Number @@ -1819,7 +2011,7 @@ def test_randomperspective_fill(mode): F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands)) -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_vertical_flip(): random_state = random.getstate() random.seed(42) @@ -1852,7 +2044,7 @@ def test_random_vertical_flip(): transforms.RandomVerticalFlip().__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_horizontal_flip(): random_state = random.getstate() random.seed(42) @@ -1885,10 +2077,10 @@ def test_random_horizontal_flip(): transforms.RandomHorizontalFlip().__repr__() -@pytest.mark.skipif(stats is None, reason='scipy.stats not available') +@pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_normalize(): def samples_from_standard_normal(tensor): - p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue + p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue return p_value > 0.0001 random_state = random.getstate() @@ -1910,8 +2102,8 @@ def samples_from_standard_normal(tensor): assert_equal(tensor, tensor_inplace) -@pytest.mark.parametrize('dtype1', [torch.float32, torch.float64]) -@pytest.mark.parametrize('dtype2', [torch.int64, torch.float32, torch.float64]) +@pytest.mark.parametrize("dtype1", [torch.float32, torch.float64]) +@pytest.mark.parametrize("dtype2", [torch.int64, torch.float32, torch.float64]) def test_normalize_different_dtype(dtype1, dtype2): img = torch.rand(3, 10, 10, dtype=dtype1) mean = torch.tensor([1, 2, 3], dtype=dtype2) @@ -1932,15 +2124,15 @@ def test_normalize_3d_tensor(): mean_unsqueezed = mean.view(-1, 1, 1) std_unsqueezed = std.view(-1, 1, 1) result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed) - result2 = F.normalize(img, mean_unsqueezed.repeat(1, img_size, img_size), - std_unsqueezed.repeat(1, img_size, img_size)) + result2 = F.normalize( + img, mean_unsqueezed.repeat(1, img_size, img_size), std_unsqueezed.repeat(1, img_size, img_size) + ) torch.testing.assert_close(target, result1) torch.testing.assert_close(target, result2) class TestAffine: - - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def input_img(self): input_img = np.zeros((40, 40, 3), dtype=np.uint8) for pt in [(16, 16), (20, 16), (20, 20)]: @@ -1953,7 +2145,7 @@ def test_affine_translate_seq(self, input_img): with pytest.raises(TypeError, match=r"Argument translate should be a sequence"): F.affine(input_img, 10, translate=0, scale=1, shear=1) - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def pil_image(self, input_img): return F.to_pil_image(input_img) @@ -1974,33 +2166,29 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_ rot = a_rad # 1) Check transformation matrix: - C = np.array([[1, 0, cx], - [0, 1, cy], - [0, 0, 1]]) - T = np.array([[1, 0, tx], - [0, 1, ty], - [0, 0, 1]]) + C = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) + T = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) Cinv = np.linalg.inv(C) RS = np.array( - [[scale * math.cos(rot), -scale * math.sin(rot), 0], - [scale * math.sin(rot), scale * math.cos(rot), 0], - [0, 0, 1]]) + [ + [scale * math.cos(rot), -scale * math.sin(rot), 0], + [scale * math.sin(rot), scale * math.cos(rot), 0], + [0, 0, 1], + ] + ) - SHx = np.array([[1, -math.tan(sx), 0], 
- [0, 1, 0], - [0, 0, 1]]) + SHx = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) - SHy = np.array([[1, 0, 0], - [-math.tan(sy), 1, 0], - [0, 0, 1]]) + SHy = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) RSS = np.matmul(RS, np.matmul(SHy, SHx)) true_matrix = np.matmul(T, np.matmul(C, np.matmul(RSS, Cinv))) - result_matrix = self._to_3x3_inv(F._get_inverse_affine_matrix(center=cnt, angle=angle, - translate=translate, scale=scale, shear=shear)) + result_matrix = self._to_3x3_inv( + F._get_inverse_affine_matrix(center=cnt, angle=angle, translate=translate, scale=scale, shear=shear) + ) assert np.sum(np.abs(true_matrix - result_matrix)) < 1e-10 # 2) Perform inverse mapping: true_result = np.zeros((40, 40, 3), dtype=np.uint8) @@ -2022,38 +2210,49 @@ def _test_transformation(self, angle, translate, scale, shear, pil_image, input_ np_result = np.array(result) n_diff_pixels = np.sum(np_result != true_result) / 3 # Accept 3 wrong pixels - error_msg = ("angle={}, translate={}, scale={}, shear={}\n".format(angle, translate, scale, shear) + - "n diff pixels={}\n".format(n_diff_pixels)) + error_msg = "angle={}, translate={}, scale={}, shear={}\n".format( + angle, translate, scale, shear + ) + "n diff pixels={}\n".format(n_diff_pixels) assert n_diff_pixels < 3, error_msg def test_transformation_discrete(self, pil_image, input_img): # Test rotation angle = 45 - self._test_transformation(angle=angle, translate=(0, 0), scale=1.0, - shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img + ) # Test translation translate = [10, 15] - self._test_transformation(angle=0.0, translate=translate, scale=1.0, - shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=0.0, translate=translate, scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img + ) # Test scale scale = 1.2 - self._test_transformation(angle=0.0, translate=(0.0, 0.0), scale=scale, - shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=0.0, translate=(0.0, 0.0), scale=scale, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img + ) # Test shear shear = [45.0, 25.0] - self._test_transformation(angle=0.0, translate=(0.0, 0.0), scale=1.0, - shear=shear, pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img + ) @pytest.mark.parametrize("angle", range(-90, 90, 36)) @pytest.mark.parametrize("translate", range(-10, 10, 5)) @pytest.mark.parametrize("scale", [0.77, 1.0, 1.27]) @pytest.mark.parametrize("shear", range(-15, 15, 5)) def test_transformation_range(self, angle, translate, scale, shear, pil_image, input_img): - self._test_transformation(angle=angle, translate=(translate, translate), scale=scale, - shear=(shear, shear), pil_image=pil_image, input_img=input_img) + self._test_transformation( + angle=angle, + translate=(translate, translate), + scale=scale, + shear=(shear, shear), + pil_image=pil_image, + input_img=input_img, + ) def test_random_affine(): @@ -2101,13 +2300,14 @@ def test_random_affine(): t = transforms.RandomAffine(10, translate=[0.5, 0.3], scale=[0.7, 1.3], shear=[-10, 10, 20, 40]) for _ in range(100): - angle, translations, scale, shear = t.get_params(t.degrees, t.translate, t.scale, t.shear, - img_size=img.size) + angle, translations, scale, shear = 
t.get_params(t.degrees, t.translate, t.scale, t.shear, img_size=img.size) assert -10 < angle < 10 - assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, ("{} vs {}" - .format(translations[0], img.size[0] * 0.5)) - assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, ("{} vs {}" - .format(translations[1], img.size[1] * 0.5)) + assert -img.size[0] * 0.5 <= translations[0] <= img.size[0] * 0.5, "{} vs {}".format( + translations[0], img.size[0] * 0.5 + ) + assert -img.size[1] * 0.5 <= translations[1] <= img.size[1] * 0.5, "{} vs {}".format( + translations[1], img.size[1] * 0.5 + ) assert 0.7 < scale < 1.3 assert -10 < shear[0] < 10 assert -20 < shear[1] < 40 @@ -2133,5 +2333,5 @@ def test_random_affine(): assert t.interpolation == transforms.InterpolationMode.BILINEAR -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 8aa398bd006..07b98f60999 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -1,14 +1,9 @@ import os -import torch -from torchvision import transforms as T -from torchvision.transforms import functional as F -from torchvision.transforms import InterpolationMode +from typing import Sequence import numpy as np import pytest - -from typing import Sequence - +import torch from common_utils import ( get_tmp_dir, int_dtypes, @@ -20,6 +15,9 @@ cpu_and_gpu, assert_equal, ) +from torchvision import transforms as T +from torchvision.transforms import InterpolationMode +from torchvision.transforms import functional as F NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC @@ -94,110 +92,137 @@ def _test_op(func, method, device, channels=3, fn_kwargs=None, meth_kwargs=None, _test_class_op(method, device, channels, meth_kwargs, test_exact_match=test_exact_match, **match_kwargs) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize( - 'func,method,fn_kwargs,match_kwargs', [ + "func,method,fn_kwargs,match_kwargs", + [ (F.hflip, T.RandomHorizontalFlip, None, {}), (F.vflip, T.RandomVerticalFlip, None, {}), (F.invert, T.RandomInvert, None, {}), (F.posterize, T.RandomPosterize, {"bits": 4}, {}), (F.solarize, T.RandomSolarize, {"threshold": 192.0}, {}), (F.adjust_sharpness, T.RandomAdjustSharpness, {"sharpness_factor": 2.0}, {}), - (F.autocontrast, T.RandomAutocontrast, None, {'test_exact_match': False, - 'agg_method': 'max', 'tol': (1 + 1e-5), - 'allowed_percentage_diff': .05}), - (F.equalize, T.RandomEqualize, None, {}) - ] + ( + F.autocontrast, + T.RandomAutocontrast, + None, + {"test_exact_match": False, "agg_method": "max", "tol": (1 + 1e-5), "allowed_percentage_diff": 0.05}, + ), + (F.equalize, T.RandomEqualize, None, {}), + ], ) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("channels", [1, 3]) def test_random(func, method, device, channels, fn_kwargs, match_kwargs): _test_op(func, method, device, channels, fn_kwargs, fn_kwargs, **match_kwargs) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("channels", [1, 3]) class TestColorJitter: - - @pytest.mark.parametrize('brightness', [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]) + @pytest.mark.parametrize("brightness", [0.1, 0.5, 1.0, 1.34, (0.3, 0.7), [0.4, 0.5]]) def test_color_jitter_brightness(self, 
brightness, device, channels): tol = 1.0 + 1e-10 meth_kwargs = {"brightness": brightness} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max", channels=channels, + T.ColorJitter, + meth_kwargs=meth_kwargs, + test_exact_match=False, + device=device, + tol=tol, + agg_method="max", + channels=channels, ) - @pytest.mark.parametrize('contrast', [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]) + @pytest.mark.parametrize("contrast", [0.2, 0.5, 1.0, 1.5, (0.3, 0.7), [0.4, 0.5]]) def test_color_jitter_contrast(self, contrast, device, channels): tol = 1.0 + 1e-10 meth_kwargs = {"contrast": contrast} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max", channels=channels + T.ColorJitter, + meth_kwargs=meth_kwargs, + test_exact_match=False, + device=device, + tol=tol, + agg_method="max", + channels=channels, ) - @pytest.mark.parametrize('saturation', [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]) + @pytest.mark.parametrize("saturation", [0.5, 0.75, 1.0, 1.25, (0.3, 0.7), [0.3, 0.4]]) def test_color_jitter_saturation(self, saturation, device, channels): tol = 1.0 + 1e-10 meth_kwargs = {"saturation": saturation} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max", channels=channels + T.ColorJitter, + meth_kwargs=meth_kwargs, + test_exact_match=False, + device=device, + tol=tol, + agg_method="max", + channels=channels, ) - @pytest.mark.parametrize('hue', [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]) + @pytest.mark.parametrize("hue", [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]) def test_color_jitter_hue(self, hue, device, channels): meth_kwargs = {"hue": hue} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=16.1, agg_method="max", channels=channels + T.ColorJitter, + meth_kwargs=meth_kwargs, + test_exact_match=False, + device=device, + tol=16.1, + agg_method="max", + channels=channels, ) def test_color_jitter_all(self, device, channels): # All 4 parameters together meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2} _test_class_op( - T.ColorJitter, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=12.1, agg_method="max", channels=channels + T.ColorJitter, + meth_kwargs=meth_kwargs, + test_exact_match=False, + device=device, + tol=12.1, + agg_method="max", + channels=channels, ) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('m', ["constant", "edge", "reflect", "symmetric"]) -@pytest.mark.parametrize('mul', [1, -1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("m", ["constant", "edge", "reflect", "symmetric"]) +@pytest.mark.parametrize("mul", [1, -1]) def test_pad(m, mul, device): fill = 127 if m == "constant" else 0 # Test functional.pad (PIL and Tensor) with padding as single int - _test_functional_op( - F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, - device=device - ) + _test_functional_op(F.pad, fn_kwargs={"padding": mul * 2, "fill": fill, "padding_mode": m}, device=device) # Test functional.pad and transforms.Pad with padding as [int, ] - fn_kwargs = meth_kwargs = {"padding": [mul * 2, ], "fill": fill, "padding_mode": m} - _test_op( - F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + fn_kwargs = meth_kwargs = { + "padding": [ + mul * 2, + ], + "fill": fill, + "padding_mode": m, + } + 
_test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) # Test functional.pad and transforms.Pad with padding as list fn_kwargs = meth_kwargs = {"padding": [mul * 4, 4], "fill": fill, "padding_mode": m} - _test_op( - F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) # Test functional.pad and transforms.Pad with padding as tuple fn_kwargs = meth_kwargs = {"padding": (mul * 2, 2, 2, mul * 2), "fill": fill, "padding_mode": m} - _test_op( - F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + _test_op(F.pad, T.Pad, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_crop(device): fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5} # Test transforms.RandomCrop with size and padding as tuple - meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, } - _test_op( - F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs - ) + meth_kwargs = { + "size": (4, 5), + "padding": (4, 4), + "pad_if_needed": True, + } + _test_op(F.crop, T.RandomCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) # Test transforms.functional.crop including outside the image area fn_kwargs = {"top": -2, "left": 3, "height": 4, "width": 5} # top @@ -216,35 +241,43 @@ def test_crop(device): _test_functional_op(F.crop, fn_kwargs=fn_kwargs, device=device) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('padding_config', [ - {"padding_mode": "constant", "fill": 0}, - {"padding_mode": "constant", "fill": 10}, - {"padding_mode": "constant", "fill": 20}, - {"padding_mode": "edge"}, - {"padding_mode": "reflect"} -]) -@pytest.mark.parametrize('size', [5, [5, ], [6, 6]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "padding_config", + [ + {"padding_mode": "constant", "fill": 0}, + {"padding_mode": "constant", "fill": 10}, + {"padding_mode": "constant", "fill": 20}, + {"padding_mode": "edge"}, + {"padding_mode": "reflect"}, + ], +) +@pytest.mark.parametrize( + "size", + [ + 5, + [ + 5, + ], + [6, 6], + ], +) def test_crop_pad(size, padding_config, device): config = dict(padding_config) config["size"] = size _test_class_op(T.RandomCrop, device, meth_kwargs=config) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_center_crop(device, tmpdir): fn_kwargs = {"output_size": (4, 5)} - meth_kwargs = {"size": (4, 5), } - _test_op( - F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, - meth_kwargs=meth_kwargs - ) + meth_kwargs = { + "size": (4, 5), + } + _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) fn_kwargs = {"output_size": (5,)} meth_kwargs = {"size": (5,)} - _test_op( - F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, - meth_kwargs=meth_kwargs - ) + _test_op(F.center_crop, T.CenterCrop, device=device, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs) tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8, device=device) # Test torchscript of transforms.CenterCrop with size as int f = T.CenterCrop(size=5) @@ -252,7 +285,11 @@ def test_center_crop(device, tmpdir): scripted_fn(tensor) # Test torchscript of transforms.CenterCrop with size as [int, ] - f = T.CenterCrop(size=[5, ]) + f = 
T.CenterCrop( + size=[ + 5, + ] + ) scripted_fn = torch.jit.script(f) scripted_fn(tensor) @@ -264,16 +301,29 @@ def test_center_crop(device, tmpdir): scripted_fn.save(os.path.join(tmpdir, "t_center_crop.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('fn, method, out_length', [ - # test_five_crop - (F.five_crop, T.FiveCrop, 5), - # test_ten_crop - (F.ten_crop, T.TenCrop, 10) -]) -@pytest.mark.parametrize('size', [(5,), [5, ], (4, 5), [4, 5]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fn, method, out_length", + [ + # test_five_crop + (F.five_crop, T.FiveCrop, 5), + # test_ten_crop + (F.ten_crop, T.TenCrop, 10), + ], +) +@pytest.mark.parametrize( + "size", + [ + (5,), + [ + 5, + ], + (4, 5), + [4, 5], + ], +) def test_x_crop(fn, method, out_length, size, device): - meth_kwargs = fn_kwargs = {'size': size} + meth_kwargs = fn_kwargs = {"size": size} scripted_fn = torch.jit.script(fn) tensor, pil_img = _create_data(height=20, width=20, device=device) @@ -309,15 +359,19 @@ def test_x_crop(fn, method, out_length, size, device): assert_equal(transformed_img, transformed_batch[i, ...]) -@pytest.mark.parametrize('method', ["FiveCrop", "TenCrop"]) +@pytest.mark.parametrize("method", ["FiveCrop", "TenCrop"]) def test_x_crop_save(method, tmpdir): - fn = getattr(T, method)(size=[5, ]) + fn = getattr(T, method)( + size=[ + 5, + ] + ) scripted_fn = torch.jit.script(fn) scripted_fn.save(os.path.join(tmpdir, "t_op_list_{}.pt".format(method))) class TestResize: - @pytest.mark.parametrize('size', [32, 34, 35, 36, 38]) + @pytest.mark.parametrize("size", [32, 34, 35, 36, 38]) def test_resize_int(self, size): # TODO: Minimal check for bug-fix, improve this later x = torch.rand(3, 32, 46) @@ -329,11 +383,21 @@ def test_resize_int(self, size): assert y.shape[1] == size assert y.shape[2] == int(size * 46 / 32) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('dt', [None, torch.float32, torch.float64]) - @pytest.mark.parametrize('size', [[32, ], [32, 32], (32, 32), [34, 35]]) - @pytest.mark.parametrize('max_size', [None, 35, 1000]) - @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC, NEAREST]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64]) + @pytest.mark.parametrize( + "size", + [ + [ + 32, + ], + [32, 32], + (32, 32), + [34, 35], + ], + ) + @pytest.mark.parametrize("max_size", [None, 35, 1000]) + @pytest.mark.parametrize("interpolation", [BILINEAR, BICUBIC, NEAREST]) def test_resize_scripted(self, dt, size, max_size, interpolation, device): tensor, _ = _create_data(height=34, width=36, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) @@ -350,15 +414,33 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) def test_resize_save(self, tmpdir): - transform = T.Resize(size=[32, ]) + transform = T.Resize( + size=[ + 32, + ] + ) s_transform = torch.jit.script(transform) s_transform.save(os.path.join(tmpdir, "t_resize.pt")) - @pytest.mark.parametrize('device', cpu_and_gpu()) - @pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]]) - @pytest.mark.parametrize('ratio', [(0.75, 1.333), [0.75, 1.333]]) - @pytest.mark.parametrize('size', [(32,), [44, ], [32, ], [32, 32], (32, 32), [44, 55]]) - @pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR, BICUBIC]) + 
@pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]]) + @pytest.mark.parametrize("ratio", [(0.75, 1.333), [0.75, 1.333]]) + @pytest.mark.parametrize( + "size", + [ + (32,), + [ + 44, + ], + [ + 32, + ], + [32, 32], + (32, 32), + [44, 55], + ], + ) + @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) def test_resized_crop(self, scale, ratio, size, interpolation, device): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) @@ -368,7 +450,11 @@ def test_resized_crop(self, scale, ratio, size, interpolation, device): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) def test_resized_crop_save(self, tmpdir): - transform = T.RandomResizedCrop(size=[32, ]) + transform = T.RandomResizedCrop( + size=[ + 32, + ] + ) s_transform = torch.jit.script(transform) s_transform.save(os.path.join(tmpdir, "t_resized_crop.pt")) @@ -383,61 +469,83 @@ def _test_random_affine_helper(device, **kwargs): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_random_affine(device, tmpdir): transform = T.RandomAffine(degrees=45.0) s_transform = torch.jit.script(transform) s_transform.save(os.path.join(tmpdir, "t_random_affine.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('shear', [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("shear", [15, 10.0, (5.0, 10.0), [-15, 15], [-10.0, 10.0, -11.0, 11.0]]) def test_random_affine_shear(device, interpolation, shear): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, shear=shear) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('scale', [(0.7, 1.2), [0.7, 1.2]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("scale", [(0.7, 1.2), [0.7, 1.2]]) def test_random_affine_scale(device, interpolation, scale): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, scale=scale) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('translate', [(0.1, 0.2), [0.2, 0.1]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("translate", [(0.1, 0.2), [0.2, 0.1]]) def test_random_affine_translate(device, interpolation, translate): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, translate=translate) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('degrees', [45, 35.0, (-45, 45), [-90.0, 90.0]]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]]) def test_random_affine_degrees(device, interpolation, degrees): 
_test_random_affine_helper(device, degrees=degrees, interpolation=interpolation) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize( + "fill", + [ + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_random_affine_fill(device, interpolation, fill): _test_random_affine_helper(device, degrees=0.0, interpolation=interpolation, fill=fill) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('center', [(0, 0), [10, 10], None, (56, 44)]) -@pytest.mark.parametrize('expand', [True, False]) -@pytest.mark.parametrize('degrees', [45, 35.0, (-45, 45), [-90.0, 90.0]]) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("center", [(0, 0), [10, 10], None, (56, 44)]) +@pytest.mark.parametrize("expand", [True, False]) +@pytest.mark.parametrize("degrees", [45, 35.0, (-45, 45), [-90.0, 90.0]]) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize( + "fill", + [ + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_random_rotate(device, center, expand, degrees, interpolation, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) - transform = T.RandomRotation( - degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill - ) + transform = T.RandomRotation(degrees=degrees, interpolation=interpolation, expand=expand, center=center, fill=fill) s_transform = torch.jit.script(transform) _test_transform_vs_scripted(transform, s_transform, tensor) @@ -450,19 +558,27 @@ def test_random_rotate_save(tmpdir): s_transform.save(os.path.join(tmpdir, "t_random_rotate.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('distortion_scale', np.linspace(0.1, 1.0, num=20)) -@pytest.mark.parametrize('interpolation', [NEAREST, BILINEAR]) -@pytest.mark.parametrize('fill', [85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("distortion_scale", np.linspace(0.1, 1.0, num=20)) +@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR]) +@pytest.mark.parametrize( + "fill", + [ + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_random_perspective(device, distortion_scale, interpolation, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) - transform = T.RandomPerspective( - distortion_scale=distortion_scale, - interpolation=interpolation, - fill=fill - ) + transform = T.RandomPerspective(distortion_scale=distortion_scale, interpolation=interpolation, fill=fill) s_transform = torch.jit.script(transform) _test_transform_vs_scripted(transform, s_transform, tensor) @@ -475,23 +591,19 @@ def test_random_perspective_save(tmpdir): s_transform.save(os.path.join(tmpdir, "t_perspective.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) 
-@pytest.mark.parametrize('Klass, meth_kwargs', [ - (T.Grayscale, {"num_output_channels": 1}), - (T.Grayscale, {"num_output_channels": 3}), - (T.RandomGrayscale, {}) -]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "Klass, meth_kwargs", + [(T.Grayscale, {"num_output_channels": 1}), (T.Grayscale, {"num_output_channels": 3}), (T.RandomGrayscale, {})], +) def test_to_grayscale(device, Klass, meth_kwargs): tol = 1.0 + 1e-10 - _test_class_op( - Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, - tol=tol, agg_method="max" - ) + _test_class_op(Klass, meth_kwargs=meth_kwargs, test_exact_match=False, device=device, tol=tol, agg_method="max") -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('in_dtype', int_dtypes() + float_dtypes()) -@pytest.mark.parametrize('out_dtype', int_dtypes() + float_dtypes()) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("in_dtype", int_dtypes() + float_dtypes()) +@pytest.mark.parametrize("out_dtype", int_dtypes() + float_dtypes()) def test_convert_image_dtype(device, in_dtype, out_dtype): tensor, _ = _create_data(26, 34, device=device) batch_tensors = torch.rand(4, 3, 44, 56, device=device) @@ -502,8 +614,9 @@ def test_convert_image_dtype(device, in_dtype, out_dtype): fn = T.ConvertImageDtype(dtype=out_dtype) scripted_fn = torch.jit.script(fn) - if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or \ - (in_dtype == torch.float64 and out_dtype == torch.int64): + if (in_dtype == torch.float32 and out_dtype in (torch.int32, torch.int64)) or ( + in_dtype == torch.float64 and out_dtype == torch.int64 + ): with pytest.raises(RuntimeError, match=r"cannot be performed safely"): _test_transform_vs_scripted(fn, scripted_fn, in_tensor) with pytest.raises(RuntimeError, match=r"cannot be performed safely"): @@ -520,9 +633,22 @@ def test_convert_image_dtype_save(tmpdir): scripted_fn.save(os.path.join(tmpdir, "t_convert_dtype.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('policy', [policy for policy in T.AutoAugmentPolicy]) -@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("policy", [policy for policy in T.AutoAugmentPolicy]) +@pytest.mark.parametrize( + "fill", + [ + None, + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_autoaugment(device, policy, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) @@ -534,10 +660,23 @@ def test_autoaugment(device, policy, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('num_ops', [1, 2, 3]) -@pytest.mark.parametrize('magnitude', [7, 9, 11]) -@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("num_ops", [1, 2, 3]) +@pytest.mark.parametrize("magnitude", [7, 9, 11]) +@pytest.mark.parametrize( + "fill", + [ + None, + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_randaugment(device, num_ops, magnitude, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 
44, 56), dtype=torch.uint8, device=device) @@ -549,8 +688,21 @@ def test_randaugment(device, num_ops, magnitude, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fill", + [ + None, + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) def test_trivialaugmentwide(device, fill): tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) @@ -562,21 +714,17 @@ def test_trivialaugmentwide(device, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide]) +@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide]) def test_autoaugment_save(augmentation, tmpdir): transform = augmentation() s_transform = torch.jit.script(transform) s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize( - 'config', [ - {"value": 0.2}, - {"value": "random"}, - {"value": (0.2, 0.2, 0.2)}, - {"value": "random", "ratio": (0.1, 0.2)} - ] + "config", + [{"value": 0.2}, {"value": "random"}, {"value": (0.2, 0.2, 0.2)}, {"value": "random", "ratio": (0.1, 0.2)}], ) def test_random_erasing(device, config): tensor, _ = _create_data(24, 32, channels=3, device=device) @@ -602,7 +750,7 @@ def test_random_erasing_with_invalid_data(): random_erasing(img) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_normalize(device, tmpdir): fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) tensor, _ = _create_data(26, 34, device=device) @@ -621,7 +769,7 @@ def test_normalize(device, tmpdir): scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_linear_transformation(device, tmpdir): c, h, w = 3, 24, 32 @@ -647,14 +795,16 @@ def test_linear_transformation(device, tmpdir): scripted_fn.save(os.path.join(tmpdir, "t_norm.pt")) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_compose(device): tensor, _ = _create_data(26, 34, device=device) tensor = tensor.to(dtype=torch.float32) / 255.0 - transforms = T.Compose([ - T.CenterCrop(10), - T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), - ]) + transforms = T.Compose( + [ + T.CenterCrop(10), + T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + ] + ) s_transforms = torch.nn.Sequential(*transforms.transforms) scripted_fn = torch.jit.script(s_transforms) @@ -664,26 +814,36 @@ def test_compose(device): transformed_tensor_script = scripted_fn(tensor) assert_equal(transformed_tensor, transformed_tensor_script, msg="{}".format(transforms)) - t = T.Compose([ - lambda x: x, - ]) + t = T.Compose( + [ + lambda x: x, + ] + ) with pytest.raises(RuntimeError, match="cannot call a value of type 'Tensor'"): torch.jit.script(t) -@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_random_apply(device): tensor, _ = _create_data(26, 34, 
device=device) tensor = tensor.to(dtype=torch.float32) / 255.0 - transforms = T.RandomApply([ - T.RandomHorizontalFlip(), - T.ColorJitter(), - ], p=0.4) - s_transforms = T.RandomApply(torch.nn.ModuleList([ - T.RandomHorizontalFlip(), - T.ColorJitter(), - ]), p=0.4) + transforms = T.RandomApply( + [ + T.RandomHorizontalFlip(), + T.ColorJitter(), + ], + p=0.4, + ) + s_transforms = T.RandomApply( + torch.nn.ModuleList( + [ + T.RandomHorizontalFlip(), + T.ColorJitter(), + ] + ), + p=0.4, + ) scripted_fn = torch.jit.script(s_transforms) torch.manual_seed(12) @@ -695,27 +855,38 @@ def test_random_apply(device): if device == "cpu": # Can't check this twice, otherwise # "Can't redefine method: forward on class: __torch__.torchvision.transforms.transforms.RandomApply" - transforms = T.RandomApply([ - T.ColorJitter(), - ], p=0.3) + transforms = T.RandomApply( + [ + T.ColorJitter(), + ], + p=0.3, + ) with pytest.raises(RuntimeError, match="Module 'RandomApply' has no attribute 'transforms'"): torch.jit.script(transforms) -@pytest.mark.parametrize('device', cpu_and_gpu()) -@pytest.mark.parametrize('meth_kwargs', [ - {"kernel_size": 3, "sigma": 0.75}, - {"kernel_size": 23, "sigma": [0.1, 2.0]}, - {"kernel_size": 23, "sigma": (0.1, 2.0)}, - {"kernel_size": [3, 3], "sigma": (1.0, 1.0)}, - {"kernel_size": (3, 3), "sigma": (0.1, 2.0)}, - {"kernel_size": [23], "sigma": 0.75} -]) -@pytest.mark.parametrize('channels', [1, 3]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "meth_kwargs", + [ + {"kernel_size": 3, "sigma": 0.75}, + {"kernel_size": 23, "sigma": [0.1, 2.0]}, + {"kernel_size": 23, "sigma": (0.1, 2.0)}, + {"kernel_size": [3, 3], "sigma": (1.0, 1.0)}, + {"kernel_size": (3, 3), "sigma": (0.1, 2.0)}, + {"kernel_size": [23], "sigma": 0.75}, + ], +) +@pytest.mark.parametrize("channels", [1, 3]) def test_gaussian_blur(device, channels, meth_kwargs): tol = 1.0 + 1e-10 torch.manual_seed(12) _test_class_op( - T.GaussianBlur, meth_kwargs=meth_kwargs, channels=channels, - test_exact_match=False, device=device, agg_method="max", tol=tol + T.GaussianBlur, + meth_kwargs=meth_kwargs, + channels=channels, + test_exact_match=False, + device=device, + agg_method="max", + tol=tol, ) diff --git a/test/test_transforms_video.py b/test/test_transforms_video.py index 975b425f6a5..a3bd8528abf 100644 --- a/test/test_transforms_video.py +++ b/test/test_transforms_video.py @@ -1,10 +1,11 @@ -import torch -from torchvision.transforms import Compose -import pytest import random -import numpy as np import warnings + +import numpy as np +import pytest +import torch from common_utils import assert_equal +from torchvision.transforms import Compose try: from scipy import stats @@ -17,8 +18,7 @@ import torchvision.transforms._transforms_video as transforms -class TestVideoTransforms(): - +class TestVideoTransforms: def test_random_crop_video(self): numFrames = random.randint(4, 128) height = random.randint(10, 32) * 2 @@ -26,10 +26,12 @@ def test_random_crop_video(self): oheight = random.randint(5, (height - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2 clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8) - result = Compose([ - transforms.ToTensorVideo(), - transforms.RandomCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.RandomCropVideo((oheight, owidth)), + ] + )(clip) assert result.size(2) == oheight assert result.size(3) == owidth @@ -42,10 +44,12 @@ def test_random_resized_crop_video(self): oheight = 
random.randint(5, (height - 2) / 2) * 2 owidth = random.randint(5, (width - 2) / 2) * 2 clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8) - result = Compose([ - transforms.ToTensorVideo(), - transforms.RandomResizedCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.RandomResizedCropVideo((oheight, owidth)), + ] + )(clip) assert result.size(2) == oheight assert result.size(3) == owidth @@ -61,47 +65,56 @@ def test_center_crop_video(self): clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255 oh1 = (height - oheight) // 2 ow1 = (width - owidth) // 2 - clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :] + clipNarrow = clip[:, oh1 : oh1 + oheight, ow1 : ow1 + owidth, :] clipNarrow.fill_(0) - result = Compose([ - transforms.ToTensorVideo(), - transforms.CenterCropVideo((oheight, owidth)), - ])(clip) - - msg = "height: " + str(height) + " width: " \ - + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.CenterCropVideo((oheight, owidth)), + ] + )(clip) + + msg = ( + "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + ) assert result.sum().item() == 0, msg oheight += 1 owidth += 1 - result = Compose([ - transforms.ToTensorVideo(), - transforms.CenterCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.CenterCropVideo((oheight, owidth)), + ] + )(clip) sum1 = result.sum() - msg = "height: " + str(height) + " width: " \ - + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + msg = ( + "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + ) assert sum1.item() > 1, msg oheight += 1 owidth += 1 - result = Compose([ - transforms.ToTensorVideo(), - transforms.CenterCropVideo((oheight, owidth)), - ])(clip) + result = Compose( + [ + transforms.ToTensorVideo(), + transforms.CenterCropVideo((oheight, owidth)), + ] + )(clip) sum2 = result.sum() - msg = "height: " + str(height) + " width: " \ - + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + msg = ( + "height: " + str(height) + " width: " + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + ) assert sum2.item() > 1, msg assert sum2.item() > sum1.item(), msg - @pytest.mark.skipif(stats is None, reason='scipy.stats is not available') - @pytest.mark.parametrize('channels', [1, 3]) + @pytest.mark.skipif(stats is None, reason="scipy.stats is not available") + @pytest.mark.parametrize("channels", [1, 3]) def test_normalize_video(self, channels): def samples_from_standard_normal(tensor): - p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue + p_value = stats.kstest(list(tensor.view(-1)), "norm", args=(0, 1)).pvalue return p_value > 0.0001 random_state = random.getstate() @@ -147,7 +160,7 @@ def test_to_tensor_video(self): trans.__repr__() - @pytest.mark.skipif(stats is None, reason='scipy.stats not available') + @pytest.mark.skipif(stats is None, reason="scipy.stats not available") def test_random_horizontal_flip_video(self): random_state = random.getstate() random.seed(42) @@ -179,5 +192,5 @@ def test_random_horizontal_flip_video(self): transforms.RandomHorizontalFlipVideo().__repr__() -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_utils.py b/test/test_utils.py 
index 37829b906f1..5c0502d7bb5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,21 +1,20 @@ -import pytest -import numpy as np import os import sys import tempfile -import torch -import torchvision.utils as utils - from io import BytesIO + +import numpy as np +import pytest +import torch import torchvision.transforms.functional as F -from PIL import Image, __version__ as PILLOW_VERSION, ImageColor +import torchvision.utils as utils from common_utils import assert_equal +from PIL import Image, __version__ as PILLOW_VERSION, ImageColor -PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split('.')) +PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split(".")) -boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) +boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) def test_make_grid_not_inplace(): @@ -23,13 +22,13 @@ def test_make_grid_not_inplace(): t_clone = t.clone() utils.make_grid(t, normalize=False) - assert_equal(t, t_clone, msg='make_grid modified tensor in-place') + assert_equal(t, t_clone, msg="make_grid modified tensor in-place") utils.make_grid(t, normalize=True, scale_each=False) - assert_equal(t, t_clone, msg='make_grid modified tensor in-place') + assert_equal(t, t_clone, msg="make_grid modified tensor in-place") utils.make_grid(t, normalize=True, scale_each=True) - assert_equal(t, t_clone, msg='make_grid modified tensor in-place') + assert_equal(t, t_clone, msg="make_grid modified tensor in-place") def test_normalize_in_make_grid(): @@ -46,48 +45,48 @@ def test_normalize_in_make_grid(): rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits) rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits) - assert_equal(norm_max, rounded_grid_max, msg='Normalized max is not equal to 1') - assert_equal(norm_min, rounded_grid_min, msg='Normalized min is not equal to 0') + assert_equal(norm_max, rounded_grid_max, msg="Normalized max is not equal to 1") + assert_equal(norm_min, rounded_grid_min, msg="Normalized min is not equal to 0") -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with tempfile.NamedTemporaryFile(suffix=".png") as f: t = torch.rand(2, 3, 64, 64) utils.save_image(t, f.name) - assert os.path.exists(f.name), 'The image is not present after save' + assert os.path.exists(f.name), "The image is not present after save" -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image_single_pixel(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with tempfile.NamedTemporaryFile(suffix=".png") as f: t = torch.rand(1, 3, 1, 1) utils.save_image(t, f.name) - assert os.path.exists(f.name), 'The pixel image is not present after save' + assert os.path.exists(f.name), "The pixel image is not present after save" -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image_file_object(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with 
tempfile.NamedTemporaryFile(suffix=".png") as f: t = torch.rand(2, 3, 64, 64) utils.save_image(t, f.name) img_orig = Image.open(f.name) fp = BytesIO() - utils.save_image(t, fp, format='png') + utils.save_image(t, fp, format="png") img_bytes = Image.open(fp) - assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') + assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object") -@pytest.mark.skipif(sys.platform in ('win32', 'cygwin'), reason='temporarily disabled on Windows') +@pytest.mark.skipif(sys.platform in ("win32", "cygwin"), reason="temporarily disabled on Windows") def test_save_image_single_pixel_file_object(): - with tempfile.NamedTemporaryFile(suffix='.png') as f: + with tempfile.NamedTemporaryFile(suffix=".png") as f: t = torch.rand(1, 3, 1, 1) utils.save_image(t, f.name) img_orig = Image.open(f.name) fp = BytesIO() - utils.save_image(t, fp, format='png') + utils.save_image(t, fp, format="png") img_bytes = Image.open(fp) - assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg='Image not stored in file object') + assert_equal(F.to_tensor(img_orig), F.to_tensor(img_bytes), msg="Image not stored in file object") def test_draw_boxes(): @@ -113,13 +112,7 @@ def test_draw_boxes(): assert_equal(img, img_cp) -@pytest.mark.parametrize('colors', [ - None, - ['red', 'blue', '#FF00FF', (1, 34, 122)], - 'red', - '#FF00FF', - (1, 34, 122) -]) +@pytest.mark.parametrize("colors", [None, ["red", "blue", "#FF00FF", (1, 34, 122)], "red", "#FF00FF", (1, 34, 122)]) def test_draw_boxes_colors(colors): img = torch.full((3, 100, 100), 0, dtype=torch.uint8) utils.draw_bounding_boxes(img, boxes, fill=False, width=7, colors=colors) @@ -154,8 +147,7 @@ def test_draw_invalid_boxes(): img_tp = ((1, 1, 1), (1, 2, 3)) img_wrong1 = torch.full((3, 5, 5), 255, dtype=torch.float) img_wrong2 = torch.full((1, 3, 5, 5), 255, dtype=torch.uint8) - boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], - [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) + boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float) with pytest.raises(TypeError, match="Tensor expected"): utils.draw_bounding_boxes(img_tp, boxes) with pytest.raises(ValueError, match="Tensor uint8 expected"): @@ -166,12 +158,15 @@ def test_draw_invalid_boxes(): utils.draw_bounding_boxes(img_wrong2[0][:2], boxes) -@pytest.mark.parametrize('colors', [ - None, - ['red', 'blue'], - ['#FF00FF', (1, 34, 122)], -]) -@pytest.mark.parametrize('alpha', (0, .5, .7, 1)) +@pytest.mark.parametrize( + "colors", + [ + None, + ["red", "blue"], + ["#FF00FF", (1, 34, 122)], + ], +) +@pytest.mark.parametrize("alpha", (0, 0.5, 0.7, 1)) def test_draw_segmentation_masks(colors, alpha): """This test makes sure that masks draw their corresponding color where they should""" num_masks, h, w = 2, 100, 100 @@ -241,10 +236,10 @@ def test_draw_segmentation_masks_errors(): with pytest.raises(ValueError, match="There are more masks"): utils.draw_segmentation_masks(image=img, masks=masks, colors=[]) with pytest.raises(ValueError, match="colors must be a tuple or a string, or a list thereof"): - bad_colors = np.array(['red', 'blue']) # should be a list + bad_colors = np.array(["red", "blue"]) # should be a list utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) with pytest.raises(ValueError, match="It seems that you passed a tuple of colors instead of"): - bad_colors = ('red', 'blue') # should be a list + bad_colors = ("red", 
"blue") # should be a list utils.draw_segmentation_masks(image=img, masks=masks, colors=bad_colors) diff --git a/test/test_video_reader.py b/test/test_video_reader.py index 41ca3e9b08a..282ce653322 100644 --- a/test/test_video_reader.py +++ b/test/test_video_reader.py @@ -2,17 +2,17 @@ import itertools import math import os -import pytest -from pytest import approx from fractions import Fraction import numpy as np +import pytest import torch import torchvision.io as io +from common_utils import assert_equal from numpy.random import randint +from pytest import approx from torchvision import set_video_backend from torchvision.io import _HAS_VIDEO_OPT -from common_utils import assert_equal try: @@ -108,18 +108,14 @@ } -DecoderResult = collections.namedtuple( - "DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase" -) +DecoderResult = collections.namedtuple("DecoderResult", "vframes vframe_pts vtimebase aframes aframe_pts atimebase") # av_seek_frame is imprecise so seek to a timestamp earlier by a margin # The unit of margin is second seek_frame_margin = 0.25 -def _read_from_stream( - container, start_pts, end_pts, stream, stream_name, buffer_size=4 -): +def _read_from_stream(container, start_pts, end_pts, stream, stream_name, buffer_size=4): """ Args: container: pyav container @@ -231,9 +227,7 @@ def _decode_frames_by_av_module( else: aframes = torch.empty((1, 0), dtype=torch.float32) - aframe_pts = torch.tensor( - [audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64 - ) + aframe_pts = torch.tensor([audio_frame.pts for audio_frame in audio_frames], dtype=torch.int64) return DecoderResult( vframes=vframes, @@ -273,25 +267,28 @@ def _get_video_tensor(video_dir, video_file): @pytest.mark.skipif(_HAS_VIDEO_OPT is False, reason="Didn't compile with ffmpeg") class TestVideoReader: def check_separate_decoding_result(self, tv_result, config): - """check the decoding results from TorchVision decoder - """ - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) - - video_duration = vduration.item() * Fraction( - vtimebase[0].item(), vtimebase[1].item() - ) + """check the decoding results from TorchVision decoder""" + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result + + video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item()) assert video_duration == approx(config.duration, abs=0.5) assert vfps.item() == approx(config.video_fps, abs=0.5) if asample_rate.numel() > 0: assert asample_rate.item() == config.audio_sample_rate - audio_duration = aduration.item() * Fraction( - atimebase[0].item(), atimebase[1].item() - ) + audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item()) assert audio_duration == approx(config.duration, abs=0.5) # check if pts of video frames are sorted in ascending order @@ -305,16 +302,12 @@ def check_separate_decoding_result(self, tv_result, config): def check_probe_result(self, result, config): vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result - video_duration = vduration.item() * Fraction( - vtimebase[0].item(), vtimebase[1].item() - ) + video_duration = vduration.item() * Fraction(vtimebase[0].item(), vtimebase[1].item()) assert video_duration == approx(config.duration, abs=0.5) assert vfps.item() == approx(config.video_fps, abs=0.5) if asample_rate.numel() > 0: assert asample_rate.item() == 
config.audio_sample_rate - audio_duration = aduration.item() * Fraction( - atimebase[0].item(), atimebase[1].item() - ) + audio_duration = aduration.item() * Fraction(atimebase[0].item(), atimebase[1].item()) assert audio_duration == approx(config.duration, abs=0.5) def check_meta_result(self, result, config): @@ -333,10 +326,18 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config decoder or TorchVision decoder with getPtsOnly = 1 config: config of decoding results checker """ - vframes, vframe_pts, vtimebase, _vfps, _vduration, \ - aframes, aframe_pts, atimebase, _asample_rate, _aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + _vfps, + _vduration, + aframes, + aframe_pts, + atimebase, + _asample_rate, + _aduration, + ) = tv_result if isinstance(ref_result, list): # the ref_result is from new video_reader decoder ref_result = DecoderResult( @@ -349,32 +350,20 @@ def compare_decoding_result(self, tv_result, ref_result, config=all_check_config ) if vframes.numel() > 0 and ref_result.vframes.numel() > 0: - mean_delta = torch.mean( - torch.abs(vframes.float() - ref_result.vframes.float()) - ) + mean_delta = torch.mean(torch.abs(vframes.float() - ref_result.vframes.float())) assert mean_delta == approx(0.0, abs=8.0) - mean_delta = torch.mean( - torch.abs(vframe_pts.float() - ref_result.vframe_pts.float()) - ) + mean_delta = torch.mean(torch.abs(vframe_pts.float() - ref_result.vframe_pts.float())) assert mean_delta == approx(0.0, abs=1.0) assert_equal(vtimebase, ref_result.vtimebase) - if ( - config.check_aframes - and aframes.numel() > 0 - and ref_result.aframes.numel() > 0 - ): + if config.check_aframes and aframes.numel() > 0 and ref_result.aframes.numel() > 0: """Audio stream is available and audio frame is required to return from decoder""" assert_equal(aframes, ref_result.aframes) - if ( - config.check_aframe_pts - and aframe_pts.numel() > 0 - and ref_result.aframe_pts.numel() > 0 - ): + if config.check_aframe_pts and aframe_pts.numel() > 0 and ref_result.aframe_pts.numel() > 0: """Audio stream is available""" assert_equal(aframe_pts, ref_result.aframe_pts) @@ -508,19 +497,25 @@ def test_read_video_from_file_read_single_stream_only(self): audio_timebase_den, ) - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result assert (vframes.numel() > 0) is bool(readVideoStream) assert (vframe_pts.numel() > 0) is bool(readVideoStream) assert (vtimebase.numel() > 0) is bool(readVideoStream) assert (vfps.numel() > 0) is bool(readVideoStream) - expect_audio_data = ( - readAudioStream == 1 and config.audio_sample_rate is not None - ) + expect_audio_data = readAudioStream == 1 and config.audio_sample_rate is not None assert (aframes.numel() > 0) is bool(expect_audio_data) assert (aframe_pts.numel() > 0) is bool(expect_audio_data) assert (atimebase.numel() > 0) is bool(expect_audio_data) @@ -808,19 +803,23 @@ def test_read_video_from_file_audio_resampling(self): audio_timebase_num, audio_timebase_den, ) - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result if aframes.numel() > 0: assert samples == 
asample_rate.item() assert 1 == aframes.size(1) # when audio stream is found - duration = ( - float(aframe_pts[-1]) - * float(atimebase[0]) - / float(atimebase[1]) - ) + duration = float(aframe_pts[-1]) * float(atimebase[0]) / float(atimebase[1]) assert aframes.size(0) == approx(int(duration * asample_rate.item()), abs=0.1 * asample_rate.item()) def test_compare_read_video_from_memory_and_file(self): @@ -1040,10 +1039,18 @@ def test_read_video_in_range_from_memory(self): audio_timebase_num, audio_timebase_den, ) - vframes, vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - tv_result - ) + ( + vframes, + vframe_pts, + vtimebase, + vfps, + vduration, + aframes, + aframe_pts, + atimebase, + asample_rate, + aduration, + ) = tv_result assert abs(config.video_fps - vfps.item()) < 0.01 for num_frames in [4, 8, 16, 32, 64, 128]: @@ -1097,41 +1104,31 @@ def test_read_video_in_range_from_memory(self): ) # pass 3: decode frames in range using PyAv - video_timebase_av, audio_timebase_av = _get_timebase_by_av_module( - full_path - ) + video_timebase_av, audio_timebase_av = _get_timebase_by_av_module(full_path) video_start_pts_av = _pts_convert( video_start_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - video_timebase_av.numerator, video_timebase_av.denominator - ), + Fraction(video_timebase_av.numerator, video_timebase_av.denominator), math.floor, ) video_end_pts_av = _pts_convert( video_end_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - video_timebase_av.numerator, video_timebase_av.denominator - ), + Fraction(video_timebase_av.numerator, video_timebase_av.denominator), math.ceil, ) if audio_timebase_av: audio_start_pts = _pts_convert( video_start_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - audio_timebase_av.numerator, audio_timebase_av.denominator - ), + Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), math.floor, ) audio_end_pts = _pts_convert( video_end_pts.item(), Fraction(video_timebase_num.item(), video_timebase_den.item()), - Fraction( - audio_timebase_av.numerator, audio_timebase_av.denominator - ), + Fraction(audio_timebase_av.numerator, audio_timebase_av.denominator), math.ceil, ) @@ -1218,46 +1215,42 @@ def test_read_video_from_memory_scripted(self): # FUTURE: check value of video / audio frames def test_invalid_file(self): - set_video_backend('video_reader') + set_video_backend("video_reader") with pytest.raises(RuntimeError): - io.read_video('foo.mp4') + io.read_video("foo.mp4") - set_video_backend('pyav') + set_video_backend("pyav") with pytest.raises(RuntimeError): - io.read_video('foo.mp4') + io.read_video("foo.mp4") def test_audio_present_pts(self): """Test if audio frames are returned with pts unit.""" - backends = ['video_reader', 'pyav'] + backends = ["video_reader", "pyav"] start_offsets = [0, 1000] end_offsets = [3000, None] for test_video, _ in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video) container = av.open(full_path) if container.streams.audio: - for backend, start_offset, end_offset in itertools.product( - backends, start_offsets, end_offsets): + for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets): set_video_backend(backend) - _, audio, _ = io.read_video( - full_path, start_offset, end_offset, pts_unit='pts') + _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="pts") assert 
all([dimension > 0 for dimension in audio.shape[:2]]) def test_audio_present_sec(self): """Test if audio frames are returned with sec unit.""" - backends = ['video_reader', 'pyav'] + backends = ["video_reader", "pyav"] start_offsets = [0, 0.1] end_offsets = [0.3, None] for test_video, _ in test_videos.items(): full_path = os.path.join(VIDEO_DIR, test_video) container = av.open(full_path) if container.streams.audio: - for backend, start_offset, end_offset in itertools.product( - backends, start_offsets, end_offsets): + for backend, start_offset, end_offset in itertools.product(backends, start_offsets, end_offsets): set_video_backend(backend) - _, audio, _ = io.read_video( - full_path, start_offset, end_offset, pts_unit='sec') + _, audio, _ = io.read_video(full_path, start_offset, end_offset, pts_unit="sec") assert all([dimension > 0 for dimension in audio.shape[:2]]) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/test_videoapi.py b/test/test_videoapi.py index 1bcce29670d..2f933ee7f73 100644 --- a/test/test_videoapi.py +++ b/test/test_videoapi.py @@ -1,13 +1,13 @@ import collections import os -import pytest -from pytest import approx import urllib +import pytest import torch import torchvision -from torchvision.io import _HAS_VIDEO_OPT, VideoReader +from pytest import approx from torchvision.datasets.utils import download_url +from torchvision.io import _HAS_VIDEO_OPT, VideoReader try: @@ -36,30 +36,16 @@ def fate(name, path="."): test_videos = { - "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth( - duration=2.0, video_fps=30.0, audio_sample_rate=None - ), + "RATRACE_wave_f_nm_np1_fr_goo_37.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None), "SchoolRulesHowTheyHelpUs_wave_f_nm_np1_ba_med_0.avi": GroundTruth( duration=2.0, video_fps=30.0, audio_sample_rate=None ), - "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth( - duration=2.0, video_fps=30.0, audio_sample_rate=None - ), - "v_SoccerJuggling_g23_c01.avi": GroundTruth( - duration=8.0, video_fps=29.97, audio_sample_rate=None - ), - "v_SoccerJuggling_g24_c01.avi": GroundTruth( - duration=8.0, video_fps=29.97, audio_sample_rate=None - ), - "R6llTwEh07w.mp4": GroundTruth( - duration=10.0, video_fps=30.0, audio_sample_rate=44100 - ), - "SOX5yA1l24A.mp4": GroundTruth( - duration=11.0, video_fps=29.97, audio_sample_rate=48000 - ), - "WUzgd7C1pWA.mp4": GroundTruth( - duration=11.0, video_fps=29.97, audio_sample_rate=48000 - ), + "TrumanShow_wave_f_nm_np1_fr_med_26.avi": GroundTruth(duration=2.0, video_fps=30.0, audio_sample_rate=None), + "v_SoccerJuggling_g23_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None), + "v_SoccerJuggling_g24_c01.avi": GroundTruth(duration=8.0, video_fps=29.97, audio_sample_rate=None), + "R6llTwEh07w.mp4": GroundTruth(duration=10.0, video_fps=30.0, audio_sample_rate=44100), + "SOX5yA1l24A.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000), + "WUzgd7C1pWA.mp4": GroundTruth(duration=11.0, video_fps=29.97, audio_sample_rate=48000), } @@ -79,13 +65,9 @@ def test_frame_reading(self): assert float(av_frame.pts * av_frame.time_base) == approx(vr_frame["pts"], abs=0.1) - av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute( - 2, 0, 1 - ) + av_array = torch.tensor(av_frame.to_rgb().to_ndarray()).permute(2, 0, 1) vr_array = vr_frame["data"] - mean_delta = torch.mean( - torch.abs(av_array.float() - vr_array.float()) - ) + mean_delta = torch.mean(torch.abs(av_array.float() - vr_array.float())) # on 
average the difference is very small and caused # by decoding (around 1%) # TODO: asses empirically how to set this? atm it's 1% @@ -102,9 +84,7 @@ def test_frame_reading(self): av_array = torch.tensor(av_frame.to_ndarray()).permute(1, 0) vr_array = vr_frame["data"] - max_delta = torch.max( - torch.abs(av_array.float() - vr_array.float()) - ) + max_delta = torch.max(torch.abs(av_array.float() - vr_array.float())) # we assure that there is never more than 1% difference in signal assert max_delta.item() < 0.001 @@ -188,5 +168,5 @@ def test_fate_suite(self): os.remove(video_path) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 34961e8684f..8cc1d344936 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -1,4 +1,3 @@ - import os.path as osp import torch diff --git a/torchvision/__init__.py b/torchvision/__init__.py index a5a6f568151..03dc20c5c54 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -1,16 +1,15 @@ -import warnings import os +import warnings -from .extension import _HAS_OPS - -from torchvision import models +import torch from torchvision import datasets +from torchvision import io +from torchvision import models from torchvision import ops from torchvision import transforms from torchvision import utils -from torchvision import io -import torch +from .extension import _HAS_OPS try: from .version import __version__ # noqa: F401 @@ -18,14 +17,17 @@ pass # Check if torchvision is being imported within the root folder -if (not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == - os.path.join(os.path.realpath(os.getcwd()), 'torchvision')): - message = ('You are importing torchvision within its own root folder ({}). ' - 'This is not expected to work and may give errors. Please exit the ' - 'torchvision project source and relaunch your python interpreter.') +if not _HAS_OPS and os.path.dirname(os.path.realpath(__file__)) == os.path.join( + os.path.realpath(os.getcwd()), "torchvision" +): + message = ( + "You are importing torchvision within its own root folder ({}). " + "This is not expected to work and may give errors. Please exit the " + "torchvision project source and relaunch your python interpreter." + ) warnings.warn(message.format(os.getcwd())) -_image_backend = 'PIL' +_image_backend = "PIL" _video_backend = "pyav" @@ -40,9 +42,8 @@ def set_image_backend(backend): generally faster than PIL, but does not support as many operations. """ global _image_backend - if backend not in ['PIL', 'accimage']: - raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'" - .format(backend)) + if backend not in ["PIL", "accimage"]: + raise ValueError("Invalid backend '{}'. Options are 'PIL' and 'accimage'".format(backend)) _image_backend = backend @@ -71,14 +72,9 @@ def set_video_backend(backend): """ global _video_backend if backend not in ["pyav", "video_reader"]: - raise ValueError( - "Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend - ) + raise ValueError("Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend) if backend == "video_reader" and not io._HAS_VIDEO_OPT: - message = ( - "video_reader video backend is not available." - " Please compile torchvision from source and try again" - ) + message = "video_reader video backend is not available." 
" Please compile torchvision from source and try again" warnings.warn(message) else: _video_backend = backend diff --git a/torchvision/_internally_replaced_utils.py b/torchvision/_internally_replaced_utils.py index 0ab3e4e3f15..b37ba8acbc3 100644 --- a/torchvision/_internally_replaced_utils.py +++ b/torchvision/_internally_replaced_utils.py @@ -1,5 +1,5 @@ -import os import importlib.machinery +import os def _download_file_from_remote_location(fpath: str, url: str) -> None: @@ -19,13 +19,13 @@ def _is_remote_location_available() -> bool: def _get_extension_path(lib_name): lib_dir = os.path.dirname(__file__) - if os.name == 'nt': + if os.name == "nt": # Register the main torchvision library location on the default DLL path import ctypes import sys - kernel32 = ctypes.WinDLL('kernel32.dll', use_last_error=True) - with_load_library_flags = hasattr(kernel32, 'AddDllDirectory') + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") prev_error_mode = kernel32.SetErrorMode(0x0001) if with_load_library_flags: @@ -42,10 +42,7 @@ def _get_extension_path(lib_name): kernel32.SetErrorMode(prev_error_mode) - loader_details = ( - importlib.machinery.ExtensionFileLoader, - importlib.machinery.EXTENSION_SUFFIXES - ) + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) ext_specs = extfinder.find_spec(lib_name) diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index b4298486491..72a73d1d51b 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -1,41 +1,74 @@ -from .lsun import LSUN, LSUNClass -from .folder import ImageFolder, DatasetFolder -from .coco import CocoCaptions, CocoDetection +from .caltech import Caltech101, Caltech256 +from .celeba import CelebA from .cifar import CIFAR10, CIFAR100 -from .stl10 import STL10 -from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST -from .svhn import SVHN -from .phototour import PhotoTour +from .cityscapes import Cityscapes +from .coco import CocoCaptions, CocoDetection from .fakedata import FakeData -from .semeion import SEMEION -from .omniglot import Omniglot -from .sbu import SBU from .flickr import Flickr8k, Flickr30k -from .voc import VOCSegmentation, VOCDetection -from .cityscapes import Cityscapes +from .folder import ImageFolder, DatasetFolder +from .hmdb51 import HMDB51 from .imagenet import ImageNet -from .caltech import Caltech101, Caltech256 -from .celeba import CelebA -from .widerface import WIDERFace -from .sbd import SBDataset -from .vision import VisionDataset -from .usps import USPS +from .inaturalist import INaturalist from .kinetics import Kinetics400, Kinetics -from .hmdb51 import HMDB51 -from .ucf101 import UCF101 -from .places365 import Places365 from .kitti import Kitti -from .inaturalist import INaturalist from .lfw import LFWPeople, LFWPairs +from .lsun import LSUN, LSUNClass +from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST +from .omniglot import Omniglot +from .phototour import PhotoTour +from .places365 import Places365 +from .sbd import SBDataset +from .sbu import SBU +from .semeion import SEMEION +from .stl10 import STL10 +from .svhn import SVHN +from .ucf101 import UCF101 +from .usps import USPS +from .vision import VisionDataset +from .voc import VOCSegmentation, VOCDetection +from .widerface import WIDERFace -__all__ = ('LSUN', 'LSUNClass', - 'ImageFolder', 
'DatasetFolder', 'FakeData', - 'CocoCaptions', 'CocoDetection', - 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'QMNIST', - 'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', - 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k', - 'VOCSegmentation', 'VOCDetection', 'Cityscapes', 'ImageNet', - 'Caltech101', 'Caltech256', 'CelebA', 'WIDERFace', 'SBDataset', - 'VisionDataset', 'USPS', 'Kinetics400', "Kinetics", 'HMDB51', 'UCF101', - 'Places365', 'Kitti', "INaturalist", "LFWPeople", "LFWPairs" - ) +__all__ = ( + "LSUN", + "LSUNClass", + "ImageFolder", + "DatasetFolder", + "FakeData", + "CocoCaptions", + "CocoDetection", + "CIFAR10", + "CIFAR100", + "EMNIST", + "FashionMNIST", + "QMNIST", + "MNIST", + "KMNIST", + "STL10", + "SVHN", + "PhotoTour", + "SEMEION", + "Omniglot", + "SBU", + "Flickr8k", + "Flickr30k", + "VOCSegmentation", + "VOCDetection", + "Cityscapes", + "ImageNet", + "Caltech101", + "Caltech256", + "CelebA", + "WIDERFace", + "SBDataset", + "VisionDataset", + "USPS", + "Kinetics400", + "Kinetics", + "HMDB51", + "UCF101", + "Places365", + "Kitti", + "INaturalist", + "LFWPeople", + "LFWPairs", +) diff --git a/torchvision/datasets/caltech.py b/torchvision/datasets/caltech.py index a99e6fde948..64d007f4f9f 100644 --- a/torchvision/datasets/caltech.py +++ b/torchvision/datasets/caltech.py @@ -1,10 +1,11 @@ -from PIL import Image import os import os.path from typing import Any, Callable, List, Optional, Union, Tuple -from .vision import VisionDataset +from PIL import Image + from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset class Caltech101(VisionDataset): @@ -32,28 +33,26 @@ class Caltech101(VisionDataset): """ def __init__( - self, - root: str, - target_type: Union[List[str], str] = "category", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + target_type: Union[List[str], str] = "category", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(Caltech101, self).__init__(os.path.join(root, 'caltech101'), - transform=transform, - target_transform=target_transform) + super(Caltech101, self).__init__( + os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform + ) os.makedirs(self.root, exist_ok=True) if not isinstance(target_type, list): target_type = [target_type] - self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) - for t in target_type] + self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation")) for t in target_type] if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) self.categories.remove("BACKGROUND_Google") # this is not a real class @@ -61,10 +60,12 @@ def __init__( # For some reason, the category names in "101_ObjectCategories" and # "Annotations" do not always match. This is a manual map between the # two. Defaults to using same name, since most names are fine. 
- name_map = {"Faces": "Faces_2", - "Faces_easy": "Faces_3", - "Motorbikes": "Motorbikes_16", - "airplanes": "Airplanes_Side_2"} + name_map = { + "Faces": "Faces_2", + "Faces_easy": "Faces_3", + "Motorbikes": "Motorbikes_16", + "airplanes": "Airplanes_Side_2", + } self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) self.index: List[int] = [] @@ -84,20 +85,28 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ import scipy.io - img = Image.open(os.path.join(self.root, - "101_ObjectCategories", - self.categories[self.y[index]], - "image_{:04d}.jpg".format(self.index[index]))) + img = Image.open( + os.path.join( + self.root, + "101_ObjectCategories", + self.categories[self.y[index]], + "image_{:04d}.jpg".format(self.index[index]), + ) + ) target: Any = [] for t in self.target_type: if t == "category": target.append(self.y[index]) elif t == "annotation": - data = scipy.io.loadmat(os.path.join(self.root, - "Annotations", - self.annotation_categories[self.y[index]], - "annotation_{:04d}.mat".format(self.index[index]))) + data = scipy.io.loadmat( + os.path.join( + self.root, + "Annotations", + self.annotation_categories[self.y[index]], + "annotation_{:04d}.mat".format(self.index[index]), + ) + ) target.append(data["obj_contour"]) target = tuple(target) if len(target) > 1 else target[0] @@ -118,19 +127,21 @@ def __len__(self) -> int: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", self.root, filename="101_ObjectCategories.tar.gz", - md5="b224c7392d521a49829488ab0f1120d9") + md5="b224c7392d521a49829488ab0f1120d9", + ) download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar", self.root, filename="101_Annotations.tar", - md5="6f83eeb1f24d99cab4eb377263132c91") + md5="6f83eeb1f24d99cab4eb377263132c91", + ) def extra_repr(self) -> str: return "Target type: {target_type}".format(**self.__dict__) @@ -152,23 +163,22 @@ class Caltech256(VisionDataset): """ def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(Caltech256, self).__init__(os.path.join(root, 'caltech256'), - transform=transform, - target_transform=target_transform) + super(Caltech256, self).__init__( + os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform + ) os.makedirs(self.root, exist_ok=True) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) self.index: List[int] = [] @@ -186,10 +196,14 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: Returns: tuple: (image, target) where target is index of the target class. 
""" - img = Image.open(os.path.join(self.root, - "256_ObjectCategories", - self.categories[self.y[index]], - "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]))) + img = Image.open( + os.path.join( + self.root, + "256_ObjectCategories", + self.categories[self.y[index]], + "{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index]), + ) + ) target = self.y[index] @@ -210,11 +224,12 @@ def __len__(self) -> int: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive( "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", self.root, filename="256_ObjectCategories.tar", - md5="67b4f42ca05d46448c6bb8ecd2220f6d") + md5="67b4f42ca05d46448c6bb8ecd2220f6d", + ) diff --git a/torchvision/datasets/celeba.py b/torchvision/datasets/celeba.py index f2fcdb74dfe..91f0fc3f919 100644 --- a/torchvision/datasets/celeba.py +++ b/torchvision/datasets/celeba.py @@ -1,12 +1,14 @@ -from collections import namedtuple import csv -from functools import partial -import torch import os -import PIL +from collections import namedtuple +from functools import partial from typing import Any, Callable, List, Optional, Union, Tuple -from .vision import VisionDataset + +import PIL +import torch + from .utils import download_file_from_google_drive, check_integrity, verify_str_arg +from .vision import VisionDataset CSV = namedtuple("CSV", ["header", "index", "data"]) @@ -57,16 +59,15 @@ class CelebA(VisionDataset): ] def __init__( - self, - root: str, - split: str = "train", - target_type: Union[List[str], str] = "attr", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + target_type: Union[List[str], str] = "attr", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(CelebA, self).__init__(root, transform=transform, - target_transform=target_transform) + super(CelebA, self).__init__(root, transform=transform, target_transform=target_transform) self.split = split if isinstance(target_type, list): self.target_type = target_type @@ -74,14 +75,13 @@ def __init__( self.target_type = [target_type] if not self.target_type and self.target_transform is not None: - raise RuntimeError('target_transform is specified but target_type is empty') + raise RuntimeError("target_transform is specified but target_type is empty") if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." 
+ " You can use download=True to download it") split_map = { "train": 0, @@ -89,8 +89,7 @@ def __init__( "test": 2, "all": None, } - split_ = split_map[verify_str_arg(split.lower(), "split", - ("train", "valid", "test", "all"))] + split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))] splits = self._load_csv("list_eval_partition.txt") identity = self._load_csv("identity_CelebA.txt") bbox = self._load_csv("list_bbox_celeba.txt", header=1) @@ -108,7 +107,7 @@ def __init__( self.landmarks_align = landmarks_align.data[mask] self.attr = attr.data[mask] # map from {-1, 1} to {0, 1} - self.attr = torch.div(self.attr + 1, 2, rounding_mode='floor') + self.attr = torch.div(self.attr + 1, 2, rounding_mode="floor") self.attr_names = attr.header def _load_csv( @@ -120,11 +119,11 @@ def _load_csv( fn = partial(os.path.join, self.root, self.base_folder) with open(fn(filename)) as csv_file: - data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True)) + data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True)) if header is not None: headers = data[header] - data = data[header + 1:] + data = data[header + 1 :] indices = [row[0] for row in data] data = [row[1:] for row in data] @@ -148,7 +147,7 @@ def download(self) -> None: import zipfile if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return for (file_id, md5, filename) in self.file_list: @@ -172,7 +171,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: target.append(self.landmarks_align[index, :]) else: # TODO: refactor with utils.verify_str_arg - raise ValueError("Target type \"{}\" is not recognized.".format(t)) + raise ValueError('Target type "{}" is not recognized.'.format(t)) if self.transform is not None: X = self.transform(X) @@ -192,4 +191,4 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Target type: {target_type}", "Split: {split}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) diff --git a/torchvision/datasets/cifar.py b/torchvision/datasets/cifar.py index 17a2b5ee9cd..f6b121f268a 100644 --- a/torchvision/datasets/cifar.py +++ b/torchvision/datasets/cifar.py @@ -1,13 +1,14 @@ -from PIL import Image import os import os.path -import numpy as np import pickle -import torch from typing import Any, Callable, Optional, Tuple -from .vision import VisionDataset +import numpy as np +import torch +from PIL import Image + from .utils import check_integrity, download_and_extract_archive +from .vision import VisionDataset class CIFAR10(VisionDataset): @@ -27,38 +28,38 @@ class CIFAR10(VisionDataset): downloaded again. 
""" - base_folder = 'cifar-10-batches-py' + + base_folder = "cifar-10-batches-py" url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" filename = "cifar-10-python.tar.gz" - tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + tgz_md5 = "c58f30108f718f92721af3b95e74349a" train_list = [ - ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], - ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], - ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], - ['data_batch_4', '634d18415352ddfa80567beed471001a'], - ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ["data_batch_1", "c99cafc152244af753f735de768cd75f"], + ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"], + ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"], + ["data_batch_4", "634d18415352ddfa80567beed471001a"], + ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"], ] test_list = [ - ['test_batch', '40351d587109b95175f43aff81a1287e'], + ["test_batch", "40351d587109b95175f43aff81a1287e"], ] meta = { - 'filename': 'batches.meta', - 'key': 'label_names', - 'md5': '5ff9c542aee3614f3951f8cda6e48888', + "filename": "batches.meta", + "key": "label_names", + "md5": "5ff9c542aee3614f3951f8cda6e48888", } def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(CIFAR10, self).__init__(root, transform=transform, - target_transform=target_transform) + super(CIFAR10, self).__init__(root, transform=transform, target_transform=target_transform) self.train = train # training set or test set @@ -66,8 +67,7 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") if self.train: downloaded_list = self.train_list @@ -80,13 +80,13 @@ def __init__( # now load the picked numpy arrays for file_name, checksum in downloaded_list: file_path = os.path.join(self.root, self.base_folder, file_name) - with open(file_path, 'rb') as f: - entry = pickle.load(f, encoding='latin1') - self.data.append(entry['data']) - if 'labels' in entry: - self.targets.extend(entry['labels']) + with open(file_path, "rb") as f: + entry = pickle.load(f, encoding="latin1") + self.data.append(entry["data"]) + if "labels" in entry: + self.targets.extend(entry["labels"]) else: - self.targets.extend(entry['fine_labels']) + self.targets.extend(entry["fine_labels"]) self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC @@ -94,13 +94,14 @@ def __init__( self._load_meta() def _load_meta(self) -> None: - path = os.path.join(self.root, self.base_folder, self.meta['filename']) - if not check_integrity(path, self.meta['md5']): - raise RuntimeError('Dataset metadata file not found or corrupted.' + - ' You can use download=True to download it') - with open(path, 'rb') as infile: - data = pickle.load(infile, encoding='latin1') - self.classes = data[self.meta['key']] + path = os.path.join(self.root, self.base_folder, self.meta["filename"]) + if not check_integrity(path, self.meta["md5"]): + raise RuntimeError( + "Dataset metadata file not found or corrupted." 
+ " You can use download=True to download it" + ) + with open(path, "rb") as infile: + data = pickle.load(infile, encoding="latin1") + self.classes = data[self.meta["key"]] self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} def __getitem__(self, index: int) -> Tuple[Any, Any]: @@ -130,7 +131,7 @@ def __len__(self) -> int: def _check_integrity(self) -> bool: root = self.root - for fentry in (self.train_list + self.test_list): + for fentry in self.train_list + self.test_list: filename, md5 = fentry[0], fentry[1] fpath = os.path.join(root, self.base_folder, filename) if not check_integrity(fpath, md5): @@ -139,7 +140,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) @@ -152,19 +153,20 @@ class CIFAR100(CIFAR10): This is a subclass of the `CIFAR10` Dataset. """ - base_folder = 'cifar-100-python' + + base_folder = "cifar-100-python" url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" filename = "cifar-100-python.tar.gz" - tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' + tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85" train_list = [ - ['train', '16019d7e3df5f24257cddd939b257f8d'], + ["train", "16019d7e3df5f24257cddd939b257f8d"], ] test_list = [ - ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], + ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"], ] meta = { - 'filename': 'meta', - 'key': 'fine_label_names', - 'md5': '7973b15100ade9c7d40fb424638fde48', + "filename": "meta", + "key": "fine_label_names", + "md5": "7973b15100ade9c7d40fb424638fde48", } diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py index bed7524ac4f..cfc3e8bab71 100644 --- a/torchvision/datasets/cityscapes.py +++ b/torchvision/datasets/cityscapes.py @@ -3,9 +3,10 @@ from collections import namedtuple from typing import Any, Callable, Dict, List, Optional, Union, Tuple +from PIL import Image + from .utils import extract_archive, verify_str_arg, iterable_to_str from .vision import VisionDataset -from PIL import Image class Cityscapes(VisionDataset): @@ -57,60 +58,62 @@ class Cityscapes(VisionDataset): """ # Based on https://github.com/mcordts/cityscapesScripts - CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', - 'has_instances', 'ignore_in_eval', 'color']) + CityscapesClass = namedtuple( + "CityscapesClass", + ["name", "id", "train_id", "category", "category_id", "has_instances", "ignore_in_eval", "color"], + ) classes = [ - CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), - CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), - CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), - CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), - CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), - CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), - CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), - CityscapesClass('building', 
11, 2, 'construction', 2, False, False, (70, 70, 70)), - CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), - CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)), - CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), - CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), - CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), - CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), - CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), - CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), - CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), - CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), - CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), - CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), - CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), - CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), - CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), - CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), - CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), - CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), - CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), - CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), - CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), - CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), - CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), + CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)), + CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)), + CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)), + CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)), + CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)), + CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)), + CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)), + CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)), + CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)), + CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)), + CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)), + CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)), + CityscapesClass("tunnel", 16, 255, "construction", 2, False, True, (150, 120, 90)), + CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)), + CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)), + CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)), + CityscapesClass("traffic 
sign", 20, 7, "object", 3, False, False, (220, 220, 0)), + CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)), + CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)), + CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)), + CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)), + CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)), + CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)), + CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)), + CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)), + CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)), + CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)), + CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)), + CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)), + CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)), + CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)), ] def __init__( - self, - root: str, - split: str = "train", - mode: str = "fine", - target_type: Union[List[str], str] = "instance", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - transforms: Optional[Callable] = None, + self, + root: str, + split: str = "train", + mode: str = "fine", + target_type: Union[List[str], str] = "instance", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + transforms: Optional[Callable] = None, ) -> None: super(Cityscapes, self).__init__(root, transforms, transform, target_transform) - self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' - self.images_dir = os.path.join(self.root, 'leftImg8bit', split) + self.mode = "gtFine" if mode == "fine" else "gtCoarse" + self.images_dir = os.path.join(self.root, "leftImg8bit", split) self.targets_dir = os.path.join(self.root, self.mode, split) self.target_type = target_type self.split = split @@ -122,35 +125,37 @@ def __init__( valid_modes = ("train", "test", "val") else: valid_modes = ("train", "train_extra", "val") - msg = ("Unknown value '{}' for argument split if mode is '{}'. " - "Valid values are {{{}}}.") + msg = "Unknown value '{}' for argument split if mode is '{}'. " "Valid values are {{{}}}." 
msg = msg.format(split, mode, iterable_to_str(valid_modes)) verify_str_arg(split, "split", valid_modes, msg) if not isinstance(target_type, list): self.target_type = [target_type] - [verify_str_arg(value, "target_type", - ("instance", "semantic", "polygon", "color")) - for value in self.target_type] + [ + verify_str_arg(value, "target_type", ("instance", "semantic", "polygon", "color")) + for value in self.target_type + ] if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir): - if split == 'train_extra': - image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainextra.zip')) + if split == "train_extra": + image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainextra.zip")) else: - image_dir_zip = os.path.join(self.root, 'leftImg8bit{}'.format('_trainvaltest.zip')) + image_dir_zip = os.path.join(self.root, "leftImg8bit{}".format("_trainvaltest.zip")) - if self.mode == 'gtFine': - target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '_trainvaltest.zip')) - elif self.mode == 'gtCoarse': - target_dir_zip = os.path.join(self.root, '{}{}'.format(self.mode, '.zip')) + if self.mode == "gtFine": + target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, "_trainvaltest.zip")) + elif self.mode == "gtCoarse": + target_dir_zip = os.path.join(self.root, "{}{}".format(self.mode, ".zip")) if os.path.isfile(image_dir_zip) and os.path.isfile(target_dir_zip): extract_archive(from_path=image_dir_zip, to_path=self.root) extract_archive(from_path=target_dir_zip, to_path=self.root) else: - raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the' - ' specified "split" and "mode" are inside the "root" directory') + raise RuntimeError( + "Dataset not found or incomplete. Please make sure all required folders for the" + ' specified "split" and "mode" are inside the "root" directory' + ) for city in os.listdir(self.images_dir): img_dir = os.path.join(self.images_dir, city) @@ -158,8 +163,9 @@ def __init__( for file_name in os.listdir(img_dir): target_types = [] for t in self.target_type: - target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], - self._get_target_suffix(self.mode, t)) + target_name = "{}_{}".format( + file_name.split("_leftImg8bit")[0], self._get_target_suffix(self.mode, t) + ) target_types.append(os.path.join(target_dir, target_name)) self.images.append(os.path.join(img_dir, file_name)) @@ -174,11 +180,11 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 
""" - image = Image.open(self.images[index]).convert('RGB') + image = Image.open(self.images[index]).convert("RGB") targets: Any = [] for i, t in enumerate(self.target_type): - if t == 'polygon': + if t == "polygon": target = self._load_json(self.targets[index][i]) else: target = Image.open(self.targets[index][i]) @@ -197,19 +203,19 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) def _load_json(self, path: str) -> Dict[str, Any]: - with open(path, 'r') as file: + with open(path, "r") as file: data = json.load(file) return data def _get_target_suffix(self, mode: str, target_type: str) -> str: - if target_type == 'instance': - return '{}_instanceIds.png'.format(mode) - elif target_type == 'semantic': - return '{}_labelIds.png'.format(mode) - elif target_type == 'color': - return '{}_color.png'.format(mode) + if target_type == "instance": + return "{}_instanceIds.png".format(mode) + elif target_type == "semantic": + return "{}_labelIds.png".format(mode) + elif target_type == "color": + return "{}_color.png".format(mode) else: - return '{}_polygons.json'.format(mode) + return "{}_polygons.json".format(mode) diff --git a/torchvision/datasets/coco.py b/torchvision/datasets/coco.py index a211f2f4f51..d65aa4dc862 100644 --- a/torchvision/datasets/coco.py +++ b/torchvision/datasets/coco.py @@ -1,9 +1,11 @@ -from .vision import VisionDataset -from PIL import Image import os import os.path from typing import Any, Callable, Optional, Tuple, List +from PIL import Image + +from .vision import VisionDataset + class CocoDetection(VisionDataset): """`MS Coco Detection `_ Dataset. diff --git a/torchvision/datasets/fakedata.py b/torchvision/datasets/fakedata.py index ddb14505275..2c95cf488c1 100644 --- a/torchvision/datasets/fakedata.py +++ b/torchvision/datasets/fakedata.py @@ -1,7 +1,9 @@ -import torch from typing import Any, Callable, Optional, Tuple -from .vision import VisionDataset + +import torch + from .. 
import transforms +from .vision import VisionDataset class FakeData(VisionDataset): @@ -21,16 +23,17 @@ class FakeData(VisionDataset): """ def __init__( - self, - size: int = 1000, - image_size: Tuple[int, int, int] = (3, 224, 224), - num_classes: int = 10, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - random_offset: int = 0, + self, + size: int = 1000, + image_size: Tuple[int, int, int] = (3, 224, 224), + num_classes: int = 10, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + random_offset: int = 0, ) -> None: - super(FakeData, self).__init__(None, transform=transform, # type: ignore[arg-type] - target_transform=target_transform) + super(FakeData, self).__init__( + None, transform=transform, target_transform=target_transform # type: ignore[arg-type] + ) self.size = size self.num_classes = num_classes self.image_size = image_size diff --git a/torchvision/datasets/flickr.py b/torchvision/datasets/flickr.py index a3b3e411b6e..31cb68d4937 100644 --- a/torchvision/datasets/flickr.py +++ b/torchvision/datasets/flickr.py @@ -1,10 +1,11 @@ +import glob +import os from collections import defaultdict -from PIL import Image from html.parser import HTMLParser from typing import Any, Callable, Dict, List, Optional, Tuple -import glob -import os +from PIL import Image + from .vision import VisionDataset @@ -27,26 +28,26 @@ def __init__(self, root: str) -> None: def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None: self.current_tag = tag - if tag == 'table': + if tag == "table": self.in_table = True def handle_endtag(self, tag: str) -> None: self.current_tag = None - if tag == 'table': + if tag == "table": self.in_table = False def handle_data(self, data: str) -> None: if self.in_table: - if data == 'Image Not Found': + if data == "Image Not Found": self.current_img = None - elif self.current_tag == 'a': - img_id = data.split('/')[-2] - img_id = os.path.join(self.root, img_id + '_*.jpg') + elif self.current_tag == "a": + img_id = data.split("/")[-2] + img_id = os.path.join(self.root, img_id + "_*.jpg") img_id = glob.glob(img_id)[0] self.current_img = img_id self.annotations[img_id] = [] - elif self.current_tag == 'li' and self.current_img: + elif self.current_tag == "li" and self.current_img: img_id = self.current_img self.annotations[img_id].append(data.strip()) @@ -64,14 +65,13 @@ class Flickr8k(VisionDataset): """ def __init__( - self, - root: str, - ann_file: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + ann_file: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: - super(Flickr8k, self).__init__(root, transform=transform, - target_transform=target_transform) + super(Flickr8k, self).__init__(root, transform=transform, target_transform=target_transform) self.ann_file = os.path.expanduser(ann_file) # Read annotations and store in a dict @@ -93,7 +93,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: img_id = self.ids[index] # Image - img = Image.open(img_id).convert('RGB') + img = Image.open(img_id).convert("RGB") if self.transform is not None: img = self.transform(img) @@ -121,21 +121,20 @@ class Flickr30k(VisionDataset): """ def __init__( - self, - root: str, - ann_file: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + ann_file: str, + transform: Optional[Callable] = None, + target_transform: 
Optional[Callable] = None, ) -> None: - super(Flickr30k, self).__init__(root, transform=transform, - target_transform=target_transform) + super(Flickr30k, self).__init__(root, transform=transform, target_transform=target_transform) self.ann_file = os.path.expanduser(ann_file) # Read annotations and store in a dict self.annotations = defaultdict(list) with open(self.ann_file) as fh: for line in fh: - img_id, caption = line.strip().split('\t') + img_id, caption = line.strip().split("\t") self.annotations[img_id[:-2]].append(caption) self.ids = list(sorted(self.annotations.keys())) @@ -152,7 +151,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # Image filename = os.path.join(self.root, img_id) - img = Image.open(filename).convert('RGB') + img = Image.open(filename).convert("RGB") if self.transform is not None: img = self.transform(img) diff --git a/torchvision/datasets/folder.py b/torchvision/datasets/folder.py index 1280495273c..fedf4a35539 100644 --- a/torchvision/datasets/folder.py +++ b/torchvision/datasets/folder.py @@ -1,11 +1,11 @@ -from .vision import VisionDataset - -from PIL import Image - import os import os.path from typing import Any, Callable, cast, Dict, List, Optional, Tuple +from PIL import Image + +from .vision import VisionDataset + def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool: """Checks if a file is an allowed extension. @@ -132,16 +132,15 @@ class DatasetFolder(VisionDataset): """ def __init__( - self, - root: str, - loader: Callable[[str], Any], - extensions: Optional[Tuple[str, ...]] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - is_valid_file: Optional[Callable[[str], bool]] = None, + self, + root: str, + loader: Callable[[str], Any], + extensions: Optional[Tuple[str, ...]] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, ) -> None: - super(DatasetFolder, self).__init__(root, transform=transform, - target_transform=target_transform) + super(DatasetFolder, self).__init__(root, transform=transform, target_transform=target_transform) classes, class_to_idx = self.find_classes(self.root) samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) @@ -186,9 +185,7 @@ def make_dataset( # prevent potential bug since make_dataset() would use the class_to_idx logic of the # find_classes() function, instead of using that of the find_classes() method, which # is potentially overridden and thus could have a different logic. - raise ValueError( - "The class_to_idx parameter cannot be None." 
- ) + raise ValueError("The class_to_idx parameter cannot be None.") return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: @@ -241,19 +238,20 @@ def __len__(self) -> int: return len(self.samples) -IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') +IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") def pil_loader(path: str) -> Image.Image: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) - with open(path, 'rb') as f: + with open(path, "rb") as f: img = Image.open(f) - return img.convert('RGB') + return img.convert("RGB") # TODO: specify the return type def accimage_loader(path: str) -> Any: import accimage + try: return accimage.Image(path) except IOError: @@ -263,7 +261,8 @@ def accimage_loader(path: str) -> Any: def default_loader(path: str) -> Any: from torchvision import get_image_backend - if get_image_backend() == 'accimage': + + if get_image_backend() == "accimage": return accimage_loader(path) else: return pil_loader(path) @@ -300,15 +299,19 @@ class ImageFolder(DatasetFolder): """ def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - loader: Callable[[str], Any] = default_loader, - is_valid_file: Optional[Callable[[str], bool]] = None, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + loader: Callable[[str], Any] = default_loader, + is_valid_file: Optional[Callable[[str], bool]] = None, ): - super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None, - transform=transform, - target_transform=target_transform, - is_valid_file=is_valid_file) + super(ImageFolder, self).__init__( + root, + loader, + IMG_EXTENSIONS if is_valid_file is None else None, + transform=transform, + target_transform=target_transform, + is_valid_file=is_valid_file, + ) self.imgs = self.samples diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py index 8d90c916b84..fe12c0d0b47 100644 --- a/torchvision/datasets/hmdb51.py +++ b/torchvision/datasets/hmdb51.py @@ -1,6 +1,7 @@ import glob import os from typing import Optional, Callable, Tuple, Dict, Any, List + from torch import Tensor from .folder import find_classes, make_dataset @@ -49,7 +50,7 @@ class HMDB51(VisionDataset): data_url = "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar" splits = { "url": "http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar", - "md5": "15e67781e70dcfbdce2d7dbb9b3344b5" + "md5": "15e67781e70dcfbdce2d7dbb9b3344b5", } TRAIN_TAG = 1 TEST_TAG = 2 @@ -75,7 +76,7 @@ def __init__( if fold not in (1, 2, 3): raise ValueError("fold should be between 1 and 3, got {}".format(fold)) - extensions = ('avi',) + extensions = ("avi",) self.classes, class_to_idx = find_classes(self.root) self.samples = make_dataset( self.root, diff --git a/torchvision/datasets/imagenet.py b/torchvision/datasets/imagenet.py index 6dfc9bfebfd..624294571aa 100644 --- a/torchvision/datasets/imagenet.py +++ b/torchvision/datasets/imagenet.py @@ -1,17 +1,19 @@ -import warnings -from contextlib import contextmanager import os import shutil import tempfile +import warnings +from contextlib import contextmanager from typing import Any, Dict, List, Iterator, Optional, Tuple + import torch 
+ from .folder import ImageFolder from .utils import check_integrity, extract_archive, verify_str_arg ARCHIVE_META = { - 'train': ('ILSVRC2012_img_train.tar', '1d675b47d978889d74fa0da5fadfb00e'), - 'val': ('ILSVRC2012_img_val.tar', '29b22e2961454d5413ddabcf34fc5622'), - 'devkit': ('ILSVRC2012_devkit_t12.tar.gz', 'fa75699e90414af021442c21a62c3abf') + "train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"), + "val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"), + "devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"), } META_FILE = "meta.bin" @@ -38,15 +40,16 @@ class ImageNet(ImageFolder): targets (list): The class_index value for each image in the dataset """ - def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None: + def __init__(self, root: str, split: str = "train", download: Optional[str] = None, **kwargs: Any) -> None: if download is True: - msg = ("The dataset is no longer publicly accessible. You need to " - "download the archives externally and place them in the root " - "directory.") + msg = ( + "The dataset is no longer publicly accessible. You need to " + "download the archives externally and place them in the root " + "directory." + ) raise RuntimeError(msg) elif download is False: - msg = ("The use of the download flag is deprecated, since the dataset " - "is no longer publicly accessible.") + msg = "The use of the download flag is deprecated, since the dataset " "is no longer publicly accessible." warnings.warn(msg, RuntimeWarning) root = self.root = os.path.expanduser(root) @@ -61,18 +64,16 @@ def __init__(self, root: str, split: str = 'train', download: Optional[str] = No self.wnids = self.classes self.wnid_to_idx = self.class_to_idx self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] - self.class_to_idx = {cls: idx - for idx, clss in enumerate(self.classes) - for cls in clss} + self.class_to_idx = {cls: idx for idx, clss in enumerate(self.classes) for cls in clss} def parse_archives(self) -> None: if not check_integrity(os.path.join(self.root, META_FILE)): parse_devkit_archive(self.root) if not os.path.isdir(self.split_folder): - if self.split == 'train': + if self.split == "train": parse_train_archive(self.root) - elif self.split == 'val': + elif self.split == "val": parse_val_archive(self.root) @property @@ -91,15 +92,19 @@ def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str if check_integrity(file): return torch.load(file) else: - msg = ("The meta file {} is not present in the root directory or is corrupted. " - "This file is automatically created by the ImageNet dataset.") + msg = ( + "The meta file {} is not present in the root directory or is corrupted. " + "This file is automatically created by the ImageNet dataset." + ) raise RuntimeError(msg.format(file, root)) def _verify_archive(root: str, file: str, md5: str) -> None: if not check_integrity(os.path.join(root, file), md5): - msg = ("The archive {} is not present in the root directory or is corrupted. " - "You need to download it externally and place it in {}.") + msg = ( + "The archive {} is not present in the root directory or is corrupted. " + "You need to download it externally and place it in {}." 
+ ) raise RuntimeError(msg.format(file, root)) @@ -116,20 +121,18 @@ def parse_devkit_archive(root: str, file: Optional[str] = None) -> None: def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]: metafile = os.path.join(devkit_root, "data", "meta.mat") - meta = sio.loadmat(metafile, squeeze_me=True)['synsets'] + meta = sio.loadmat(metafile, squeeze_me=True)["synsets"] nums_children = list(zip(*meta))[4] - meta = [meta[idx] for idx, num_children in enumerate(nums_children) - if num_children == 0] + meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0] idcs, wnids, classes = list(zip(*meta))[:3] - classes = [tuple(clss.split(', ')) for clss in classes] + classes = [tuple(clss.split(", ")) for clss in classes] idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)} wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)} return idx_to_wnid, wnid_to_classes def parse_val_groundtruth_txt(devkit_root: str) -> List[int]: - file = os.path.join(devkit_root, "data", - "ILSVRC2012_validation_ground_truth.txt") - with open(file, 'r') as txtfh: + file = os.path.join(devkit_root, "data", "ILSVRC2012_validation_ground_truth.txt") + with open(file, "r") as txtfh: val_idcs = txtfh.readlines() return [int(val_idx) for val_idx in val_idcs] diff --git a/torchvision/datasets/inaturalist.py b/torchvision/datasets/inaturalist.py index 7b9a911f823..1e2d09d39f8 100644 --- a/torchvision/datasets/inaturalist.py +++ b/torchvision/datasets/inaturalist.py @@ -1,29 +1,30 @@ -from PIL import Image import os import os.path from typing import Any, Callable, Dict, List, Optional, Union, Tuple -from .vision import VisionDataset +from PIL import Image + from .utils import download_and_extract_archive, verify_str_arg +from .vision import VisionDataset CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"] DATASET_URLS = { - '2017': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz', - '2018': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz', - '2019': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz', - '2021_train': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz', - '2021_train_mini': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz', - '2021_valid': 'https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz', + "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz", + "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz", + "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz", + "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz", + "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz", + "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz", } DATASET_MD5 = { - '2017': '7c784ea5e424efaec655bd392f87301f', - '2018': 'b1c6952ce38f31868cc50ea72d066cc3', - '2019': 'c60a6e2962c9b8ccbd458d12c8582644', - '2021_train': '38a7bb733f7a09214d44293460ec0021', - '2021_train_mini': 'db6ed8330e634445efc8fec83ae81442', - '2021_valid': 'f6f6e0e242e3d4c9569ba56400938afc', + "2017": "7c784ea5e424efaec655bd392f87301f", + "2018": "b1c6952ce38f31868cc50ea72d066cc3", + "2019": "c60a6e2962c9b8ccbd458d12c8582644", + "2021_train": "38a7bb733f7a09214d44293460ec0021", + 
"2021_train_mini": "db6ed8330e634445efc8fec83ae81442", + "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc", } @@ -63,27 +64,26 @@ class INaturalist(VisionDataset): """ def __init__( - self, - root: str, - version: str = "2021_train", - target_type: Union[List[str], str] = "full", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + version: str = "2021_train", + target_type: Union[List[str], str] = "full", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: self.version = verify_str_arg(version, "version", DATASET_URLS.keys()) - super(INaturalist, self).__init__(os.path.join(root, version), - transform=transform, - target_transform=target_transform) + super(INaturalist, self).__init__( + os.path.join(root, version), transform=transform, target_transform=target_transform + ) os.makedirs(root, exist_ok=True) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") self.all_categories: List[str] = [] @@ -96,12 +96,10 @@ def __init__( if not isinstance(target_type, list): target_type = [target_type] if self.version[:4] == "2021": - self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) - for t in target_type] + self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type] self._init_2021() else: - self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) - for t in target_type] + self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type] self._init_pre2021() # index of all files: (full category id, filename) @@ -118,16 +116,14 @@ def _init_2021(self) -> None: self.all_categories = sorted(os.listdir(self.root)) # map: category type -> name of category -> index - self.categories_index = { - k: {} for k in CATEGORIES_2021 - } + self.categories_index = {k: {} for k in CATEGORIES_2021} for dir_index, dir_name in enumerate(self.all_categories): - pieces = dir_name.split('_') + pieces = dir_name.split("_") if len(pieces) != 8: - raise RuntimeError(f'Unexpected category name {dir_name}, wrong number of pieces') - if pieces[0] != f'{dir_index:05d}': - raise RuntimeError(f'Unexpected category id {pieces[0]}, expecting {dir_index:05d}') + raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces") + if pieces[0] != f"{dir_index:05d}": + raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}") cat_map = {} for cat, name in zip(CATEGORIES_2021, pieces[1:7]): if name in self.categories_index[cat]: @@ -142,7 +138,7 @@ def _init_pre2021(self) -> None: """Initialize based on 2017-2019 layout""" # map: category type -> name of category -> index - self.categories_index = {'super': {}} + self.categories_index = {"super": {}} cat_index = 0 super_categories = sorted(os.listdir(self.root)) @@ -165,7 +161,7 @@ def _init_pre2021(self) -> None: self.all_categories.extend([""] * (subcat_i - old_len + 1)) if self.categories_map[subcat_i]: raise RuntimeError(f"Duplicate category {subcat}") - self.categories_map[subcat_i] = {'super': sindex} + self.categories_map[subcat_i] = {"super": sindex} self.all_categories[subcat_i] = os.path.join(scat, subcat) # validate the dictionary @@ -183,9 +179,7 @@ def 
__getitem__(self, index: int) -> Tuple[Any, Any]: """ cat_id, fname = self.index[index] - img = Image.open(os.path.join(self.root, - self.all_categories[cat_id], - fname)) + img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname)) target: Any = [] for t in self.target_type: @@ -239,10 +233,8 @@ def download(self) -> None: base_root = os.path.dirname(self.root) download_and_extract_archive( - DATASET_URLS[self.version], - base_root, - filename=f"{self.version}.tgz", - md5=DATASET_MD5[self.version]) + DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version] + ) orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz")) if not os.path.exists(orig_dir_name): diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index 68f7470e6ab..058fcca29e6 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -1,17 +1,16 @@ -import time +import csv import os +import time import warnings - - -from os import path -import csv -from typing import Any, Callable, Dict, Optional, Tuple from functools import partial from multiprocessing import Pool +from os import path +from typing import Any, Callable, Dict, Optional, Tuple + from torch import Tensor -from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity from .folder import find_classes, make_dataset +from .utils import download_and_extract_archive, download_url, verify_str_arg, check_integrity from .video_utils import VideoClips from .vision import VisionDataset @@ -214,18 +213,13 @@ def _make_ds_structure(self) -> None: start=int(row["time_start"]), end=int(row["time_end"]), ) - label = ( - row["label"] - .replace(" ", "_") - .replace("'", "") - .replace("(", "") - .replace(")", "") - ) + label = row["label"].replace(" ", "_").replace("'", "").replace("(", "").replace(")", "") os.makedirs(path.join(self.split_folder, label), exist_ok=True) downloaded_file = path.join(self.split_folder, f) if path.isfile(downloaded_file): os.replace( - downloaded_file, path.join(self.split_folder, label, f), + downloaded_file, + path.join(self.split_folder, label, f), ) @property @@ -303,11 +297,12 @@ def __init__( split: Any = None, download: Any = None, num_download_workers: Any = None, - **kwargs: Any + **kwargs: Any, ) -> None: warnings.warn( "Kinetics400 is deprecated and will be removed in a future release." - "It was replaced by Kinetics(..., num_classes=\"400\").") + 'It was replaced by Kinetics(..., num_classes="400").' + ) if any(value is not None for value in (num_classes, split, download, num_download_workers)): raise RuntimeError( "Usage of 'num_classes', 'split', 'download', or 'num_download_workers' is not supported in " diff --git a/torchvision/datasets/kitti.py b/torchvision/datasets/kitti.py index c26a5aee066..1c5d5217c44 100644 --- a/torchvision/datasets/kitti.py +++ b/torchvision/datasets/kitti.py @@ -73,9 +73,7 @@ def __init__( if download: self.download() if not self._check_exists(): - raise RuntimeError( - "Dataset not found. You may use download=True to download it." - ) + raise RuntimeError("Dataset not found. 
You may use download=True to download it.") image_dir = os.path.join(self._raw_folder, self._location, self.image_dir_name) if self.train: @@ -83,9 +81,7 @@ def __init__( for img_file in os.listdir(image_dir): self.images.append(os.path.join(image_dir, img_file)) if self.train: - self.targets.append( - os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt") - ) + self.targets.append(os.path.join(labels_dir, f"{img_file.split('.')[0]}.txt")) def __getitem__(self, index: int) -> Tuple[Any, Any]: """Get item at a given index. @@ -117,16 +113,18 @@ def _parse_target(self, index: int) -> List: with open(self.targets[index]) as inp: content = csv.reader(inp, delimiter=" ") for line in content: - target.append({ - "type": line[0], - "truncated": float(line[1]), - "occluded": int(line[2]), - "alpha": float(line[3]), - "bbox": [float(x) for x in line[4:8]], - "dimensions": [float(x) for x in line[8:11]], - "location": [float(x) for x in line[11:14]], - "rotation_y": float(line[14]), - }) + target.append( + { + "type": line[0], + "truncated": float(line[1]), + "occluded": int(line[2]), + "alpha": float(line[3]), + "bbox": [float(x) for x in line[4:8]], + "dimensions": [float(x) for x in line[8:11]], + "location": [float(x) for x in line[11:14]], + "rotation_y": float(line[14]), + } + ) return target def __len__(self) -> int: @@ -141,10 +139,7 @@ def _check_exists(self) -> bool: folders = [self.image_dir_name] if self.train: folders.append(self.labels_dir_name) - return all( - os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) - for fname in folders - ) + return all(os.path.isdir(os.path.join(self._raw_folder, self._location, fname)) for fname in folders) def download(self) -> None: """Download the KITTI data if it doesn't exist already.""" diff --git a/torchvision/datasets/lfw.py b/torchvision/datasets/lfw.py index ecd5c3820ed..77a2b41ba35 100644 --- a/torchvision/datasets/lfw.py +++ b/torchvision/datasets/lfw.py @@ -1,30 +1,32 @@ import os from typing import Any, Callable, List, Optional, Tuple + from PIL import Image -from .vision import VisionDataset + from .utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg +from .vision import VisionDataset class _LFW(VisionDataset): - base_folder = 'lfw-py' + base_folder = "lfw-py" download_url_prefix = "http://vis-www.cs.umass.edu/lfw/" file_dict = { - 'original': ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"), - 'funneled': ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"), - 'deepfunneled': ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201") + "original": ("lfw", "lfw.tgz", "a17d05bd522c52d84eca14327a23d494"), + "funneled": ("lfw_funneled", "lfw-funneled.tgz", "1b42dfed7d15c9b2dd63d5e5840c86ad"), + "deepfunneled": ("lfw-deepfunneled", "lfw-deepfunneled.tgz", "68331da3eb755a505a502b5aacb3c201"), } checksums = { - 'pairs.txt': '9f1ba174e4e1c508ff7cdf10ac338a7d', - 'pairsDevTest.txt': '5132f7440eb68cf58910c8a45a2ac10b', - 'pairsDevTrain.txt': '4f27cbf15b2da4a85c1907eb4181ad21', - 'people.txt': '450f0863dd89e85e73936a6d71a3474b', - 'peopleDevTest.txt': 'e4bf5be0a43b5dcd9dc5ccfcb8fb19c5', - 'peopleDevTrain.txt': '54eaac34beb6d042ed3a7d883e247a21', - 'lfw-names.txt': 'a6d0a479bd074669f656265a6e693f6d' + "pairs.txt": "9f1ba174e4e1c508ff7cdf10ac338a7d", + "pairsDevTest.txt": "5132f7440eb68cf58910c8a45a2ac10b", + "pairsDevTrain.txt": "4f27cbf15b2da4a85c1907eb4181ad21", + "people.txt": "450f0863dd89e85e73936a6d71a3474b", + "peopleDevTest.txt": 
"e4bf5be0a43b5dcd9dc5ccfcb8fb19c5", + "peopleDevTrain.txt": "54eaac34beb6d042ed3a7d883e247a21", + "lfw-names.txt": "a6d0a479bd074669f656265a6e693f6d", } - annot_file = {'10fold': '', 'train': 'DevTrain', 'test': 'DevTest'} + annot_file = {"10fold": "", "train": "DevTrain", "test": "DevTest"} names = "lfw-names.txt" def __init__( @@ -37,14 +39,15 @@ def __init__( target_transform: Optional[Callable] = None, download: bool = False, ): - super(_LFW, self).__init__(os.path.join(root, self.base_folder), - transform=transform, target_transform=target_transform) + super(_LFW, self).__init__( + os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform + ) - self.image_set = verify_str_arg(image_set.lower(), 'image_set', self.file_dict.keys()) + self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys()) images_dir, self.filename, self.md5 = self.file_dict[self.image_set] - self.view = verify_str_arg(view.lower(), 'view', ['people', 'pairs']) - self.split = verify_str_arg(split.lower(), 'split', ['10fold', 'train', 'test']) + self.view = verify_str_arg(view.lower(), "view", ["people", "pairs"]) + self.split = verify_str_arg(split.lower(), "split", ["10fold", "train", "test"]) self.labels_file = f"{self.view}{self.annot_file[self.split]}.txt" self.data: List[Any] = [] @@ -52,15 +55,14 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") self.images_dir = os.path.join(self.root, images_dir) def _loader(self, path: str) -> Image.Image: - with open(path, 'rb') as f: + with open(path, "rb") as f: img = Image.open(f) - return img.convert('RGB') + return img.convert("RGB") def _check_integrity(self): st1 = check_integrity(os.path.join(self.root, self.filename), self.md5) @@ -73,7 +75,7 @@ def _check_integrity(self): def download(self): if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return url = f"{self.download_url_prefix}{self.filename}" download_and_extract_archive(url, self.root, filename=self.filename, md5=self.md5) @@ -120,21 +122,20 @@ def __init__( target_transform: Optional[Callable] = None, download: bool = False, ): - super(LFWPeople, self).__init__(root, split, image_set, "people", - transform, target_transform, download) + super(LFWPeople, self).__init__(root, split, image_set, "people", transform, target_transform, download) self.class_to_idx = self._get_classes() self.data, self.targets = self._get_people() def _get_people(self): data, targets = [], [] - with open(os.path.join(self.root, self.labels_file), 'r') as f: + with open(os.path.join(self.root, self.labels_file), "r") as f: lines = f.readlines() n_folds, s = (int(lines[0]), 1) if self.split == "10fold" else (1, 0) for fold in range(n_folds): n_lines = int(lines[s]) - people = [line.strip().split("\t") for line in lines[s + 1: s + n_lines + 1]] + people = [line.strip().split("\t") for line in lines[s + 1 : s + n_lines + 1]] s += n_lines + 1 for i, (identity, num_imgs) in enumerate(people): for num in range(1, int(num_imgs) + 1): @@ -145,7 +146,7 @@ def _get_people(self): return data, targets def _get_classes(self): - with open(os.path.join(self.root, self.names), 'r') as f: + with open(os.path.join(self.root, self.names), "r") as f: lines = f.readlines() names = 
[line.strip().split()[0] for line in lines] class_to_idx = {name: i for i, name in enumerate(names)} @@ -203,14 +204,13 @@ def __init__( target_transform: Optional[Callable] = None, download: bool = False, ): - super(LFWPairs, self).__init__(root, split, image_set, "pairs", - transform, target_transform, download) + super(LFWPairs, self).__init__(root, split, image_set, "pairs", transform, target_transform, download) self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir) def _get_pairs(self, images_dir): pair_names, data, targets = [], [], [] - with open(os.path.join(self.root, self.labels_file), 'r') as f: + with open(os.path.join(self.root, self.labels_file), "r") as f: lines = f.readlines() if self.split == "10fold": n_folds, n_pairs = lines[0].split("\t") @@ -220,9 +220,9 @@ def _get_pairs(self, images_dir): s = 1 for fold in range(n_folds): - matched_pairs = [line.strip().split("\t") for line in lines[s: s + n_pairs]] - unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs: s + (2 * n_pairs)]] - s += (2 * n_pairs) + matched_pairs = [line.strip().split("\t") for line in lines[s : s + n_pairs]] + unmatched_pairs = [line.strip().split("\t") for line in lines[s + n_pairs : s + (2 * n_pairs)]] + s += 2 * n_pairs for pair in matched_pairs: img1, img2, same = self._get_path(pair[0], pair[1]), self._get_path(pair[0], pair[2]), 1 pair_names.append((pair[0], pair[0])) diff --git a/torchvision/datasets/lsun.py b/torchvision/datasets/lsun.py index 592806d30d0..5d4bcf948d7 100644 --- a/torchvision/datasets/lsun.py +++ b/torchvision/datasets/lsun.py @@ -1,29 +1,29 @@ -from .vision import VisionDataset -from PIL import Image +import io import os import os.path -import io +import pickle import string from collections.abc import Iterable -import pickle from typing import Any, Callable, cast, List, Optional, Tuple, Union + +from PIL import Image + from .utils import verify_str_arg, iterable_to_str +from .vision import VisionDataset class LSUNClass(VisionDataset): def __init__( - self, root: str, transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None + self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None ) -> None: import lmdb - super(LSUNClass, self).__init__(root, transform=transform, - target_transform=target_transform) - self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, - readahead=False, meminit=False) + super(LSUNClass, self).__init__(root, transform=transform, target_transform=target_transform) + + self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) with self.env.begin(write=False) as txn: - self.length = txn.stat()['entries'] - cache_file = '_cache_' + ''.join(c for c in root if c in string.ascii_letters) + self.length = txn.stat()["entries"] + cache_file = "_cache_" + "".join(c for c in root if c in string.ascii_letters) if os.path.isfile(cache_file): self.keys = pickle.load(open(cache_file, "rb")) else: @@ -40,7 +40,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: buf = io.BytesIO() buf.write(imgbuf) buf.seek(0) - img = Image.open(buf).convert('RGB') + img = Image.open(buf).convert("RGB") if self.transform is not None: img = self.transform(img) @@ -71,22 +71,19 @@ class LSUN(VisionDataset): """ def __init__( - self, - root: str, - classes: Union[str, List[str]] = "train", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + classes: Union[str, 
List[str]] = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: - super(LSUN, self).__init__(root, transform=transform, - target_transform=target_transform) + super(LSUN, self).__init__(root, transform=transform, target_transform=target_transform) self.classes = self._verify_classes(classes) # for each class, create an LSUNClassDataset self.dbs = [] for c in self.classes: - self.dbs.append(LSUNClass( - root=os.path.join(root, f"{c}_lmdb"), - transform=transform)) + self.dbs.append(LSUNClass(root=os.path.join(root, f"{c}_lmdb"), transform=transform)) self.indices = [] count = 0 @@ -97,35 +94,41 @@ def __init__( self.length = count def _verify_classes(self, classes: Union[str, List[str]]) -> List[str]: - categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', - 'conference_room', 'dining_room', 'kitchen', - 'living_room', 'restaurant', 'tower'] - dset_opts = ['train', 'val', 'test'] + categories = [ + "bedroom", + "bridge", + "church_outdoor", + "classroom", + "conference_room", + "dining_room", + "kitchen", + "living_room", + "restaurant", + "tower", + ] + dset_opts = ["train", "val", "test"] try: classes = cast(str, classes) verify_str_arg(classes, "classes", dset_opts) - if classes == 'test': + if classes == "test": classes = [classes] else: - classes = [c + '_' + classes for c in categories] + classes = [c + "_" + classes for c in categories] except ValueError: if not isinstance(classes, Iterable): - msg = ("Expected type str or Iterable for argument classes, " - "but got type {}.") + msg = "Expected type str or Iterable for argument classes, " "but got type {}." raise ValueError(msg.format(type(classes))) classes = list(classes) - msg_fmtstr_type = ("Expected type str for elements in argument classes, " - "but got type {}.") + msg_fmtstr_type = "Expected type str for elements in argument classes, " "but got type {}." for c in classes: verify_str_arg(c, custom_msg=msg_fmtstr_type.format(type(c))) - c_short = c.split('_') - category, dset_opt = '_'.join(c_short[:-1]), c_short[-1] + c_short = c.split("_") + category, dset_opt = "_".join(c_short[:-1]), c_short[-1] msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 
- msg = msg_fmtstr.format(category, "LSUN class", - iterable_to_str(categories)) + msg = msg_fmtstr.format(category, "LSUN class", iterable_to_str(categories)) verify_str_arg(category, valid_values=categories, custom_msg=msg) msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py index 237a135722f..49e20438874 100644 --- a/torchvision/datasets/mnist.py +++ b/torchvision/datasets/mnist.py @@ -1,16 +1,18 @@ -from .vision import VisionDataset -import warnings -from PIL import Image +import codecs import os import os.path -import numpy as np -import torch -import codecs +import shutil import string +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple from urllib.error import URLError + +import numpy as np +import torch +from PIL import Image + from .utils import download_and_extract_archive, extract_archive, verify_str_arg, check_integrity -import shutil +from .vision import VisionDataset class MNIST(VisionDataset): @@ -31,21 +33,31 @@ class MNIST(VisionDataset): """ mirrors = [ - 'http://yann.lecun.com/exdb/mnist/', - 'https://ossci-datasets.s3.amazonaws.com/mnist/', + "http://yann.lecun.com/exdb/mnist/", + "https://ossci-datasets.s3.amazonaws.com/mnist/", ] resources = [ ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), - ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c") + ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), ] - training_file = 'training.pt' - test_file = 'test.pt' - classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', - '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] + training_file = "training.pt" + test_file = "test.pt" + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] @property def train_labels(self): @@ -68,15 +80,14 @@ def test_data(self): return self.data def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(MNIST, self).__init__(root, transform=transform, - target_transform=target_transform) + super(MNIST, self).__init__(root, transform=transform, target_transform=target_transform) self.train = train # training set or test set if self._check_legacy_exist(): @@ -87,8 +98,7 @@ def __init__( self.download() if not self._check_exists(): - raise RuntimeError('Dataset not found.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found." 
+ " You can use download=True to download it") self.data, self.targets = self._load_data() @@ -128,7 +138,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img.numpy(), mode='L') + img = Image.fromarray(img.numpy(), mode="L") if self.transform is not None: img = self.transform(img) @@ -143,11 +153,11 @@ def __len__(self) -> int: @property def raw_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, 'raw') + return os.path.join(self.root, self.__class__.__name__, "raw") @property def processed_folder(self) -> str: - return os.path.join(self.root, self.__class__.__name__, 'processed') + return os.path.join(self.root, self.__class__.__name__, "processed") @property def class_to_idx(self) -> Dict[str, int]: @@ -173,15 +183,9 @@ def download(self) -> None: url = "{}{}".format(mirror, filename) try: print("Downloading {}".format(url)) - download_and_extract_archive( - url, download_root=self.raw_folder, - filename=filename, - md5=md5 - ) + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) except URLError as error: - print( - "Failed to download (trying next):\n{}".format(error) - ) + print("Failed to download (trying next):\n{}".format(error)) continue finally: print() @@ -209,18 +213,16 @@ class FashionMNIST(MNIST): target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ - mirrors = [ - "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" - ] + + mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] resources = [ ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), - ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310") + ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), ] - classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', - 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'] + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] class KMNIST(MNIST): @@ -239,17 +241,16 @@ class KMNIST(MNIST): target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ - mirrors = [ - "http://codh.rois.ac.jp/kmnist/dataset/kmnist/" - ] + + mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"] resources = [ ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"), ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"), ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"), - ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134") + ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"), ] - classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo'] + classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"] class EMNIST(MNIST): @@ -271,19 +272,20 @@ class EMNIST(MNIST): target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
""" - url = 'https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip' + + url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" md5 = "58c8d27c78d21e728a6bc7b3cc06412e" - splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist') + splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") # Merged Classes assumes Same structure for both uppercase and lowercase version - _merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'} + _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} _all_classes = set(string.digits + string.ascii_letters) classes_split_dict = { - 'byclass': sorted(list(_all_classes)), - 'bymerge': sorted(list(_all_classes - _merged_classes)), - 'balanced': sorted(list(_all_classes - _merged_classes)), - 'letters': ['N/A'] + list(string.ascii_lowercase), - 'digits': list(string.digits), - 'mnist': list(string.digits), + "byclass": sorted(list(_all_classes)), + "bymerge": sorted(list(_all_classes - _merged_classes)), + "balanced": sorted(list(_all_classes - _merged_classes)), + "letters": ["N/A"] + list(string.ascii_lowercase), + "digits": list(string.digits), + "mnist": list(string.digits), } def __init__(self, root: str, split: str, **kwargs: Any) -> None: @@ -295,11 +297,11 @@ def __init__(self, root: str, split: str, **kwargs: Any) -> None: @staticmethod def _training_file(split) -> str: - return 'training_{}.pt'.format(split) + return "training_{}.pt".format(split) @staticmethod def _test_file(split) -> str: - return 'test_{}.pt'.format(split) + return "test_{}.pt".format(split) @property def _file_prefix(self) -> str: @@ -328,9 +330,9 @@ def download(self) -> None: os.makedirs(self.raw_folder, exist_ok=True) download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) - gzip_folder = os.path.join(self.raw_folder, 'gzip') + gzip_folder = os.path.join(self.raw_folder, "gzip") for gzip_file in os.listdir(gzip_folder): - if gzip_file.endswith('.gz'): + if gzip_file.endswith(".gz"): extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) shutil.rmtree(gzip_folder) @@ -365,39 +367,60 @@ class QMNIST(MNIST): training set ot the testing set. Default: True. 
""" - subsets = { - 'train': 'train', - 'test': 'test', - 'test10k': 'test', - 'test50k': 'test', - 'nist': 'nist' - } + subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] - 'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz', - 'ed72d4157d28c017586c42bc6afe6370'), - ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz', - '0058f8dd561b90ffdd0f734c6a30e5e4')], - 'test': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz', - '1394631089c404de565df7b7aeaf9412'), - ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz', - '5b5b05890a5e13444e108efe57b788aa')], - 'nist': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz', - '7f124b3b8ab81486c9d8c2749c17f834'), - ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz', - '5ed0e788978e45d4a8bd4b7caec3d79d')] + "train": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", + "ed72d4157d28c017586c42bc6afe6370", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", + "0058f8dd561b90ffdd0f734c6a30e5e4", + ), + ], + "test": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", + "1394631089c404de565df7b7aeaf9412", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", + "5b5b05890a5e13444e108efe57b788aa", + ), + ], + "nist": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", + "7f124b3b8ab81486c9d8c2749c17f834", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", + "5ed0e788978e45d4a8bd4b7caec3d79d", + ), + ], } - classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', - '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] def __init__( - self, root: str, what: Optional[str] = None, compat: bool = True, - train: bool = True, **kwargs: Any + self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any ) -> None: if what is None: - what = 'train' if train else 'test' + what = "train" if train else "test" self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) self.compat = compat - self.data_file = what + '.pt' + self.data_file = what + ".pt" self.training_file = self.data_file self.test_file = self.data_file super(QMNIST, self).__init__(root, train, **kwargs) @@ -417,16 +440,16 @@ def _check_exists(self) -> bool: def _load_data(self): data = read_sn3_pascalvincent_tensor(self.images_file) - assert (data.dtype == torch.uint8) - assert (data.ndimension() == 3) + assert data.dtype == torch.uint8 + assert data.ndimension() == 3 targets = read_sn3_pascalvincent_tensor(self.labels_file).long() - assert (targets.ndimension() == 2) + assert targets.ndimension() == 2 - if self.what == 'test10k': + if self.what == "test10k": data = data[0:10000, :, :].clone() targets = targets[0:10000, :].clone() - elif self.what == 
'test50k': + elif self.what == "test50k": data = data[10000:, :, :].clone() targets = targets[10000:, :].clone() @@ -434,7 +457,7 @@ def _load_data(self): def download(self) -> None: """Download the QMNIST data if it doesn't exist already. - Note that we only download what has been asked for (argument 'what'). + Note that we only download what has been asked for (argument 'what'). """ if self._check_exists(): return @@ -443,7 +466,7 @@ def download(self) -> None: split = self.resources[self.subsets[self.what]] for url, md5 in split: - filename = url.rpartition('/')[2] + filename = url.rpartition("/")[2] file_path = os.path.join(self.raw_folder, filename) if not os.path.isfile(file_path): download_and_extract_archive(url, self.raw_folder, filename=filename, md5=md5) @@ -451,7 +474,7 @@ def download(self) -> None: def __getitem__(self, index: int) -> Tuple[Any, Any]: # redefined to handle the compat flag img, target = self.data[index], self.targets[index] - img = Image.fromarray(img.numpy(), mode='L') + img = Image.fromarray(img.numpy(), mode="L") if self.transform is not None: img = self.transform(img) if self.compat: @@ -465,22 +488,22 @@ def extra_repr(self) -> str: def get_int(b: bytes) -> int: - return int(codecs.encode(b, 'hex'), 16) + return int(codecs.encode(b, "hex"), 16) SN3_PASCALVINCENT_TYPEMAP = { 8: (torch.uint8, np.uint8, np.uint8), 9: (torch.int8, np.int8, np.int8), - 11: (torch.int16, np.dtype('>i2'), 'i2'), - 12: (torch.int32, np.dtype('>i4'), 'i4'), - 13: (torch.float32, np.dtype('>f4'), 'f4'), - 14: (torch.float64, np.dtype('>f8'), 'f8') + 11: (torch.int16, np.dtype(">i2"), "i2"), + 12: (torch.int32, np.dtype(">i4"), "i4"), + 13: (torch.float32, np.dtype(">f4"), "f4"), + 14: (torch.float64, np.dtype(">f8"), "f8"), } def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). - Argument may be a filename, compressed filename, or file object. + Argument may be a filename, compressed filename, or file object. 
""" # read with open(path, "rb") as f: @@ -492,7 +515,7 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso assert 1 <= nd <= 3 assert 8 <= ty <= 14 m = SN3_PASCALVINCENT_TYPEMAP[ty] - s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] + s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) assert parsed.shape[0] == np.prod(s) or not strict return torch.from_numpy(parsed.astype(m[2])).view(*s) @@ -500,13 +523,13 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso def read_label_file(path: str) -> torch.Tensor: x = read_sn3_pascalvincent_tensor(path, strict=False) - assert(x.dtype == torch.uint8) - assert(x.ndimension() == 1) + assert x.dtype == torch.uint8 + assert x.ndimension() == 1 return x.long() def read_image_file(path: str) -> torch.Tensor: x = read_sn3_pascalvincent_tensor(path, strict=False) - assert(x.dtype == torch.uint8) - assert(x.ndimension() == 3) + assert x.dtype == torch.uint8 + assert x.ndimension() == 3 return x diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py index b78bf86d16f..0a6577beaae 100644 --- a/torchvision/datasets/omniglot.py +++ b/torchvision/datasets/omniglot.py @@ -1,8 +1,10 @@ -from PIL import Image from os.path import join from typing import Any, Callable, List, Optional, Tuple -from .vision import VisionDataset + +from PIL import Image + from .utils import download_and_extract_archive, check_integrity, list_dir, list_files +from .vision import VisionDataset class Omniglot(VisionDataset): @@ -21,38 +23,40 @@ class Omniglot(VisionDataset): puts it in root directory. If the zip files are already downloaded, they are not downloaded again. """ - folder = 'omniglot-py' - download_url_prefix = 'https://raw.githubusercontent.com/brendenlake/omniglot/master/python' + + folder = "omniglot-py" + download_url_prefix = "https://raw.githubusercontent.com/brendenlake/omniglot/master/python" zips_md5 = { - 'images_background': '68d2efa1b9178cc56df9314c21c6e718', - 'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' + "images_background": "68d2efa1b9178cc56df9314c21c6e718", + "images_evaluation": "6b91aef0f799c5bb55b94e3f2daec811", } def __init__( - self, - root: str, - background: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + background: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(Omniglot, self).__init__(join(root, self.folder), transform=transform, - target_transform=target_transform) + super(Omniglot, self).__init__(join(root, self.folder), transform=transform, target_transform=target_transform) self.background = background if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." 
+ " You can use download=True to download it") self.target_folder = join(self.root, self._get_target_folder()) self._alphabets = list_dir(self.target_folder) - self._characters: List[str] = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] - for a in self._alphabets], []) - self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] - for idx, character in enumerate(self._characters)] + self._characters: List[str] = sum( + [[join(a, c) for c in list_dir(join(self.target_folder, a))] for a in self._alphabets], [] + ) + self._character_images = [ + [(image, idx) for image in list_files(join(self.target_folder, character), ".png")] + for idx, character in enumerate(self._characters) + ] self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, []) def __len__(self) -> int: @@ -68,7 +72,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: """ image_name, character_class = self._flat_character_images[index] image_path = join(self.target_folder, self._characters[character_class], image_name) - image = Image.open(image_path, mode='r').convert('L') + image = Image.open(image_path, mode="r").convert("L") if self.transform: image = self.transform(image) @@ -80,19 +84,19 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: def _check_integrity(self) -> bool: zip_filename = self._get_target_folder() - if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): + if not check_integrity(join(self.root, zip_filename + ".zip"), self.zips_md5[zip_filename]): return False return True def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return filename = self._get_target_folder() - zip_filename = filename + '.zip' - url = self.download_url_prefix + '/' + zip_filename + zip_filename = filename + ".zip" + url = self.download_url_prefix + "/" + zip_filename download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename]) def _get_target_folder(self) -> str: - return 'images_background' if self.background else 'images_evaluation' + return "images_background" if self.background else "images_evaluation" diff --git a/torchvision/datasets/phototour.py b/torchvision/datasets/phototour.py index abb89701e1e..fe8134fe2c8 100644 --- a/torchvision/datasets/phototour.py +++ b/torchvision/datasets/phototour.py @@ -1,12 +1,12 @@ import os -import numpy as np -from PIL import Image from typing import Any, Callable, List, Optional, Tuple, Union +import numpy as np import torch -from .vision import VisionDataset +from PIL import Image from .utils import download_url +from .vision import VisionDataset class PhotoTour(VisionDataset): @@ -33,56 +33,67 @@ class PhotoTour(VisionDataset): downloaded again. 
""" + urls = { - 'notredame_harris': [ - 'http://matthewalunbrown.com/patchdata/notredame_harris.zip', - 'notredame_harris.zip', - '69f8c90f78e171349abdf0307afefe4d' - ], - 'yosemite_harris': [ - 'http://matthewalunbrown.com/patchdata/yosemite_harris.zip', - 'yosemite_harris.zip', - 'a73253d1c6fbd3ba2613c45065c00d46' + "notredame_harris": [ + "http://matthewalunbrown.com/patchdata/notredame_harris.zip", + "notredame_harris.zip", + "69f8c90f78e171349abdf0307afefe4d", ], - 'liberty_harris': [ - 'http://matthewalunbrown.com/patchdata/liberty_harris.zip', - 'liberty_harris.zip', - 'c731fcfb3abb4091110d0ae8c7ba182c' + "yosemite_harris": [ + "http://matthewalunbrown.com/patchdata/yosemite_harris.zip", + "yosemite_harris.zip", + "a73253d1c6fbd3ba2613c45065c00d46", ], - 'notredame': [ - 'http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip', - 'notredame.zip', - '509eda8535847b8c0a90bbb210c83484' + "liberty_harris": [ + "http://matthewalunbrown.com/patchdata/liberty_harris.zip", + "liberty_harris.zip", + "c731fcfb3abb4091110d0ae8c7ba182c", ], - 'yosemite': [ - 'http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip', - 'yosemite.zip', - '533b2e8eb7ede31be40abc317b2fd4f0' - ], - 'liberty': [ - 'http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip', - 'liberty.zip', - 'fdd9152f138ea5ef2091746689176414' + "notredame": [ + "http://icvl.ee.ic.ac.uk/vbalnt/notredame.zip", + "notredame.zip", + "509eda8535847b8c0a90bbb210c83484", ], + "yosemite": ["http://icvl.ee.ic.ac.uk/vbalnt/yosemite.zip", "yosemite.zip", "533b2e8eb7ede31be40abc317b2fd4f0"], + "liberty": ["http://icvl.ee.ic.ac.uk/vbalnt/liberty.zip", "liberty.zip", "fdd9152f138ea5ef2091746689176414"], + } + means = { + "notredame": 0.4854, + "yosemite": 0.4844, + "liberty": 0.4437, + "notredame_harris": 0.4854, + "yosemite_harris": 0.4844, + "liberty_harris": 0.4437, + } + stds = { + "notredame": 0.1864, + "yosemite": 0.1818, + "liberty": 0.2019, + "notredame_harris": 0.1864, + "yosemite_harris": 0.1818, + "liberty_harris": 0.2019, + } + lens = { + "notredame": 468159, + "yosemite": 633587, + "liberty": 450092, + "liberty_harris": 379587, + "yosemite_harris": 450912, + "notredame_harris": 325295, } - means = {'notredame': 0.4854, 'yosemite': 0.4844, 'liberty': 0.4437, - 'notredame_harris': 0.4854, 'yosemite_harris': 0.4844, 'liberty_harris': 0.4437} - stds = {'notredame': 0.1864, 'yosemite': 0.1818, 'liberty': 0.2019, - 'notredame_harris': 0.1864, 'yosemite_harris': 0.1818, 'liberty_harris': 0.2019} - lens = {'notredame': 468159, 'yosemite': 633587, 'liberty': 450092, - 'liberty_harris': 379587, 'yosemite_harris': 450912, 'notredame_harris': 325295} - image_ext = 'bmp' - info_file = 'info.txt' - matches_files = 'm50_100000_100000_0.txt' + image_ext = "bmp" + info_file = "info.txt" + matches_files = "m50_100000_100000_0.txt" def __init__( - self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False + self, root: str, name: str, train: bool = True, transform: Optional[Callable] = None, download: bool = False ) -> None: super(PhotoTour, self).__init__(root, transform=transform) self.name = name self.data_dir = os.path.join(self.root, name) - self.data_down = os.path.join(self.root, '{}.zip'.format(name)) - self.data_file = os.path.join(self.root, '{}.pt'.format(name)) + self.data_down = os.path.join(self.root, "{}.zip".format(name)) + self.data_file = os.path.join(self.root, "{}.pt".format(name)) self.train = train self.mean = self.means[name] @@ -128,7 +139,7 @@ def _check_downloaded(self) -> bool: def download(self) -> None: if 
self._check_datafile_exists(): - print('# Found cached data {}'.format(self.data_file)) + print("# Found cached data {}".format(self.data_file)) return if not self._check_downloaded(): @@ -140,25 +151,26 @@ def download(self) -> None: download_url(url, self.root, filename, md5) - print('# Extracting data {}\n'.format(self.data_down)) + print("# Extracting data {}\n".format(self.data_down)) import zipfile - with zipfile.ZipFile(fpath, 'r') as z: + + with zipfile.ZipFile(fpath, "r") as z: z.extractall(self.data_dir) os.unlink(fpath) def cache(self) -> None: # process and save as torch files - print('# Caching data {}'.format(self.data_file)) + print("# Caching data {}".format(self.data_file)) dataset = ( read_image_file(self.data_dir, self.image_ext, self.lens[self.name]), read_info_file(self.data_dir, self.info_file), - read_matches_files(self.data_dir, self.matches_files) + read_matches_files(self.data_dir, self.matches_files), ) - with open(self.data_file, 'wb') as f: + with open(self.data_file, "wb") as f: torch.save(dataset, f) def extra_repr(self) -> str: @@ -166,17 +178,14 @@ def extra_repr(self) -> str: def read_image_file(data_dir: str, image_ext: str, n: int) -> torch.Tensor: - """Return a Tensor containing the patches - """ + """Return a Tensor containing the patches""" def PIL2array(_img: Image.Image) -> np.ndarray: - """Convert PIL image type to numpy 2D array - """ + """Convert PIL image type to numpy 2D array""" return np.array(_img.getdata(), dtype=np.uint8).reshape(64, 64) def find_files(_data_dir: str, _image_ext: str) -> List[str]: - """Return a list with the file names of the images containing the patches - """ + """Return a list with the file names of the images containing the patches""" files = [] # find those files with the specified extension for file_dir in os.listdir(_data_dir): @@ -198,22 +207,21 @@ def find_files(_data_dir: str, _image_ext: str) -> List[str]: def read_info_file(data_dir: str, info_file: str) -> torch.Tensor: """Return a Tensor containing the list of labels - Read the file and keep only the ID of the 3D point. + Read the file and keep only the ID of the 3D point. """ - with open(os.path.join(data_dir, info_file), 'r') as f: + with open(os.path.join(data_dir, info_file), "r") as f: labels = [int(line.split()[0]) for line in f] return torch.LongTensor(labels) def read_matches_files(data_dir: str, matches_file: str) -> torch.Tensor: """Return a Tensor containing the ground truth matches - Read the file and keep only 3D point ID. - Matches are represented with a 1, non matches with a 0. + Read the file and keep only 3D point ID. + Matches are represented with a 1, non matches with a 0. 
""" matches = [] - with open(os.path.join(data_dir, matches_file), 'r') as f: + with open(os.path.join(data_dir, matches_file), "r") as f: for line in f: line_split = line.split() - matches.append([int(line_split[0]), int(line_split[3]), - int(line_split[1] == line_split[4])]) + matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])]) return torch.LongTensor(matches) diff --git a/torchvision/datasets/samplers/__init__.py b/torchvision/datasets/samplers/__init__.py index 870322d39b4..861a029a9ec 100644 --- a/torchvision/datasets/samplers/__init__.py +++ b/torchvision/datasets/samplers/__init__.py @@ -1,3 +1,3 @@ from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler -__all__ = ('DistributedSampler', 'UniformClipSampler', 'RandomClipSampler') +__all__ = ("DistributedSampler", "UniformClipSampler", "RandomClipSampler") diff --git a/torchvision/datasets/samplers/clip_sampler.py b/torchvision/datasets/samplers/clip_sampler.py index 0f90e3ad1b0..259621bf91f 100644 --- a/torchvision/datasets/samplers/clip_sampler.py +++ b/torchvision/datasets/samplers/clip_sampler.py @@ -1,9 +1,10 @@ import math +from typing import Optional, List, Iterator, Sized, Union, cast + import torch -from torch.utils.data import Sampler import torch.distributed as dist +from torch.utils.data import Sampler from torchvision.datasets.video_utils import VideoClips -from typing import Optional, List, Iterator, Sized, Union, cast class DistributedSampler(Sampler): @@ -36,12 +37,12 @@ class DistributedSampler(Sampler): """ def __init__( - self, - dataset: Sized, - num_replicas: Optional[int] = None, - rank: Optional[int] = None, - shuffle: bool = False, - group_size: int = 1, + self, + dataset: Sized, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + group_size: int = 1, ) -> None: if num_replicas is None: if not dist.is_available(): @@ -51,9 +52,11 @@ def __init__( if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() - assert len(dataset) % group_size == 0, ( - "dataset length must be a multiplier of group size" - "dataset length: %d, group size: %d" % (len(dataset), group_size) + assert ( + len(dataset) % group_size == 0 + ), "dataset length must be a multiplier of group size" "dataset length: %d, group size: %d" % ( + len(dataset), + group_size, ) self.dataset = dataset self.group_size = group_size @@ -61,9 +64,7 @@ def __init__( self.rank = rank self.epoch = 0 dataset_group_length = len(dataset) // group_size - self.num_group_samples = int( - math.ceil(dataset_group_length * 1.0 / self.num_replicas) - ) + self.num_group_samples = int(math.ceil(dataset_group_length * 1.0 / self.num_replicas)) self.num_samples = self.num_group_samples * group_size self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle @@ -79,16 +80,14 @@ def __iter__(self) -> Iterator[int]: indices = list(range(len(self.dataset))) # add extra samples to make it evenly divisible - indices += indices[:(self.total_size - len(indices))] + indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size total_group_size = self.total_size // self.group_size - indices = torch.reshape( - torch.LongTensor(indices), (total_group_size, self.group_size) - ) + indices = torch.reshape(torch.LongTensor(indices), (total_group_size, self.group_size)) # subsample - indices = indices[self.rank:total_group_size:self.num_replicas, :] + indices = 
indices[self.rank : total_group_size : self.num_replicas, :] indices = torch.reshape(indices, (-1,)).tolist() assert len(indices) == self.num_samples @@ -115,10 +114,10 @@ class UniformClipSampler(Sampler): video_clips (VideoClips): video clips to sample from num_clips_per_video (int): number of clips to be sampled per video """ + def __init__(self, video_clips: VideoClips, num_clips_per_video: int) -> None: if not isinstance(video_clips, VideoClips): - raise TypeError("Expected video_clips to be an instance of VideoClips, " - "got {}".format(type(video_clips))) + raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) self.video_clips = video_clips self.num_clips_per_video = num_clips_per_video @@ -132,19 +131,13 @@ def __iter__(self) -> Iterator[int]: # corner case where video decoding fails continue - sampled = ( - torch.linspace(s, s + length - 1, steps=self.num_clips_per_video) - .floor() - .to(torch.int64) - ) + sampled = torch.linspace(s, s + length - 1, steps=self.num_clips_per_video).floor().to(torch.int64) s += length idxs.append(sampled) return iter(cast(List[int], torch.cat(idxs).tolist())) def __len__(self) -> int: - return sum( - self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0 - ) + return sum(self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0) class RandomClipSampler(Sampler): @@ -155,10 +148,10 @@ class RandomClipSampler(Sampler): video_clips (VideoClips): video clips to sample from max_clips_per_video (int): maximum number of clips to be sampled per video """ + def __init__(self, video_clips: VideoClips, max_clips_per_video: int) -> None: if not isinstance(video_clips, VideoClips): - raise TypeError("Expected video_clips to be an instance of VideoClips, " - "got {}".format(type(video_clips))) + raise TypeError("Expected video_clips to be an instance of VideoClips, " "got {}".format(type(video_clips))) self.video_clips = video_clips self.max_clips_per_video = max_clips_per_video diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py index e47c9493858..889dfc3a0be 100644 --- a/torchvision/datasets/sbd.py +++ b/torchvision/datasets/sbd.py @@ -1,12 +1,12 @@ import os import shutil -from .vision import VisionDataset from typing import Any, Callable, Optional, Tuple import numpy as np - from PIL import Image + from .utils import download_url, verify_str_arg, download_and_extract_archive +from .vision import VisionDataset class SBDataset(VisionDataset): @@ -50,30 +50,29 @@ class SBDataset(VisionDataset): voc_split_md5 = "79bff800c5f0b1ec6b21080a3c066722" def __init__( - self, - root: str, - image_set: str = "train", - mode: str = "boundaries", - download: bool = False, - transforms: Optional[Callable] = None, + self, + root: str, + image_set: str = "train", + mode: str = "boundaries", + download: bool = False, + transforms: Optional[Callable] = None, ) -> None: try: from scipy.io import loadmat + self._loadmat = loadmat except ImportError: - raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: " - "pip install scipy") + raise RuntimeError("Scipy is not found. 
This dataset needs to have scipy installed: " "pip install scipy") super(SBDataset, self).__init__(root, transforms) - self.image_set = verify_str_arg(image_set, "image_set", - ("train", "val", "train_noval")) + self.image_set = verify_str_arg(image_set, "image_set", ("train", "val", "train_noval")) self.mode = verify_str_arg(mode, "mode", ("segmentation", "boundaries")) self.num_classes = 20 sbd_root = self.root - image_dir = os.path.join(sbd_root, 'img') - mask_dir = os.path.join(sbd_root, 'cls') + image_dir = os.path.join(sbd_root, "img") + mask_dir = os.path.join(sbd_root, "cls") if download: download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5) @@ -81,36 +80,35 @@ def __init__( for f in ["cls", "img", "inst", "train.txt", "val.txt"]: old_path = os.path.join(extracted_ds_root, f) shutil.move(old_path, sbd_root) - download_url(self.voc_train_url, sbd_root, self.voc_split_filename, - self.voc_split_md5) + download_url(self.voc_train_url, sbd_root, self.voc_split_filename, self.voc_split_md5) if not os.path.isdir(sbd_root): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") - split_f = os.path.join(sbd_root, image_set.rstrip('\n') + '.txt') + split_f = os.path.join(sbd_root, image_set.rstrip("\n") + ".txt") with open(os.path.join(split_f), "r") as fh: file_names = [x.strip() for x in fh.readlines()] self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names] self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names] - assert (len(self.images) == len(self.masks)) + assert len(self.images) == len(self.masks) - self._get_target = self._get_segmentation_target \ - if self.mode == "segmentation" else self._get_boundaries_target + self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target def _get_segmentation_target(self, filepath: str) -> Image.Image: mat = self._loadmat(filepath) - return Image.fromarray(mat['GTcls'][0]['Segmentation'][0]) + return Image.fromarray(mat["GTcls"][0]["Segmentation"][0]) def _get_boundaries_target(self, filepath: str) -> np.ndarray: mat = self._loadmat(filepath) - return np.concatenate([np.expand_dims(mat['GTcls'][0]['Boundaries'][0][i][0].toarray(), axis=0) - for i in range(self.num_classes)], axis=0) + return np.concatenate( + [np.expand_dims(mat["GTcls"][0]["Boundaries"][0][i][0].toarray(), axis=0) for i in range(self.num_classes)], + axis=0, + ) def __getitem__(self, index: int) -> Tuple[Any, Any]: - img = Image.open(self.images[index]).convert('RGB') + img = Image.open(self.images[index]).convert("RGB") target = self._get_target(self.masks[index]) if self.transforms is not None: @@ -123,4 +121,4 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Image set: {image_set}", "Mode: {mode}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) diff --git a/torchvision/datasets/sbu.py b/torchvision/datasets/sbu.py index 6c8ad15686b..53c6218a7de 100644 --- a/torchvision/datasets/sbu.py +++ b/torchvision/datasets/sbu.py @@ -1,8 +1,9 @@ -from PIL import Image -from .utils import download_url, check_integrity +import os from typing import Any, Callable, Optional, Tuple -import os +from PIL import Image + +from .utils import download_url, check_integrity from .vision import VisionDataset @@ -20,38 +21,37 @@ class SBU(VisionDataset): puts it in 
root directory. If dataset is already downloaded, it is not downloaded again. """ + url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz" filename = "SBUCaptionedPhotoDataset.tar.gz" - md5_checksum = '9aec147b3488753cf758b4d493422285' + md5_checksum = "9aec147b3488753cf758b4d493422285" def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = True, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, ) -> None: - super(SBU, self).__init__(root, transform=transform, - target_transform=target_transform) + super(SBU, self).__init__(root, transform=transform, target_transform=target_transform) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") # Read the caption for each photo self.photos = [] self.captions = [] - file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt') - file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt') + file1 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt") + file2 = os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_captions.txt") for line1, line2 in zip(open(file1), open(file2)): url = line1.rstrip() photo = os.path.basename(url) - filename = os.path.join(self.root, 'dataset', photo) + filename = os.path.join(self.root, "dataset", photo) if os.path.exists(filename): caption = line2.rstrip() self.photos.append(photo) @@ -65,8 +65,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: Returns: tuple: (image, target) where target is a caption for the photo. """ - filename = os.path.join(self.root, 'dataset', self.photos[index]) - img = Image.open(filename).convert('RGB') + filename = os.path.join(self.root, "dataset", self.photos[index]) + img = Image.open(filename).convert("RGB") if self.transform is not None: img = self.transform(img) @@ -93,21 +93,21 @@ def download(self) -> None: import tarfile if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_url(self.url, self.root, self.filename, self.md5_checksum) # Extract file - with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar: + with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: tar.extractall(path=self.root) # Download individual photos - with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh: + with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh: for line in fh: url = line.rstrip() try: - download_url(url, os.path.join(self.root, 'dataset')) + download_url(url, os.path.join(self.root, "dataset")) except OSError: # The images point to public images on Flickr. # Note: Images might be removed by users at anytime. 
diff --git a/torchvision/datasets/semeion.py b/torchvision/datasets/semeion.py index 20ce4e5f5d5..b97918a6292 100644 --- a/torchvision/datasets/semeion.py +++ b/torchvision/datasets/semeion.py @@ -1,10 +1,12 @@ -from PIL import Image import os import os.path -import numpy as np from typing import Any, Callable, Optional, Tuple -from .vision import VisionDataset + +import numpy as np +from PIL import Image + from .utils import download_url, check_integrity +from .vision import VisionDataset class SEMEION(VisionDataset): @@ -24,30 +26,28 @@ class SEMEION(VisionDataset): """ url = "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data" filename = "semeion.data" - md5_checksum = 'cb545d371d2ce14ec121470795a77432' + md5_checksum = "cb545d371d2ce14ec121470795a77432" def __init__( - self, - root: str, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = True, + self, + root: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = True, ) -> None: - super(SEMEION, self).__init__(root, transform=transform, - target_transform=target_transform) + super(SEMEION, self).__init__(root, transform=transform, target_transform=target_transform) if download: self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") fp = os.path.join(self.root, self.filename) data = np.loadtxt(fp) # convert value to 8 bit unsigned integer # color (white #255) the pixels - self.data = (data[:, :256] * 255).astype('uint8') + self.data = (data[:, :256] * 255).astype("uint8") self.data = np.reshape(self.data, (-1, 16, 16)) self.labels = np.nonzero(data[:, 256:])[1] @@ -63,7 +63,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img, mode='L') + img = Image.fromarray(img, mode="L") if self.transform is not None: img = self.transform(img) @@ -85,7 +85,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return root = self.root diff --git a/torchvision/datasets/stl10.py b/torchvision/datasets/stl10.py index 50e9af882bc..20ebbc3b0ee 100644 --- a/torchvision/datasets/stl10.py +++ b/torchvision/datasets/stl10.py @@ -1,11 +1,12 @@ -from PIL import Image import os import os.path -import numpy as np from typing import Any, Callable, Optional, Tuple -from .vision import VisionDataset +import numpy as np +from PIL import Image + from .utils import check_integrity, download_and_extract_archive, verify_str_arg +from .vision import VisionDataset class STL10(VisionDataset): @@ -27,70 +28,60 @@ class STL10(VisionDataset): puts it in root directory. If dataset is already downloaded, it is not downloaded again. 
""" - base_folder = 'stl10_binary' + + base_folder = "stl10_binary" url = "http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz" filename = "stl10_binary.tar.gz" - tgz_md5 = '91f7769df0f17e558f3565bffb0c7dfb' - class_names_file = 'class_names.txt' - folds_list_file = 'fold_indices.txt' + tgz_md5 = "91f7769df0f17e558f3565bffb0c7dfb" + class_names_file = "class_names.txt" + folds_list_file = "fold_indices.txt" train_list = [ - ['train_X.bin', '918c2871b30a85fa023e0c44e0bee87f'], - ['train_y.bin', '5a34089d4802c674881badbb80307741'], - ['unlabeled_X.bin', '5242ba1fed5e4be9e1e742405eb56ca4'] + ["train_X.bin", "918c2871b30a85fa023e0c44e0bee87f"], + ["train_y.bin", "5a34089d4802c674881badbb80307741"], + ["unlabeled_X.bin", "5242ba1fed5e4be9e1e742405eb56ca4"], ] - test_list = [ - ['test_X.bin', '7f263ba9f9e0b06b93213547f721ac82'], - ['test_y.bin', '36f9794fa4beb8a2c72628de14fa638e'] - ] - splits = ('train', 'train+unlabeled', 'unlabeled', 'test') + test_list = [["test_X.bin", "7f263ba9f9e0b06b93213547f721ac82"], ["test_y.bin", "36f9794fa4beb8a2c72628de14fa638e"]] + splits = ("train", "train+unlabeled", "unlabeled", "test") def __init__( - self, - root: str, - split: str = "train", - folds: Optional[int] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + folds: Optional[int] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(STL10, self).__init__(root, transform=transform, - target_transform=target_transform) + super(STL10, self).__init__(root, transform=transform, target_transform=target_transform) self.split = verify_str_arg(split, "split", self.splits) self.folds = self._verify_folds(folds) if download: self.download() elif not self._check_integrity(): - raise RuntimeError( - 'Dataset not found or corrupted. ' - 'You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted. 
" "You can use download=True to download it") # now load the picked numpy arrays self.labels: Optional[np.ndarray] - if self.split == 'train': - self.data, self.labels = self.__loadfile( - self.train_list[0][0], self.train_list[1][0]) + if self.split == "train": + self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) self.__load_folds(folds) - elif self.split == 'train+unlabeled': - self.data, self.labels = self.__loadfile( - self.train_list[0][0], self.train_list[1][0]) + elif self.split == "train+unlabeled": + self.data, self.labels = self.__loadfile(self.train_list[0][0], self.train_list[1][0]) self.__load_folds(folds) unlabeled_data, _ = self.__loadfile(self.train_list[2][0]) self.data = np.concatenate((self.data, unlabeled_data)) - self.labels = np.concatenate( - (self.labels, np.asarray([-1] * unlabeled_data.shape[0]))) + self.labels = np.concatenate((self.labels, np.asarray([-1] * unlabeled_data.shape[0]))) - elif self.split == 'unlabeled': + elif self.split == "unlabeled": self.data, _ = self.__loadfile(self.train_list[2][0]) self.labels = np.asarray([-1] * self.data.shape[0]) else: # self.split == 'test': - self.data, self.labels = self.__loadfile( - self.test_list[0][0], self.test_list[1][0]) + self.data, self.labels = self.__loadfile(self.test_list[0][0], self.test_list[1][0]) - class_file = os.path.join( - self.root, self.base_folder, self.class_names_file) + class_file = os.path.join(self.root, self.base_folder, self.class_names_file) if os.path.isfile(class_file): with open(class_file) as f: self.classes = f.read().splitlines() @@ -101,8 +92,7 @@ def _verify_folds(self, folds: Optional[int]) -> Optional[int]: elif isinstance(folds, int): if folds in range(10): return folds - msg = ("Value for argument folds should be in the range [0, 10), " - "but got {}.") + msg = "Value for argument folds should be in the range [0, 10), " "but got {}." raise ValueError(msg.format(folds)) else: msg = "Expected type None or int for argument folds, but got type {}." 
@@ -140,13 +130,12 @@ def __len__(self) -> int: def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple[np.ndarray, Optional[np.ndarray]]: labels = None if labels_file: - path_to_labels = os.path.join( - self.root, self.base_folder, labels_file) - with open(path_to_labels, 'rb') as f: + path_to_labels = os.path.join(self.root, self.base_folder, labels_file) + with open(path_to_labels, "rb") as f: labels = np.fromfile(f, dtype=np.uint8) - 1 # 0-based path_to_data = os.path.join(self.root, self.base_folder, data_file) - with open(path_to_data, 'rb') as f: + with open(path_to_data, "rb") as f: # read whole file in uint8 chunks everything = np.fromfile(f, dtype=np.uint8) images = np.reshape(everything, (-1, 3, 96, 96)) @@ -156,7 +145,7 @@ def __loadfile(self, data_file: str, labels_file: Optional[str] = None) -> Tuple def _check_integrity(self) -> bool: root = self.root - for fentry in (self.train_list + self.test_list): + for fentry in self.train_list + self.test_list: filename, md5 = fentry[0], fentry[1] fpath = os.path.join(root, self.base_folder, filename) if not check_integrity(fpath, md5): @@ -165,7 +154,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) self._check_integrity() @@ -177,11 +166,10 @@ def __load_folds(self, folds: Optional[int]) -> None: # loads one of the folds if specified if folds is None: return - path_to_folds = os.path.join( - self.root, self.base_folder, self.folds_list_file) - with open(path_to_folds, 'r') as f: + path_to_folds = os.path.join(self.root, self.base_folder, self.folds_list_file) + with open(path_to_folds, "r") as f: str_idx = f.read().splitlines()[folds] - list_idx = np.fromstring(str_idx, dtype=np.int64, sep=' ') + list_idx = np.fromstring(str_idx, dtype=np.int64, sep=" ") self.data = self.data[list_idx, :, :, :] if self.labels is not None: self.labels = self.labels[list_idx] diff --git a/torchvision/datasets/svhn.py b/torchvision/datasets/svhn.py index f1adee687eb..f5c6087b778 100644 --- a/torchvision/datasets/svhn.py +++ b/torchvision/datasets/svhn.py @@ -1,10 +1,12 @@ -from .vision import VisionDataset -from PIL import Image import os import os.path -import numpy as np from typing import Any, Callable, Optional, Tuple + +import numpy as np +from PIL import Image + from .utils import download_url, check_integrity, verify_str_arg +from .vision import VisionDataset class SVHN(VisionDataset): @@ -33,23 +35,32 @@ class SVHN(VisionDataset): """ split_list = { - 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", - "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], - 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", - "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], - 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", - "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} + "train": [ + "http://ufldl.stanford.edu/housenumbers/train_32x32.mat", + "train_32x32.mat", + "e26dedcc434d2e4c54c9b2d4a06d8373", + ], + "test": [ + "http://ufldl.stanford.edu/housenumbers/test_32x32.mat", + "test_32x32.mat", + "eb5a983be6a315427106f1b164d9cef3", + ], + "extra": [ + "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", + "extra_32x32.mat", + "a93ce644f1a588dc4d68dda5feec44a7", + ], + } def __init__( - self, - root: str, - split: str = 
"train", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(SVHN, self).__init__(root, transform=transform, - target_transform=target_transform) + super(SVHN, self).__init__(root, transform=transform, target_transform=target_transform) self.split = verify_str_arg(split, "split", tuple(self.split_list.keys())) self.url = self.split_list[split][0] self.filename = self.split_list[split][1] @@ -59,8 +70,7 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError('Dataset not found or corrupted.' + - ' You can use download=True to download it') + raise RuntimeError("Dataset not found or corrupted." + " You can use download=True to download it") # import here rather than at top of file because this is # an optional dependency for torchvision @@ -69,12 +79,12 @@ def __init__( # reading(loading) mat file as array loaded_mat = sio.loadmat(os.path.join(self.root, self.filename)) - self.data = loaded_mat['X'] + self.data = loaded_mat["X"] # loading from the .mat file gives an np array of type np.uint8 # converting to np.int64, so that we have a LongTensor after # the conversion from the numpy array # the squeeze is needed to obtain a 1D tensor - self.labels = loaded_mat['y'].astype(np.int64).squeeze() + self.labels = loaded_mat["y"].astype(np.int64).squeeze() # the svhn dataset assigns the class label "10" to the digit 0 # this makes it inconsistent with several loss functions diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py index ca8963efd75..dbe9b22e603 100644 --- a/torchvision/datasets/ucf101.py +++ b/torchvision/datasets/ucf101.py @@ -1,5 +1,6 @@ import os from typing import Any, Dict, List, Tuple, Optional, Callable + from torch import Tensor from .folder import find_classes, make_dataset @@ -62,13 +63,13 @@ def __init__( _video_width: int = 0, _video_height: int = 0, _video_min_dimension: int = 0, - _audio_samples: int = 0 + _audio_samples: int = 0, ) -> None: super(UCF101, self).__init__(root) if not 1 <= fold <= 3: raise ValueError("fold should be between 1 and 3, got {}".format(fold)) - extensions = ('avi',) + extensions = ("avi",) self.fold = fold self.train = train diff --git a/torchvision/datasets/usps.py b/torchvision/datasets/usps.py index c315b8d3111..c90ebfa7e6f 100644 --- a/torchvision/datasets/usps.py +++ b/torchvision/datasets/usps.py @@ -1,8 +1,9 @@ -from PIL import Image import os -import numpy as np from typing import Any, Callable, cast, Optional, Tuple +import numpy as np +from PIL import Image + from .utils import download_url from .vision import VisionDataset @@ -26,28 +27,30 @@ class USPS(VisionDataset): downloaded again. 
""" + split_list = { - 'train': [ + "train": [ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2", - "usps.bz2", 'ec16c51db3855ca6c91edd34d0e9b197' + "usps.bz2", + "ec16c51db3855ca6c91edd34d0e9b197", ], - 'test': [ + "test": [ "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2", - "usps.t.bz2", '8ea070ee2aca1ac39742fdd1ef5ed118' + "usps.t.bz2", + "8ea070ee2aca1ac39742fdd1ef5ed118", ], } def __init__( - self, - root: str, - train: bool = True, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(USPS, self).__init__(root, transform=transform, - target_transform=target_transform) - split = 'train' if train else 'test' + super(USPS, self).__init__(root, transform=transform, target_transform=target_transform) + split = "train" if train else "test" url, filename, checksum = self.split_list[split] full_path = os.path.join(self.root, filename) @@ -55,9 +58,10 @@ def __init__( download_url(url, self.root, filename, md5=checksum) import bz2 + with bz2.open(full_path) as fp: raw_data = [line.decode().split() for line in fp.readlines()] - tmp_list = [[x.split(':')[-1] for x in data[1:]] for data in raw_data] + tmp_list = [[x.split(":")[-1] for x in data[1:]] for data in raw_data] imgs = np.asarray(tmp_list, dtype=np.float32).reshape((-1, 16, 16)) imgs = ((cast(np.ndarray, imgs) + 1) / 2 * 255).astype(dtype=np.uint8) targets = [int(d[0]) - 1 for d in raw_data] @@ -77,7 +81,7 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]: # doing this so that it is consistent with all other datasets # to return a PIL Image - img = Image.fromarray(img, mode='L') + img = Image.fromarray(img, mode="L") if self.transform is not None: img = self.transform(img) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 9ae726edd8f..feb8d28bcce 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -1,19 +1,19 @@ import bz2 +import gzip +import hashlib +import itertools +import lzma import os import os.path -import hashlib -import gzip +import pathlib import re import tarfile -from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator -from urllib.parse import urlparse -import zipfile -import lzma import urllib -import urllib.request import urllib.error -import pathlib -import itertools +import urllib.request +import zipfile +from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator +from urllib.parse import urlparse import torch from torch.utils.model_zoo import tqdm @@ -52,8 +52,8 @@ def bar_update(count, block_size, total_size): def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: md5 = hashlib.md5() - with open(fpath, 'rb') as f: - for chunk in iter(lambda: f.read(chunk_size), b''): + with open(fpath, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): md5.update(chunk) return md5.hexdigest() @@ -120,7 +120,7 @@ def download_url( # check if file is already present locally if check_integrity(fpath, md5): - print('Using downloaded and verified file: ' + fpath) + print("Using downloaded and verified file: " + fpath) return if _is_remote_location_available(): @@ -136,13 +136,12 @@ def download_url( # download the file try: - print('Downloading ' + url + ' to ' + fpath) + 
print("Downloading " + url + " to " + fpath) _urlretrieve(url, fpath) except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] - if url[:5] == 'https': - url = url.replace('https:', 'http:') - print('Failed download. Trying https -> http instead.' - ' Downloading ' + url + ' to ' + fpath) + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead." " Downloading " + url + " to " + fpath) _urlretrieve(url, fpath) else: raise e @@ -202,6 +201,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ """ # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url import requests + url = "https://docs.google.com/uc?export=download" root = os.path.expanduser(root) @@ -212,15 +212,15 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ os.makedirs(root, exist_ok=True) if os.path.isfile(fpath) and check_integrity(fpath, md5): - print('Using downloaded and verified file: ' + fpath) + print("Using downloaded and verified file: " + fpath) else: session = requests.Session() - response = session.get(url, params={'id': file_id}, stream=True) + response = session.get(url, params={"id": file_id}, stream=True) token = _get_confirm_token(response) if token: - params = {'id': file_id, 'confirm': token} + params = {"id": file_id, "confirm": token} response = session.get(url, params=params, stream=True) # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent @@ -240,20 +240,21 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ ) raise RuntimeError(msg) - _save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath) + _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath) response.close() def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined] for key, value in response.cookies.items(): - if key.startswith('download_warning'): + if key.startswith("download_warning"): return value return None def _save_response_content( - response_gen: Iterator[bytes], destination: str, # type: ignore[name-defined] + response_gen: Iterator[bytes], + destination: str, # type: ignore[name-defined] ) -> None: with open(destination, "wb") as f: pbar = tqdm(total=None) @@ -439,7 +440,10 @@ def iterable_to_str(iterable: Iterable) -> str: def verify_str_arg( - value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None, + value: T, + arg: Optional[str] = None, + valid_values: Iterable[T] = None, + custom_msg: Optional[str] = None, ) -> T: if not isinstance(value, torch._six.string_classes): if arg is None: @@ -456,10 +460,8 @@ def verify_str_arg( if custom_msg is not None: msg = custom_msg else: - msg = ("Unknown value '{value}' for argument {arg}. " - "Valid values are {{{valid_values}}}.") - msg = msg.format(value=value, arg=arg, - valid_values=iterable_to_str(valid_values)) + msg = "Unknown value '{value}' for argument {arg}. " "Valid values are {{{valid_values}}}." 
+ msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) raise ValueError(msg) return value diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py index 987270c4cd4..8d427c9d80c 100644 --- a/torchvision/datasets/video_utils.py +++ b/torchvision/datasets/video_utils.py @@ -206,14 +206,14 @@ def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate): if frame_rate is None: frame_rate = fps total_frames = len(video_pts) * (float(frame_rate) / fps) - idxs = VideoClips._resample_video_idx( - int(math.floor(total_frames)), fps, frame_rate - ) + idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate) video_pts = video_pts[idxs] clips = unfold(video_pts, num_frames, step) if not clips.numel(): - warnings.warn("There aren't enough frames in the current video to get a clip for the given clip length and " - "frames between clips. The video (and potentially others) will be skipped.") + warnings.warn( + "There aren't enough frames in the current video to get a clip for the given clip length and " + "frames between clips. The video (and potentially others) will be skipped." + ) if isinstance(idxs, slice): idxs = [idxs] * len(clips) else: @@ -237,9 +237,7 @@ def compute_clips(self, num_frames, step, frame_rate=None): self.clips = [] self.resampling_idxs = [] for video_pts, fps in zip(self.video_pts, self.video_fps): - clips, idxs = self.compute_clips_for_video( - video_pts, num_frames, step, fps, frame_rate - ) + clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate) self.clips.append(clips) self.resampling_idxs.append(idxs) clip_lengths = torch.as_tensor([len(v) for v in self.clips]) @@ -295,10 +293,7 @@ def get_clip(self, idx): video_idx (int): index of the video in `video_paths` """ if idx >= self.num_clips(): - raise IndexError( - "Index {} out of range " - "({} number of clips)".format(idx, self.num_clips()) - ) + raise IndexError("Index {} out of range " "({} number of clips)".format(idx, self.num_clips())) video_idx, clip_idx = self.get_clip_location(idx) video_path = self.video_paths[video_idx] clip_pts = self.clips[video_idx][clip_idx] @@ -314,13 +309,9 @@ def get_clip(self, idx): if self._video_height != 0: raise ValueError("pyav backend doesn't support _video_height != 0") if self._video_min_dimension != 0: - raise ValueError( - "pyav backend doesn't support _video_min_dimension != 0" - ) + raise ValueError("pyav backend doesn't support _video_min_dimension != 0") if self._video_max_dimension != 0: - raise ValueError( - "pyav backend doesn't support _video_max_dimension != 0" - ) + raise ValueError("pyav backend doesn't support _video_max_dimension != 0") if self._audio_samples != 0: raise ValueError("pyav backend doesn't support _audio_samples != 0") @@ -338,19 +329,11 @@ def get_clip(self, idx): audio_start_pts, audio_end_pts = 0, -1 audio_timebase = Fraction(0, 1) - video_timebase = Fraction( - info.video_timebase.numerator, info.video_timebase.denominator - ) + video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) if info.has_audio: - audio_timebase = Fraction( - info.audio_timebase.numerator, info.audio_timebase.denominator - ) - audio_start_pts = pts_convert( - video_start_pts, video_timebase, audio_timebase, math.floor - ) - audio_end_pts = pts_convert( - video_end_pts, video_timebase, audio_timebase, math.ceil - ) + audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) + 
audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor) + audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil) audio_fps = info.audio_sample_rate video, audio, info = _read_video_from_file( video_path, @@ -376,9 +359,7 @@ def get_clip(self, idx): resampling_idx = resampling_idx - resampling_idx[0] video = video[resampling_idx] info["video_fps"] = self.frame_rate - assert len(video) == self.num_frames, "{} x {}".format( - video.shape, self.num_frames - ) + assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames) return video, audio, info, video_idx def __getstate__(self): diff --git a/torchvision/datasets/vision.py b/torchvision/datasets/vision.py index db44a8b1ba0..591c0bb0c0c 100644 --- a/torchvision/datasets/vision.py +++ b/torchvision/datasets/vision.py @@ -1,7 +1,8 @@ import os +from typing import Any, Callable, List, Optional, Tuple + import torch import torch.utils.data as data -from typing import Any, Callable, List, Optional, Tuple class VisionDataset(data.Dataset): @@ -22,14 +23,15 @@ class VisionDataset(data.Dataset): :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. """ + _repr_indent = 4 def __init__( - self, - root: str, - transforms: Optional[Callable] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + self, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, ) -> None: torch._C._log_api_usage_once(f"torchvision.datasets.{self.__class__.__name__}") if isinstance(root, torch._six.string_classes): @@ -39,8 +41,7 @@ def __init__( has_transforms = transforms is not None has_separate_transform = transform is not None or target_transform is not None if has_transforms and has_separate_transform: - raise ValueError("Only transforms or transform/target_transform can " - "be passed as argument") + raise ValueError("Only transforms or transform/target_transform can " "be passed as argument") # for backwards-compatibility self.transform = transform @@ -72,12 +73,11 @@ def __repr__(self) -> str: if hasattr(self, "transforms") and self.transforms is not None: body += [repr(self.transforms)] lines = [head] + [" " * self._repr_indent + line for line in body] - return '\n'.join(lines) + return "\n".join(lines) def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: lines = transform.__repr__().splitlines() - return (["{}{}".format(head, lines[0])] + - ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] def extra_repr(self) -> str: return "" @@ -97,16 +97,13 @@ def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: lines = transform.__repr__().splitlines() - return (["{}{}".format(head, lines[0])] + - ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] def __repr__(self) -> str: body = [self.__class__.__name__] if self.transform is not None: - body += self._format_transform_repr(self.transform, - "Transform: ") + body += self._format_transform_repr(self.transform, "Transform: ") if self.target_transform is not None: - body += 
self._format_transform_repr(self.target_transform, - "Target transform: ") + body += self._format_transform_repr(self.target_transform, "Target transform: ") - return '\n'.join(body) + return "\n".join(body) diff --git a/torchvision/datasets/voc.py b/torchvision/datasets/voc.py index 56bd92c7972..a089d43125d 100644 --- a/torchvision/datasets/voc.py +++ b/torchvision/datasets/voc.py @@ -1,59 +1,63 @@ -import os import collections -from .vision import VisionDataset +import os from xml.etree.ElementTree import Element as ET_Element + +from .vision import VisionDataset + try: from defusedxml.ElementTree import parse as ET_parse except ImportError: from xml.etree.ElementTree import parse as ET_parse -from PIL import Image +import warnings from typing import Any, Callable, Dict, Optional, Tuple, List + +from PIL import Image + from .utils import download_and_extract_archive, verify_str_arg -import warnings DATASET_YEAR_DICT = { - '2012': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar', - 'filename': 'VOCtrainval_11-May-2012.tar', - 'md5': '6cd6e144f989b92b3379bac3b3de84fd', - 'base_dir': os.path.join('VOCdevkit', 'VOC2012') + "2012": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar", + "filename": "VOCtrainval_11-May-2012.tar", + "md5": "6cd6e144f989b92b3379bac3b3de84fd", + "base_dir": os.path.join("VOCdevkit", "VOC2012"), + }, + "2011": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar", + "filename": "VOCtrainval_25-May-2011.tar", + "md5": "6c3384ef61512963050cb5d687e5bf1e", + "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"), }, - '2011': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar', - 'filename': 'VOCtrainval_25-May-2011.tar', - 'md5': '6c3384ef61512963050cb5d687e5bf1e', - 'base_dir': os.path.join('TrainVal', 'VOCdevkit', 'VOC2011') + "2010": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar", + "filename": "VOCtrainval_03-May-2010.tar", + "md5": "da459979d0c395079b5c75ee67908abb", + "base_dir": os.path.join("VOCdevkit", "VOC2010"), }, - '2010': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar', - 'filename': 'VOCtrainval_03-May-2010.tar', - 'md5': 'da459979d0c395079b5c75ee67908abb', - 'base_dir': os.path.join('VOCdevkit', 'VOC2010') + "2009": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar", + "filename": "VOCtrainval_11-May-2009.tar", + "md5": "59065e4b188729180974ef6572f6a212", + "base_dir": os.path.join("VOCdevkit", "VOC2009"), }, - '2009': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar', - 'filename': 'VOCtrainval_11-May-2009.tar', - 'md5': '59065e4b188729180974ef6572f6a212', - 'base_dir': os.path.join('VOCdevkit', 'VOC2009') + "2008": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar", + "filename": "VOCtrainval_11-May-2012.tar", + "md5": "2629fa636546599198acfcfbfcf1904a", + "base_dir": os.path.join("VOCdevkit", "VOC2008"), }, - '2008': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar', - 'filename': 'VOCtrainval_11-May-2012.tar', - 'md5': '2629fa636546599198acfcfbfcf1904a', - 'base_dir': os.path.join('VOCdevkit', 'VOC2008') + "2007": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar", + "filename": "VOCtrainval_06-Nov-2007.tar", + "md5": 
"c52e279531787c972589f7e41ab4ae64", + "base_dir": os.path.join("VOCdevkit", "VOC2007"), }, - '2007': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar', - 'filename': 'VOCtrainval_06-Nov-2007.tar', - 'md5': 'c52e279531787c972589f7e41ab4ae64', - 'base_dir': os.path.join('VOCdevkit', 'VOC2007') + "2007-test": { + "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar", + "filename": "VOCtest_06-Nov-2007.tar", + "md5": "b6e924de25625d8de591ea690078ad9f", + "base_dir": os.path.join("VOCdevkit", "VOC2007"), }, - '2007-test': { - 'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar', - 'filename': 'VOCtest_06-Nov-2007.tar', - 'md5': 'b6e924de25625d8de591ea690078ad9f', - 'base_dir': os.path.join('VOCdevkit', 'VOC2007') - } } diff --git a/torchvision/datasets/widerface.py b/torchvision/datasets/widerface.py index c1775309b29..dd054556b46 100644 --- a/torchvision/datasets/widerface.py +++ b/torchvision/datasets/widerface.py @@ -1,10 +1,17 @@ -from PIL import Image import os from os.path import abspath, expanduser -import torch from typing import Any, Callable, List, Dict, Optional, Tuple, Union -from .utils import check_integrity, download_file_from_google_drive, \ - download_and_extract_archive, extract_archive, verify_str_arg + +import torch +from PIL import Image + +from .utils import ( + check_integrity, + download_file_from_google_drive, + download_and_extract_archive, + extract_archive, + verify_str_arg, +) from .vision import VisionDataset @@ -40,25 +47,25 @@ class WIDERFace(VisionDataset): # File ID MD5 Hash Filename ("0B6eKvaijfFUDQUUwd21EckhUbWs", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"), ("0B6eKvaijfFUDd3dIRmpvSk8tLUk", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"), - ("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip") + ("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"), ] ANNOTATIONS_FILE = ( "http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip", "0e3767bcf0e326556d407bf5bff5d27c", - "wider_face_split.zip" + "wider_face_split.zip", ) def __init__( - self, - root: str, - split: str = "train", - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - download: bool = False, + self, + root: str, + split: str = "train", + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, ) -> None: - super(WIDERFace, self).__init__(root=os.path.join(root, self.BASE_FOLDER), - transform=transform, - target_transform=target_transform) + super(WIDERFace, self).__init__( + root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform + ) # check arguments self.split = verify_str_arg(split, "split", ("train", "val", "test")) @@ -66,8 +73,9 @@ def __init__( self.download() if not self._check_integrity(): - raise RuntimeError("Dataset not found or corrupted. " + - "You can use download=True to download and prepare it") + raise RuntimeError( + "Dataset not found or corrupted. 
" + "You can use download=True to download and prepare it" + ) self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = [] if self.split in ("train", "val"): @@ -102,7 +110,7 @@ def __len__(self) -> int: def extra_repr(self) -> str: lines = ["Split: {split}"] - return '\n'.join(lines).format(**self.__dict__) + return "\n".join(lines).format(**self.__dict__) def parse_train_val_annotations_file(self) -> None: filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt" @@ -133,16 +141,20 @@ def parse_train_val_annotations_file(self) -> None: box_annotation_line = False file_name_line = True labels_tensor = torch.tensor(labels) - self.img_info.append({ - "img_path": img_path, - "annotations": {"bbox": labels_tensor[:, 0:4], # x, y, width, height - "blur": labels_tensor[:, 4], - "expression": labels_tensor[:, 5], - "illumination": labels_tensor[:, 6], - "occlusion": labels_tensor[:, 7], - "pose": labels_tensor[:, 8], - "invalid": labels_tensor[:, 9]} - }) + self.img_info.append( + { + "img_path": img_path, + "annotations": { + "bbox": labels_tensor[:, 0:4], # x, y, width, height + "blur": labels_tensor[:, 4], + "expression": labels_tensor[:, 5], + "illumination": labels_tensor[:, 6], + "occlusion": labels_tensor[:, 7], + "pose": labels_tensor[:, 8], + "invalid": labels_tensor[:, 9], + }, + } + ) box_counter = 0 labels.clear() else: @@ -172,7 +184,7 @@ def _check_integrity(self) -> bool: def download(self) -> None: if self._check_integrity(): - print('Files already downloaded and verified') + print("Files already downloaded and verified") return # download and extract image data @@ -182,6 +194,6 @@ def download(self) -> None: extract_archive(filepath) # download and extract annotation files - download_and_extract_archive(url=self.ANNOTATIONS_FILE[0], - download_root=self.root, - md5=self.ANNOTATIONS_FILE[1]) + download_and_extract_archive( + url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1] + ) diff --git a/torchvision/extension.py b/torchvision/extension.py index bea6db33636..b59bed94dff 100644 --- a/torchvision/extension.py +++ b/torchvision/extension.py @@ -11,12 +11,14 @@ def _has_ops(): try: - lib_path = _get_extension_path('_C') + lib_path = _get_extension_path("_C") torch.ops.load_library(lib_path) _HAS_OPS = True def _has_ops(): # noqa: F811 return True + + except (ImportError, OSError): pass @@ -41,6 +43,7 @@ def _check_cuda_version(): if not _HAS_OPS: return -1 import torch + _version = torch.ops.torchvision._cuda_version() if _version != -1 and torch.version.cuda is not None: tv_version = str(_version) @@ -51,14 +54,17 @@ def _check_cuda_version(): tv_major = int(tv_version[0:2]) tv_minor = int(tv_version[3]) t_version = torch.version.cuda - t_version = t_version.split('.') + t_version = t_version.split(".") t_major = int(t_version[0]) t_minor = int(t_version[1]) if t_major != tv_major or t_minor != tv_minor: - raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. " - "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. " - "Please reinstall the torchvision that matches your PyTorch install." - .format(t_major, t_minor, tv_major, tv_minor)) + raise RuntimeError( + "Detected that PyTorch and torchvision were compiled with different CUDA versions. " + "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. 
" + "Please reinstall the torchvision that matches your PyTorch install.".format( + t_major, t_minor, tv_major, tv_minor + ) + ) return _version diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index d0ec1b406f3..c443bee4bb2 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -1,6 +1,7 @@ -import torch from typing import Any, Dict, Iterator +import torch + from ._video_opt import ( Timebase, VideoMetaData, @@ -12,11 +13,6 @@ _read_video_timestamps_from_file, _read_video_timestamps_from_memory, ) -from .video import ( - read_video, - read_video_timestamps, - write_video, -) from .image import ( ImageReadMode, decode_image, @@ -30,6 +26,11 @@ write_jpeg, write_png, ) +from .video import ( + read_video, + read_video_timestamps, + write_video, +) if _HAS_VIDEO_OPT: @@ -127,10 +128,10 @@ def __next__(self) -> Dict[str, Any]: raise StopIteration return {"data": frame, "pts": pts} - def __iter__(self) -> Iterator['VideoReader']: + def __iter__(self) -> Iterator["VideoReader"]: return self - def seek(self, time_s: float) -> 'VideoReader': + def seek(self, time_s: float) -> "VideoReader": """Seek within current stream. Args: diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py index a4a811dec4b..b94e25ea698 100644 --- a/torchvision/io/_video_opt.py +++ b/torchvision/io/_video_opt.py @@ -1,4 +1,3 @@ - import math import os import warnings @@ -12,7 +11,7 @@ try: - lib_path = _get_extension_path('video_reader') + lib_path = _get_extension_path("video_reader") torch.ops.load_library(lib_path) _HAS_VIDEO_OPT = True except (ImportError, OSError): @@ -90,9 +89,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): """ meta = VideoMetaData() if vtimebase.numel() > 0: - meta.video_timebase = Timebase( - int(vtimebase[0].item()), int(vtimebase[1].item()) - ) + meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item())) timebase = vtimebase[0].item() / float(vtimebase[1].item()) if vduration.numel() > 0: meta.has_video = True @@ -100,9 +97,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration): if vfps.numel() > 0: meta.video_fps = float(vfps.item()) if atimebase.numel() > 0: - meta.audio_timebase = Timebase( - int(atimebase[0].item()), int(atimebase[1].item()) - ) + meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item())) timebase = atimebase[0].item() / float(atimebase[1].item()) if aduration.numel() > 0: meta.has_audio = True @@ -216,10 +211,7 @@ def _read_video_from_file( audio_timebase.numerator, audio_timebase.denominator, ) - vframes, _vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - result - ) + vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) if aframes.numel() > 0: # when audio stream is found @@ -254,8 +246,7 @@ def _read_video_timestamps_from_file(filename): 0, # audio_timebase_num 1, # audio_timebase_den ) - _vframes, vframe_pts, vtimebase, vfps, vduration, \ - _aframes, aframe_pts, atimebase, asample_rate, aduration = result + _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) vframe_pts = vframe_pts.numpy().tolist() @@ -372,10 +363,7 @@ def _read_video_from_memory( 
audio_timebase_denominator, ) - vframes, _vframe_pts, vtimebase, vfps, vduration, \ - aframes, aframe_pts, atimebase, asample_rate, aduration = ( - result - ) + vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result if aframes.numel() > 0: # when audio stream is found @@ -413,10 +401,7 @@ def _read_video_timestamps_from_memory(video_data): 0, # audio_timebase_num 1, # audio_timebase_den ) - _vframes, vframe_pts, vtimebase, vfps, vduration, \ - _aframes, aframe_pts, atimebase, asample_rate, aduration = ( - result - ) + _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration) vframe_pts = vframe_pts.numpy().tolist() @@ -439,10 +424,10 @@ def _probe_video_from_memory(video_data): def _convert_to_sec(start_pts, end_pts, pts_unit, time_base): - if pts_unit == 'pts': + if pts_unit == "pts": start_pts = float(start_pts * time_base) end_pts = float(end_pts * time_base) - pts_unit = 'sec' + pts_unit = "sec" return start_pts, end_pts, pts_unit @@ -467,20 +452,15 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"): time_base = default_timebase if has_video: - video_timebase = Fraction( - info.video_timebase.numerator, info.video_timebase.denominator - ) + video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) time_base = video_timebase if has_audio: - audio_timebase = Fraction( - info.audio_timebase.numerator, info.audio_timebase.denominator - ) + audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator) time_base = time_base if time_base else audio_timebase # video_timebase is the default time_base - start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec( - start_pts, end_pts, pts_unit, time_base) + start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base) def get_pts(time_base): start_offset = start_pts_sec @@ -527,9 +507,7 @@ def _read_video_timestamps(filename, pts_unit="pts"): pts, _, info = _read_video_timestamps_from_file(filename) if pts_unit == "sec": - video_time_base = Fraction( - info.video_timebase.numerator, info.video_timebase.denominator - ) + video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator) pts = [x * video_time_base for x in pts] video_fps = info.video_fps if info.has_video else None diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 343c0b3a33d..2ba1e9eddd9 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -1,11 +1,12 @@ -import torch from enum import Enum +import torch + from .._internally_replaced_utils import _get_extension_path try: - lib_path = _get_extension_path('image') + lib_path = _get_extension_path("image") torch.ops.load_library(lib_path) except (ImportError, OSError): pass @@ -21,6 +22,7 @@ class ImageReadMode(Enum): ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for RGB with transparency. 
""" + UNCHANGED = 0 GRAY = 1 GRAY_ALPHA = 2 @@ -111,8 +113,9 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): write_file(filename, output) -def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, - device: str = 'cpu') -> torch.Tensor: +def decode_jpeg( + input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu" +) -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. Optionally converts the image to the desired format. @@ -135,7 +138,7 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG output (Tensor[image_channels, image_height, image_width]) """ device = torch.device(device) - if device.type == 'cuda': + if device.type == "cuda": output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) else: output = torch.ops.image.decode_jpeg(input, mode.value) @@ -158,8 +161,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor: JPEG file. """ if quality < 1 or quality > 100: - raise ValueError('Image quality should be a positive number ' - 'between 1 and 100') + raise ValueError("Image quality should be a positive number " "between 1 and 100") output = torch.ops.image.encode_jpeg(input, quality) return output diff --git a/torchvision/io/video.py b/torchvision/io/video.py index e39b0dae301..e5648459113 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -94,16 +94,16 @@ def write_video( if audio_array is not None: audio_format_dtypes = { - 'dbl': ' 0 and start_offset > 0 and start_offset not in frames: # if there is no frame that exactly matches the pts of start_offset # add the last frame smaller than start_offset, to guarantee that @@ -264,7 +260,7 @@ def read_video( from torchvision import get_video_backend if not os.path.exists(filename): - raise RuntimeError(f'File not found: {filename}') + raise RuntimeError(f"File not found: {filename}") if get_video_backend() != "pyav": return _video_opt._read_video(filename, start_pts, end_pts, pts_unit) @@ -276,8 +272,7 @@ def read_video( if end_pts < start_pts: raise ValueError( - "end_pts should be larger than start_pts, got " - "start_pts={} and end_pts={}".format(start_pts, end_pts) + "end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts) ) info = {} @@ -295,8 +290,7 @@ def read_video( elif container.streams.audio: time_base = container.streams.audio[0].time_base # video_timebase is the default time_base - start_pts, end_pts, pts_unit = _video_opt._convert_to_sec( - start_pts, end_pts, pts_unit, time_base) + start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base) if container.streams.video: video_frames = _read_from_stream( container, @@ -337,7 +331,7 @@ def read_video( if aframes_list: aframes = np.concatenate(aframes_list, 1) aframes = torch.as_tensor(aframes) - if pts_unit == 'sec': + if pts_unit == "sec": start_pts = int(math.floor(start_pts * (1 / audio_timebase))) if end_pts != float("inf"): end_pts = int(math.ceil(end_pts * (1 / audio_timebase))) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 07ccf8de7f5..516e47feb19 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -10,8 +10,8 @@ from .shufflenetv2 import * from .efficientnet import * from .regnet import * -from . import segmentation from . import detection -from . import video -from . import quantization from . 
import feature_extraction +from . import quantization +from . import segmentation +from . import video diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index ec2e287d974..2a7a1bbaa0f 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,7 +1,7 @@ from collections import OrderedDict +from typing import Dict, Optional from torch import nn -from typing import Dict, Optional class IntermediateLayerGetter(nn.ModuleDict): @@ -35,6 +35,7 @@ class IntermediateLayerGetter(nn.ModuleDict): >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] """ + _version = 2 __annotations__ = { "return_layers": Dict[str, str], diff --git a/torchvision/models/alexnet.py b/torchvision/models/alexnet.py index 156a453c3cc..ca255843e79 100644 --- a/torchvision/models/alexnet.py +++ b/torchvision/models/alexnet.py @@ -1,19 +1,20 @@ +from typing import Any + import torch import torch.nn as nn + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Any -__all__ = ['AlexNet', 'alexnet'] +__all__ = ["AlexNet", "alexnet"] model_urls = { - 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-7be5be79.pth', + "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", } class AlexNet(nn.Module): - def __init__(self, num_classes: int = 1000) -> None: super(AlexNet, self).__init__() self.features = nn.Sequential( @@ -61,7 +62,6 @@ def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> A """ model = AlexNet(**kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['alexnet'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress) model.load_state_dict(state_dict) return model diff --git a/torchvision/models/densenet.py b/torchvision/models/densenet.py index aef7977773b..17acb556d6b 100644 --- a/torchvision/models/densenet.py +++ b/torchvision/models/densenet.py @@ -1,50 +1,47 @@ import re +from collections import OrderedDict +from typing import Any, List, Tuple + import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as cp -from collections import OrderedDict -from .._internally_replaced_utils import load_state_dict_from_url from torch import Tensor -from typing import Any, List, Tuple + +from .._internally_replaced_utils import load_state_dict_from_url -__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] +__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"] model_urls = { - 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', - 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', - 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', - 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", } class _DenseLayer(nn.Module): def __init__( - self, - num_input_features: int, - growth_rate: int, - bn_size: int, - drop_rate: float, - memory_efficient: bool = False + self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False ) -> None: 
super(_DenseLayer, self).__init__() self.norm1: nn.BatchNorm2d - self.add_module('norm1', nn.BatchNorm2d(num_input_features)) + self.add_module("norm1", nn.BatchNorm2d(num_input_features)) self.relu1: nn.ReLU - self.add_module('relu1', nn.ReLU(inplace=True)) + self.add_module("relu1", nn.ReLU(inplace=True)) self.conv1: nn.Conv2d - self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * - growth_rate, kernel_size=1, stride=1, - bias=False)) + self.add_module( + "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False) + ) self.norm2: nn.BatchNorm2d - self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)) + self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate)) self.relu2: nn.ReLU - self.add_module('relu2', nn.ReLU(inplace=True)) + self.add_module("relu2", nn.ReLU(inplace=True)) self.conv2: nn.Conv2d - self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, - kernel_size=3, stride=1, padding=1, - bias=False)) + self.add_module( + "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False) + ) self.drop_rate = float(drop_rate) self.memory_efficient = memory_efficient @@ -93,8 +90,7 @@ def forward(self, input: Tensor) -> Tensor: # noqa: F811 new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) if self.drop_rate > 0: - new_features = F.dropout(new_features, p=self.drop_rate, - training=self.training) + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) return new_features @@ -108,7 +104,7 @@ def __init__( bn_size: int, growth_rate: int, drop_rate: float, - memory_efficient: bool = False + memory_efficient: bool = False, ) -> None: super(_DenseBlock, self).__init__() for i in range(num_layers): @@ -119,7 +115,7 @@ def __init__( drop_rate=drop_rate, memory_efficient=memory_efficient, ) - self.add_module('denselayer%d' % (i + 1), layer) + self.add_module("denselayer%d" % (i + 1), layer) def forward(self, init_features: Tensor) -> Tensor: features = [init_features] @@ -132,11 +128,10 @@ def forward(self, init_features: Tensor) -> Tensor: class _Transition(nn.Sequential): def __init__(self, num_input_features: int, num_output_features: int) -> None: super(_Transition, self).__init__() - self.add_module('norm', nn.BatchNorm2d(num_input_features)) - self.add_module('relu', nn.ReLU(inplace=True)) - self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, - kernel_size=1, stride=1, bias=False)) - self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + self.add_module("norm", nn.BatchNorm2d(num_input_features)) + self.add_module("relu", nn.ReLU(inplace=True)) + self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)) + self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2)) class DenseNet(nn.Module): @@ -163,19 +158,22 @@ def __init__( bn_size: int = 4, drop_rate: float = 0, num_classes: int = 1000, - memory_efficient: bool = False + memory_efficient: bool = False, ) -> None: super(DenseNet, self).__init__() # First convolution - self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, - padding=3, bias=False)), - ('norm0', nn.BatchNorm2d(num_init_features)), - ('relu0', nn.ReLU(inplace=True)), - ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), - ])) + self.features = nn.Sequential( + OrderedDict( + [ + ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, 
bias=False)), + ("norm0", nn.BatchNorm2d(num_init_features)), + ("relu0", nn.ReLU(inplace=True)), + ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ] + ) + ) # Each denseblock num_features = num_init_features @@ -186,18 +184,17 @@ def __init__( bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate, - memory_efficient=memory_efficient + memory_efficient=memory_efficient, ) - self.features.add_module('denseblock%d' % (i + 1), block) + self.features.add_module("denseblock%d" % (i + 1), block) num_features = num_features + num_layers * growth_rate if i != len(block_config) - 1: - trans = _Transition(num_input_features=num_features, - num_output_features=num_features // 2) - self.features.add_module('transition%d' % (i + 1), trans) + trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) + self.features.add_module("transition%d" % (i + 1), trans) num_features = num_features // 2 # Final batch norm - self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + self.features.add_module("norm5", nn.BatchNorm2d(num_features)) # Linear layer self.classifier = nn.Linear(num_features, num_classes) @@ -227,7 +224,8 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( - r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') + r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$" + ) state_dict = load_state_dict_from_url(model_url, progress=progress) for key in list(state_dict.keys()): @@ -246,7 +244,7 @@ def _densenet( num_init_features: int, pretrained: bool, progress: bool, - **kwargs: Any + **kwargs: Any, ) -> DenseNet: model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) if pretrained: @@ -265,8 +263,7 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, - **kwargs) + return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs) def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: @@ -280,8 +277,7 @@ def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, - **kwargs) + return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs) def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: @@ -295,8 +291,7 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. 
""" - return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, - **kwargs) + return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs) def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: @@ -310,5 +305,4 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, but slower. Default: *False*. See `"paper" `_. """ - return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, - **kwargs) + return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs) diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index 1d3bcdba7fe..130b445874a 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -1,10 +1,9 @@ import math -import torch - from collections import OrderedDict -from torch import Tensor from typing import List, Tuple +import torch +from torch import Tensor from torchvision.ops.misc import FrozenBatchNorm2d @@ -61,12 +60,8 @@ def __call__(self, matched_idxs): neg_idx_per_image = negative[perm2] # create binary mask from indices - pos_idx_per_image_mask = torch.zeros_like( - matched_idxs_per_image, dtype=torch.uint8 - ) - neg_idx_per_image_mask = torch.zeros_like( - matched_idxs_per_image, dtype=torch.uint8 - ) + pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) + neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8) pos_idx_per_image_mask[pos_idx_per_image] = 1 neg_idx_per_image_mask[neg_idx_per_image] = 1 @@ -132,7 +127,7 @@ class BoxCoder(object): the representation used for training the regressors. """ - def __init__(self, weights, bbox_xform_clip=math.log(1000. 
/ 16)): + def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)): # type: (Tuple[float, float, float, float], float) -> None """ Args: @@ -177,9 +172,7 @@ def decode(self, rel_codes, boxes): box_sum += val if box_sum > 0: rel_codes = rel_codes.reshape(box_sum, -1) - pred_boxes = self.decode_single( - rel_codes, concat_boxes - ) + pred_boxes = self.decode_single(rel_codes, concat_boxes) if box_sum > 0: pred_boxes = pred_boxes.reshape(box_sum, -1, 4) return pred_boxes @@ -247,8 +240,8 @@ class Matcher(object): BETWEEN_THRESHOLDS = -2 __annotations__ = { - 'BELOW_LOW_THRESHOLD': int, - 'BETWEEN_THRESHOLDS': int, + "BELOW_LOW_THRESHOLD": int, + "BETWEEN_THRESHOLDS": int, } def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): @@ -287,13 +280,9 @@ def __call__(self, match_quality_matrix): if match_quality_matrix.numel() == 0: # empty targets or proposals not supported during training if match_quality_matrix.shape[0] == 0: - raise ValueError( - "No ground-truth boxes available for one of the images " - "during training") + raise ValueError("No ground-truth boxes available for one of the images " "during training") else: - raise ValueError( - "No proposal boxes available for one of the images " - "during training") + raise ValueError("No proposal boxes available for one of the images " "during training") # match_quality_matrix is M (gt) x N (predicted) # Max over gt elements (dim 0) to find best gt candidate for each prediction @@ -305,9 +294,7 @@ def __call__(self, match_quality_matrix): # Assign candidate matches with low quality to negative (unassigned) values below_low_threshold = matched_vals < self.low_threshold - between_thresholds = (matched_vals >= self.low_threshold) & ( - matched_vals < self.high_threshold - ) + between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold) matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD matches[between_thresholds] = self.BETWEEN_THRESHOLDS @@ -328,9 +315,7 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): # For each gt, find the prediction with which it has highest quality highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) # Find highest quality match available, even if it is low, including ties - gt_pred_pairs_of_highest_quality = torch.where( - match_quality_matrix == highest_quality_foreach_gt[:, None] - ) + gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None]) # Example gt_pred_pairs_of_highest_quality: # tensor([[ 0, 39796], # [ 1, 32055], @@ -350,7 +335,6 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix): class SSDMatcher(Matcher): - def __init__(self, threshold): super().__init__(threshold, threshold, allow_low_quality_matches=False) @@ -359,9 +343,9 @@ def __call__(self, match_quality_matrix): # For each gt, find the prediction with which it has the highest quality _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1) - matches[highest_quality_pred_foreach_gt] = torch.arange(highest_quality_pred_foreach_gt.size(0), - dtype=torch.int64, - device=highest_quality_pred_foreach_gt.device) + matches[highest_quality_pred_foreach_gt] = torch.arange( + highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device + ) return matches @@ -405,7 +389,7 @@ def retrieve_out_channels(model, size): tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device) features = model(tmp_img) if 
isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) out_channels = [x.size(1) for x in features.values()] if in_training: diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 0057da45e24..2e433958715 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -1,8 +1,9 @@ import math +from typing import List, Optional + import torch from torch import nn, Tensor -from typing import List, Optional from .image_list import ImageList @@ -48,15 +49,21 @@ def __init__( self.sizes = sizes self.aspect_ratios = aspect_ratios - self.cell_anchors = [self.generate_anchors(size, aspect_ratio) - for size, aspect_ratio in zip(sizes, aspect_ratios)] + self.cell_anchors = [ + self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios) + ] # TODO: https://github.com/pytorch/pytorch/issues/26792 # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) # This method assumes aspect ratio = height / width for an anchor. - def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu")): + def generate_anchors( + self, + scales: List[int], + aspect_ratios: List[float], + dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu"), + ): scales = torch.as_tensor(scales, dtype=dtype, device=device) aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) h_ratios = torch.sqrt(aspect_ratios) @@ -69,8 +76,7 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: return base_anchors.round() def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): - self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) - for cell_anchor in self.cell_anchors] + self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors] def num_anchors_per_location(self): return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] @@ -83,25 +89,21 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) assert cell_anchors is not None if not (len(grid_sizes) == len(strides) == len(cell_anchors)): - raise ValueError("Anchors should be Tuple[Tuple[int]] because each feature " - "map could potentially have different sizes and aspect ratios. " - "There needs to be a match between the number of " - "feature maps passed and the number of sizes / aspect ratios specified.") - - for size, stride, base_anchors in zip( - grid_sizes, strides, cell_anchors - ): + raise ValueError( + "Anchors should be Tuple[Tuple[int]] because each feature " + "map could potentially have different sizes and aspect ratios. " + "There needs to be a match between the number of " + "feature maps passed and the number of sizes / aspect ratios specified." 
+ ) + + for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors): grid_height, grid_width = size stride_height, stride_width = stride device = base_anchors.device # For output anchor, compute [x_center, y_center, x_center, y_center] - shifts_x = torch.arange( - 0, grid_width, dtype=torch.int32, device=device - ) * stride_width - shifts_y = torch.arange( - 0, grid_height, dtype=torch.int32, device=device - ) * stride_height + shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width + shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1) @@ -109,9 +111,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) # For every (base anchor, output anchor) pair, # offset each zero-centered base anchor by the center of the output anchor. - anchors.append( - (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) - ) + anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)) return anchors @@ -119,8 +119,13 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] image_size = image_list.tensors.shape[-2:] dtype, device = feature_maps[0].dtype, feature_maps[0].device - strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), - torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] + strides = [ + [ + torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), + torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device), + ] + for g in grid_sizes + ] self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides) anchors: List[List[torch.Tensor]] = [] @@ -149,8 +154,15 @@ class DefaultBoxGenerator(nn.Module): is applied while the boxes are encoded in format ``(cx, cy, w, h)``. 
""" - def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ratio: float = 0.9, - scales: Optional[List[float]] = None, steps: Optional[List[int]] = None, clip: bool = True): + def __init__( + self, + aspect_ratios: List[List[int]], + min_ratio: float = 0.15, + max_ratio: float = 0.9, + scales: Optional[List[float]] = None, + steps: Optional[List[int]] = None, + clip: bool = True, + ): super().__init__() if steps is not None: assert len(aspect_ratios) == len(steps) @@ -172,8 +184,9 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ self._wh_pairs = self._generate_wh_pairs(num_outputs) - def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32, - device: torch.device = torch.device("cpu")) -> List[Tensor]: + def _generate_wh_pairs( + self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu") + ) -> List[Tensor]: _wh_pairs: List[Tensor] = [] for k in range(num_outputs): # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k @@ -196,8 +209,9 @@ def num_anchors_per_location(self): return [2 + 2 * len(r) for r in self.aspect_ratios] # Default Boxes calculation based on page 6 of SSD paper - def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int], - dtype: torch.dtype = torch.float32) -> Tensor: + def _grid_default_boxes( + self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32 + ) -> Tensor: default_boxes = [] for k, f_k in enumerate(grid_sizes): # Now add the default boxes for each width-height pair @@ -224,12 +238,12 @@ def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int] return torch.cat(default_boxes, dim=0) def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'aspect_ratios={aspect_ratios}' - s += ', clip={clip}' - s += ', scales={scales}' - s += ', steps={steps}' - s += ')' + s = self.__class__.__name__ + "(" + s += "aspect_ratios={aspect_ratios}" + s += ", clip={clip}" + s += ", scales={scales}" + s += ", steps={steps}" + s += ")" return s.format(**self.__dict__) def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: @@ -242,8 +256,13 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten dboxes = [] for _ in image_list.image_sizes: dboxes_in_image = default_boxes - dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:], - dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1) + dboxes_in_image = torch.cat( + [ + dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:], + dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:], + ], + -1, + ) dboxes_in_image[:, 0::2] *= image_size[1] dboxes_in_image[:, 1::2] *= image_size[0] dboxes.append(dboxes_in_image) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index b4699d0ee12..70a7b40bd50 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -1,11 +1,12 @@ import warnings + from torch import nn +from torchvision.ops import misc as misc_nn_ops from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool -from torchvision.ops import misc as misc_nn_ops -from .._utils import IntermediateLayerGetter from .. import mobilenet from .. 
import resnet +from .._utils import IntermediateLayerGetter class BackboneWithFPN(nn.Module): @@ -26,6 +27,7 @@ class BackboneWithFPN(nn.Module): Attributes: out_channels (int): the number of channels in the FPN """ + def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): super(BackboneWithFPN, self).__init__() @@ -52,7 +54,7 @@ def resnet_fpn_backbone( norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3, returned_layers=None, - extra_blocks=None + extra_blocks=None, ): """ Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. @@ -89,15 +91,13 @@ def resnet_fpn_backbone( a new list of feature maps and their corresponding names. By default a ``LastLevelMaxPool`` is used. """ - backbone = resnet.__dict__[backbone_name]( - pretrained=pretrained, - norm_layer=norm_layer) + backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) # select layers that wont be frozen assert 0 <= trainable_layers <= 5 - layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] + layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers] if trainable_layers == 5: - layers_to_train.append('bn1') + layers_to_train.append("bn1") for name, parameter in backbone.named_parameters(): if all([not name.startswith(layer) for layer in layers_to_train]): parameter.requires_grad_(False) @@ -108,7 +108,7 @@ def resnet_fpn_backbone( if returned_layers is None: returned_layers = [1, 2, 3, 4] assert min(returned_layers) > 0 and max(returned_layers) < 5 - return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)} + return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)} in_channels_stage2 = backbone.inplanes // 8 in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] @@ -123,7 +123,8 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, warnings.warn( "Changing trainable_backbone_layers has not effect if " "neither pretrained nor pretrained_backbone have been set to True, " - "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value)) + "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value) + ) trainable_backbone_layers = max_value # by default freeze first blocks @@ -140,7 +141,7 @@ def mobilenet_backbone( norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=2, returned_layers=None, - extra_blocks=None + extra_blocks=None, ): backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features @@ -165,7 +166,7 @@ def mobilenet_backbone( if returned_layers is None: returned_layers = [num_stages - 2, num_stages - 1] assert min(returned_layers) >= 0 and max(returned_layers) < num_stages - return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)} + return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)} in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers] return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index c3ec8db0a19..02da39e8c73 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -1,22 +1,22 @@ -from torch import nn import 
torch.nn.functional as F - +from torch import nn from torchvision.ops import MultiScaleRoIAlign -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - +from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone from .generalized_rcnn import GeneralizedRCNN -from .rpn import RPNHead, RegionProposalNetwork from .roi_heads import RoIHeads +from .rpn import RPNHead, RegionProposalNetwork from .transform import GeneralizedRCNNTransform -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone __all__ = [ - "FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn", - "fasterrcnn_mobilenet_v3_large_fpn" + "FasterRCNN", + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", ] @@ -141,30 +141,48 @@ class FasterRCNN(GeneralizedRCNN): >>> predictions = model(x) """ - def __init__(self, backbone, num_classes=None, - # transform parameters - min_size=800, max_size=1333, - image_mean=None, image_std=None, - # RPN parameters - rpn_anchor_generator=None, rpn_head=None, - rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, - # Box parameters - box_roi_pool=None, box_head=None, box_predictor=None, - box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, - box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, box_positive_fraction=0.25, - bbox_reg_weights=None): + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + ): if not hasattr(backbone, "out_channels"): raise ValueError( "backbone should contain an attribute out_channels " "specifying the number of output channels (assumed to be the " - "same for all the levels)") + "same for all the levels)" + ) assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) @@ -174,58 +192,59 @@ def __init__(self, backbone, num_classes=None, raise ValueError("num_classes should be None when box_predictor is specified") else: if box_predictor is None: - raise ValueError("num_classes should not be None when box_predictor " - "is not specified") + raise ValueError("num_classes should not be None when box_predictor " "is not specified") out_channels = backbone.out_channels if rpn_anchor_generator is None: anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - rpn_anchor_generator = 
AnchorGenerator( - anchor_sizes, aspect_ratios - ) + rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) if rpn_head is None: - rpn_head = RPNHead( - out_channels, rpn_anchor_generator.num_anchors_per_location()[0] - ) + rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) rpn = RegionProposalNetwork( - rpn_anchor_generator, rpn_head, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, - rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, - score_thresh=rpn_score_thresh) + rpn_anchor_generator, + rpn_head, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, + rpn_pre_nms_top_n, + rpn_post_nms_top_n, + rpn_nms_thresh, + score_thresh=rpn_score_thresh, + ) if box_roi_pool is None: - box_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=7, - sampling_ratio=2) + box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2) if box_head is None: resolution = box_roi_pool.output_size[0] representation_size = 1024 - box_head = TwoMLPHead( - out_channels * resolution ** 2, - representation_size) + box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size) if box_predictor is None: representation_size = 1024 - box_predictor = FastRCNNPredictor( - representation_size, - num_classes) + box_predictor = FastRCNNPredictor(representation_size, num_classes) roi_heads = RoIHeads( # Box - box_roi_pool, box_head, box_predictor, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, + box_roi_pool, + box_head, + box_predictor, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, bbox_reg_weights, - box_score_thresh, box_nms_thresh, box_detections_per_img) + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + ) if image_mean is None: image_mean = [0.485, 0.456, 0.406] @@ -286,17 +305,15 @@ def forward(self, x): model_urls = { - 'fasterrcnn_resnet50_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', - 'fasterrcnn_mobilenet_v3_large_320_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth', - 'fasterrcnn_mobilenet_v3_large_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' + "fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", + "fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", + "fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", } -def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def fasterrcnn_resnet50_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. @@ -362,36 +379,54 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
""" trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) model = FasterRCNN(backbone, num_classes, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model -def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91, - pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def _fasterrcnn_mobilenet_v3_large_fpn( + weights_name, + pretrained=False, + progress=True, + num_classes=91, + pretrained_backbone=True, + trainable_backbone_layers=None, + **kwargs, +): trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3 + ) if pretrained: pretrained_backbone = False - backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True, - trainable_layers=trainable_backbone_layers) - - anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3 + backbone = mobilenet_backbone( + "mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers + ) + + anchor_sizes = ( + ( + 32, + 64, + 128, + 256, + 512, + ), + ) * 3 aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), - **kwargs) + model = FasterRCNN( + backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs + ) if pretrained: if model_urls.get(weights_name, None) is None: raise ValueError("No checkpoint is available for model {}".format(weights_name)) @@ -400,8 +435,9 @@ def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress= return model -def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, - trainable_backbone_layers=None, **kwargs): +def fasterrcnn_mobilenet_v3_large_320_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. 
See @@ -433,13 +469,20 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_c } kwargs = {**defaults, **kwargs} - return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, - num_classes=num_classes, pretrained_backbone=pretrained_backbone, - trainable_backbone_layers=trainable_backbone_layers, **kwargs) - - -def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, - trainable_backbone_layers=None, **kwargs): + return _fasterrcnn_mobilenet_v3_large_fpn( + weights_name, + pretrained=pretrained, + progress=progress, + num_classes=num_classes, + pretrained_backbone=pretrained_backbone, + trainable_backbone_layers=trainable_backbone_layers, + **kwargs, + ) + + +def fasterrcnn_mobilenet_v3_large_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See @@ -467,6 +510,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class } kwargs = {**defaults, **kwargs} - return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, - num_classes=num_classes, pretrained_backbone=pretrained_backbone, - trainable_backbone_layers=trainable_backbone_layers, **kwargs) + return _fasterrcnn_mobilenet_v3_large_fpn( + weights_name, + pretrained=pretrained, + progress=progress, + num_classes=num_classes, + pretrained_backbone=pretrained_backbone, + trainable_backbone_layers=trainable_backbone_layers, + **kwargs, + ) diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py index 1d3979caa3f..c77c892e63e 100644 --- a/torchvision/models/detection/generalized_rcnn.py +++ b/torchvision/models/detection/generalized_rcnn.py @@ -2,11 +2,12 @@ Implements the Generalized R-CNN framework """ +import warnings from collections import OrderedDict +from typing import Tuple, List, Dict, Optional, Union + import torch from torch import nn, Tensor -import warnings -from typing import Tuple, List, Dict, Optional, Union class GeneralizedRCNN(nn.Module): @@ -61,12 +62,11 @@ def forward(self, images, targets=None): boxes = target["boxes"] if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) + raise ValueError( + "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape) + ) else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) + raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) original_image_sizes: List[Tuple[int, int]] = [] for img in images: @@ -86,13 +86,14 @@ def forward(self, images, targets=None): # print the first degenerate box bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) + raise ValueError( + "All bounding boxes should have positive height and width." 
+ " Found invalid box {} for target at index {}.".format(degen_bb, target_idx) + ) features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) proposals, proposal_losses = self.rpn(images, features, targets) detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) diff --git a/torchvision/models/detection/image_list.py b/torchvision/models/detection/image_list.py index a389b3c3ce1..333d3b569f2 100644 --- a/torchvision/models/detection/image_list.py +++ b/torchvision/models/detection/image_list.py @@ -1,6 +1,7 @@ +from typing import List, Tuple + import torch from torch import Tensor -from typing import List, Tuple class ImageList(object): @@ -20,6 +21,6 @@ def __init__(self, tensors: Tensor, image_sizes: List[Tuple[int, int]]): self.tensors = tensors self.image_sizes = image_sizes - def to(self, device: torch.device) -> 'ImageList': + def to(self, device: torch.device) -> "ImageList": cast_tensor = self.tensors.to(device) return ImageList(cast_tensor, self.image_sizes) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 79df3b450c4..7cd975ea6a0 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,18 +1,14 @@ import torch from torch import nn - from torchvision.ops import MultiScaleRoIAlign -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - -from .faster_rcnn import FasterRCNN +from ._utils import overwrite_eps from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers +from .faster_rcnn import FasterRCNN -__all__ = [ - "KeypointRCNN", "keypointrcnn_resnet50_fpn" -] +__all__ = ["KeypointRCNN", "keypointrcnn_resnet50_fpn"] class KeypointRCNN(FasterRCNN): @@ -151,27 +147,47 @@ class KeypointRCNN(FasterRCNN): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ - def __init__(self, backbone, num_classes=None, - # transform parameters - min_size=None, max_size=1333, - image_mean=None, image_std=None, - # RPN parameters - rpn_anchor_generator=None, rpn_head=None, - rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, - # Box parameters - box_roi_pool=None, box_head=None, box_predictor=None, - box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, - box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, box_positive_fraction=0.25, - bbox_reg_weights=None, - # keypoint parameters - keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None, - num_keypoints=17): + + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=None, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box 
parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + # keypoint parameters + keypoint_roi_pool=None, + keypoint_head=None, + keypoint_predictor=None, + num_keypoints=17, + ): assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) if min_size is None: @@ -184,10 +200,7 @@ def __init__(self, backbone, num_classes=None, out_channels = backbone.out_channels if keypoint_roi_pool is None: - keypoint_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=14, - sampling_ratio=2) + keypoint_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) if keypoint_head is None: keypoint_layers = tuple(512 for _ in range(8)) @@ -198,24 +211,39 @@ def __init__(self, backbone, num_classes=None, keypoint_predictor = KeypointRCNNPredictor(keypoint_dim_reduced, num_keypoints) super(KeypointRCNN, self).__init__( - backbone, num_classes, + backbone, + num_classes, # transform parameters - min_size, max_size, - image_mean, image_std, + min_size, + max_size, + image_mean, + image_std, # RPN-specific parameters - rpn_anchor_generator, rpn_head, - rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test, - rpn_post_nms_top_n_train, rpn_post_nms_top_n_test, + rpn_anchor_generator, + rpn_head, + rpn_pre_nms_top_n_train, + rpn_pre_nms_top_n_test, + rpn_post_nms_top_n_train, + rpn_post_nms_top_n_test, rpn_nms_thresh, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, rpn_score_thresh, # Box parameters - box_roi_pool, box_head, box_predictor, - box_score_thresh, box_nms_thresh, box_detections_per_img, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, - bbox_reg_weights) + box_roi_pool, + box_head, + box_predictor, + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, + bbox_reg_weights, + ) self.roi_heads.keypoint_roi_pool = keypoint_roi_pool self.roi_heads.keypoint_head = keypoint_head @@ -249,9 +277,7 @@ def __init__(self, in_channels, num_keypoints): stride=2, padding=deconv_kernel // 2 - 1, ) - nn.init.kaiming_normal_( - self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu" - ) + nn.init.kaiming_normal_(self.kps_score_lowres.weight, mode="fan_out", nonlinearity="relu") nn.init.constant_(self.kps_score_lowres.bias, 0) self.up_scale = 2 self.out_channels = num_keypoints @@ -265,16 +291,20 @@ def forward(self, x): model_urls = { # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606 - 'keypointrcnn_resnet50_fpn_coco_legacy': - 'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth', - 'keypointrcnn_resnet50_fpn_coco': - 'https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth', + "keypointrcnn_resnet50_fpn_coco_legacy": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", + "keypointrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", } -def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, - num_classes=2, num_keypoints=17, - pretrained_backbone=True, 
trainable_backbone_layers=None, **kwargs): +def keypointrcnn_resnet50_fpn( + pretrained=False, + progress=True, + num_classes=2, + num_keypoints=17, + pretrained_backbone=True, + trainable_backbone_layers=None, + **kwargs, +): """ Constructs a Keypoint R-CNN model with a ResNet-50-FPN backbone. @@ -331,19 +361,19 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) if pretrained: - key = 'keypointrcnn_resnet50_fpn_coco' - if pretrained == 'legacy': - key += '_legacy' - state_dict = load_state_dict_from_url(model_urls[key], - progress=progress) + key = "keypointrcnn_resnet50_fpn_coco" + if pretrained == "legacy": + key += "_legacy" + state_dict = load_state_dict_from_url(model_urls[key], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 06b36d573ab..6b8208b19d8 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -1,17 +1,16 @@ from collections import OrderedDict from torch import nn - from torchvision.ops import MultiScaleRoIAlign -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - -from .faster_rcnn import FasterRCNN +from ._utils import overwrite_eps from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers +from .faster_rcnn import FasterRCNN __all__ = [ - "MaskRCNN", "maskrcnn_resnet50_fpn", + "MaskRCNN", + "maskrcnn_resnet50_fpn", ] @@ -149,26 +148,46 @@ class MaskRCNN(FasterRCNN): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ - def __init__(self, backbone, num_classes=None, - # transform parameters - min_size=800, max_size=1333, - image_mean=None, image_std=None, - # RPN parameters - rpn_anchor_generator=None, rpn_head=None, - rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, - rpn_score_thresh=0.0, - # Box parameters - box_roi_pool=None, box_head=None, box_predictor=None, - box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, - box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, box_positive_fraction=0.25, - bbox_reg_weights=None, - # Mask parameters - mask_roi_pool=None, mask_head=None, mask_predictor=None): + + def __init__( + self, + backbone, + num_classes=None, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # RPN parameters + rpn_anchor_generator=None, + rpn_head=None, + rpn_pre_nms_top_n_train=2000, + rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, + 
rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, + rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, + rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, + # Box parameters + box_roi_pool=None, + box_head=None, + box_predictor=None, + box_score_thresh=0.05, + box_nms_thresh=0.5, + box_detections_per_img=100, + box_fg_iou_thresh=0.5, + box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, + box_positive_fraction=0.25, + bbox_reg_weights=None, + # Mask parameters + mask_roi_pool=None, + mask_head=None, + mask_predictor=None, + ): assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))) @@ -179,10 +198,7 @@ def __init__(self, backbone, num_classes=None, out_channels = backbone.out_channels if mask_roi_pool is None: - mask_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=14, - sampling_ratio=2) + mask_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=14, sampling_ratio=2) if mask_head is None: mask_layers = (256, 256, 256, 256) @@ -192,28 +208,42 @@ def __init__(self, backbone, num_classes=None, if mask_predictor is None: mask_predictor_in_channels = 256 # == mask_layers[-1] mask_dim_reduced = 256 - mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, - mask_dim_reduced, num_classes) + mask_predictor = MaskRCNNPredictor(mask_predictor_in_channels, mask_dim_reduced, num_classes) super(MaskRCNN, self).__init__( - backbone, num_classes, + backbone, + num_classes, # transform parameters - min_size, max_size, - image_mean, image_std, + min_size, + max_size, + image_mean, + image_std, # RPN-specific parameters - rpn_anchor_generator, rpn_head, - rpn_pre_nms_top_n_train, rpn_pre_nms_top_n_test, - rpn_post_nms_top_n_train, rpn_post_nms_top_n_test, + rpn_anchor_generator, + rpn_head, + rpn_pre_nms_top_n_train, + rpn_pre_nms_top_n_test, + rpn_post_nms_top_n_train, + rpn_post_nms_top_n_test, rpn_nms_thresh, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, + rpn_fg_iou_thresh, + rpn_bg_iou_thresh, + rpn_batch_size_per_image, + rpn_positive_fraction, rpn_score_thresh, # Box parameters - box_roi_pool, box_head, box_predictor, - box_score_thresh, box_nms_thresh, box_detections_per_img, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, - bbox_reg_weights) + box_roi_pool, + box_head, + box_predictor, + box_score_thresh, + box_nms_thresh, + box_detections_per_img, + box_fg_iou_thresh, + box_bg_iou_thresh, + box_batch_size_per_image, + box_positive_fraction, + bbox_reg_weights, + ) self.roi_heads.mask_roi_pool = mask_roi_pool self.roi_heads.mask_head = mask_head @@ -232,8 +262,8 @@ def __init__(self, in_channels, layers, dilation): next_feature = in_channels for layer_idx, layer_features in enumerate(layers, 1): d["mask_fcn{}".format(layer_idx)] = nn.Conv2d( - next_feature, layer_features, kernel_size=3, - stride=1, padding=dilation, dilation=dilation) + next_feature, layer_features, kernel_size=3, stride=1, padding=dilation, dilation=dilation + ) d["relu{}".format(layer_idx)] = nn.ReLU(inplace=True) next_feature = layer_features @@ -247,11 +277,15 @@ def __init__(self, in_channels, layers, dilation): class MaskRCNNPredictor(nn.Sequential): def __init__(self, in_channels, dim_reduced, num_classes): - super(MaskRCNNPredictor, self).__init__(OrderedDict([ - ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)), - ("relu", nn.ReLU(inplace=True)), - ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 
1, 1, 0)), - ])) + super(MaskRCNNPredictor, self).__init__( + OrderedDict( + [ + ("conv5_mask", nn.ConvTranspose2d(in_channels, dim_reduced, 2, 2, 0)), + ("relu", nn.ReLU(inplace=True)), + ("mask_fcn_logits", nn.Conv2d(dim_reduced, num_classes, 1, 1, 0)), + ] + ) + ) for name, param in self.named_parameters(): if "weight" in name: @@ -261,13 +295,13 @@ def __init__(self, in_channels, dim_reduced, num_classes): model_urls = { - 'maskrcnn_resnet50_fpn_coco': - 'https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth', + "maskrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", } -def maskrcnn_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def maskrcnn_resnet50_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a Mask R-CNN model with a ResNet-50-FPN backbone. @@ -324,16 +358,16 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) model = MaskRCNN(backbone, num_classes, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index c6e301c268c..eb05144cb0c 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -1,26 +1,23 @@ import math -from collections import OrderedDict import warnings +from collections import OrderedDict +from typing import Dict, List, Tuple, Optional import torch from torch import nn, Tensor -from typing import Dict, List, Tuple, Optional -from ._utils import overwrite_eps from ..._internally_replaced_utils import load_state_dict_from_url - +from ...ops import sigmoid_focal_loss +from ...ops import boxes as box_ops +from ...ops.feature_pyramid_network import LastLevelP6P7 from . 
import _utils as det_utils +from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator -from .transform import GeneralizedRCNNTransform from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers -from ...ops.feature_pyramid_network import LastLevelP6P7 -from ...ops import sigmoid_focal_loss -from ...ops import boxes as box_ops +from .transform import GeneralizedRCNNTransform -__all__ = [ - "RetinaNet", "retinanet_resnet50_fpn" -] +__all__ = ["RetinaNet", "retinanet_resnet50_fpn"] def _sum(x: List[Tensor]) -> Tensor: @@ -48,16 +45,13 @@ def __init__(self, in_channels, num_anchors, num_classes): def compute_loss(self, targets, head_outputs, anchors, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Dict[str, Tensor] return { - 'classification': self.classification_head.compute_loss(targets, head_outputs, matched_idxs), - 'bbox_regression': self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), + "classification": self.classification_head.compute_loss(targets, head_outputs, matched_idxs), + "bbox_regression": self.regression_head.compute_loss(targets, head_outputs, anchors, matched_idxs), } def forward(self, x): # type: (List[Tensor]) -> Dict[str, Tensor] - return { - 'cls_logits': self.classification_head(x), - 'bbox_regression': self.regression_head(x) - } + return {"cls_logits": self.classification_head(x), "bbox_regression": self.regression_head(x)} class RetinaNetClassificationHead(nn.Module): @@ -100,7 +94,7 @@ def compute_loss(self, targets, head_outputs, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Tensor losses = [] - cls_logits = head_outputs['cls_logits'] + cls_logits = head_outputs["cls_logits"] for targets_per_image, cls_logits_per_image, matched_idxs_per_image in zip(targets, cls_logits, matched_idxs): # determine only the foreground @@ -111,18 +105,21 @@ def compute_loss(self, targets, head_outputs, matched_idxs): gt_classes_target = torch.zeros_like(cls_logits_per_image) gt_classes_target[ foreground_idxs_per_image, - targets_per_image['labels'][matched_idxs_per_image[foreground_idxs_per_image]] + targets_per_image["labels"][matched_idxs_per_image[foreground_idxs_per_image]], ] = 1.0 # find indices for which anchors should be ignored valid_idxs_per_image = matched_idxs_per_image != self.BETWEEN_THRESHOLDS # compute the classification loss - losses.append(sigmoid_focal_loss( - cls_logits_per_image[valid_idxs_per_image], - gt_classes_target[valid_idxs_per_image], - reduction='sum', - ) / max(1, num_foreground)) + losses.append( + sigmoid_focal_loss( + cls_logits_per_image[valid_idxs_per_image], + gt_classes_target[valid_idxs_per_image], + reduction="sum", + ) + / max(1, num_foreground) + ) return _sum(losses) / len(targets) @@ -153,8 +150,9 @@ class RetinaNetRegressionHead(nn.Module): in_channels (int): number of channels of the input feature num_anchors (int): number of anchors to be predicted """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, + "box_coder": det_utils.BoxCoder, } def __init__(self, in_channels, num_anchors): @@ -181,16 +179,17 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor], List[Tensor]) -> Tensor losses = [] - bbox_regression = head_outputs['bbox_regression'] + bbox_regression = head_outputs["bbox_regression"] - for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in \ - 
zip(targets, bbox_regression, anchors, matched_idxs): + for targets_per_image, bbox_regression_per_image, anchors_per_image, matched_idxs_per_image in zip( + targets, bbox_regression, anchors, matched_idxs + ): # determine only the foreground indices, ignore the rest foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] num_foreground = foreground_idxs_per_image.numel() # select only the foreground boxes - matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]] + matched_gt_boxes_per_image = targets_per_image["boxes"][matched_idxs_per_image[foreground_idxs_per_image]] bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] @@ -198,11 +197,10 @@ def compute_loss(self, targets, head_outputs, anchors, matched_idxs): target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) # compute the loss - losses.append(torch.nn.functional.l1_loss( - bbox_regression_per_image, - target_regression, - reduction='sum' - ) / max(1, num_foreground)) + losses.append( + torch.nn.functional.l1_loss(bbox_regression_per_image, target_regression, reduction="sum") + / max(1, num_foreground) + ) return _sum(losses) / max(1, len(targets)) @@ -309,30 +307,40 @@ class RetinaNet(nn.Module): >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] >>> predictions = model(x) """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, } - def __init__(self, backbone, num_classes, - # transform parameters - min_size=800, max_size=1333, - image_mean=None, image_std=None, - # Anchor parameters - anchor_generator=None, head=None, - proposal_matcher=None, - score_thresh=0.05, - nms_thresh=0.5, - detections_per_img=300, - fg_iou_thresh=0.5, bg_iou_thresh=0.4, - topk_candidates=1000): + def __init__( + self, + backbone, + num_classes, + # transform parameters + min_size=800, + max_size=1333, + image_mean=None, + image_std=None, + # Anchor parameters + anchor_generator=None, + head=None, + proposal_matcher=None, + score_thresh=0.05, + nms_thresh=0.5, + detections_per_img=300, + fg_iou_thresh=0.5, + bg_iou_thresh=0.4, + topk_candidates=1000, + ): super().__init__() if not hasattr(backbone, "out_channels"): raise ValueError( "backbone should contain an attribute out_channels " "specifying the number of output channels (assumed to be the " - "same for all the levels)") + "same for all the levels)" + ) self.backbone = backbone assert isinstance(anchor_generator, (AnchorGenerator, type(None))) @@ -340,9 +348,7 @@ def __init__(self, backbone, num_classes, if anchor_generator is None: anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512]) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - anchor_generator = AnchorGenerator( - anchor_sizes, aspect_ratios - ) + anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) self.anchor_generator = anchor_generator if head is None: @@ -385,20 +391,21 @@ def compute_loss(self, targets, head_outputs, anchors): # type: (List[Dict[str, Tensor]], Dict[str, Tensor], List[Tensor]) -> Dict[str, Tensor] matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): - if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, - 
device=anchors_per_image.device)) + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) continue - match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) + match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) matched_idxs.append(self.proposal_matcher(match_quality_matrix)) return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) def postprocess_detections(self, head_outputs, anchors, image_shapes): # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]] - class_logits = head_outputs['cls_logits'] - box_regression = head_outputs['bbox_regression'] + class_logits = head_outputs["cls_logits"] + box_regression = head_outputs["bbox_regression"] num_images = len(image_shapes) @@ -413,8 +420,9 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): image_scores = [] image_labels = [] - for box_regression_per_level, logits_per_level, anchors_per_level in \ - zip(box_regression_per_image, logits_per_image, anchors_per_image): + for box_regression_per_level, logits_per_level, anchors_per_level in zip( + box_regression_per_image, logits_per_image, anchors_per_image + ): num_classes = logits_per_level.shape[-1] # remove low scoring boxes @@ -428,11 +436,12 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): scores_per_level, idxs = scores_per_level.topk(num_topk) topk_idxs = topk_idxs[idxs] - anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor') + anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor") labels_per_level = topk_idxs % num_classes - boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs], - anchors_per_level[anchor_idxs]) + boxes_per_level = self.box_coder.decode_single( + box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs] + ) boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape) image_boxes.append(boxes_per_level) @@ -445,13 +454,15 @@ def postprocess_detections(self, head_outputs, anchors, image_shapes): # non-maximum suppression keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) - keep = keep[:self.detections_per_img] - - detections.append({ - 'boxes': image_boxes[keep], - 'scores': image_scores[keep], - 'labels': image_labels[keep], - }) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) return detections @@ -478,12 +489,11 @@ def forward(self, images, targets=None): boxes = target["boxes"] if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) + raise ValueError( + "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape) + ) else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) + raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) # get the original image sizes original_image_sizes: List[Tuple[int, int]] = [] @@ -505,14 +515,15 @@ def forward(self, images, targets=None): # print the first degenerate box bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] degen_bb: List[float] = 
boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) + raise ValueError( + "All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}.".format(degen_bb, target_idx) + ) # get the features from the backbone features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) # TODO: Do we want a list or a dict? features = list(features.values()) @@ -536,7 +547,7 @@ def forward(self, images, targets=None): HW = 0 for v in num_anchors_per_level: HW += v - HWA = head_outputs['cls_logits'].size(1) + HWA = head_outputs["cls_logits"].size(1) A = HWA // HW num_anchors_per_level = [hw * A for hw in num_anchors_per_level] @@ -559,13 +570,13 @@ def forward(self, images, targets=None): model_urls = { - 'retinanet_resnet50_fpn_coco': - 'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth', + "retinanet_resnet50_fpn_coco": "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", } -def retinanet_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): +def retinanet_resnet50_fpn( + pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs +): """ Constructs a RetinaNet model with a ResNet-50-FPN backbone. @@ -613,18 +624,23 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False # skip P2 because it generates too many anchors (according to their paper) - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, returned_layers=[2, 3, 4], - extra_blocks=LastLevelP6P7(256, 256), trainable_layers=trainable_backbone_layers) + backbone = resnet_fpn_backbone( + "resnet50", + pretrained_backbone, + returned_layers=[2, 3, 4], + extra_blocks=LastLevelP6P7(256, 256), + trainable_layers=trainable_backbone_layers, + ) model = RetinaNet(backbone, num_classes, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["retinanet_resnet50_fpn_coco"], progress=progress) model.load_state_dict(state_dict) overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 9948d5f537f..35aee4b7d54 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -1,17 +1,14 @@ -import torch -import torchvision +from typing import Optional, List, Dict, Tuple +import torch import torch.nn.functional as F +import torchvision from torch import nn, Tensor - from torchvision.ops import boxes as box_ops - from torchvision.ops import roi_align from . 
import _utils as det_utils -from typing import Optional, List, Dict, Tuple - def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] @@ -46,7 +43,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): box_regression[sampled_pos_inds_subset, labels_pos], regression_targets[sampled_pos_inds_subset], beta=1 / 9, - reduction='sum', + reduction="sum", ) box_loss = box_loss / labels.numel() @@ -95,7 +92,7 @@ def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M): matched_idxs = matched_idxs.to(boxes) rois = torch.cat([matched_idxs[:, None], boxes], dim=1) gt_masks = gt_masks[:, None].to(rois) - return roi_align(gt_masks, rois, (M, M), 1.)[:, 0] + return roi_align(gt_masks, rois, (M, M), 1.0)[:, 0] def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs): @@ -113,8 +110,7 @@ def maskrcnn_loss(mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs discretization_size = mask_logits.shape[-1] labels = [gt_label[idxs] for gt_label, idxs in zip(gt_labels, mask_matched_idxs)] mask_targets = [ - project_masks_on_boxes(m, p, i, discretization_size) - for m, p, i in zip(gt_masks, proposals, mask_matched_idxs) + project_masks_on_boxes(m, p, i, discretization_size) for m, p, i in zip(gt_masks, proposals, mask_matched_idxs) ] labels = torch.cat(labels, dim=0) @@ -167,59 +163,72 @@ def keypoints_to_heatmap(keypoints, rois, heatmap_size): return heatmaps, valid -def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, - widths_i, heights_i, offset_x_i, offset_y_i): +def _onnx_heatmaps_to_keypoints( + maps, maps_i, roi_map_width, roi_map_height, widths_i, heights_i, offset_x_i, offset_y_i +): num_keypoints = torch.scalar_tensor(maps.size(1), dtype=torch.int64) width_correction = widths_i / roi_map_width height_correction = heights_i / roi_map_height roi_map = F.interpolate( - maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0] + maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode="bicubic", align_corners=False + )[:, 0] w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64) pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) - x_int = (pos % w) - y_int = ((pos - x_int) // w) + x_int = pos % w + y_int = (pos - x_int) // w - x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * \ - width_correction.to(dtype=torch.float32) - y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * \ - height_correction.to(dtype=torch.float32) + x = (torch.tensor(0.5, dtype=torch.float32) + x_int.to(dtype=torch.float32)) * width_correction.to( + dtype=torch.float32 + ) + y = (torch.tensor(0.5, dtype=torch.float32) + y_int.to(dtype=torch.float32)) * height_correction.to( + dtype=torch.float32 + ) xy_preds_i_0 = x + offset_x_i.to(dtype=torch.float32) xy_preds_i_1 = y + offset_y_i.to(dtype=torch.float32) xy_preds_i_2 = torch.ones(xy_preds_i_1.shape, dtype=torch.float32) - xy_preds_i = torch.stack([xy_preds_i_0.to(dtype=torch.float32), - xy_preds_i_1.to(dtype=torch.float32), - xy_preds_i_2.to(dtype=torch.float32)], 0) + xy_preds_i = torch.stack( + [ + xy_preds_i_0.to(dtype=torch.float32), + xy_preds_i_1.to(dtype=torch.float32), + xy_preds_i_2.to(dtype=torch.float32), + ], + 0, + ) # TODO: simplify when indexing without rank will be supported by ONNX base = num_keypoints * num_keypoints + num_keypoints + 1 ind = 
torch.arange(num_keypoints) ind = ind.to(dtype=torch.int64) * base - end_scores_i = roi_map.index_select(1, y_int.to(dtype=torch.int64)) \ - .index_select(2, x_int.to(dtype=torch.int64)).view(-1).index_select(0, ind.to(dtype=torch.int64)) + end_scores_i = ( + roi_map.index_select(1, y_int.to(dtype=torch.int64)) + .index_select(2, x_int.to(dtype=torch.int64)) + .view(-1) + .index_select(0, ind.to(dtype=torch.int64)) + ) return xy_preds_i, end_scores_i @torch.jit._script_if_tracing -def _onnx_heatmaps_to_keypoints_loop(maps, rois, widths_ceil, heights_ceil, - widths, heights, offset_x, offset_y, num_keypoints): +def _onnx_heatmaps_to_keypoints_loop( + maps, rois, widths_ceil, heights_ceil, widths, heights, offset_x, offset_y, num_keypoints +): xy_preds = torch.zeros((0, 3, int(num_keypoints)), dtype=torch.float32, device=maps.device) end_scores = torch.zeros((0, int(num_keypoints)), dtype=torch.float32, device=maps.device) for i in range(int(rois.size(0))): - xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints(maps, maps[i], - widths_ceil[i], heights_ceil[i], - widths[i], heights[i], - offset_x[i], offset_y[i]) - xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), - xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0) - end_scores = torch.cat((end_scores.to(dtype=torch.float32), - end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0) + xy_preds_i, end_scores_i = _onnx_heatmaps_to_keypoints( + maps, maps[i], widths_ceil[i], heights_ceil[i], widths[i], heights[i], offset_x[i], offset_y[i] + ) + xy_preds = torch.cat((xy_preds.to(dtype=torch.float32), xy_preds_i.unsqueeze(0).to(dtype=torch.float32)), 0) + end_scores = torch.cat( + (end_scores.to(dtype=torch.float32), end_scores_i.to(dtype=torch.float32).unsqueeze(0)), 0 + ) return xy_preds, end_scores @@ -246,10 +255,17 @@ def heatmaps_to_keypoints(maps, rois): num_keypoints = maps.shape[1] if torchvision._is_tracing(): - xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop(maps, rois, - widths_ceil, heights_ceil, widths, heights, - offset_x, offset_y, - torch.scalar_tensor(num_keypoints, dtype=torch.int64)) + xy_preds, end_scores = _onnx_heatmaps_to_keypoints_loop( + maps, + rois, + widths_ceil, + heights_ceil, + widths, + heights, + offset_x, + offset_y, + torch.scalar_tensor(num_keypoints, dtype=torch.int64), + ) return xy_preds.permute(0, 2, 1), end_scores xy_preds = torch.zeros((len(rois), 3, num_keypoints), dtype=torch.float32, device=maps.device) @@ -260,13 +276,14 @@ def heatmaps_to_keypoints(maps, rois): width_correction = widths[i] / roi_map_width height_correction = heights[i] / roi_map_height roi_map = F.interpolate( - maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0] + maps[i][:, None], size=(roi_map_height, roi_map_width), mode="bicubic", align_corners=False + )[:, 0] # roi_map_probs = scores_to_probs(roi_map.copy()) w = roi_map.shape[2] pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) x_int = pos % w - y_int = torch.div(pos - x_int, w, rounding_mode='floor') + y_int = torch.div(pos - x_int, w, rounding_mode="floor") # assert (roi_map_probs[k, y_int, x_int] == # roi_map_probs[k, :, :].max()) x = (x_int.float() + 0.5) * width_correction @@ -288,9 +305,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched valid = [] for proposals_per_image, gt_kp_in_image, midx in zip(proposals, gt_keypoints, keypoint_matched_idxs): kp = gt_kp_in_image[midx] - heatmaps_per_image, valid_per_image = keypoints_to_heatmap( - kp, proposals_per_image, 
discretization_size - ) + heatmaps_per_image, valid_per_image = keypoints_to_heatmap(kp, proposals_per_image, discretization_size) heatmaps.append(heatmaps_per_image.view(-1)) valid.append(valid_per_image.view(-1)) @@ -327,10 +342,10 @@ def keypointrcnn_inference(x, boxes): def _onnx_expand_boxes(boxes, scale): # type: (Tensor, float) -> Tensor - w_half = (boxes[:, 2] - boxes[:, 0]) * .5 - h_half = (boxes[:, 3] - boxes[:, 1]) * .5 - x_c = (boxes[:, 2] + boxes[:, 0]) * .5 - y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 w_half = w_half.to(dtype=torch.float32) * scale h_half = h_half.to(dtype=torch.float32) * scale @@ -350,10 +365,10 @@ def expand_boxes(boxes, scale): # type: (Tensor, float) -> Tensor if torchvision._is_tracing(): return _onnx_expand_boxes(boxes, scale) - w_half = (boxes[:, 2] - boxes[:, 0]) * .5 - h_half = (boxes[:, 3] - boxes[:, 1]) * .5 - x_c = (boxes[:, 2] + boxes[:, 0]) * .5 - y_c = (boxes[:, 3] + boxes[:, 1]) * .5 + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 w_half *= scale h_half *= scale @@ -395,7 +410,7 @@ def paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, -1, -1)) # Resize mask - mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False) mask = mask[0][0] im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device) @@ -404,9 +419,7 @@ def paste_mask_in_image(mask, box, im_h, im_w): y_0 = max(box[1], 0) y_1 = min(box[3] + 1, im_h) - im_mask[y_0:y_1, x_0:x_1] = mask[ - (y_0 - box[1]):(y_1 - box[1]), (x_0 - box[0]):(x_1 - box[0]) - ] + im_mask[y_0:y_1, x_0:x_1] = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])] return im_mask @@ -414,8 +427,8 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): one = torch.ones(1, dtype=torch.int64) zero = torch.zeros(1, dtype=torch.int64) - w = (box[2] - box[0] + one) - h = (box[3] - box[1] + one) + w = box[2] - box[0] + one + h = box[3] - box[1] + one w = torch.max(torch.cat((w, one))) h = torch.max(torch.cat((h, one))) @@ -423,7 +436,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, mask.size(0), mask.size(1))) # Resize mask - mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(int(h), int(w)), mode="bilinear", align_corners=False) mask = mask[0][0] x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero))) @@ -431,23 +444,18 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): y_0 = torch.max(torch.cat((box[1].unsqueeze(0), zero))) y_1 = torch.min(torch.cat((box[3].unsqueeze(0) + one, im_h.unsqueeze(0)))) - unpaded_im_mask = mask[(y_0 - box[1]):(y_1 - box[1]), - (x_0 - box[0]):(x_1 - box[0])] + unpaded_im_mask = mask[(y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0])] # TODO : replace below with a dynamic padding when support is added in ONNX # pad y zeros_y0 = torch.zeros(y_0, unpaded_im_mask.size(1)) zeros_y1 = torch.zeros(im_h - y_1, unpaded_im_mask.size(1)) - concat_0 = torch.cat((zeros_y0, - unpaded_im_mask.to(dtype=torch.float32), - zeros_y1), 0)[0:im_h, :] + concat_0 = torch.cat((zeros_y0, unpaded_im_mask.to(dtype=torch.float32), zeros_y1), 0)[0:im_h, :] # 
pad x zeros_x0 = torch.zeros(concat_0.size(0), x_0) zeros_x1 = torch.zeros(concat_0.size(0), im_w - x_1) - im_mask = torch.cat((zeros_x0, - concat_0, - zeros_x1), 1)[:, :im_w] + im_mask = torch.cat((zeros_x0, concat_0, zeros_x1), 1)[:, :im_w] return im_mask @@ -468,13 +476,10 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1): im_h, im_w = img_shape if torchvision._is_tracing(): - return _onnx_paste_masks_in_image_loop(masks, boxes, - torch.scalar_tensor(im_h, dtype=torch.int64), - torch.scalar_tensor(im_w, dtype=torch.int64))[:, None] - res = [ - paste_mask_in_image(m[0], b, im_h, im_w) - for m, b in zip(masks, boxes) - ] + return _onnx_paste_masks_in_image_loop( + masks, boxes, torch.scalar_tensor(im_h, dtype=torch.int64), torch.scalar_tensor(im_w, dtype=torch.int64) + )[:, None] + res = [paste_mask_in_image(m[0], b, im_h, im_w) for m, b in zip(masks, boxes)] if len(res) > 0: ret = torch.stack(res, dim=0)[:, None] else: @@ -484,46 +489,44 @@ def paste_masks_in_image(masks, boxes, img_shape, padding=1): class RoIHeads(nn.Module): __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, } - def __init__(self, - box_roi_pool, - box_head, - box_predictor, - # Faster R-CNN training - fg_iou_thresh, bg_iou_thresh, - batch_size_per_image, positive_fraction, - bbox_reg_weights, - # Faster R-CNN inference - score_thresh, - nms_thresh, - detections_per_img, - # Mask - mask_roi_pool=None, - mask_head=None, - mask_predictor=None, - keypoint_roi_pool=None, - keypoint_head=None, - keypoint_predictor=None, - ): + def __init__( + self, + box_roi_pool, + box_head, + box_predictor, + # Faster R-CNN training + fg_iou_thresh, + bg_iou_thresh, + batch_size_per_image, + positive_fraction, + bbox_reg_weights, + # Faster R-CNN inference + score_thresh, + nms_thresh, + detections_per_img, + # Mask + mask_roi_pool=None, + mask_head=None, + mask_predictor=None, + keypoint_roi_pool=None, + keypoint_head=None, + keypoint_predictor=None, + ): super(RoIHeads, self).__init__() self.box_similarity = box_ops.box_iou # assign ground-truth boxes for each proposal - self.proposal_matcher = det_utils.Matcher( - fg_iou_thresh, - bg_iou_thresh, - allow_low_quality_matches=False) + self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False) - self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( - batch_size_per_image, - positive_fraction) + self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction) if bbox_reg_weights is None: - bbox_reg_weights = (10., 10., 5., 5.) 
+ bbox_reg_weights = (10.0, 10.0, 5.0, 5.0) self.box_coder = det_utils.BoxCoder(bbox_reg_weights) self.box_roi_pool = box_roi_pool @@ -572,9 +575,7 @@ def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels): clamped_matched_idxs_in_image = torch.zeros( (proposals_in_image.shape[0],), dtype=torch.int64, device=device ) - labels_in_image = torch.zeros( - (proposals_in_image.shape[0],), dtype=torch.int64, device=device - ) + labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device) else: # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image) @@ -601,19 +602,14 @@ def subsample(self, labels): # type: (List[Tensor]) -> List[Tensor] sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels) sampled_inds = [] - for img_idx, (pos_inds_img, neg_inds_img) in enumerate( - zip(sampled_pos_inds, sampled_neg_inds) - ): + for img_idx, (pos_inds_img, neg_inds_img) in enumerate(zip(sampled_pos_inds, sampled_neg_inds)): img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0] sampled_inds.append(img_sampled_inds) return sampled_inds def add_gt_proposals(self, proposals, gt_boxes): # type: (List[Tensor], List[Tensor]) -> List[Tensor] - proposals = [ - torch.cat((proposal, gt_box)) - for proposal, gt_box in zip(proposals, gt_boxes) - ] + proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)] return proposals @@ -625,10 +621,11 @@ def check_targets(self, targets): if self.has_mask(): assert all(["masks" in t for t in targets]) - def select_training_samples(self, - proposals, # type: List[Tensor] - targets # type: Optional[List[Dict[str, Tensor]]] - ): + def select_training_samples( + self, + proposals, # type: List[Tensor] + targets, # type: Optional[List[Dict[str, Tensor]]] + ): # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]] self.check_targets(targets) assert targets is not None @@ -661,12 +658,13 @@ def select_training_samples(self, regression_targets = self.box_coder.encode(matched_gt_boxes, proposals) return proposals, matched_idxs, labels, regression_targets - def postprocess_detections(self, - class_logits, # type: Tensor - box_regression, # type: Tensor - proposals, # type: List[Tensor] - image_shapes # type: List[Tuple[int, int]] - ): + def postprocess_detections( + self, + class_logits, # type: Tensor + box_regression, # type: Tensor + proposals, # type: List[Tensor] + image_shapes, # type: List[Tuple[int, int]] + ): # type: (...) 
-> Tuple[List[Tensor], List[Tensor], List[Tensor]] device = class_logits.device num_classes = class_logits.shape[-1] @@ -710,7 +708,7 @@ def postprocess_detections(self, # non-maximum suppression, independently done per class keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh) # keep only topk scoring predictions - keep = keep[:self.detections_per_img] + keep = keep[: self.detections_per_img] boxes, scores, labels = boxes[keep], scores[keep], labels[keep] all_boxes.append(boxes) @@ -719,12 +717,13 @@ def postprocess_detections(self, return all_boxes, all_scores, all_labels - def forward(self, - features, # type: Dict[str, Tensor] - proposals, # type: List[Tensor] - image_shapes, # type: List[Tuple[int, int]] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): + def forward( + self, + features, # type: Dict[str, Tensor] + proposals, # type: List[Tensor] + image_shapes, # type: List[Tuple[int, int]] + targets=None, # type: Optional[List[Dict[str, Tensor]]] + ): # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] """ Args: @@ -737,10 +736,10 @@ def forward(self, for t in targets: # TODO: https://github.com/pytorch/pytorch/issues/26731 floating_point_types = (torch.float, torch.double, torch.half) - assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type' - assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' + assert t["boxes"].dtype in floating_point_types, "target boxes must of float type" + assert t["labels"].dtype == torch.int64, "target labels must of int64 type" if self.has_keypoint(): - assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' + assert t["keypoints"].dtype == torch.float32, "target keypoints must of float type" if self.training: proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) @@ -757,12 +756,8 @@ def forward(self, losses = {} if self.training: assert labels is not None and regression_targets is not None - loss_classifier, loss_box_reg = fastrcnn_loss( - class_logits, box_regression, labels, regression_targets) - losses = { - "loss_classifier": loss_classifier, - "loss_box_reg": loss_box_reg - } + loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets) + losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg} else: boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) num_images = len(boxes) @@ -805,12 +800,8 @@ def forward(self, gt_masks = [t["masks"] for t in targets] gt_labels = [t["labels"] for t in targets] - rcnn_loss_mask = maskrcnn_loss( - mask_logits, mask_proposals, - gt_masks, gt_labels, pos_matched_idxs) - loss_mask = { - "loss_mask": rcnn_loss_mask - } + rcnn_loss_mask = maskrcnn_loss(mask_logits, mask_proposals, gt_masks, gt_labels, pos_matched_idxs) + loss_mask = {"loss_mask": rcnn_loss_mask} else: labels = [r["labels"] for r in result] masks_probs = maskrcnn_inference(mask_logits, labels) @@ -821,8 +812,11 @@ def forward(self, # keep none checks in if conditional so torchscript will conditionally # compile each branch - if self.keypoint_roi_pool is not None and self.keypoint_head is not None \ - and self.keypoint_predictor is not None: + if ( + self.keypoint_roi_pool is not None + and self.keypoint_head is not None + and self.keypoint_predictor is not None + ): keypoint_proposals = [p["boxes"] for p in result] if self.training: # during training, only focus on positive 
boxes @@ -848,11 +842,9 @@ def forward(self, gt_keypoints = [t["keypoints"] for t in targets] rcnn_loss_keypoint = keypointrcnn_loss( - keypoint_logits, keypoint_proposals, - gt_keypoints, pos_matched_idxs) - loss_keypoint = { - "loss_keypoint": rcnn_loss_keypoint - } + keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs + ) + loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint} else: assert keypoint_logits is not None assert keypoint_proposals is not None diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index fcac856a916..c963b3c9944 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -1,27 +1,25 @@ -import torch -from torch.nn import functional as F -from torch import nn, Tensor +from typing import List, Optional, Dict, Tuple +import torch import torchvision +from torch import nn, Tensor +from torch.nn import functional as F from torchvision.ops import boxes as box_ops from . import _utils as det_utils -from .image_list import ImageList - -from typing import List, Optional, Dict, Tuple # Import AnchorGenerator to keep compatibility. from .anchor_utils import AnchorGenerator +from .image_list import ImageList @torch.jit.unused def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): # type: (Tensor, int) -> Tuple[int, int] from torch.onnx import operators + num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) - pre_nms_top_n = torch.min(torch.cat( - (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), - num_anchors), 0)) + pre_nms_top_n = torch.min(torch.cat((torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors), 0)) return num_anchors, pre_nms_top_n @@ -37,13 +35,9 @@ class RPNHead(nn.Module): def __init__(self, in_channels, num_anchors): super(RPNHead, self).__init__() - self.conv = nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1) - self.bbox_pred = nn.Conv2d( - in_channels, num_anchors * 4, kernel_size=1, stride=1 - ) + self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1) for layer in self.children(): torch.nn.init.normal_(layer.weight, std=0.01) @@ -76,21 +70,15 @@ def concat_box_prediction_layers(box_cls, box_regression): # same format as the labels. 
Note that the labels are computed for # all feature levels concatenated, so we keep the same representation # for the objectness and the box_regression - for box_cls_per_level, box_regression_per_level in zip( - box_cls, box_regression - ): + for box_cls_per_level, box_regression_per_level in zip(box_cls, box_regression): N, AxC, H, W = box_cls_per_level.shape Ax4 = box_regression_per_level.shape[1] A = Ax4 // 4 C = AxC // A - box_cls_per_level = permute_and_flatten( - box_cls_per_level, N, A, C, H, W - ) + box_cls_per_level = permute_and_flatten(box_cls_per_level, N, A, C, H, W) box_cls_flattened.append(box_cls_per_level) - box_regression_per_level = permute_and_flatten( - box_regression_per_level, N, A, 4, H, W - ) + box_regression_per_level = permute_and_flatten(box_regression_per_level, N, A, 4, H, W) box_regression_flattened.append(box_regression_per_level) # concatenate on the first dimension (representing the feature levels), to # take into account the way the labels were generated (with all feature maps @@ -125,22 +113,30 @@ class RegionProposalNetwork(torch.nn.Module): nms_thresh (float): NMS threshold used for postprocessing the RPN proposals """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, - 'fg_bg_sampler': det_utils.BalancedPositiveNegativeSampler, - 'pre_nms_top_n': Dict[str, int], - 'post_nms_top_n': Dict[str, int], + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, + "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler, + "pre_nms_top_n": Dict[str, int], + "post_nms_top_n": Dict[str, int], } - def __init__(self, - anchor_generator, - head, - # - fg_iou_thresh, bg_iou_thresh, - batch_size_per_image, positive_fraction, - # - pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0): + def __init__( + self, + anchor_generator, + head, + # + fg_iou_thresh, + bg_iou_thresh, + batch_size_per_image, + positive_fraction, + # + pre_nms_top_n, + post_nms_top_n, + nms_thresh, + score_thresh=0.0, + ): super(RegionProposalNetwork, self).__init__() self.anchor_generator = anchor_generator self.head = head @@ -155,9 +151,7 @@ def __init__(self, allow_low_quality_matches=True, ) - self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler( - batch_size_per_image, positive_fraction - ) + self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction) # used during testing self._pre_nms_top_n = pre_nms_top_n self._post_nms_top_n = post_nms_top_n @@ -167,13 +161,13 @@ def __init__(self, def pre_nms_top_n(self): if self.training: - return self._pre_nms_top_n['training'] - return self._pre_nms_top_n['testing'] + return self._pre_nms_top_n["training"] + return self._pre_nms_top_n["testing"] def post_nms_top_n(self): if self.training: - return self._post_nms_top_n['training'] - return self._post_nms_top_n['testing'] + return self._post_nms_top_n["training"] + return self._post_nms_top_n["testing"] def assign_targets_to_anchors(self, anchors, targets): # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]] @@ -235,8 +229,7 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ objectness = objectness.reshape(num_images, -1) levels = [ - torch.full((n,), idx, dtype=torch.int64, device=device) - for idx, n in enumerate(num_anchors_per_level) + torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level) ] levels = torch.cat(levels, 0) levels = levels.reshape(1, 
-1).expand_as(objectness) @@ -271,7 +264,7 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh) # keep only topk scoring predictions - keep = keep[:self.post_nms_top_n()] + keep = keep[: self.post_nms_top_n()] boxes, scores = boxes[keep], scores[keep] final_boxes.append(boxes) @@ -303,24 +296,26 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) labels = torch.cat(labels, dim=0) regression_targets = torch.cat(regression_targets, dim=0) - box_loss = F.smooth_l1_loss( - pred_bbox_deltas[sampled_pos_inds], - regression_targets[sampled_pos_inds], - beta=1 / 9, - reduction='sum', - ) / (sampled_inds.numel()) - - objectness_loss = F.binary_cross_entropy_with_logits( - objectness[sampled_inds], labels[sampled_inds] + box_loss = ( + F.smooth_l1_loss( + pred_bbox_deltas[sampled_pos_inds], + regression_targets[sampled_pos_inds], + beta=1 / 9, + reduction="sum", + ) + / (sampled_inds.numel()) ) + objectness_loss = F.binary_cross_entropy_with_logits(objectness[sampled_inds], labels[sampled_inds]) + return objectness_loss, box_loss - def forward(self, - images, # type: ImageList - features, # type: Dict[str, Tensor] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): + def forward( + self, + images, # type: ImageList + features, # type: Dict[str, Tensor] + targets=None, # type: Optional[List[Dict[str, Tensor]]] + ): # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]] """ Args: @@ -346,8 +341,7 @@ def forward(self, num_images = len(anchors) num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] - objectness, pred_bbox_deltas = \ - concat_box_prediction_layers(objectness, pred_bbox_deltas) + objectness, pred_bbox_deltas = concat_box_prediction_layers(objectness, pred_bbox_deltas) # apply pred_bbox_deltas to anchors to obtain the decoded proposals # note that we detach the deltas because Faster R-CNN do not backprop through # the proposals @@ -361,7 +355,8 @@ def forward(self, labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) loss_objectness, loss_rpn_box_reg = self.compute_loss( - objectness, pred_bbox_deltas, labels, regression_targets) + objectness, pred_bbox_deltas, labels, regression_targets + ) losses = { "loss_objectness": loss_objectness, "loss_rpn_box_reg": loss_rpn_box_reg, diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index e67c4930b30..37eb7522f94 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -1,29 +1,29 @@ -import torch -import torch.nn.functional as F import warnings - from collections import OrderedDict -from torch import nn, Tensor from typing import Any, Dict, List, Optional, Tuple +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops import boxes as box_ops +from .. import vgg from . import _utils as det_utils from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .transform import GeneralizedRCNNTransform -from .. 
import vgg -from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops import boxes as box_ops -__all__ = ['SSD', 'ssd300_vgg16'] +__all__ = ["SSD", "ssd300_vgg16"] model_urls = { - 'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth', + "ssd300_vgg16_coco": "https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", } backbone_urls = { # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth - 'vgg16_features': 'https://download.pytorch.org/models/vgg16_features-amdegroot.pth' + "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth" } @@ -43,8 +43,8 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: return { - 'bbox_regression': self.regression_head(x), - 'cls_logits': self.classification_head(x), + "bbox_regression": self.regression_head(x), + "cls_logits": self.classification_head(x), } @@ -159,31 +159,38 @@ class SSD(nn.Module): proposals used during the training of the classification head. It is used to estimate the negative to positive ratio. """ + __annotations__ = { - 'box_coder': det_utils.BoxCoder, - 'proposal_matcher': det_utils.Matcher, + "box_coder": det_utils.BoxCoder, + "proposal_matcher": det_utils.Matcher, } - def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, - size: Tuple[int, int], num_classes: int, - image_mean: Optional[List[float]] = None, image_std: Optional[List[float]] = None, - head: Optional[nn.Module] = None, - score_thresh: float = 0.01, - nms_thresh: float = 0.45, - detections_per_img: int = 200, - iou_thresh: float = 0.5, - topk_candidates: int = 400, - positive_fraction: float = 0.25): + def __init__( + self, + backbone: nn.Module, + anchor_generator: DefaultBoxGenerator, + size: Tuple[int, int], + num_classes: int, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + head: Optional[nn.Module] = None, + score_thresh: float = 0.01, + nms_thresh: float = 0.45, + detections_per_img: int = 200, + iou_thresh: float = 0.5, + topk_candidates: int = 400, + positive_fraction: float = 0.25, + ): super().__init__() self.backbone = backbone self.anchor_generator = anchor_generator - self.box_coder = det_utils.BoxCoder(weights=(10., 10., 5., 5.)) + self.box_coder = det_utils.BoxCoder(weights=(10.0, 10.0, 5.0, 5.0)) if head is None: - if hasattr(backbone, 'out_channels'): + if hasattr(backbone, "out_channels"): out_channels = backbone.out_channels else: out_channels = det_utils.retrieve_out_channels(backbone, size) @@ -200,8 +207,9 @@ def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, image_mean = [0.485, 0.456, 0.406] if image_std is None: image_std = [0.229, 0.224, 0.225] - self.transform = GeneralizedRCNNTransform(min(size), max(size), image_mean, image_std, - size_divisible=1, fixed_size=size) + self.transform = GeneralizedRCNNTransform( + min(size), max(size), image_mean, image_std, size_divisible=1, fixed_size=size + ) self.score_thresh = score_thresh self.nms_thresh = nms_thresh @@ -213,45 +221,58 @@ def __init__(self, backbone: nn.Module, anchor_generator: DefaultBoxGenerator, self._has_warned = False @torch.jit.unused - def eager_outputs(self, losses: Dict[str, Tensor], - detections: List[Dict[str, Tensor]]) -> 
Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + def eager_outputs( + self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]] + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: if self.training: return losses return detections - def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, Tensor], anchors: List[Tensor], - matched_idxs: List[Tensor]) -> Dict[str, Tensor]: - bbox_regression = head_outputs['bbox_regression'] - cls_logits = head_outputs['cls_logits'] + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + ) -> Dict[str, Tensor]: + bbox_regression = head_outputs["bbox_regression"] + cls_logits = head_outputs["cls_logits"] # Match original targets with default boxes num_foreground = 0 bbox_loss = [] cls_targets = [] - for (targets_per_image, bbox_regression_per_image, cls_logits_per_image, anchors_per_image, - matched_idxs_per_image) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs): + for ( + targets_per_image, + bbox_regression_per_image, + cls_logits_per_image, + anchors_per_image, + matched_idxs_per_image, + ) in zip(targets, bbox_regression, cls_logits, anchors, matched_idxs): # produce the matching between boxes and targets foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] foreground_matched_idxs_per_image = matched_idxs_per_image[foreground_idxs_per_image] num_foreground += foreground_matched_idxs_per_image.numel() # Calculate regression loss - matched_gt_boxes_per_image = targets_per_image['boxes'][foreground_matched_idxs_per_image] + matched_gt_boxes_per_image = targets_per_image["boxes"][foreground_matched_idxs_per_image] bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] anchors_per_image = anchors_per_image[foreground_idxs_per_image, :] target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors_per_image) - bbox_loss.append(torch.nn.functional.smooth_l1_loss( - bbox_regression_per_image, - target_regression, - reduction='sum' - )) + bbox_loss.append( + torch.nn.functional.smooth_l1_loss(bbox_regression_per_image, target_regression, reduction="sum") + ) # Estimate ground truth for class targets - gt_classes_target = torch.zeros((cls_logits_per_image.size(0), ), dtype=targets_per_image['labels'].dtype, - device=targets_per_image['labels'].device) - gt_classes_target[foreground_idxs_per_image] = \ - targets_per_image['labels'][foreground_matched_idxs_per_image] + gt_classes_target = torch.zeros( + (cls_logits_per_image.size(0),), + dtype=targets_per_image["labels"].dtype, + device=targets_per_image["labels"].device, + ) + gt_classes_target[foreground_idxs_per_image] = targets_per_image["labels"][ + foreground_matched_idxs_per_image + ] cls_targets.append(gt_classes_target) bbox_loss = torch.stack(bbox_loss) @@ -259,30 +280,29 @@ def compute_loss(self, targets: List[Dict[str, Tensor]], head_outputs: Dict[str, # Calculate classification loss num_classes = cls_logits.size(-1) - cls_loss = F.cross_entropy( - cls_logits.view(-1, num_classes), - cls_targets.view(-1), - reduction='none' - ).view(cls_targets.size()) + cls_loss = F.cross_entropy(cls_logits.view(-1, num_classes), cls_targets.view(-1), reduction="none").view( + cls_targets.size() + ) # Hard Negative Sampling foreground_idxs = cls_targets > 0 num_negative = self.neg_to_pos_ratio * foreground_idxs.sum(1, keepdim=True) # num_negative[num_negative < self.neg_to_pos_ratio] = 
self.neg_to_pos_ratio negative_loss = cls_loss.clone() - negative_loss[foreground_idxs] = -float('inf') # use -inf to detect positive values that creeped in the sample + negative_loss[foreground_idxs] = -float("inf") # use -inf to detect positive values that creeped in the sample values, idx = negative_loss.sort(1, descending=True) # background_idxs = torch.logical_and(idx.sort(1)[1] < num_negative, torch.isfinite(values)) background_idxs = idx.sort(1)[1] < num_negative N = max(1, num_foreground) return { - 'bbox_regression': bbox_loss.sum() / N, - 'classification': (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N, + "bbox_regression": bbox_loss.sum() / N, + "classification": (cls_loss[foreground_idxs].sum() + cls_loss[background_idxs].sum()) / N, } - def forward(self, images: List[Tensor], - targets: Optional[List[Dict[str, Tensor]]] = None) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + def forward( + self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: if self.training and targets is None: raise ValueError("In training mode, targets should be passed") @@ -292,12 +312,11 @@ def forward(self, images: List[Tensor], boxes = target["boxes"] if isinstance(boxes, torch.Tensor): if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) + raise ValueError( + "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape) + ) else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) + raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes))) # get the original image sizes original_image_sizes: List[Tuple[int, int]] = [] @@ -317,14 +336,15 @@ def forward(self, images: List[Tensor], if degenerate_boxes.any(): bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) + raise ValueError( + "All bounding boxes should have positive height and width." 
+ " Found invalid box {} for target at index {}.".format(degen_bb, target_idx) + ) # get the features from the backbone features = self.backbone(images.tensors) if isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) + features = OrderedDict([("0", features)]) features = list(features.values()) @@ -341,12 +361,13 @@ def forward(self, images: List[Tensor], matched_idxs = [] for anchors_per_image, targets_per_image in zip(anchors, targets): - if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, - device=anchors_per_image.device)) + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) continue - match_quality_matrix = box_ops.box_iou(targets_per_image['boxes'], anchors_per_image) + match_quality_matrix = box_ops.box_iou(targets_per_image["boxes"], anchors_per_image) matched_idxs.append(self.proposal_matcher(match_quality_matrix)) losses = self.compute_loss(targets, head_outputs, anchors, matched_idxs) @@ -361,10 +382,11 @@ def forward(self, images: List[Tensor], return losses, detections return self.eager_outputs(losses, detections) - def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], - image_shapes: List[Tuple[int, int]]) -> List[Dict[str, Tensor]]: - bbox_regression = head_outputs['bbox_regression'] - pred_scores = F.softmax(head_outputs['cls_logits'], dim=-1) + def postprocess_detections( + self, head_outputs: Dict[str, Tensor], image_anchors: List[Tensor], image_shapes: List[Tuple[int, int]] + ) -> List[Dict[str, Tensor]]: + bbox_regression = head_outputs["bbox_regression"] + pred_scores = F.softmax(head_outputs["cls_logits"], dim=-1) num_classes = pred_scores.size(-1) device = pred_scores.device @@ -400,13 +422,15 @@ def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: # non-maximum suppression keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh) - keep = keep[:self.detections_per_img] - - detections.append({ - 'boxes': image_boxes[keep], - 'scores': image_scores[keep], - 'labels': image_labels[keep], - }) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) return detections @@ -423,45 +447,47 @@ def __init__(self, backbone: nn.Module, highres: bool): self.scale_weight = nn.Parameter(torch.ones(512) * 20) # Multiple Feature maps - page 4, Fig 2 of SSD paper - self.features = nn.Sequential( - *backbone[:maxpool4_pos] # until conv4_3 - ) + self.features = nn.Sequential(*backbone[:maxpool4_pos]) # until conv4_3 # SSD300 case - page 4, Fig 2 of SSD paper - extra = nn.ModuleList([ - nn.Sequential( - nn.Conv2d(1024, 256, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2 - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(512, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2 - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3), # conv10_2 - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3), # conv11_2 - nn.ReLU(inplace=True), - ) - ]) + extra = 
nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(1024, 256, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2), # conv8_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(512, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2), # conv9_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3), # conv10_2 + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3), # conv11_2 + nn.ReLU(inplace=True), + ), + ] + ) if highres: # Additional layers for the SSD512 case. See page 11, footernote 5. - extra.append(nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=4), # conv12_2 - nn.ReLU(inplace=True), - )) + extra.append( + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=4), # conv12_2 + nn.ReLU(inplace=True), + ) + ) _xavier_init(extra) fc = nn.Sequential( @@ -469,13 +495,16 @@ def __init__(self, backbone: nn.Module, highres: bool): nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, padding=6, dilation=6), # FC6 with atrous nn.ReLU(inplace=True), nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=1), # FC7 - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) _xavier_init(fc) - extra.insert(0, nn.Sequential( - *backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5 - fc, - )) + extra.insert( + 0, + nn.Sequential( + *backbone[maxpool4_pos:-1], # until conv5_3, skip maxpool5 + fc, + ), + ) self.extra = extra def forward(self, x: Tensor) -> Dict[str, Tensor]: @@ -495,7 +524,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): if backbone_name in backbone_urls: # Use custom backbones more appropriate for SSD - arch = backbone_name.split('_')[0] + arch = backbone_name.split("_")[0] backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features if pretrained: state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress) @@ -519,8 +548,14 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained return SSDFeatureExtractorVGG(backbone, highres) -def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91, - pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any): +def ssd300_vgg16( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +): """Constructs an SSD model with input size 300x300 and a VGG16 backbone. Reference: `"SSD: Single Shot MultiBox Detector" `_. 
@@ -569,16 +604,19 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i warnings.warn("The size of the model is already fixed; ignoring the argument.") trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5 + ) if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], - scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], - steps=[8, 16, 32, 64, 100, 300]) + anchor_generator = DefaultBoxGenerator( + [[2], [2, 3], [2, 3], [2, 3], [2], [2]], + scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], + steps=[8, 16, 32, 64, 100, 300], + ) defaults = { # Rescale the input in a way compatible to the backbone @@ -588,7 +626,7 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i kwargs = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) if pretrained: - weights_name = 'ssd300_vgg16_coco' + weights_name = "ssd300_vgg16_coco" if model_urls.get(weights_name, None) is None: raise ValueError("No checkpoint is available for model {}".format(weights_name)) state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index c78caad1be7..2ede62040b7 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -1,38 +1,43 @@ -import torch import warnings - from collections import OrderedDict from functools import partial -from torch import nn, Tensor from typing import Any, Callable, Dict, List, Optional, Tuple +import torch +from torch import nn, Tensor + +from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops.misc import ConvNormActivation +from .. import mobilenet from . import _utils as det_utils -from .ssd import SSD, SSDScoringHead from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers -from .. 
import mobilenet -from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation +from .ssd import SSD, SSDScoringHead -__all__ = ['ssdlite320_mobilenet_v3_large'] +__all__ = ["ssdlite320_mobilenet_v3_large"] model_urls = { - 'ssdlite320_mobilenet_v3_large_coco': - 'https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth' + "ssdlite320_mobilenet_v3_large_coco": "https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth" } # Building blocks of SSDlite as described in section 6.2 of MobileNetV2 paper -def _prediction_block(in_channels: int, out_channels: int, kernel_size: int, - norm_layer: Callable[..., nn.Module]) -> nn.Sequential: +def _prediction_block( + in_channels: int, out_channels: int, kernel_size: int, norm_layer: Callable[..., nn.Module] +) -> nn.Sequential: return nn.Sequential( # 3x3 depthwise with stride 1 and padding 1 - ConvNormActivation(in_channels, in_channels, kernel_size=kernel_size, groups=in_channels, - norm_layer=norm_layer, activation_layer=nn.ReLU6), - + ConvNormActivation( + in_channels, + in_channels, + kernel_size=kernel_size, + groups=in_channels, + norm_layer=norm_layer, + activation_layer=nn.ReLU6, + ), # 1x1 projetion to output channels - nn.Conv2d(in_channels, out_channels, 1) + nn.Conv2d(in_channels, out_channels, 1), ) @@ -41,16 +46,23 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., intermediate_channels = out_channels // 2 return nn.Sequential( # 1x1 projection to half output channels - ConvNormActivation(in_channels, intermediate_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation), - + ConvNormActivation( + in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation + ), # 3x3 depthwise with stride 2 and padding 1 - ConvNormActivation(intermediate_channels, intermediate_channels, kernel_size=3, stride=2, - groups=intermediate_channels, norm_layer=norm_layer, activation_layer=activation), - + ConvNormActivation( + intermediate_channels, + intermediate_channels, + kernel_size=3, + stride=2, + groups=intermediate_channels, + norm_layer=norm_layer, + activation_layer=activation, + ), # 1x1 projetion to output channels - ConvNormActivation(intermediate_channels, out_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation), + ConvNormActivation( + intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation + ), ) @@ -63,22 +75,24 @@ def _normal_init(conv: nn.Module): class SSDLiteHead(nn.Module): - def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, - norm_layer: Callable[..., nn.Module]): + def __init__( + self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module] + ): super().__init__() self.classification_head = SSDLiteClassificationHead(in_channels, num_anchors, num_classes, norm_layer) self.regression_head = SSDLiteRegressionHead(in_channels, num_anchors, norm_layer) def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: return { - 'bbox_regression': self.regression_head(x), - 'cls_logits': self.classification_head(x), + "bbox_regression": self.regression_head(x), + "cls_logits": self.classification_head(x), } class SSDLiteClassificationHead(SSDScoringHead): - def __init__(self, in_channels: List[int], num_anchors: List[int], num_classes: int, - norm_layer: Callable[..., nn.Module]): + def __init__( + 
self, in_channels: List[int], num_anchors: List[int], num_classes: int, norm_layer: Callable[..., nn.Module] + ): cls_logits = nn.ModuleList() for channels, anchors in zip(in_channels, num_anchors): cls_logits.append(_prediction_block(channels, num_classes * anchors, 3, norm_layer)) @@ -96,24 +110,33 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C class SSDLiteFeatureExtractorMobileNet(nn.Module): - def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], width_mult: float = 1.0, - min_depth: int = 16, **kwargs: Any): + def __init__( + self, + backbone: nn.Module, + c4_pos: int, + norm_layer: Callable[..., nn.Module], + width_mult: float = 1.0, + min_depth: int = 16, + **kwargs: Any, + ): super().__init__() assert not backbone[c4_pos].use_res_connect self.features = nn.Sequential( # As described in section 6.3 of MobileNetV3 paper nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer - nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1:]), # from C4 depthwise until end + nn.Sequential(backbone[c4_pos].block[1:], *backbone[c4_pos + 1 :]), # from C4 depthwise until end ) get_depth = lambda d: max(min_depth, int(d * width_mult)) # noqa: E731 - extra = nn.ModuleList([ - _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer), - _extra_block(get_depth(512), get_depth(256), norm_layer), - _extra_block(get_depth(256), get_depth(256), norm_layer), - _extra_block(get_depth(256), get_depth(128), norm_layer), - ]) + extra = nn.ModuleList( + [ + _extra_block(backbone[-1].out_channels, get_depth(512), norm_layer), + _extra_block(get_depth(512), get_depth(256), norm_layer), + _extra_block(get_depth(256), get_depth(256), norm_layer), + _extra_block(get_depth(256), get_depth(128), norm_layer), + ] + ) _normal_init(extra) self.extra = extra @@ -132,10 +155,17 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int, - norm_layer: Callable[..., nn.Module], **kwargs: Any): - backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress, - norm_layer=norm_layer, **kwargs).features +def _mobilenet_extractor( + backbone_name: str, + progress: bool, + pretrained: bool, + trainable_layers: int, + norm_layer: Callable[..., nn.Module], + **kwargs: Any, +): + backbone = mobilenet.__dict__[backbone_name]( + pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs + ).features if not pretrained: # Change the default initialization scheme if not pretrained _normal_init(backbone) @@ -156,10 +186,15 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs) -def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91, - pretrained_backbone: bool = False, trainable_backbone_layers: Optional[int] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any): +def ssdlite320_mobilenet_v3_large( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = False, + trainable_backbone_layers: Optional[int] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, +): """Constructs an SSDlite model with input size 320x320 and a MobileNetV3 
Large backbone, as described at `"Searching for MobileNetV3" `_ and @@ -188,7 +223,8 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru warnings.warn("The size of the model is already fixed; ignoring the argument.") trainable_backbone_layers = _validate_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6) + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6 + ) if pretrained: pretrained_backbone = False @@ -199,8 +235,15 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) - backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers, - norm_layer, reduced_tail=reduce_tail, **kwargs) + backbone = _mobilenet_extractor( + "mobilenet_v3_large", + progress, + pretrained_backbone, + trainable_backbone_layers, + norm_layer, + reduced_tail=reduce_tail, + **kwargs, + ) size = (320, 320) anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) @@ -219,11 +262,17 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru "image_std": [0.5, 0.5, 0.5], } kwargs = {**defaults, **kwargs} - model = SSD(backbone, anchor_generator, size, num_classes, - head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), **kwargs) + model = SSD( + backbone, + anchor_generator, + size, + num_classes, + head=SSDLiteHead(out_channels, num_anchors, num_classes, norm_layer), + **kwargs, + ) if pretrained: - weights_name = 'ssdlite320_mobilenet_v3_large_coco' + weights_name = "ssdlite320_mobilenet_v3_large_coco" if model_urls.get(weights_name, None) is None: raise ValueError("No checkpoint is available for model {}".format(weights_name)) state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 60fa2e344ed..e4a1134b85c 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -1,9 +1,9 @@ import math +from typing import List, Tuple, Dict, Optional + import torch import torchvision - from torch import nn, Tensor -from typing import List, Tuple, Dict, Optional from .image_list import ImageList from .roi_heads import paste_masks_in_image @@ -12,6 +12,7 @@ @torch.jit.unused def _get_shape_onnx(image: Tensor) -> Tensor: from torch.onnx import operators + return operators.shape_as_tensor(image)[-2:] @@ -21,10 +22,13 @@ def _fake_cast_onnx(v: Tensor) -> float: return v -def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size: float, - target: Optional[Dict[str, Tensor]] = None, - fixed_size: Optional[Tuple[int, int]] = None, - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: +def _resize_image_and_masks( + image: Tensor, + self_min_size: float, + self_max_size: float, + target: Optional[Dict[str, Tensor]] = None, + fixed_size: Optional[Tuple[int, int]] = None, +) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: if torchvision._is_tracing(): im_shape = _get_shape_onnx(image) else: @@ -46,16 +50,23 @@ def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size: scale_factor = scale.item() recompute_scale_factor = True - image = torch.nn.functional.interpolate(image[None], size=size, scale_factor=scale_factor, mode='bilinear', - recompute_scale_factor=recompute_scale_factor, 
align_corners=False)[0] + image = torch.nn.functional.interpolate( + image[None], + size=size, + scale_factor=scale_factor, + mode="bilinear", + recompute_scale_factor=recompute_scale_factor, + align_corners=False, + )[0] if target is None: return image, target if "masks" in target: mask = target["masks"] - mask = torch.nn.functional.interpolate(mask[:, None].float(), size=size, scale_factor=scale_factor, - recompute_scale_factor=recompute_scale_factor)[:, 0].byte() + mask = torch.nn.functional.interpolate( + mask[:, None].float(), size=size, scale_factor=scale_factor, recompute_scale_factor=recompute_scale_factor + )[:, 0].byte() target["masks"] = mask return image, target @@ -72,8 +83,15 @@ class GeneralizedRCNNTransform(nn.Module): It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets """ - def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float], - size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None): + def __init__( + self, + min_size: int, + max_size: int, + image_mean: List[float], + image_std: List[float], + size_divisible: int = 32, + fixed_size: Optional[Tuple[int, int]] = None, + ): super(GeneralizedRCNNTransform, self).__init__() if not isinstance(min_size, (list, tuple)): min_size = (min_size,) @@ -84,10 +102,9 @@ def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_ self.size_divisible = size_divisible self.fixed_size = fixed_size - def forward(self, - images: List[Tensor], - targets: Optional[List[Dict[str, Tensor]]] = None - ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]: + def forward( + self, images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]: images = [img for img in images] if targets is not None: # make a copy of targets to avoid modifying it in-place @@ -106,8 +123,9 @@ def forward(self, target_index = targets[i] if targets is not None else None if image.dim() != 3: - raise ValueError("images is expected to be a list of 3d tensors " - "of shape [C, H, W], got {}".format(image.shape)) + raise ValueError( + "images is expected to be a list of 3d tensors " "of shape [C, H, W], got {}".format(image.shape) + ) image = self.normalize(image) image, target_index = self.resize(image, target_index) images[i] = image @@ -141,13 +159,14 @@ def torch_choice(self, k: List[int]) -> int: TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 is fixed. 
""" - index = int(torch.empty(1).uniform_(0., float(len(k))).item()) + index = int(torch.empty(1).uniform_(0.0, float(len(k))).item()) return k[index] - def resize(self, - image: Tensor, - target: Optional[Dict[str, Tensor]] = None, - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + def resize( + self, + image: Tensor, + target: Optional[Dict[str, Tensor]] = None, + ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: h, w = image.shape[-2:] if self.training: size = float(self.torch_choice(self.min_size)) @@ -220,11 +239,12 @@ def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor return batched_imgs - def postprocess(self, - result: List[Dict[str, Tensor]], - image_shapes: List[Tuple[int, int]], - original_image_sizes: List[Tuple[int, int]] - ) -> List[Dict[str, Tensor]]: + def postprocess( + self, + result: List[Dict[str, Tensor]], + image_shapes: List[Tuple[int, int]], + original_image_sizes: List[Tuple[int, int]], + ) -> List[Dict[str, Tensor]]: if self.training: return result for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): @@ -242,19 +262,20 @@ def postprocess(self, return result def __repr__(self) -> str: - format_string = self.__class__.__name__ + '(' - _indent = '\n ' + format_string = self.__class__.__name__ + "(" + _indent = "\n " format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std) - format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format(_indent, self.min_size, - self.max_size) - format_string += '\n)' + format_string += "{0}Resize(min_size={1}, max_size={2}, mode='bilinear')".format( + _indent, self.min_size, self.max_size + ) + format_string += "\n)" return format_string def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: ratios = [ - torch.tensor(s, dtype=torch.float32, device=keypoints.device) / - torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) + torch.tensor(s, dtype=torch.float32, device=keypoints.device) + / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) for s, s_orig in zip(new_size, original_size) ] ratio_h, ratio_w = ratios @@ -271,8 +292,8 @@ def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: ratios = [ - torch.tensor(s, dtype=torch.float32, device=boxes.device) / - torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) + torch.tensor(s, dtype=torch.float32, device=boxes.device) + / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) for s, s_orig in zip(new_size, original_size) ] ratio_height, ratio_width = ratios diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index 4dd23c1ea45..b9a5913ea77 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -1,19 +1,28 @@ import copy import math -import torch - from functools import partial -from torch import nn, Tensor from typing import Any, Callable, List, Optional, Sequence +import torch +from torch import nn, Tensor +from torchvision.ops import StochasticDepth + from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import ConvNormActivation, SqueezeExcitation from ._utils import _make_divisible -from torchvision.ops import StochasticDepth -__all__ = ["EfficientNet", "efficientnet_b0", "efficientnet_b1", "efficientnet_b2", "efficientnet_b3", - "efficientnet_b4", 
"efficientnet_b5", "efficientnet_b6", "efficientnet_b7"] +__all__ = [ + "EfficientNet", + "efficientnet_b0", + "efficientnet_b1", + "efficientnet_b2", + "efficientnet_b3", + "efficientnet_b4", + "efficientnet_b5", + "efficientnet_b6", + "efficientnet_b7", +] model_urls = { @@ -32,10 +41,17 @@ class MBConvConfig: # Stores information listed at Table 1 of the EfficientNet paper - def __init__(self, - expand_ratio: float, kernel: int, stride: int, - input_channels: int, out_channels: int, num_layers: int, - width_mult: float, depth_mult: float) -> None: + def __init__( + self, + expand_ratio: float, + kernel: int, + stride: int, + input_channels: int, + out_channels: int, + num_layers: int, + width_mult: float, + depth_mult: float, + ) -> None: self.expand_ratio = expand_ratio self.kernel = kernel self.stride = stride @@ -44,14 +60,14 @@ def __init__(self, self.num_layers = self.adjust_depth(num_layers, depth_mult) def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'expand_ratio={expand_ratio}' - s += ', kernel={kernel}' - s += ', stride={stride}' - s += ', input_channels={input_channels}' - s += ', out_channels={out_channels}' - s += ', num_layers={num_layers}' - s += ')' + s = self.__class__.__name__ + "(" + s += "expand_ratio={expand_ratio}" + s += ", kernel={kernel}" + s += ", stride={stride}" + s += ", input_channels={input_channels}" + s += ", out_channels={out_channels}" + s += ", num_layers={num_layers}" + s += ")" return s.format(**self.__dict__) @staticmethod @@ -64,12 +80,17 @@ def adjust_depth(num_layers: int, depth_mult: float): class MBConv(nn.Module): - def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module], - se_layer: Callable[..., nn.Module] = SqueezeExcitation) -> None: + def __init__( + self, + cnf: MBConvConfig, + stochastic_depth_prob: float, + norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = SqueezeExcitation, + ) -> None: super().__init__() if not (1 <= cnf.stride <= 2): - raise ValueError('illegal stride value') + raise ValueError("illegal stride value") self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels @@ -79,21 +100,39 @@ def __init__(self, cnf: MBConvConfig, stochastic_depth_prob: float, norm_layer: # expand expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) if expanded_channels != cnf.input_channels: - layers.append(ConvNormActivation(cnf.input_channels, expanded_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append( + ConvNormActivation( + cnf.input_channels, + expanded_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) # depthwise - layers.append(ConvNormActivation(expanded_channels, expanded_channels, kernel_size=cnf.kernel, - stride=cnf.stride, groups=expanded_channels, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append( + ConvNormActivation( + expanded_channels, + expanded_channels, + kernel_size=cnf.kernel, + stride=cnf.stride, + groups=expanded_channels, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) # squeeze and excitation squeeze_channels = max(1, cnf.input_channels // 4) layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True))) # project - layers.append(ConvNormActivation(expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, - activation_layer=None)) + layers.append( + ConvNormActivation( + 
expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None + ) + ) self.block = nn.Sequential(*layers) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") @@ -109,14 +148,14 @@ def forward(self, input: Tensor) -> Tensor: class EfficientNet(nn.Module): def __init__( - self, - inverted_residual_setting: List[MBConvConfig], - dropout: float, - stochastic_depth_prob: float = 0.2, - num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any + self, + inverted_residual_setting: List[MBConvConfig], + dropout: float, + stochastic_depth_prob: float = 0.2, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, ) -> None: """ EfficientNet main class @@ -133,8 +172,10 @@ def __init__( if not inverted_residual_setting: raise ValueError("The inverted_residual_setting should not be empty") - elif not (isinstance(inverted_residual_setting, Sequence) and - all([isinstance(s, MBConvConfig) for s in inverted_residual_setting])): + elif not ( + isinstance(inverted_residual_setting, Sequence) + and all([isinstance(s, MBConvConfig) for s in inverted_residual_setting]) + ): raise TypeError("The inverted_residual_setting should be List[MBConvConfig]") if block is None: @@ -147,8 +188,11 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels - layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, - activation_layer=nn.SiLU)) + layers.append( + ConvNormActivation( + 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU + ) + ) # building inverted residual blocks total_stage_blocks = sum([cnf.num_layers for cnf in inverted_residual_setting]) @@ -175,8 +219,15 @@ def __init__( # building last several layers lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 4 * lastconv_input_channels - layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=nn.SiLU)) + layers.append( + ConvNormActivation( + lastconv_input_channels, + lastconv_output_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.SiLU, + ) + ) self.features = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) @@ -187,7 +238,7 @@ def __init__( for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): @@ -232,7 +283,7 @@ def _efficientnet_model( dropout: float, pretrained: bool, progress: bool, - **kwargs: Any + **kwargs: Any, ) -> EfficientNet: model = EfficientNet(inverted_residual_setting, dropout, **kwargs) if pretrained: @@ -318,8 +369,15 @@ def efficientnet_b5(pretrained: bool = False, progress: bool = True, **kwargs: A progress (bool): If True, displays a progress bar of the download to stderr """ inverted_residual_setting = _efficientnet_conf(width_mult=1.6, depth_mult=2.2, **kwargs) - return _efficientnet_model("efficientnet_b5", inverted_residual_setting, 0.4, pretrained, progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) + return _efficientnet_model( + 
"efficientnet_b5", + inverted_residual_setting, + 0.4, + pretrained, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), + **kwargs, + ) def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -332,8 +390,15 @@ def efficientnet_b6(pretrained: bool = False, progress: bool = True, **kwargs: A progress (bool): If True, displays a progress bar of the download to stderr """ inverted_residual_setting = _efficientnet_conf(width_mult=1.8, depth_mult=2.6, **kwargs) - return _efficientnet_model("efficientnet_b6", inverted_residual_setting, 0.5, pretrained, progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) + return _efficientnet_model( + "efficientnet_b6", + inverted_residual_setting, + 0.5, + pretrained, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), + **kwargs, + ) def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> EfficientNet: @@ -346,5 +411,12 @@ def efficientnet_b7(pretrained: bool = False, progress: bool = True, **kwargs: A progress (bool): If True, displays a progress bar of the download to stderr """ inverted_residual_setting = _efficientnet_conf(width_mult=2.0, depth_mult=3.1, **kwargs) - return _efficientnet_model("efficientnet_b7", inverted_residual_setting, 0.5, pretrained, progress, - norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), **kwargs) + return _efficientnet_model( + "efficientnet_b7", + inverted_residual_setting, + 0.5, + pretrained, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), + **kwargs, + ) diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 9409d31172b..c2674cb83a1 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -1,17 +1,17 @@ -from typing import Dict, Callable, List, Union, Optional, Tuple -from collections import OrderedDict -import warnings import re +import warnings +from collections import OrderedDict from copy import deepcopy from itertools import chain +from typing import Dict, Callable, List, Union, Optional, Tuple import torch -from torch import nn from torch import fx +from torch import nn from torch.fx.graph_module import _copy_attr -__all__ = ['create_feature_extractor', 'get_graph_node_names'] +__all__ = ["create_feature_extractor", "get_graph_node_names"] class LeafModuleAwareTracer(fx.Tracer): @@ -20,10 +20,11 @@ class LeafModuleAwareTracer(fx.Tracer): modules that are not to be traced through. The resulting graph ends up having single nodes referencing calls to the leaf modules' forward methods. """ + def __init__(self, *args, **kwargs): self.leaf_modules = {} - if 'leaf_modules' in kwargs: - leaf_modules = kwargs.pop('leaf_modules') + if "leaf_modules" in kwargs: + leaf_modules = kwargs.pop("leaf_modules") self.leaf_modules = leaf_modules super(LeafModuleAwareTracer, self).__init__(*args, **kwargs) @@ -51,10 +52,11 @@ class NodePathTracer(LeafModuleAwareTracer): - When a duplicate node name is encountered, a suffix of the form _{int} is added. The counter starts from 1. 
""" + def __init__(self, *args, **kwargs): super(NodePathTracer, self).__init__(*args, **kwargs) # Track the qualified name of the Node being traced - self.current_module_qualname = '' + self.current_module_qualname = "" # A map from FX Node to the qualified name\# # NOTE: This is loosely like the "qualified name" mentioned in the # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted @@ -78,32 +80,31 @@ def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): if not self.is_leaf_module(m, module_qualname): out = forward(*args, **kwargs) return out - return self.create_proxy('call_module', module_qualname, args, kwargs) + return self.create_proxy("call_module", module_qualname, args, kwargs) finally: self.current_module_qualname = old_qualname - def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, - name=None, type_expr=None, *_) -> fx.proxy.Proxy: + def create_proxy( + self, kind: str, target: fx.node.Target, args, kwargs, name=None, type_expr=None, *_ + ) -> fx.proxy.Proxy: """ Override of `Tracer.create_proxy`. This override intercepts the recording of every operation and stores away the current traced module's qualified name in `node_to_qualname` """ proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) - self.node_to_qualname[proxy.node] = self._get_node_qualname( - self.current_module_qualname, proxy.node) + self.node_to_qualname[proxy.node] = self._get_node_qualname(self.current_module_qualname, proxy.node) return proxy - def _get_node_qualname( - self, module_qualname: str, node: fx.node.Node) -> str: + def _get_node_qualname(self, module_qualname: str, node: fx.node.Node) -> str: node_qualname = module_qualname - if node.op != 'call_module': + if node.op != "call_module": # In this case module_qualname from torch.fx doesn't go all the # way to the leaf function/op so we need to append it if len(node_qualname) > 0: # Only append '.' if we are deeper than the top level module - node_qualname += '.' + node_qualname += "." node_qualname += str(node) # Now we need to add an _{index} postfix on any repeated node names @@ -111,23 +112,22 @@ def _get_node_qualname( # But for anything else, torch.fx already has a globally scoped # _{index} postfix. But we want it locally (relative to direct parent) # scoped. So first we need to undo the torch.fx postfix - if re.match(r'.+_[0-9]+$', node_qualname) is not None: - node_qualname = node_qualname.rsplit('_', 1)[0] + if re.match(r".+_[0-9]+$", node_qualname) is not None: + node_qualname = node_qualname.rsplit("_", 1)[0] # ... 
and now we add on our own postfix for existing_qualname in reversed(self.node_to_qualname.values()): # Check to see if existing_qualname is of the form # {node_qualname} or {node_qualname}_{int} - if re.match(rf'{node_qualname}(_[0-9]+)?$', - existing_qualname) is not None: - postfix = existing_qualname.replace(node_qualname, '') + if re.match(rf"{node_qualname}(_[0-9]+)?$", existing_qualname) is not None: + postfix = existing_qualname.replace(node_qualname, "") if len(postfix): # existing_qualname is of the form {node_qualname}_{int} next_index = int(postfix[1:]) + 1 else: # existing_qualname is of the form {node_qualname} next_index = 1 - node_qualname += f'_{next_index}' + node_qualname += f"_{next_index}" break return node_qualname @@ -141,8 +141,7 @@ def _is_subseq(x, y): return all(any(x_item == y_item for x_item in iter_x) for y_item in y) -def _warn_graph_differences( - train_tracer: NodePathTracer, eval_tracer: NodePathTracer): +def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathTracer): """ Utility function for warning the user if there are differences between the train graph nodes and the eval graph nodes. @@ -150,29 +149,32 @@ def _warn_graph_differences( train_nodes = list(train_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values()) - if len(train_nodes) == len(eval_nodes) and all( - t == e for t, e in zip(train_nodes, eval_nodes)): + if len(train_nodes) == len(eval_nodes) and all(t == e for t, e in zip(train_nodes, eval_nodes)): return suggestion_msg = ( "When choosing nodes for feature extraction, you may need to specify " - "output nodes for train and eval mode separately.") + "output nodes for train and eval mode separately." + ) if _is_subseq(train_nodes, eval_nodes): - msg = ("NOTE: The nodes obtained by tracing the model in eval mode " - "are a subsequence of those obtained in train mode. ") + msg = ( + "NOTE: The nodes obtained by tracing the model in eval mode " + "are a subsequence of those obtained in train mode. " + ) elif _is_subseq(eval_nodes, train_nodes): - msg = ("NOTE: The nodes obtained by tracing the model in train mode " - "are a subsequence of those obtained in eval mode. ") + msg = ( + "NOTE: The nodes obtained by tracing the model in train mode " + "are a subsequence of those obtained in eval mode. " + ) else: - msg = ("The nodes obtained by tracing the model in train mode " - "are different to those obtained in eval mode. ") + msg = "The nodes obtained by tracing the model in train mode " "are different to those obtained in eval mode. " warnings.warn(msg + suggestion_msg) def get_graph_node_names( - model: nn.Module, tracer_kwargs: Dict = {}, - suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]: + model: nn.Module, tracer_kwargs: Dict = {}, suppress_diff_warning: bool = False +) -> Tuple[List[str], List[str]]: """ Dev utility to return node names in order of execution. See note on node names under :func:`create_feature_extractor`. Useful for seeing which node @@ -230,11 +232,10 @@ class DualGraphModule(fx.GraphModule): - Copies submodules according to the nodes of both train and eval graphs. - Calling train(mode) switches between train graph and eval graph. 
""" - def __init__(self, - root: torch.nn.Module, - train_graph: fx.Graph, - eval_graph: fx.Graph, - class_name: str = 'GraphModule'): + + def __init__( + self, root: torch.nn.Module, train_graph: fx.Graph, eval_graph: fx.Graph, class_name: str = "GraphModule" + ): """ Args: root (nn.Module): module from which the copied module hierarchy is @@ -252,7 +253,7 @@ def __init__(self, # Copy all get_attr and call_module ops (indicated by BOTH train and # eval graphs) for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): - if node.op in ['get_attr', 'call_module']: + if node.op in ["get_attr", "call_module"]: assert isinstance(node.target, str) _copy_attr(root, self, node.target) @@ -266,10 +267,11 @@ def __init__(self, # Locally defined Tracers are not pickleable. This is needed because torch.package will # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer # to re-create the Graph during deserialization. - assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ - "Train mode and eval mode should use the same tracer class" + assert ( + self.eval_graph._tracer_cls == self.train_graph._tracer_cls + ), "Train mode and eval mode should use the same tracer class" self._tracer_cls = None - if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: + if self.graph._tracer_cls and "" not in self.graph._tracer_cls.__qualname__: self._tracer_cls = self.graph._tracer_cls def train(self, mode=True): @@ -288,12 +290,13 @@ def train(self, mode=True): def create_feature_extractor( - model: nn.Module, - return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - tracer_kwargs: Dict = {}, - suppress_diff_warning: bool = False) -> fx.GraphModule: + model: nn.Module, + return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + tracer_kwargs: Dict = {}, + suppress_diff_warning: bool = False, +) -> fx.GraphModule: """ Creates a new graph module that returns intermediate nodes from a given model as dictionary with user specified keys as strings, and the requested @@ -396,18 +399,17 @@ def create_feature_extractor( """ is_training = model.training - assert any(arg is not None for arg in [ - return_nodes, train_return_nodes, eval_return_nodes]), ( - "Either `return_nodes` or `train_return_nodes` and " - "`eval_return_nodes` together, should be specified") + assert any(arg is not None for arg in [return_nodes, train_return_nodes, eval_return_nodes]), ( + "Either `return_nodes` or `train_return_nodes` and " "`eval_return_nodes` together, should be specified" + ) - assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ - ("If any of `train_return_nodes` and `eval_return_nodes` are " - "specified, then both should be specified") + assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), ( + "If any of `train_return_nodes` and `eval_return_nodes` are " "specified, then both should be specified" + ) - assert ((return_nodes is None) ^ (train_return_nodes is None)), \ - ("If `train_return_nodes` and `eval_return_nodes` are specified, " - "then both should be specified") + assert (return_nodes is None) ^ (train_return_nodes is None), ( + "If `train_return_nodes` and `eval_return_nodes` are specified, " "then 
both should be specified" + ) # Put *_return_nodes into Dict[str, str] format def to_strdict(n) -> Dict[str, str]: @@ -426,45 +428,42 @@ def to_strdict(n) -> Dict[str, str]: # Repeat the tracing and graph rewriting for train and eval mode tracers = {} graphs = {} - mode_return_nodes: Dict[str, Dict[str, str]] = { - 'train': train_return_nodes, - 'eval': eval_return_nodes - } - for mode in ['train', 'eval']: - if mode == 'train': + mode_return_nodes: Dict[str, Dict[str, str]] = {"train": train_return_nodes, "eval": eval_return_nodes} + for mode in ["train", "eval"]: + if mode == "train": model.train() - elif mode == 'eval': + elif mode == "eval": model.eval() # Instantiate our NodePathTracer and use that to trace the model tracer = NodePathTracer(**tracer_kwargs) graph = tracer.trace(model) - name = model.__class__.__name__ if isinstance( - model, nn.Module) else model.__name__ + name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ graph_module = fx.GraphModule(tracer.root, graph, name) available_nodes = list(tracer.node_to_qualname.values()) # FIXME We don't know if we should expect this to happen - assert len(set(available_nodes)) == len(available_nodes), \ - "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" + assert len(set(available_nodes)) == len( + available_nodes + ), "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" # Check that all outputs in return_nodes are present in the model for query in mode_return_nodes[mode].keys(): # To check if a query is available we need to check that at least # one of the available names starts with it up to a . - if not any([re.match(rf'^{query}(\.|$)', n) is not None - for n in available_nodes]): + if not any([re.match(rf"^{query}(\.|$)", n) is not None for n in available_nodes]): raise ValueError( f"node: '{query}' is not present in model. Hint: use " "`get_graph_node_names` to make sure the " "`return_nodes` you specified are present. It may even " "be that you need to specify `train_return_nodes` and " - "`eval_return_nodes` separately.") + "`eval_return_nodes` separately." 
+ ) # Remove existing output nodes (train mode) orig_output_nodes = [] for n in reversed(graph_module.graph.nodes): - if n.op == 'output': + if n.op == "output": orig_output_nodes.append(n) assert len(orig_output_nodes) for n in orig_output_nodes: @@ -482,8 +481,8 @@ def to_strdict(n) -> Dict[str, str]: # - When packing outputs into a named tuple like in InceptionV3 continue for query in mode_return_nodes[mode]: - depth = query.count('.') - if '.'.join(module_qualname.split('.')[:depth + 1]) == query: + depth = query.count(".") + if ".".join(module_qualname.split(".")[: depth + 1]) == query: output_nodes[mode_return_nodes[mode][query]] = n mode_return_nodes[mode].pop(query) break @@ -504,11 +503,10 @@ def to_strdict(n) -> Dict[str, str]: # Warn user if there are any discrepancies between the graphs of the # train and eval modes if not suppress_diff_warning: - _warn_graph_differences(tracers['train'], tracers['eval']) + _warn_graph_differences(tracers["train"], tracers["eval"]) # Build the final graph module - graph_module = DualGraphModule( - model, graphs['train'], graphs['eval'], class_name=name) + graph_module = DualGraphModule(model, graphs["train"], graphs["eval"], class_name=name) # Restore original training mode model.train(is_training) diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py index 6c34f521a81..132805389a7 100644 --- a/torchvision/models/googlenet.py +++ b/torchvision/models/googlenet.py @@ -1,22 +1,23 @@ import warnings from collections import namedtuple +from typing import Optional, Tuple, List, Callable, Any + import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Optional, Tuple, List, Callable, Any -__all__ = ['GoogLeNet', 'googlenet', "GoogLeNetOutputs", "_GoogLeNetOutputs"] +__all__ = ["GoogLeNet", "googlenet", "GoogLeNetOutputs", "_GoogLeNetOutputs"] model_urls = { # GoogLeNet ported from TensorFlow - 'googlenet': 'https://download.pytorch.org/models/googlenet-1378be20.pth', + "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", } -GoogLeNetOutputs = namedtuple('GoogLeNetOutputs', ['logits', 'aux_logits2', 'aux_logits1']) -GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor], - 'aux_logits1': Optional[Tensor]} +GoogLeNetOutputs = namedtuple("GoogLeNetOutputs", ["logits", "aux_logits2", "aux_logits1"]) +GoogLeNetOutputs.__annotations__ = {"logits": Tensor, "aux_logits2": Optional[Tensor], "aux_logits1": Optional[Tensor]} # Script annotations failed with _GoogleNetOutputs = namedtuple ... # _GoogLeNetOutputs set here for backwards compat @@ -37,19 +38,19 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> was trained on ImageNet. 
Default: *False* """ if pretrained: - if 'transform_input' not in kwargs: - kwargs['transform_input'] = True - if 'aux_logits' not in kwargs: - kwargs['aux_logits'] = False - if kwargs['aux_logits']: - warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' - 'so make sure to train them') - original_aux_logits = kwargs['aux_logits'] - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "aux_logits" not in kwargs: + kwargs["aux_logits"] = False + if kwargs["aux_logits"]: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, " "so make sure to train them" + ) + original_aux_logits = kwargs["aux_logits"] + kwargs["aux_logits"] = True + kwargs["init_weights"] = False model = GoogLeNet(**kwargs) - state_dict = load_state_dict_from_url(model_urls['googlenet'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["googlenet"], progress=progress) model.load_state_dict(state_dict) if not original_aux_logits: model.aux_logits = False @@ -61,7 +62,7 @@ def googlenet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> class GoogLeNet(nn.Module): - __constants__ = ['aux_logits', 'transform_input'] + __constants__ = ["aux_logits", "transform_input"] def __init__( self, @@ -69,15 +70,18 @@ def __init__( aux_logits: bool = True, transform_input: bool = False, init_weights: Optional[bool] = None, - blocks: Optional[List[Callable[..., nn.Module]]] = None + blocks: Optional[List[Callable[..., nn.Module]]] = None, ) -> None: super(GoogLeNet, self).__init__() if blocks is None: blocks = [BasicConv2d, Inception, InceptionAux] if init_weights is None: - warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of ' - 'torchvision. If you wish to keep the old behavior (which leads to long initialization times' - ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) + warnings.warn( + "The default weight initialization of GoogleNet will be changed in future releases of " + "torchvision. 
If you wish to keep the old behavior (which leads to long initialization times" + " due to scipy/scipy#11299), please set init_weights=True.", + FutureWarning, + ) init_weights = True assert len(blocks) == 3 conv_block = blocks[0] @@ -197,7 +201,7 @@ def eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> Goog if self.training and self.aux_logits: return _GoogLeNetOutputs(x, aux2, aux1) else: - return x # type: ignore[return-value] + return x # type: ignore[return-value] def forward(self, x: Tensor) -> GoogLeNetOutputs: x = self._transform_input(x) @@ -212,7 +216,6 @@ def forward(self, x: Tensor) -> GoogLeNetOutputs: class Inception(nn.Module): - def __init__( self, in_channels: int, @@ -222,7 +225,7 @@ def __init__( ch5x5red: int, ch5x5: int, pool_proj: int, - conv_block: Optional[Callable[..., nn.Module]] = None + conv_block: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Inception, self).__init__() if conv_block is None: @@ -230,20 +233,19 @@ def __init__( self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1) self.branch2 = nn.Sequential( - conv_block(in_channels, ch3x3red, kernel_size=1), - conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1) + conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1) ) self.branch3 = nn.Sequential( conv_block(in_channels, ch5x5red, kernel_size=1), # Here, kernel_size=3 instead of kernel_size=5 is a known bug. # Please see https://github.com/pytorch/vision/issues/906 for details. - conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1) + conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1), ) self.branch4 = nn.Sequential( nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), - conv_block(in_channels, pool_proj, kernel_size=1) + conv_block(in_channels, pool_proj, kernel_size=1), ) def _forward(self, x: Tensor) -> List[Tensor]: @@ -261,12 +263,8 @@ def forward(self, x: Tensor) -> Tensor: class InceptionAux(nn.Module): - def __init__( - self, - in_channels: int, - num_classes: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionAux, self).__init__() if conv_block is None: @@ -295,13 +293,7 @@ def forward(self, x: Tensor) -> Tensor: class BasicConv2d(nn.Module): - - def __init__( - self, - in_channels: int, - out_channels: int, - **kwargs: Any - ) -> None: + def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: super(BasicConv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=0.001) diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index b5cefefa78d..2f18b8bc569 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -1,22 +1,24 @@ -from collections import namedtuple import warnings +from collections import namedtuple +from typing import Callable, Any, Optional, Tuple, List + import torch -from torch import nn, Tensor import torch.nn.functional as F +from torch import nn, Tensor + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Callable, Any, Optional, Tuple, List -__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs'] +__all__ = ["Inception3", "inception_v3", "InceptionOutputs", "_InceptionOutputs"] model_urls = { # Inception v3 ported from TensorFlow - 'inception_v3_google': 
'https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth', + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", } -InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits']) -InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]} +InceptionOutputs = namedtuple("InceptionOutputs", ["logits", "aux_logits"]) +InceptionOutputs.__annotations__ = {"logits": Tensor, "aux_logits": Optional[Tensor]} # Script annotations failed with _GoogleNetOutputs = namedtuple ... # _InceptionOutputs set here for backwards compat @@ -41,17 +43,16 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) was trained on ImageNet. Default: *False* """ if pretrained: - if 'transform_input' not in kwargs: - kwargs['transform_input'] = True - if 'aux_logits' in kwargs: - original_aux_logits = kwargs['aux_logits'] - kwargs['aux_logits'] = True + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "aux_logits" in kwargs: + original_aux_logits = kwargs["aux_logits"] + kwargs["aux_logits"] = True else: original_aux_logits = True - kwargs['init_weights'] = False # we are loading weights from a pretrained model + kwargs["init_weights"] = False # we are loading weights from a pretrained model model = Inception3(**kwargs) - state_dict = load_state_dict_from_url(model_urls['inception_v3_google'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["inception_v3_google"], progress=progress) model.load_state_dict(state_dict) if not original_aux_logits: model.aux_logits = False @@ -62,25 +63,24 @@ def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) class Inception3(nn.Module): - def __init__( self, num_classes: int = 1000, aux_logits: bool = True, transform_input: bool = False, inception_blocks: Optional[List[Callable[..., nn.Module]]] = None, - init_weights: Optional[bool] = None + init_weights: Optional[bool] = None, ) -> None: super(Inception3, self).__init__() if inception_blocks is None: - inception_blocks = [ - BasicConv2d, InceptionA, InceptionB, InceptionC, - InceptionD, InceptionE, InceptionAux - ] + inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] if init_weights is None: - warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of ' - 'torchvision. If you wish to keep the old behavior (which leads to long initialization times' - ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning) + warnings.warn( + "The default weight initialization of inception_v3 will be changed in future releases of " + "torchvision. 
If you wish to keep the old behavior (which leads to long initialization times" + " due to scipy/scipy#11299), please set init_weights=True.", + FutureWarning, + ) init_weights = True assert len(inception_blocks) == 7 conv_block = inception_blocks[0] @@ -120,7 +120,7 @@ def __init__( if init_weights: for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - stddev = float(m.stddev) if hasattr(m, 'stddev') else 0.1 # type: ignore + stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore torch.nn.init.trunc_normal_(m.weight, mean=0.0, std=stddev, a=-2, b=2) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) @@ -208,12 +208,8 @@ def forward(self, x: Tensor) -> InceptionOutputs: class InceptionA(nn.Module): - def __init__( - self, - in_channels: int, - pool_features: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, pool_features: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionA, self).__init__() if conv_block is None: @@ -251,12 +247,7 @@ def forward(self, x: Tensor) -> Tensor: class InceptionB(nn.Module): - - def __init__( - self, - in_channels: int, - conv_block: Optional[Callable[..., nn.Module]] = None - ) -> None: + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: super(InceptionB, self).__init__() if conv_block is None: conv_block = BasicConv2d @@ -284,12 +275,8 @@ def forward(self, x: Tensor) -> Tensor: class InceptionC(nn.Module): - def __init__( - self, - in_channels: int, - channels_7x7: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, channels_7x7: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionC, self).__init__() if conv_block is None: @@ -334,12 +321,7 @@ def forward(self, x: Tensor) -> Tensor: class InceptionD(nn.Module): - - def __init__( - self, - in_channels: int, - conv_block: Optional[Callable[..., nn.Module]] = None - ) -> None: + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: super(InceptionD, self).__init__() if conv_block is None: conv_block = BasicConv2d @@ -370,12 +352,7 @@ def forward(self, x: Tensor) -> Tensor: class InceptionE(nn.Module): - - def __init__( - self, - in_channels: int, - conv_block: Optional[Callable[..., nn.Module]] = None - ) -> None: + def __init__(self, in_channels: int, conv_block: Optional[Callable[..., nn.Module]] = None) -> None: super(InceptionE, self).__init__() if conv_block is None: conv_block = BasicConv2d @@ -422,12 +399,8 @@ def forward(self, x: Tensor) -> Tensor: class InceptionAux(nn.Module): - def __init__( - self, - in_channels: int, - num_classes: int, - conv_block: Optional[Callable[..., nn.Module]] = None + self, in_channels: int, num_classes: int, conv_block: Optional[Callable[..., nn.Module]] = None ) -> None: super(InceptionAux, self).__init__() if conv_block is None: @@ -457,13 +430,7 @@ def forward(self, x: Tensor) -> Tensor: class BasicConv2d(nn.Module): - - def __init__( - self, - in_channels: int, - out_channels: int, - **kwargs: Any - ) -> None: + def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None: super(BasicConv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) self.bn = nn.BatchNorm2d(out_channels, eps=0.001) diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py index ffefab77628..3f48f82c41e 100644 --- 
a/torchvision/models/mnasnet.py +++ b/torchvision/models/mnasnet.py @@ -1,20 +1,19 @@ import warnings +from typing import Any, Dict, List import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Any, Dict, List -__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] +__all__ = ["MNASNet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3"] _MODEL_URLS = { - "mnasnet0_5": - "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", + "mnasnet0_5": "https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", "mnasnet0_75": None, - "mnasnet1_0": - "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", - "mnasnet1_3": None + "mnasnet1_0": "https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", + "mnasnet1_3": None, } # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is @@ -23,34 +22,27 @@ class _InvertedResidual(nn.Module): - def __init__( - self, - in_ch: int, - out_ch: int, - kernel_size: int, - stride: int, - expansion_factor: int, - bn_momentum: float = 0.1 + self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1 ) -> None: super(_InvertedResidual, self).__init__() assert stride in [1, 2] assert kernel_size in [3, 5] mid_ch = in_ch * expansion_factor - self.apply_residual = (in_ch == out_ch and stride == 1) + self.apply_residual = in_ch == out_ch and stride == 1 self.layers = nn.Sequential( # Pointwise nn.Conv2d(in_ch, mid_ch, 1, bias=False), nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.ReLU(inplace=True), # Depthwise - nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, - stride=stride, groups=mid_ch, bias=False), + nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2, stride=stride, groups=mid_ch, bias=False), nn.BatchNorm2d(mid_ch, momentum=bn_momentum), nn.ReLU(inplace=True), # Linear pointwise. Note that there's no activation. nn.Conv2d(mid_ch, out_ch, 1, bias=False), - nn.BatchNorm2d(out_ch, momentum=bn_momentum)) + nn.BatchNorm2d(out_ch, momentum=bn_momentum), + ) def forward(self, input: Tensor) -> Tensor: if self.apply_residual: @@ -59,39 +51,37 @@ def forward(self, input: Tensor) -> Tensor: return self.layers(input) -def _stack(in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, - bn_momentum: float) -> nn.Sequential: - """ Creates a stack of inverted residuals. """ +def _stack( + in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float +) -> nn.Sequential: + """Creates a stack of inverted residuals.""" assert repeats >= 1 # First one has no skip, because feature map size changes. - first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, - bn_momentum=bn_momentum) + first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum) remaining = [] for _ in range(1, repeats): - remaining.append( - _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, - bn_momentum=bn_momentum)) + remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum)) return nn.Sequential(first, *remaining) def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9) -> int: - """ Asymmetric rounding to make `val` divisible by `divisor`. 
With default + """Asymmetric rounding to make `val` divisible by `divisor`. With default bias, will round up, unless the number is no more than 10% greater than the - smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """ + smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88.""" assert 0.0 < round_up_bias < 1.0 new_val = max(divisor, int(val + divisor / 2) // divisor * divisor) return new_val if new_val >= round_up_bias * val else new_val + divisor def _get_depths(alpha: float) -> List[int]: - """ Scales tensor depths as in reference MobileNet code, prefers rouding up - rather than down. """ + """Scales tensor depths as in reference MobileNet code, prefers rouding up + rather than down.""" depths = [32, 16, 24, 40, 80, 96, 192, 320] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] class MNASNet(torch.nn.Module): - """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This + """MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This implements the B1 variant of the model. >>> model = MNASNet(1.0, num_classes=1000) >>> x = torch.rand(1, 3, 224, 224) @@ -101,15 +91,11 @@ class MNASNet(torch.nn.Module): >>> y.nelement() 1000 """ + # Version 2 adds depth scaling in the initial stages of the network. _version = 2 - def __init__( - self, - alpha: float, - num_classes: int = 1000, - dropout: float = 0.2 - ) -> None: + def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None: super(MNASNet, self).__init__() assert alpha > 0.0 self.alpha = alpha @@ -121,8 +107,7 @@ def __init__( nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), # Depthwise separable, no skip. - nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, - groups=depths[0], bias=False), + nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, groups=depths[0], bias=False), nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), @@ -140,8 +125,7 @@ def __init__( nn.ReLU(inplace=True), ] self.layers = nn.Sequential(*layers) - self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), - nn.Linear(1280, num_classes)) + self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), nn.Linear(1280, num_classes)) self._initialize_weights() def forward(self, x: Tensor) -> Tensor: @@ -153,20 +137,26 @@ def forward(self, x: Tensor) -> Tensor: def _initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", - nonlinearity="relu") + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, nn.BatchNorm2d): nn.init.ones_(m.weight) nn.init.zeros_(m.bias) elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode="fan_out", - nonlinearity="sigmoid") + nn.init.kaiming_uniform_(m.weight, mode="fan_out", nonlinearity="sigmoid") nn.init.zeros_(m.bias) - def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: Dict, strict: bool, - missing_keys: List[str], unexpected_keys: List[str], error_msgs: List[str]) -> None: + def _load_from_state_dict( + self, + state_dict: Dict, + prefix: str, + local_metadata: Dict, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: version = local_metadata.get("version", None) assert version in [1, 2] @@ -180,8 +170,7 @@ def _load_from_state_dict(self, 
state_dict: Dict, prefix: str, local_metadata: D nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), - nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, - bias=False), + nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False), nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), @@ -199,20 +188,19 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, local_metadata: D "This checkpoint will load and work as before, but " "you may want to upgrade by training a newer model or " "transfer learning from an updated ImageNet checkpoint.", - UserWarning) + UserWarning, + ) super(MNASNet, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, - unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def _load_pretrained(model_name: str, model: nn.Module, progress: bool) -> None: if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None: - raise ValueError( - "No checkpoint is available for model type {}".format(model_name)) + raise ValueError("No checkpoint is available for model type {}".format(model_name)) checkpoint_url = _MODEL_URLS[model_name] - model.load_state_dict( - load_state_dict_from_url(checkpoint_url, progress=progress)) + model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress)) def mnasnet0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MNASNet: diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 9e68fbfc5c7..19c7e80dd69 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,20 +1,21 @@ -import torch import warnings - from functools import partial -from torch import nn +from typing import Callable, Any, Optional, List + +import torch from torch import Tensor +from torch import nn + from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import ConvNormActivation from ._utils import _make_divisible -from typing import Callable, Any, Optional, List -__all__ = ['MobileNetV2', 'mobilenet_v2'] +__all__ = ["MobileNetV2", "mobilenet_v2"] model_urls = { - 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", } @@ -23,7 +24,9 @@ class _DeprecatedConvBNAct(ConvNormActivation): def __init__(self, *args, **kwargs): warnings.warn( "The ConvBNReLU/ConvBNActivation classes are deprecated and will be removed in future versions. 
" - "Use torchvision.ops.misc.ConvNormActivation instead.", FutureWarning) + "Use torchvision.ops.misc.ConvNormActivation instead.", + FutureWarning, + ) if kwargs.get("norm_layer", None) is None: kwargs["norm_layer"] = nn.BatchNorm2d if kwargs.get("activation_layer", None) is None: @@ -37,12 +40,7 @@ def __init__(self, *args, **kwargs): class InvertedResidual(nn.Module): def __init__( - self, - inp: int, - oup: int, - stride: int, - expand_ratio: int, - norm_layer: Optional[Callable[..., nn.Module]] = None + self, inp: int, oup: int, stride: int, expand_ratio: int, norm_layer: Optional[Callable[..., nn.Module]] = None ) -> None: super(InvertedResidual, self).__init__() self.stride = stride @@ -57,16 +55,25 @@ def __init__( layers: List[nn.Module] = [] if expand_ratio != 1: # pw - layers.append(ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, - activation_layer=nn.ReLU6)) - layers.extend([ - # dw - ConvNormActivation(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer, - activation_layer=nn.ReLU6), - # pw-linear - nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), - norm_layer(oup), - ]) + layers.append( + ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) + ) + layers.extend( + [ + # dw + ConvNormActivation( + hidden_dim, + hidden_dim, + stride=stride, + groups=hidden_dim, + norm_layer=norm_layer, + activation_layer=nn.ReLU6, + ), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + norm_layer(oup), + ] + ) self.conv = nn.Sequential(*layers) self.out_channels = oup self._is_cn = stride > 1 @@ -86,7 +93,7 @@ def __init__( inverted_residual_setting: Optional[List[List[int]]] = None, round_nearest: int = 8, block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: """ MobileNet V2 main class @@ -126,14 +133,17 @@ def __init__( # only check the first element, assuming user knows t,c,n,s are required if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: - raise ValueError("inverted_residual_setting should be non-empty " - "or a 4-element list, got {}".format(inverted_residual_setting)) + raise ValueError( + "inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting) + ) # building first layer input_channel = _make_divisible(input_channel * width_mult, round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) - features: List[nn.Module] = [ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, - activation_layer=nn.ReLU6)] + features: List[nn.Module] = [ + ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) + ] # building inverted residual blocks for t, c, n, s in inverted_residual_setting: output_channel = _make_divisible(c * width_mult, round_nearest) @@ -142,8 +152,11 @@ def __init__( features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) input_channel = output_channel # building last several layers - features.append(ConvNormActivation(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, - activation_layer=nn.ReLU6)) + features.append( + ConvNormActivation( + input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 + ) + ) # make it nn.Sequential self.features = 
nn.Sequential(*features) @@ -156,7 +169,7 @@ def __init__( # weight initialization for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): @@ -191,7 +204,6 @@ def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) """ model = MobileNetV2(**kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], - progress=progress) + state_dict = load_state_dict_from_url(model_urls["mobilenet_v2"], progress=progress) model.load_state_dict(state_dict) return model diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 537e2136bbb..00a6a200c70 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -1,10 +1,10 @@ import warnings -import torch - from functools import partial -from torch import nn, Tensor from typing import Any, Callable, List, Optional, Sequence +import torch +from torch import nn, Tensor + from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer from ._utils import _make_divisible @@ -20,22 +20,34 @@ class SqueezeExcitation(SElayer): - """DEPRECATED - """ + """DEPRECATED""" + def __init__(self, input_channels: int, squeeze_factor: int = 4): squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid) self.relu = self.activation - delattr(self, 'activation') + delattr(self, "activation") warnings.warn( "This SqueezeExcitation class is deprecated and will be removed in future versions. 
" - "Use torchvision.ops.misc.SqueezeExcitation instead.", FutureWarning) + "Use torchvision.ops.misc.SqueezeExcitation instead.", + FutureWarning, + ) class InvertedResidualConfig: # Stores information listed at Tables 1 and 2 of the MobileNetV3 paper - def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, - activation: str, stride: int, dilation: int, width_mult: float): + def __init__( + self, + input_channels: int, + kernel: int, + expanded_channels: int, + out_channels: int, + use_se: bool, + activation: str, + stride: int, + dilation: int, + width_mult: float, + ): self.input_channels = self.adjust_channels(input_channels, width_mult) self.kernel = kernel self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) @@ -52,11 +64,15 @@ def adjust_channels(channels: int, width_mult: float): class InvertedResidual(nn.Module): # Implemented as described at section 5 of MobileNetV3 paper - def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module], - se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid)): + def __init__( + self, + cnf: InvertedResidualConfig, + norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = partial(SElayer, scale_activation=nn.Hardsigmoid), + ): super().__init__() if not (1 <= cnf.stride <= 2): - raise ValueError('illegal stride value') + raise ValueError("illegal stride value") self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels @@ -65,21 +81,40 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod # expand if cnf.expanded_channels != cnf.input_channels: - layers.append(ConvNormActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append( + ConvNormActivation( + cnf.input_channels, + cnf.expanded_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) # depthwise stride = 1 if cnf.dilation > 1 else cnf.stride - layers.append(ConvNormActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, - stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, - norm_layer=norm_layer, activation_layer=activation_layer)) + layers.append( + ConvNormActivation( + cnf.expanded_channels, + cnf.expanded_channels, + kernel_size=cnf.kernel, + stride=stride, + dilation=cnf.dilation, + groups=cnf.expanded_channels, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) if cnf.use_se: squeeze_channels = _make_divisible(cnf.expanded_channels // 4, 8) layers.append(se_layer(cnf.expanded_channels, squeeze_channels)) # project - layers.append(ConvNormActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, - activation_layer=None)) + layers.append( + ConvNormActivation( + cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None + ) + ) self.block = nn.Sequential(*layers) self.out_channels = cnf.out_channels @@ -93,15 +128,14 @@ def forward(self, input: Tensor) -> Tensor: class MobileNetV3(nn.Module): - def __init__( - self, - inverted_residual_setting: List[InvertedResidualConfig], - last_channel: int, - num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any + self, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: 
int, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, ) -> None: """ MobileNet V3 main class @@ -117,8 +151,10 @@ def __init__( if not inverted_residual_setting: raise ValueError("The inverted_residual_setting should not be empty") - elif not (isinstance(inverted_residual_setting, Sequence) and - all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): + elif not ( + isinstance(inverted_residual_setting, Sequence) + and all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting]) + ): raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") if block is None: @@ -131,8 +167,16 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels - layers.append(ConvNormActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, - activation_layer=nn.Hardswish)) + layers.append( + ConvNormActivation( + 3, + firstconv_output_channels, + kernel_size=3, + stride=2, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + ) # building inverted residual blocks for cnf in inverted_residual_setting: @@ -141,8 +185,15 @@ def __init__( # building last several layers lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels - layers.append(ConvNormActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, - norm_layer=norm_layer, activation_layer=nn.Hardswish)) + layers.append( + ConvNormActivation( + lastconv_input_channels, + lastconv_output_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.Hardswish, + ) + ) self.features = nn.Sequential(*layers) self.avgpool = nn.AdaptiveAvgPool2d(1) @@ -155,7 +206,7 @@ def __init__( for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out') + nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): @@ -179,8 +230,9 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _mobilenet_v3_conf(arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, - **kwargs: Any): +def _mobilenet_v3_conf( + arch: str, width_mult: float = 1.0, reduced_tail: bool = False, dilated: bool = False, **kwargs: Any +): reduce_divider = 2 if reduced_tail else 1 dilation = 2 if dilated else 1 @@ -233,7 +285,7 @@ def _mobilenet_v3_model( last_channel: int, pretrained: bool, progress: bool, - **kwargs: Any + **kwargs: Any, ): model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) if pretrained: diff --git a/torchvision/models/quantization/googlenet.py b/torchvision/models/quantization/googlenet.py index 685815ac676..4b6d25e013c 100644 --- a/torchvision/models/quantization/googlenet.py +++ b/torchvision/models/quantization/googlenet.py @@ -1,22 +1,21 @@ import warnings +from typing import Any + import torch import torch.nn as nn -from torch.nn import functional as F -from typing import Any from torch import Tensor +from torch.nn import functional as F +from torchvision.models.googlenet import GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, GoogLeNet, model_urls from ..._internally_replaced_utils import load_state_dict_from_url -from torchvision.models.googlenet import ( - GoogLeNetOutputs, BasicConv2d, Inception, InceptionAux, 
GoogLeNet, model_urls) - from .utils import _replace_relu, quantize_model -__all__ = ['QuantizableGoogLeNet', 'googlenet'] +__all__ = ["QuantizableGoogLeNet", "googlenet"] quant_model_urls = { # fp32 GoogLeNet ported from TensorFlow, with weights quantized in PyTorch - 'googlenet_fbgemm': 'https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth', + "googlenet_fbgemm": "https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth", } @@ -44,35 +43,35 @@ def googlenet( was trained on ImageNet. Default: *False* """ if pretrained: - if 'transform_input' not in kwargs: - kwargs['transform_input'] = True - if 'aux_logits' not in kwargs: - kwargs['aux_logits'] = False - if kwargs['aux_logits']: - warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ' - 'so make sure to train them') - original_aux_logits = kwargs['aux_logits'] - kwargs['aux_logits'] = True - kwargs['init_weights'] = False + if "transform_input" not in kwargs: + kwargs["transform_input"] = True + if "aux_logits" not in kwargs: + kwargs["aux_logits"] = False + if kwargs["aux_logits"]: + warnings.warn( + "auxiliary heads in the pretrained googlenet model are NOT pretrained, so make sure to train them" + ) + original_aux_logits = kwargs["aux_logits"] + kwargs["aux_logits"] = True + kwargs["init_weights"] = False model = QuantizableGoogLeNet(**kwargs) _replace_relu(model) if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls['googlenet' + '_' + backend] + model_url = quant_model_urls["googlenet" + "_" + backend] else: - model_url = model_urls['googlenet'] + model_url = model_urls["googlenet"] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -84,7 +83,6 @@ def googlenet( class QuantizableBasicConv2d(BasicConv2d): - def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableBasicConv2d, self).__init__(*args, **kwargs) self.relu = nn.ReLU() @@ -100,10 +98,10 @@ def fuse_model(self) -> None: class QuantizableInception(Inception): - def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInception, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, *args, **kwargs) + conv_block=QuantizableBasicConv2d, *args, **kwargs + ) self.cat = nn.quantized.FloatFunctional() def forward(self, x: Tensor) -> Tensor: @@ -115,9 +113,7 @@ class QuantizableInceptionAux(InceptionAux): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInceptionAux, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, - *args, - **kwargs + conv_block=QuantizableBasicConv2d, *args, **kwargs ) self.relu = nn.ReLU() self.dropout = nn.Dropout(0.7) @@ -144,9 +140,7 @@ class QuantizableGoogLeNet(GoogLeNet): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableGoogLeNet, self).__init__( # type: ignore[misc] - blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], - *args, - **kwargs + blocks=[QuantizableBasicConv2d, QuantizableInception, QuantizableInceptionAux], *args, **kwargs ) self.quant = 
torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() diff --git a/torchvision/models/quantization/inception.py b/torchvision/models/quantization/inception.py index 6c6384c295a..acad3f6df53 100644 --- a/torchvision/models/quantization/inception.py +++ b/torchvision/models/quantization/inception.py @@ -1,13 +1,13 @@ import warnings +from typing import Any, List import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from typing import Any, List - from torchvision.models import inception as inception_module from torchvision.models.inception import InceptionOutputs + from ..._internally_replaced_utils import load_state_dict_from_url from .utils import _replace_relu, quantize_model @@ -20,8 +20,7 @@ quant_model_urls = { # fp32 weights ported from TensorFlow, quantized in PyTorch - "inception_v3_google_fbgemm": - "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" + "inception_v3_google_fbgemm": "https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth" } @@ -66,7 +65,7 @@ def inception_v3( if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: assert pretrained in [True, False] @@ -76,12 +75,11 @@ def inception_v3( if not original_aux_logits: model.aux_logits = False model.AuxLogits = None - model_url = quant_model_urls['inception_v3_google' + '_' + backend] + model_url = quant_model_urls["inception_v3_google" + "_" + backend] else: - model_url = inception_module.model_urls['inception_v3_google'] + model_url = inception_module.model_urls["inception_v3_google"] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -111,9 +109,7 @@ class QuantizableInceptionA(inception_module.InceptionA): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInceptionA, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, - *args, - **kwargs + conv_block=QuantizableBasicConv2d, *args, **kwargs ) self.myop = nn.quantized.FloatFunctional() @@ -126,9 +122,7 @@ class QuantizableInceptionB(inception_module.InceptionB): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInceptionB, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, - *args, - **kwargs + conv_block=QuantizableBasicConv2d, *args, **kwargs ) self.myop = nn.quantized.FloatFunctional() @@ -141,9 +135,7 @@ class QuantizableInceptionC(inception_module.InceptionC): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInceptionC, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, - *args, - **kwargs + conv_block=QuantizableBasicConv2d, *args, **kwargs ) self.myop = nn.quantized.FloatFunctional() @@ -156,9 +148,7 @@ class QuantizableInceptionD(inception_module.InceptionD): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInceptionD, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, - *args, - **kwargs + conv_block=QuantizableBasicConv2d, 
*args, **kwargs ) self.myop = nn.quantized.FloatFunctional() @@ -171,9 +161,7 @@ class QuantizableInceptionE(inception_module.InceptionE): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInceptionE, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, - *args, - **kwargs + conv_block=QuantizableBasicConv2d, *args, **kwargs ) self.myop1 = nn.quantized.FloatFunctional() self.myop2 = nn.quantized.FloatFunctional() @@ -209,9 +197,7 @@ class QuantizableInceptionAux(inception_module.InceptionAux): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableInceptionAux, self).__init__( # type: ignore[misc] - conv_block=QuantizableBasicConv2d, - *args, - **kwargs + conv_block=QuantizableBasicConv2d, *args, **kwargs ) @@ -233,8 +219,8 @@ def __init__( QuantizableInceptionC, QuantizableInceptionD, QuantizableInceptionE, - QuantizableInceptionAux - ] + QuantizableInceptionAux, + ], ) self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index 2349afff447..a2c88cdd388 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -1,21 +1,19 @@ -from torch import nn -from torch import Tensor - -from ..._internally_replaced_utils import load_state_dict_from_url - from typing import Any -from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls +from torch import Tensor +from torch import nn from torch.quantization import QuantStub, DeQuantStub, fuse_modules -from .utils import _replace_relu, quantize_model +from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls + +from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import ConvNormActivation +from .utils import _replace_relu, quantize_model -__all__ = ['QuantizableMobileNetV2', 'mobilenet_v2'] +__all__ = ["QuantizableMobileNetV2", "mobilenet_v2"] quant_model_urls = { - 'mobilenet_v2_qnnpack': - 'https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth' + "mobilenet_v2_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth" } @@ -57,7 +55,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self) -> None: for m in self.modules(): if type(m) == ConvNormActivation: - fuse_modules(m, ['0', '1', '2'], inplace=True) + fuse_modules(m, ["0", "1", "2"], inplace=True) if type(m) == QuantizableInvertedResidual: m.fuse_model() @@ -87,19 +85,18 @@ def mobilenet_v2( if quantize: # TODO use pretrained as a string to specify the backend - backend = 'qnnpack' + backend = "qnnpack" quantize_model(model, backend) else: assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls['mobilenet_v2_' + backend] + model_url = quant_model_urls["mobilenet_v2_" + backend] else: - model_url = model_urls['mobilenet_v2'] + model_url = model_urls["mobilenet_v2"] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) return model diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 8655a9b0a45..ad195d178c7 100644 
--- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,19 +1,19 @@ +from typing import Any, List, Optional + import torch from torch import nn, Tensor +from torch.quantization import QuantStub, DeQuantStub, fuse_modules + from ..._internally_replaced_utils import load_state_dict_from_url from ...ops.misc import ConvNormActivation, SqueezeExcitation -from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3,\ - model_urls, _mobilenet_v3_conf -from torch.quantization import QuantStub, DeQuantStub, fuse_modules -from typing import Any, List, Optional +from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf from .utils import _replace_relu -__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large'] +__all__ = ["QuantizableMobileNetV3", "mobilenet_v3_large"] quant_model_urls = { - 'mobilenet_v3_large_qnnpack': - "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", + "mobilenet_v3_large_qnnpack": "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", } @@ -29,7 +29,7 @@ def forward(self, input: Tensor) -> Tensor: return self.skip_mul.mul(self._scale(input), input) def fuse_model(self) -> None: - fuse_modules(self, ['fc1', 'activation'], inplace=True) + fuse_modules(self, ["fc1", "activation"], inplace=True) def _load_from_state_dict( self, @@ -45,7 +45,7 @@ def _load_from_state_dict( if version is None or version < 2: default_state_dict = { - "scale_activation.activation_post_process.scale": torch.tensor([1.]), + "scale_activation.activation_post_process.scale": torch.tensor([1.0]), "scale_activation.activation_post_process.zero_point": torch.tensor([0], dtype=torch.int32), "scale_activation.activation_post_process.fake_quant_enabled": torch.tensor([1]), "scale_activation.activation_post_process.observer_enabled": torch.tensor([1]), @@ -69,11 +69,7 @@ def _load_from_state_dict( class QuantizableInvertedResidual(InvertedResidual): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__( # type: ignore[misc] - se_layer=QuantizableSqueezeExcitation, - *args, - **kwargs - ) + super().__init__(se_layer=QuantizableSqueezeExcitation, *args, **kwargs) # type: ignore[misc] self.skip_add = nn.quantized.FloatFunctional() def forward(self, x: Tensor) -> Tensor: @@ -104,20 +100,15 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self) -> None: for m in self.modules(): if type(m) == ConvNormActivation: - modules_to_fuse = ['0', '1'] + modules_to_fuse = ["0", "1"] if len(m) == 3 and type(m[2]) == nn.ReLU: - modules_to_fuse.append('2') + modules_to_fuse.append("2") fuse_modules(m, modules_to_fuse, inplace=True) elif type(m) == QuantizableSqueezeExcitation: m.fuse_model() -def _load_weights( - arch: str, - model: QuantizableMobileNetV3, - model_url: Optional[str], - progress: bool -) -> None: +def _load_weights(arch: str, model: QuantizableMobileNetV3, model_url: Optional[str], progress: bool) -> None: if model_url is None: raise ValueError("No checkpoint is available for {}".format(arch)) state_dict = load_state_dict_from_url(model_url, progress=progress) @@ -138,14 +129,14 @@ def _mobilenet_v3_model( _replace_relu(model) if quantize: - backend = 'qnnpack' + backend = "qnnpack" model.fuse_model() model.qconfig = torch.quantization.get_default_qat_qconfig(backend) torch.quantization.prepare_qat(model, 
inplace=True) if pretrained: - _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress) + _load_weights(arch, model, quant_model_urls.get(arch + "_" + backend, None), progress) torch.quantization.convert(model, inplace=True) model.eval() diff --git a/torchvision/models/quantization/resnet.py b/torchvision/models/quantization/resnet.py index 8f87e40ec3d..f7124798254 100644 --- a/torchvision/models/quantization/resnet.py +++ b/torchvision/models/quantization/resnet.py @@ -1,24 +1,21 @@ +from typing import Any, Type, Union, List + import torch -from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls import torch.nn as nn from torch import Tensor -from typing import Any, Type, Union, List +from torch.quantization import fuse_modules +from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, model_urls from ..._internally_replaced_utils import load_state_dict_from_url -from torch.quantization import fuse_modules from .utils import _replace_relu, quantize_model -__all__ = ['QuantizableResNet', 'resnet18', 'resnet50', - 'resnext101_32x8d'] +__all__ = ["QuantizableResNet", "resnet18", "resnet50", "resnext101_32x8d"] quant_model_urls = { - 'resnet18_fbgemm': - 'https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth', - 'resnet50_fbgemm': - 'https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth', - 'resnext101_32x8d_fbgemm': - 'https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth', + "resnet18_fbgemm": "https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth", + "resnet50_fbgemm": "https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth", + "resnext101_32x8d_fbgemm": "https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth", } @@ -45,10 +42,9 @@ def forward(self, x: Tensor) -> Tensor: return out def fuse_model(self) -> None: - torch.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu'], - ['conv2', 'bn2']], inplace=True) + torch.quantization.fuse_modules(self, [["conv1", "bn1", "relu"], ["conv2", "bn2"]], inplace=True) if self.downsample: - torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) + torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True) class QuantizableBottleneck(Bottleneck): @@ -77,15 +73,12 @@ def forward(self, x: Tensor) -> Tensor: return out def fuse_model(self) -> None: - fuse_modules(self, [['conv1', 'bn1', 'relu1'], - ['conv2', 'bn2', 'relu2'], - ['conv3', 'bn3']], inplace=True) + fuse_modules(self, [["conv1", "bn1", "relu1"], ["conv2", "bn2", "relu2"], ["conv3", "bn3"]], inplace=True) if self.downsample: - torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True) + torch.quantization.fuse_modules(self.downsample, ["0", "1"], inplace=True) class QuantizableResNet(ResNet): - def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableResNet, self).__init__(*args, **kwargs) @@ -109,7 +102,7 @@ def fuse_model(self) -> None: and the model after modification is in floating point """ - fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True) + fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True) for m in self.modules(): if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock: m.fuse_model() @@ -129,19 +122,18 @@ def _resnet( _replace_relu(model) if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: 
assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls[arch + '_' + backend] + model_url = quant_model_urls[arch + "_" + backend] else: model_url = model_urls[arch] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) return model @@ -161,8 +153,7 @@ def resnet18( progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet('resnet18', QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, - quantize, **kwargs) + return _resnet("resnet18", QuantizableBasicBlock, [2, 2, 2, 2], pretrained, progress, quantize, **kwargs) def resnet50( @@ -180,8 +171,7 @@ def resnet50( progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _resnet('resnet50', QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, - quantize, **kwargs) + return _resnet("resnet50", QuantizableBottleneck, [3, 4, 6, 3], pretrained, progress, quantize, **kwargs) def resnext101_32x8d( @@ -198,7 +188,6 @@ def resnext101_32x8d( progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', QuantizableBottleneck, [3, 4, 23, 3], - pretrained, progress, quantize, **kwargs) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 8 + return _resnet("resnext101_32x8d", QuantizableBottleneck, [3, 4, 23, 3], pretrained, progress, quantize, **kwargs) diff --git a/torchvision/models/quantization/shufflenetv2.py b/torchvision/models/quantization/shufflenetv2.py index 4f0861dcb30..a4c4aede665 100644 --- a/torchvision/models/quantization/shufflenetv2.py +++ b/torchvision/models/quantization/shufflenetv2.py @@ -1,23 +1,26 @@ +from typing import Any + import torch import torch.nn as nn from torch import Tensor -from typing import Any +from torchvision.models import shufflenetv2 from ..._internally_replaced_utils import load_state_dict_from_url -from torchvision.models import shufflenetv2 from .utils import _replace_relu, quantize_model __all__ = [ - 'QuantizableShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', - 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' + "QuantizableShuffleNetV2", + "shufflenet_v2_x0_5", + "shufflenet_v2_x1_0", + "shufflenet_v2_x1_5", + "shufflenet_v2_x2_0", ] quant_model_urls = { - 'shufflenetv2_x0.5_fbgemm': None, - 'shufflenetv2_x1.0_fbgemm': - 'https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth', - 'shufflenetv2_x1.5_fbgemm': None, - 'shufflenetv2_x2.0_fbgemm': None, + "shufflenetv2_x0.5_fbgemm": None, + "shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", + "shufflenetv2_x1.5_fbgemm": None, + "shufflenetv2_x2.0_fbgemm": None, } @@ -42,9 +45,7 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2): # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 def __init__(self, *args: Any, **kwargs: Any) -> None: super(QuantizableShuffleNetV2, self).__init__( # type: ignore[misc] - *args, - inverted_residual=QuantizableInvertedResidual, - **kwargs + *args, inverted_residual=QuantizableInvertedResidual, **kwargs ) self.quant = torch.quantization.QuantStub() self.dequant = 
torch.quantization.DeQuantStub() @@ -69,9 +70,7 @@ def fuse_model(self) -> None: for m in self.modules(): if type(m) == QuantizableInvertedResidual: if len(m.branch1._modules.items()) > 0: - torch.quantization.fuse_modules( - m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True - ) + torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True) torch.quantization.fuse_modules( m.branch2, [["0", "1", "2"], ["3", "4"], ["5", "6", "7"]], @@ -93,19 +92,18 @@ def _shufflenetv2( if quantize: # TODO use pretrained as a string to specify the backend - backend = 'fbgemm' + backend = "fbgemm" quantize_model(model, backend) else: assert pretrained in [True, False] if pretrained: if quantize: - model_url = quant_model_urls[arch + '_' + backend] + model_url = quant_model_urls[arch + "_" + backend] else: model_url = shufflenetv2.model_urls[arch] - state_dict = load_state_dict_from_url(model_url, - progress=progress) + state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) return model @@ -127,8 +125,9 @@ def shufflenet_v2_x0_5( progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, quantize, - [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + return _shufflenetv2( + "shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs + ) def shufflenet_v2_x1_0( @@ -147,8 +146,9 @@ def shufflenet_v2_x1_0( progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, quantize, - [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + return _shufflenetv2( + "shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs + ) def shufflenet_v2_x1_5( @@ -167,8 +167,9 @@ def shufflenet_v2_x1_5( progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, quantize, - [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + return _shufflenetv2( + "shufflenetv2_x1.5", pretrained, progress, quantize, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs + ) def shufflenet_v2_x2_0( @@ -187,5 +188,6 @@ def shufflenet_v2_x2_0( progress (bool): If True, displays a progress bar of the download to stderr quantize (bool): If True, return a quantized version of the model """ - return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, quantize, - [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) + return _shufflenetv2( + "shufflenetv2_x2.0", pretrained, progress, quantize, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs + ) diff --git a/torchvision/models/quantization/utils.py b/torchvision/models/quantization/utils.py index c195d162482..74a8287030b 100644 --- a/torchvision/models/quantization/utils.py +++ b/torchvision/models/quantization/utils.py @@ -23,14 +23,15 @@ def quantize_model(model: nn.Module, backend: str) -> None: torch.backends.quantized.engine = backend model.eval() # Make sure that weight qconfig matches that of the serialized models - if backend == 'fbgemm': + if backend == "fbgemm": model.qconfig = torch.quantization.QConfig( # type: ignore[assignment] activation=torch.quantization.default_observer, - 
weight=torch.quantization.default_per_channel_weight_observer) - elif backend == 'qnnpack': + weight=torch.quantization.default_per_channel_weight_observer, + ) + elif backend == "qnnpack": model.qconfig = torch.quantization.QConfig( # type: ignore[assignment] - activation=torch.quantization.default_observer, - weight=torch.quantization.default_weight_observer) + activation=torch.quantization.default_observer, weight=torch.quantization.default_weight_observer + ) # TODO https://github.com/pytorch/vision/pull/4232#pullrequestreview-730461659 model.fuse_model() # type: ignore[operator] diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 0bd89f7799f..7736fa4cfe0 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -4,11 +4,11 @@ import math -import torch - from collections import OrderedDict from functools import partial from typing import Any, Callable, List, Optional, Tuple + +import torch from torch import nn, Tensor from .._internally_replaced_utils import load_state_dict_from_url @@ -16,10 +16,23 @@ from ._utils import _make_divisible -__all__ = ["RegNet", "regnet_y_400mf", "regnet_y_800mf", "regnet_y_1_6gf", - "regnet_y_3_2gf", "regnet_y_8gf", "regnet_y_16gf", "regnet_y_32gf", - "regnet_x_400mf", "regnet_x_800mf", "regnet_x_1_6gf", "regnet_x_3_2gf", - "regnet_x_8gf", "regnet_x_16gf", "regnet_x_32gf"] +__all__ = [ + "RegNet", + "regnet_y_400mf", + "regnet_y_800mf", + "regnet_y_1_6gf", + "regnet_y_3_2gf", + "regnet_y_8gf", + "regnet_y_16gf", + "regnet_y_32gf", + "regnet_x_400mf", + "regnet_x_800mf", + "regnet_x_1_6gf", + "regnet_x_3_2gf", + "regnet_x_8gf", + "regnet_x_16gf", + "regnet_x_32gf", +] model_urls = { @@ -42,8 +55,9 @@ def __init__( norm_layer: Callable[..., nn.Module], activation_layer: Callable[..., nn.Module], ) -> None: - super().__init__(width_in, width_out, kernel_size=3, stride=2, - norm_layer=norm_layer, activation_layer=activation_layer) + super().__init__( + width_in, width_out, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=activation_layer + ) class BottleneckTransform(nn.Sequential): @@ -64,10 +78,12 @@ def __init__( w_b = int(round(width_out * bottleneck_multiplier)) g = w_b // group_width - layers["a"] = ConvNormActivation(width_in, w_b, kernel_size=1, stride=1, - norm_layer=norm_layer, activation_layer=activation_layer) - layers["b"] = ConvNormActivation(w_b, w_b, kernel_size=3, stride=stride, groups=g, - norm_layer=norm_layer, activation_layer=activation_layer) + layers["a"] = ConvNormActivation( + width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer + ) + layers["b"] = ConvNormActivation( + w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer + ) if se_ratio: # The SE reduction ratio is defined with respect to the @@ -79,8 +95,9 @@ def __init__( activation=activation_layer, ) - layers["c"] = ConvNormActivation(w_b, width_out, kernel_size=1, stride=1, - norm_layer=norm_layer, activation_layer=None) + layers["c"] = ConvNormActivation( + w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None + ) super().__init__(layers) @@ -104,8 +121,9 @@ def __init__( self.proj = None should_proj = (width_in != width_out) or (stride != 1) if should_proj: - self.proj = ConvNormActivation(width_in, width_out, kernel_size=1, - stride=stride, norm_layer=norm_layer, activation_layer=None) + self.proj = ConvNormActivation( + width_in, width_out, kernel_size=1, stride=stride, 
norm_layer=norm_layer, activation_layer=None + ) self.f = BottleneckTransform( width_in, width_out, @@ -217,10 +235,7 @@ def from_init_params( # Compute the block widths. Each stage has one unique block width widths_cont = torch.arange(depth) * w_a + w_0 block_capacity = torch.round(torch.log(widths_cont / w_0) / math.log(w_m)) - block_widths = ( - torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) - * QUANT - ).int().tolist() + block_widths = (torch.round(torch.divide(w_0 * torch.pow(w_m, block_capacity), QUANT)) * QUANT).int().tolist() num_stages = len(set(block_widths)) # Convert to per stage parameters @@ -254,14 +269,12 @@ def from_init_params( ) def _get_expanded_params(self): - return zip( - self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers - ) + return zip(self.widths, self.strides, self.depths, self.group_widths, self.bottleneck_multipliers) @staticmethod def _adjust_widths_groups_compatibilty( - stage_widths: List[int], bottleneck_ratios: List[float], - group_widths: List[int]) -> Tuple[List[int], List[int]]: + stage_widths: List[int], bottleneck_ratios: List[float], group_widths: List[int] + ) -> Tuple[List[int], List[int]]: """ Adjusts the compatibility of widths and groups, depending on the bottleneck ratio. @@ -389,8 +402,7 @@ def regnet_y_400mf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, - group_width=8, se_ratio=0.25, **kwargs) + params = BlockParams.from_init_params(depth=16, w_0=48, w_a=27.89, w_m=2.09, group_width=8, se_ratio=0.25, **kwargs) return _regnet("regnet_y_400mf", params, pretrained, progress, **kwargs) @@ -403,8 +415,7 @@ def regnet_y_800mf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, - group_width=16, se_ratio=0.25, **kwargs) + params = BlockParams.from_init_params(depth=14, w_0=56, w_a=38.84, w_m=2.4, group_width=16, se_ratio=0.25, **kwargs) return _regnet("regnet_y_800mf", params, pretrained, progress, **kwargs) @@ -417,8 +428,9 @@ def regnet_y_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=27, w_0=48, w_a=20.71, w_m=2.65, - group_width=24, se_ratio=0.25, **kwargs) + params = BlockParams.from_init_params( + depth=27, w_0=48, w_a=20.71, w_m=2.65, group_width=24, se_ratio=0.25, **kwargs + ) return _regnet("regnet_y_1_6gf", params, pretrained, progress, **kwargs) @@ -431,8 +443,9 @@ def regnet_y_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=21, w_0=80, w_a=42.63, w_m=2.66, - group_width=24, se_ratio=0.25, **kwargs) + params = BlockParams.from_init_params( + depth=21, w_0=80, w_a=42.63, w_m=2.66, group_width=24, se_ratio=0.25, **kwargs + ) return _regnet("regnet_y_3_2gf", params, pretrained, progress, 
**kwargs) @@ -445,8 +458,9 @@ def regnet_y_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=17, w_0=192, w_a=76.82, w_m=2.19, - group_width=56, se_ratio=0.25, **kwargs) + params = BlockParams.from_init_params( + depth=17, w_0=192, w_a=76.82, w_m=2.19, group_width=56, se_ratio=0.25, **kwargs + ) return _regnet("regnet_y_8gf", params, pretrained, progress, **kwargs) @@ -459,8 +473,9 @@ def regnet_y_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=18, w_0=200, w_a=106.23, w_m=2.48, - group_width=112, se_ratio=0.25, **kwargs) + params = BlockParams.from_init_params( + depth=18, w_0=200, w_a=106.23, w_m=2.48, group_width=112, se_ratio=0.25, **kwargs + ) return _regnet("regnet_y_16gf", params, pretrained, progress, **kwargs) @@ -473,8 +488,9 @@ def regnet_y_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=20, w_0=232, w_a=115.89, w_m=2.53, - group_width=232, se_ratio=0.25, **kwargs) + params = BlockParams.from_init_params( + depth=20, w_0=232, w_a=115.89, w_m=2.53, group_width=232, se_ratio=0.25, **kwargs + ) return _regnet("regnet_y_32gf", params, pretrained, progress, **kwargs) @@ -487,8 +503,7 @@ def regnet_x_400mf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, - group_width=16, **kwargs) + params = BlockParams.from_init_params(depth=22, w_0=24, w_a=24.48, w_m=2.54, group_width=16, **kwargs) return _regnet("regnet_x_400mf", params, pretrained, progress, **kwargs) @@ -501,8 +516,7 @@ def regnet_x_800mf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, - group_width=16, **kwargs) + params = BlockParams.from_init_params(depth=16, w_0=56, w_a=35.73, w_m=2.28, group_width=16, **kwargs) return _regnet("regnet_x_800mf", params, pretrained, progress, **kwargs) @@ -515,8 +529,7 @@ def regnet_x_1_6gf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, - group_width=24, **kwargs) + params = BlockParams.from_init_params(depth=18, w_0=80, w_a=34.01, w_m=2.25, group_width=24, **kwargs) return _regnet("regnet_x_1_6gf", params, pretrained, progress, **kwargs) @@ -529,8 +542,7 @@ def regnet_x_3_2gf(pretrained: bool = False, progress: bool = True, **kwargs: An pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of 
the download to stderr """ - params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, - group_width=48, **kwargs) + params = BlockParams.from_init_params(depth=25, w_0=88, w_a=26.31, w_m=2.25, group_width=48, **kwargs) return _regnet("regnet_x_3_2gf", params, pretrained, progress, **kwargs) @@ -543,8 +555,7 @@ def regnet_x_8gf(pretrained: bool = False, progress: bool = True, **kwargs: Any) pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, - group_width=120, **kwargs) + params = BlockParams.from_init_params(depth=23, w_0=80, w_a=49.56, w_m=2.88, group_width=120, **kwargs) return _regnet("regnet_x_8gf", params, pretrained, progress, **kwargs) @@ -557,8 +568,7 @@ def regnet_x_16gf(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, - group_width=128, **kwargs) + params = BlockParams.from_init_params(depth=22, w_0=216, w_a=55.59, w_m=2.1, group_width=128, **kwargs) return _regnet("regnet_x_16gf", params, pretrained, progress, **kwargs) @@ -571,8 +581,8 @@ def regnet_x_32gf(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, - group_width=168, **kwargs) + params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) return _regnet("regnet_x_32gf", params, pretrained, progress, **kwargs) + # TODO(kazhang): Add RegNetZ_500MF and RegNetZ_4GF diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 6d708767441..7584ebb98ea 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -1,32 +1,51 @@ +from typing import Type, Any, Callable, Union, List, Optional + import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Type, Any, Callable, Union, List, Optional -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', - 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', - 'wide_resnet50_2', 'wide_resnet101_2'] +__all__ = [ + "ResNet", + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "wide_resnet50_2", + "wide_resnet101_2", +] model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth', - 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', - 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', - 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', - 'wide_resnet101_2': 
'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", } def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: @@ -46,13 +65,13 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d if groups != 1 or base_width != 64: - raise ValueError('BasicBlock only supports groups=1 and base_width=64') + raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") # Both self.conv1 and self.downsample layers downsample the input when stride != 1 @@ -101,12 +120,12 @@ def __init__( groups: int = 1, base_width: int = 64, dilation: int = 1, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d - width = int(planes * (base_width / 64.)) * groups + width = int(planes * (base_width / 64.0)) * groups # Both self.conv2 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) @@ -142,7 +161,6 @@ def forward(self, x: Tensor) -> Tensor: class ResNet(nn.Module): - def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], @@ -152,7 +170,7 @@ def __init__( groups: int = 1, width_per_group: int = 64, replace_stride_with_dilation: Optional[List[bool]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None + norm_layer: Optional[Callable[..., nn.Module]] = None, ) -> None: super(ResNet, self).__init__() if norm_layer is None: @@ -166,28 +184,26 @@ def __init__( # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: - raise ValueError("replace_stride_with_dilation should be None " - "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) self.groups = groups self.base_width = width_per_group - self.conv1 = 
nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, - bias=False) + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) - self.layer2 = self._make_layer(block, 128, layers[1], stride=2, - dilate=replace_stride_with_dilation[0]) - self.layer3 = self._make_layer(block, 256, layers[2], stride=2, - dilate=replace_stride_with_dilation[1]) - self.layer4 = self._make_layer(block, 512, layers[3], stride=2, - dilate=replace_stride_with_dilation[2]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) @@ -202,8 +218,14 @@ def __init__( elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] - def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, - stride: int = 1, dilate: bool = False) -> nn.Sequential: + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -217,13 +239,23 @@ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, b ) layers = [] - layers.append(block(self.inplanes, planes, stride, downsample, self.groups, - self.base_width, previous_dilation, norm_layer)) + layers.append( + block( + self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer + ) + ) self.inplanes = planes * block.expansion for _ in range(1, blocks): - layers.append(block(self.inplanes, planes, groups=self.groups, - base_width=self.base_width, dilation=self.dilation, - norm_layer=norm_layer)) + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) return nn.Sequential(*layers) @@ -255,12 +287,11 @@ def _resnet( layers: List[int], pretrained: bool, progress: bool, - **kwargs: Any + **kwargs: Any, ) -> ResNet: model = ResNet(block, layers, **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -273,8 +304,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, - **kwargs) + return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) def 
resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -285,8 +315,7 @@ def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, - **kwargs) + return _resnet("resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -297,8 +326,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, - **kwargs) + return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -309,8 +337,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, - **kwargs) + return _resnet("resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -321,8 +348,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, - **kwargs) + return _resnet("resnet152", Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs) def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -333,10 +359,9 @@ def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: A pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 4 - return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 4 + return _resnet("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -347,10 +372,9 @@ def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['groups'] = 32 - kwargs['width_per_group'] = 8 - return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) + kwargs["groups"] = 32 + kwargs["width_per_group"] = 8 + return _resnet("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -366,9 +390,8 @@ def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: A 
pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], - pretrained, progress, **kwargs) + kwargs["width_per_group"] = 64 * 2 + return _resnet("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: @@ -384,6 +407,5 @@ def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - kwargs['width_per_group'] = 64 * 2 - return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], - pretrained, progress, **kwargs) + kwargs["width_per_group"] = 64 * 2 + return _resnet("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) diff --git a/torchvision/models/segmentation/_utils.py b/torchvision/models/segmentation/_utils.py index fb94b9b1528..0e9a9477838 100644 --- a/torchvision/models/segmentation/_utils.py +++ b/torchvision/models/segmentation/_utils.py @@ -6,14 +6,9 @@ class _SimpleSegmentationModel(nn.Module): - __constants__ = ['aux_classifier'] - - def __init__( - self, - backbone: nn.Module, - classifier: nn.Module, - aux_classifier: Optional[nn.Module] = None - ) -> None: + __constants__ = ["aux_classifier"] + + def __init__(self, backbone: nn.Module, classifier: nn.Module, aux_classifier: Optional[nn.Module] = None) -> None: super(_SimpleSegmentationModel, self).__init__() self.backbone = backbone self.classifier = classifier @@ -27,13 +22,13 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: result = OrderedDict() x = features["out"] x = self.classifier(x) - x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) result["out"] = x if self.aux_classifier is not None: x = features["aux"] x = self.aux_classifier(x) - x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + x = F.interpolate(x, size=input_shape, mode="bilinear", align_corners=False) result["aux"] = x return result diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 15ab8846e7d..a8f06bd89bd 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -1,7 +1,8 @@ +from typing import List + import torch from torch import nn from torch.nn import functional as F -from typing import List from ._utils import _SimpleSegmentationModel @@ -24,6 +25,7 @@ class DeepLabV3(_SimpleSegmentationModel): the backbone and returns a dense prediction. 
aux_classifier (nn.Module, optional): auxiliary classifier used during training """ + pass @@ -34,7 +36,7 @@ def __init__(self, in_channels: int, num_classes: int) -> None: nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(), - nn.Conv2d(256, num_classes, 1) + nn.Conv2d(256, num_classes, 1), ) @@ -43,7 +45,7 @@ def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None: modules = [ nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), nn.BatchNorm2d(out_channels), - nn.ReLU() + nn.ReLU(), ] super(ASPPConv, self).__init__(*modules) @@ -54,23 +56,23 @@ def __init__(self, in_channels: int, out_channels: int) -> None: nn.AdaptiveAvgPool2d(1), nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), - nn.ReLU()) + nn.ReLU(), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: size = x.shape[-2:] for mod in self: x = mod(x) - return F.interpolate(x, size=size, mode='bilinear', align_corners=False) + return F.interpolate(x, size=size, mode="bilinear", align_corners=False) class ASPP(nn.Module): def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None: super(ASPP, self).__init__() modules = [] - modules.append(nn.Sequential( - nn.Conv2d(in_channels, out_channels, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU())) + modules.append( + nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU()) + ) rates = tuple(atrous_rates) for rate in rates: @@ -84,7 +86,8 @@ def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), - nn.Dropout(0.5)) + nn.Dropout(0.5), + ) def forward(self, x: torch.Tensor) -> torch.Tensor: _res = [] diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 9c8db1e1211..6a935e9ac48 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -19,6 +19,7 @@ class FCN(_SimpleSegmentationModel): the backbone and returns a dense prediction. 
aux_classifier (nn.Module, optional): auxiliary classifier used during training """ + pass @@ -30,7 +31,7 @@ def __init__(self, in_channels: int, channels: int) -> None: nn.BatchNorm2d(inter_channels), nn.ReLU(), nn.Dropout(0.1), - nn.Conv2d(inter_channels, channels, 1) + nn.Conv2d(inter_channels, channels, 1), ] super(FCNHead, self).__init__(*layers) diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 0e5fb5ee898..654e2811315 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -1,8 +1,8 @@ from collections import OrderedDict +from typing import Dict from torch import nn, Tensor from torch.nn import functional as F -from typing import Dict __all__ = ["LRASPP"] @@ -25,12 +25,7 @@ class LRASPP(nn.Module): """ def __init__( - self, - backbone: nn.Module, - low_channels: int, - high_channels: int, - num_classes: int, - inter_channels: int = 128 + self, backbone: nn.Module, low_channels: int, high_channels: int, num_classes: int, inter_channels: int = 128 ) -> None: super().__init__() self.backbone = backbone @@ -39,7 +34,7 @@ def __init__( def forward(self, input: Tensor) -> Dict[str, Tensor]: features = self.backbone(input) out = self.classifier(features) - out = F.interpolate(out, size=input.shape[-2:], mode='bilinear', align_corners=False) + out = F.interpolate(out, size=input.shape[-2:], mode="bilinear", align_corners=False) result = OrderedDict() result["out"] = out @@ -48,19 +43,12 @@ def forward(self, input: Tensor) -> Dict[str, Tensor]: class LRASPPHead(nn.Module): - - def __init__( - self, - low_channels: int, - high_channels: int, - num_classes: int, - inter_channels: int - ) -> None: + def __init__(self, low_channels: int, high_channels: int, num_classes: int, inter_channels: int) -> None: super().__init__() self.cbr = nn.Sequential( nn.Conv2d(high_channels, inter_channels, 1, bias=False), nn.BatchNorm2d(inter_channels), - nn.ReLU(inplace=True) + nn.ReLU(inplace=True), ) self.scale = nn.Sequential( nn.AdaptiveAvgPool2d(1), @@ -77,6 +65,6 @@ def forward(self, input: Dict[str, Tensor]) -> Tensor: x = self.cbr(high) s = self.scale(high) x = x * s - x = F.interpolate(x, size=low.shape[-2:], mode='bilinear', align_corners=False) + x = F.interpolate(x, size=low.shape[-2:], mode="bilinear", align_corners=False) return self.low_classifier(low) + self.high_classifier(x) diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 938965e330b..d5223842010 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,45 +1,48 @@ -from torch import nn from typing import Any, Optional -from .._utils import IntermediateLayerGetter + +from torch import nn + from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 from .. 
import resnet +from .._utils import IntermediateLayerGetter from .deeplabv3 import DeepLabHead, DeepLabV3 from .fcn import FCN, FCNHead from .lraspp import LRASPP -__all__ = ['fcn_resnet50', 'fcn_resnet101', 'deeplabv3_resnet50', 'deeplabv3_resnet101', - 'deeplabv3_mobilenet_v3_large', 'lraspp_mobilenet_v3_large'] +__all__ = [ + "fcn_resnet50", + "fcn_resnet101", + "deeplabv3_resnet50", + "deeplabv3_resnet101", + "deeplabv3_mobilenet_v3_large", + "lraspp_mobilenet_v3_large", +] model_urls = { - 'fcn_resnet50_coco': 'https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth', - 'fcn_resnet101_coco': 'https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth', - 'deeplabv3_resnet50_coco': 'https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth', - 'deeplabv3_resnet101_coco': 'https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth', - 'deeplabv3_mobilenet_v3_large_coco': - 'https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth', - 'lraspp_mobilenet_v3_large_coco': 'https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth', + "fcn_resnet50_coco": "https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth", + "fcn_resnet101_coco": "https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth", + "deeplabv3_resnet50_coco": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth", + "deeplabv3_resnet101_coco": "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth", + "deeplabv3_mobilenet_v3_large_coco": "https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth", + "lraspp_mobilenet_v3_large_coco": "https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth", } def _segm_model( - name: str, - backbone_name: str, - num_classes: int, - aux: Optional[bool], - pretrained_backbone: bool = True + name: str, backbone_name: str, num_classes: int, aux: Optional[bool], pretrained_backbone: bool = True ) -> nn.Module: - if 'resnet' in backbone_name: + if "resnet" in backbone_name: backbone = resnet.__dict__[backbone_name]( - pretrained=pretrained_backbone, - replace_stride_with_dilation=[False, True, True]) - out_layer = 'layer4' + pretrained=pretrained_backbone, replace_stride_with_dilation=[False, True, True] + ) + out_layer = "layer4" out_inplanes = 2048 - aux_layer = 'layer3' + aux_layer = "layer3" aux_inplanes = 1024 - elif 'mobilenet_v3' in backbone_name: + elif "mobilenet_v3" in backbone_name: backbone = mobilenetv3.__dict__[backbone_name](pretrained=pretrained_backbone, dilated=True).features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. 
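For context on the hunks around here: _segm_model wraps the chosen backbone in IntermediateLayerGetter so that the layers named in return_layers come back as an OrderedDict under new keys ("out", plus "aux" when an auxiliary head is requested). A minimal sketch of that mechanism, assuming the same private torchvision.models._utils import path this file uses; the 224x224 input is arbitrary:

    import torch
    from torchvision.models import resnet50
    from torchvision.models._utils import IntermediateLayerGetter

    # Mirror _segm_model for a ResNet backbone: dilate the last two stages and
    # expose layer4 as "out" and layer3 as "aux".
    backbone = resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True])
    wrapped = IntermediateLayerGetter(backbone, return_layers={"layer4": "out", "layer3": "aux"})

    features = wrapped(torch.rand(1, 3, 224, 224))
    print(features["out"].shape)  # torch.Size([1, 2048, 28, 28])
    print(features["aux"].shape)  # torch.Size([1, 1024, 28, 28])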
@@ -52,11 +55,11 @@ def _segm_model( aux_layer = str(aux_pos) aux_inplanes = backbone[aux_pos].out_channels else: - raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name)) + raise NotImplementedError("backbone {} is not supported as of now".format(backbone_name)) - return_layers = {out_layer: 'out'} + return_layers = {out_layer: "out"} if aux: - return_layers[aux_layer] = 'aux' + return_layers[aux_layer] = "aux" backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) aux_classifier = None @@ -64,8 +67,8 @@ def _segm_model( aux_classifier = FCNHead(aux_inplanes, num_classes) model_map = { - 'deeplabv3': (DeepLabHead, DeepLabV3), - 'fcn': (FCNHead, FCN), + "deeplabv3": (DeepLabHead, DeepLabV3), + "fcn": (FCNHead, FCN), } classifier = model_map[name][0](out_inplanes, num_classes) base_model = model_map[name][1] @@ -81,7 +84,7 @@ def _load_model( progress: bool, num_classes: int, aux_loss: Optional[bool], - **kwargs: Any + **kwargs: Any, ) -> nn.Module: if pretrained: aux_loss = True @@ -93,10 +96,10 @@ def _load_model( def _load_weights(model: nn.Module, arch_type: str, backbone: str, progress: bool) -> None: - arch = arch_type + '_' + backbone + '_coco' + arch = arch_type + "_" + backbone + "_coco" model_url = model_urls.get(arch, None) if model_url is None: - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) + raise NotImplementedError("pretrained {} is not supported as of now".format(arch)) else: state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -113,7 +116,7 @@ def _segm_lraspp_mobilenetv3(backbone_name: str, num_classes: int, pretrained_ba low_channels = backbone[low_pos].out_channels high_channels = backbone[high_pos].out_channels - backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): 'low', str(high_pos): 'high'}) + backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"}) model = LRASPP(backbone, low_channels, high_channels, num_classes) return model @@ -124,7 +127,7 @@ def fcn_resnet50( progress: bool = True, num_classes: int = 21, aux_loss: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> nn.Module: """Constructs a Fully-Convolutional Network model with a ResNet-50 backbone. @@ -135,7 +138,7 @@ def fcn_resnet50( num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('fcn', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("fcn", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs) def fcn_resnet101( @@ -143,7 +146,7 @@ def fcn_resnet101( progress: bool = True, num_classes: int = 21, aux_loss: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> nn.Module: """Constructs a Fully-Convolutional Network model with a ResNet-101 backbone. 
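The fcn_resnet50 / fcn_resnet101 builders reformatted below all route through _load_model and return the OrderedDict produced by _SimpleSegmentationModel. A short usage sketch; the 520x520 input size is arbitrary, and pretrained_backbone=False (forwarded to _segm_model) is only there to avoid a weight download:

    import torch
    from torchvision.models.segmentation import fcn_resnet50

    # aux_loss=True attaches the FCNHead auxiliary classifier, so the result dict has both keys.
    model = fcn_resnet50(pretrained=False, num_classes=21, aux_loss=True, pretrained_backbone=False).eval()
    with torch.no_grad():
        result = model(torch.rand(1, 3, 520, 520))
    print(result["out"].shape)  # torch.Size([1, 21, 520, 520])
    print(result["aux"].shape)  # torch.Size([1, 21, 520, 520])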
@@ -154,7 +157,7 @@ def fcn_resnet101( num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('fcn', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("fcn", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs) def deeplabv3_resnet50( @@ -162,7 +165,7 @@ def deeplabv3_resnet50( progress: bool = True, num_classes: int = 21, aux_loss: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> nn.Module: """Constructs a DeepLabV3 model with a ResNet-50 backbone. @@ -173,7 +176,7 @@ def deeplabv3_resnet50( num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('deeplabv3', 'resnet50', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("deeplabv3", "resnet50", pretrained, progress, num_classes, aux_loss, **kwargs) def deeplabv3_resnet101( @@ -181,7 +184,7 @@ def deeplabv3_resnet101( progress: bool = True, num_classes: int = 21, aux_loss: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> nn.Module: """Constructs a DeepLabV3 model with a ResNet-101 backbone. @@ -192,7 +195,7 @@ def deeplabv3_resnet101( num_classes (int): The number of classes aux_loss (bool): If True, include an auxiliary classifier """ - return _load_model('deeplabv3', 'resnet101', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("deeplabv3", "resnet101", pretrained, progress, num_classes, aux_loss, **kwargs) def deeplabv3_mobilenet_v3_large( @@ -200,7 +203,7 @@ def deeplabv3_mobilenet_v3_large( progress: bool = True, num_classes: int = 21, aux_loss: Optional[bool] = None, - **kwargs: Any + **kwargs: Any, ) -> nn.Module: """Constructs a DeepLabV3 model with a MobileNetV3-Large backbone. @@ -211,14 +214,11 @@ def deeplabv3_mobilenet_v3_large( num_classes (int): number of output classes of the model (including the background) aux_loss (bool): If True, it uses an auxiliary loss """ - return _load_model('deeplabv3', 'mobilenet_v3_large', pretrained, progress, num_classes, aux_loss, **kwargs) + return _load_model("deeplabv3", "mobilenet_v3_large", pretrained, progress, num_classes, aux_loss, **kwargs) def lraspp_mobilenet_v3_large( - pretrained: bool = False, - progress: bool = True, - num_classes: int = 21, - **kwargs: Any + pretrained: bool = False, progress: bool = True, num_classes: int = 21, **kwargs: Any ) -> nn.Module: """Constructs a Lite R-ASPP Network model with a MobileNetV3-Large backbone. 
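The LR-ASPP entry point touched below behaves the same way, except that it has no auxiliary head; the next hunk shows it raising NotImplementedError when aux_loss is passed. A quick sketch under the same assumptions as above:

    import torch
    from torchvision.models.segmentation import lraspp_mobilenet_v3_large

    model = lraspp_mobilenet_v3_large(pretrained=False, num_classes=21, pretrained_backbone=False).eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 520, 520))["out"]
    print(out.shape)  # torch.Size([1, 21, 520, 520])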
@@ -229,14 +229,14 @@ def lraspp_mobilenet_v3_large( num_classes (int): number of output classes of the model (including the background) """ if kwargs.pop("aux_loss", False): - raise NotImplementedError('This model does not use auxiliary loss') + raise NotImplementedError("This model does not use auxiliary loss") - backbone_name = 'mobilenet_v3_large' + backbone_name = "mobilenet_v3_large" if pretrained: kwargs["pretrained_backbone"] = False model = _segm_lraspp_mobilenetv3(backbone_name, num_classes, **kwargs) if pretrained: - _load_weights(model, 'lraspp', backbone_name, progress) + _load_weights(model, "lraspp", backbone_name, progress) return model diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index 65d60a09e6c..a9bb58fc9d1 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -1,20 +1,19 @@ +from typing import Callable, Any, List + import torch -from torch import Tensor import torch.nn as nn +from torch import Tensor + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Callable, Any, List -__all__ = [ - 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', - 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' -] +__all__ = ["ShuffleNetV2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0"] model_urls = { - 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', - 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', - 'shufflenetv2_x1.5': None, - 'shufflenetv2_x2.0': None, + "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + "shufflenetv2_x1.5": None, + "shufflenetv2_x2.0": None, } @@ -23,8 +22,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: channels_per_group = num_channels // groups # reshape - x = x.view(batchsize, groups, - channels_per_group, height, width) + x = x.view(batchsize, groups, channels_per_group, height, width) x = torch.transpose(x, 1, 2).contiguous() @@ -35,16 +33,11 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: class InvertedResidual(nn.Module): - def __init__( - self, - inp: int, - oup: int, - stride: int - ) -> None: + def __init__(self, inp: int, oup: int, stride: int) -> None: super(InvertedResidual, self).__init__() if not (1 <= stride <= 3): - raise ValueError('illegal stride value') + raise ValueError("illegal stride value") self.stride = stride branch_features = oup // 2 @@ -62,8 +55,14 @@ def __init__( self.branch1 = nn.Sequential() self.branch2 = nn.Sequential( - nn.Conv2d(inp if (self.stride > 1) else branch_features, - branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), @@ -75,12 +74,7 @@ def __init__( @staticmethod def depthwise_conv( - i: int, - o: int, - kernel_size: int, - stride: int = 1, - padding: int = 0, - bias: bool = False + i: int, o: int, kernel_size: int, stride: int = 1, padding: int = 0, bias: bool = False ) -> nn.Conv2d: return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) @@ -102,14 +96,14 @@ def __init__( stages_repeats: List[int], 
stages_out_channels: List[int], num_classes: int = 1000, - inverted_residual: Callable[..., nn.Module] = InvertedResidual + inverted_residual: Callable[..., nn.Module] = InvertedResidual, ) -> None: super(ShuffleNetV2, self).__init__() if len(stages_repeats) != 3: - raise ValueError('expected stages_repeats as list of 3 positive ints') + raise ValueError("expected stages_repeats as list of 3 positive ints") if len(stages_out_channels) != 5: - raise ValueError('expected stages_out_channels as list of 5 positive ints') + raise ValueError("expected stages_out_channels as list of 5 positive ints") self._stage_out_channels = stages_out_channels input_channels = 3 @@ -127,9 +121,8 @@ def __init__( self.stage2: nn.Sequential self.stage3: nn.Sequential self.stage4: nn.Sequential - stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] - for name, repeats, output_channels in zip( - stage_names, stages_repeats, self._stage_out_channels[1:]): + stage_names = ["stage{}".format(i) for i in [2, 3, 4]] + for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]): seq = [inverted_residual(input_channels, output_channels, 2)] for i in range(repeats - 1): seq.append(inverted_residual(output_channels, output_channels, 1)) @@ -167,7 +160,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa if pretrained: model_url = model_urls[arch] if model_url is None: - raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) + raise NotImplementedError("pretrained {} is not supported as of now".format(arch)) else: state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) @@ -185,8 +178,7 @@ def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, - [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) + return _shufflenetv2("shufflenetv2_x0.5", pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: @@ -199,8 +191,7 @@ def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, - [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) + return _shufflenetv2("shufflenetv2_x1.0", pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: @@ -213,8 +204,7 @@ def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, - [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) + return _shufflenetv2("shufflenetv2_x1.5", pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2: @@ -227,5 +217,4 @@ def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, 
**kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, - [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) + return _shufflenetv2("shufflenetv2_x2.0", pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) diff --git a/torchvision/models/squeezenet.py b/torchvision/models/squeezenet.py index c54e475d412..e6258502da0 100644 --- a/torchvision/models/squeezenet.py +++ b/torchvision/models/squeezenet.py @@ -1,55 +1,42 @@ +from typing import Any + import torch import torch.nn as nn import torch.nn.init as init + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Any -__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] +__all__ = ["SqueezeNet", "squeezenet1_0", "squeezenet1_1"] model_urls = { - 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth', - 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth', + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", } class Fire(nn.Module): - - def __init__( - self, - inplanes: int, - squeeze_planes: int, - expand1x1_planes: int, - expand3x3_planes: int - ) -> None: + def __init__(self, inplanes: int, squeeze_planes: int, expand1x1_planes: int, expand3x3_planes: int) -> None: super(Fire, self).__init__() self.inplanes = inplanes self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) self.squeeze_activation = nn.ReLU(inplace=True) - self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, - kernel_size=1) + self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1) self.expand1x1_activation = nn.ReLU(inplace=True) - self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, - kernel_size=3, padding=1) + self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1) self.expand3x3_activation = nn.ReLU(inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.squeeze_activation(self.squeeze(x)) - return torch.cat([ - self.expand1x1_activation(self.expand1x1(x)), - self.expand3x3_activation(self.expand3x3(x)) - ], 1) + return torch.cat( + [self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], 1 + ) class SqueezeNet(nn.Module): - - def __init__( - self, - version: str = '1_0', - num_classes: int = 1000 - ) -> None: + def __init__(self, version: str = "1_0", num_classes: int = 1000) -> None: super(SqueezeNet, self).__init__() self.num_classes = num_classes - if version == '1_0': + if version == "1_0": self.features = nn.Sequential( nn.Conv2d(3, 96, kernel_size=7, stride=2), nn.ReLU(inplace=True), @@ -65,7 +52,7 @@ def __init__( nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), Fire(512, 64, 256, 256), ) - elif version == '1_1': + elif version == "1_1": self.features = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, stride=2), nn.ReLU(inplace=True), @@ -85,16 +72,12 @@ def __init__( # FIXME: Is this needed? 
SqueezeNet should only be called from the # FIXME: squeezenet1_x() functions # FIXME: This checking is not done for the other models - raise ValueError("Unsupported SqueezeNet version {version}:" - "1_0 or 1_1 expected".format(version=version)) + raise ValueError("Unsupported SqueezeNet version {version}:" "1_0 or 1_1 expected".format(version=version)) # Final convolution is initialized differently from the rest final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) self.classifier = nn.Sequential( - nn.Dropout(p=0.5), - final_conv, - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool2d((1, 1)) + nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)) ) for m in self.modules(): @@ -115,9 +98,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _squeezenet(version: str, pretrained: bool, progress: bool, **kwargs: Any) -> SqueezeNet: model = SqueezeNet(version, **kwargs) if pretrained: - arch = 'squeezenet' + version - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + arch = "squeezenet" + version + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -132,7 +114,7 @@ def squeezenet1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet('1_0', pretrained, progress, **kwargs) + return _squeezenet("1_0", pretrained, progress, **kwargs) def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> SqueezeNet: @@ -146,4 +128,4 @@ def squeezenet1_1(pretrained: bool = False, progress: bool = True, **kwargs: Any pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _squeezenet('1_1', pretrained, progress, **kwargs) + return _squeezenet("1_1", pretrained, progress, **kwargs) diff --git a/torchvision/models/vgg.py b/torchvision/models/vgg.py index 619bce97b2f..f5109b647d9 100644 --- a/torchvision/models/vgg.py +++ b/torchvision/models/vgg.py @@ -1,35 +1,38 @@ +from typing import Union, List, Dict, Any, cast + import torch import torch.nn as nn + from .._internally_replaced_utils import load_state_dict_from_url -from typing import Union, List, Dict, Any, cast __all__ = [ - 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', - 'vgg19_bn', 'vgg19', + "VGG", + "vgg11", + "vgg11_bn", + "vgg13", + "vgg13_bn", + "vgg16", + "vgg16_bn", + "vgg19_bn", + "vgg19", ] model_urls = { - 'vgg11': 'https://download.pytorch.org/models/vgg11-8a719046.pth', - 'vgg13': 'https://download.pytorch.org/models/vgg13-19584684.pth', - 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', - 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', - 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', - 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', - 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', - 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', + "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg11_bn": 
"https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", } class VGG(nn.Module): - - def __init__( - self, - features: nn.Module, - num_classes: int = 1000, - init_weights: bool = True - ) -> None: + def __init__(self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True) -> None: super(VGG, self).__init__() self.features = features self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) @@ -55,7 +58,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): @@ -70,7 +73,7 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ layers: List[nn.Module] = [] in_channels = 3 for v in cfg: - if v == 'M': + if v == "M": layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: v = cast(int, v) @@ -84,20 +87,19 @@ def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequ cfgs: Dict[str, List[Union[str, int]]] = { - 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], - 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], + "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], } def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: if pretrained: - kwargs['init_weights'] = False + kwargs["init_weights"] = False model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -111,7 +113,7 @@ def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) + return _vgg("vgg11", "A", False, pretrained, progress, **kwargs) def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -123,7 +125,7 @@ def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) + return _vgg("vgg11_bn", "A", True, pretrained, progress, **kwargs) def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: 
Any) -> VGG: @@ -135,7 +137,7 @@ def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) + return _vgg("vgg13", "B", False, pretrained, progress, **kwargs) def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -147,7 +149,7 @@ def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) + return _vgg("vgg13_bn", "B", True, pretrained, progress, **kwargs) def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -159,7 +161,7 @@ def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) + return _vgg("vgg16", "D", False, pretrained, progress, **kwargs) def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -171,7 +173,7 @@ def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) + return _vgg("vgg16_bn", "D", True, pretrained, progress, **kwargs) def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -183,7 +185,7 @@ def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) + return _vgg("vgg19", "E", False, pretrained, progress, **kwargs) def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: @@ -195,4 +197,4 @@ def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) + return _vgg("vgg19_bn", "E", True, pretrained, progress, **kwargs) diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py index faf3b3bc4a8..5cfbbaeb559 100644 --- a/torchvision/models/video/resnet.py +++ b/torchvision/models/video/resnet.py @@ -1,27 +1,23 @@ -from torch import Tensor -import torch.nn as nn from typing import Tuple, Optional, Callable, List, Type, Any, Union +import torch.nn as nn +from torch import Tensor + from ..._internally_replaced_utils import load_state_dict_from_url -__all__ = ['r3d_18', 'mc3_18', 'r2plus1d_18'] +__all__ = ["r3d_18", "mc3_18", "r2plus1d_18"] model_urls = { - 'r3d_18': 'https://download.pytorch.org/models/r3d_18-b3b3357e.pth', - 'mc3_18': 'https://download.pytorch.org/models/mc3_18-a90a0ba3.pth', - 'r2plus1d_18': 'https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth', + "r3d_18": 
"https://download.pytorch.org/models/r3d_18-b3b3357e.pth", + "mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", + "r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", } class Conv3DSimple(nn.Conv3d): def __init__( - self, - in_planes: int, - out_planes: int, - midplanes: Optional[int] = None, - stride: int = 1, - padding: int = 1 + self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1 ) -> None: super(Conv3DSimple, self).__init__( @@ -30,7 +26,8 @@ def __init__( kernel_size=(3, 3, 3), stride=stride, padding=padding, - bias=False) + bias=False, + ) @staticmethod def get_downsample_stride(stride: int) -> Tuple[int, int, int]: @@ -38,24 +35,22 @@ def get_downsample_stride(stride: int) -> Tuple[int, int, int]: class Conv2Plus1D(nn.Sequential): - - def __init__( - self, - in_planes: int, - out_planes: int, - midplanes: int, - stride: int = 1, - padding: int = 1 - ) -> None: + def __init__(self, in_planes: int, out_planes: int, midplanes: int, stride: int = 1, padding: int = 1) -> None: super(Conv2Plus1D, self).__init__( - nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3), - stride=(1, stride, stride), padding=(0, padding, padding), - bias=False), + nn.Conv3d( + in_planes, + midplanes, + kernel_size=(1, 3, 3), + stride=(1, stride, stride), + padding=(0, padding, padding), + bias=False, + ), nn.BatchNorm3d(midplanes), nn.ReLU(inplace=True), - nn.Conv3d(midplanes, out_planes, kernel_size=(3, 1, 1), - stride=(stride, 1, 1), padding=(padding, 0, 0), - bias=False)) + nn.Conv3d( + midplanes, out_planes, kernel_size=(3, 1, 1), stride=(stride, 1, 1), padding=(padding, 0, 0), bias=False + ), + ) @staticmethod def get_downsample_stride(stride: int) -> Tuple[int, int, int]: @@ -63,14 +58,8 @@ def get_downsample_stride(stride: int) -> Tuple[int, int, int]: class Conv3DNoTemporal(nn.Conv3d): - def __init__( - self, - in_planes: int, - out_planes: int, - midplanes: Optional[int] = None, - stride: int = 1, - padding: int = 1 + self, in_planes: int, out_planes: int, midplanes: Optional[int] = None, stride: int = 1, padding: int = 1 ) -> None: super(Conv3DNoTemporal, self).__init__( @@ -79,7 +68,8 @@ def __init__( kernel_size=(1, 3, 3), stride=(1, stride, stride), padding=(0, padding, padding), - bias=False) + bias=False, + ) @staticmethod def get_downsample_stride(stride: int) -> Tuple[int, int, int]: @@ -102,14 +92,9 @@ def __init__( super(BasicBlock, self).__init__() self.conv1 = nn.Sequential( - conv_builder(inplanes, planes, midplanes, stride), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) - ) - self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes), - nn.BatchNorm3d(planes) + conv_builder(inplanes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) + self.conv2 = nn.Sequential(conv_builder(planes, planes, midplanes), nn.BatchNorm3d(planes)) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -145,21 +130,17 @@ def __init__( # 1x1x1 self.conv1 = nn.Sequential( - nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) + nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) # Second kernel self.conv2 = nn.Sequential( - conv_builder(planes, planes, midplanes, stride), - nn.BatchNorm3d(planes), - nn.ReLU(inplace=True) + conv_builder(planes, planes, midplanes, stride), nn.BatchNorm3d(planes), nn.ReLU(inplace=True) ) # 1x1x1 self.conv3 = 
nn.Sequential( nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), - nn.BatchNorm3d(planes * self.expansion) + nn.BatchNorm3d(planes * self.expansion), ) self.relu = nn.ReLU(inplace=True) self.downsample = downsample @@ -182,35 +163,31 @@ def forward(self, x: Tensor) -> Tensor: class BasicStem(nn.Sequential): - """The default conv-batchnorm-relu stem - """ + """The default conv-batchnorm-relu stem""" + def __init__(self) -> None: super(BasicStem, self).__init__( - nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), - padding=(1, 3, 3), bias=False), + nn.Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False), nn.BatchNorm3d(64), - nn.ReLU(inplace=True)) + nn.ReLU(inplace=True), + ) class R2Plus1dStem(nn.Sequential): - """R(2+1)D stem is different than the default one as it uses separated 3D convolution - """ + """R(2+1)D stem is different than the default one as it uses separated 3D convolution""" + def __init__(self) -> None: super(R2Plus1dStem, self).__init__( - nn.Conv3d(3, 45, kernel_size=(1, 7, 7), - stride=(1, 2, 2), padding=(0, 3, 3), - bias=False), + nn.Conv3d(3, 45, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False), nn.BatchNorm3d(45), nn.ReLU(inplace=True), - nn.Conv3d(45, 64, kernel_size=(3, 1, 1), - stride=(1, 1, 1), padding=(1, 0, 0), - bias=False), + nn.Conv3d(45, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False), nn.BatchNorm3d(64), - nn.ReLU(inplace=True)) + nn.ReLU(inplace=True), + ) class VideoResNet(nn.Module): - def __init__( self, block: Type[Union[BasicBlock, Bottleneck]], @@ -273,16 +250,15 @@ def _make_layer( conv_builder: Type[Union[Conv3DSimple, Conv3DNoTemporal, Conv2Plus1D]], planes: int, blocks: int, - stride: int = 1 + stride: int = 1, ) -> nn.Sequential: downsample = None if stride != 1 or self.inplanes != planes * block.expansion: ds_stride = conv_builder.get_downsample_stride(stride) downsample = nn.Sequential( - nn.Conv3d(self.inplanes, planes * block.expansion, - kernel_size=1, stride=ds_stride, bias=False), - nn.BatchNorm3d(planes * block.expansion) + nn.Conv3d(self.inplanes, planes * block.expansion, kernel_size=1, stride=ds_stride, bias=False), + nn.BatchNorm3d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) @@ -296,8 +272,7 @@ def _make_layer( def _initialize_weights(self) -> None: for m in self.modules(): if isinstance(m, nn.Conv3d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', - nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm3d): @@ -312,8 +287,7 @@ def _video_resnet(arch: str, pretrained: bool = False, progress: bool = True, ** model = VideoResNet(**kwargs) if pretrained: - state_dict = load_state_dict_from_url(model_urls[arch], - progress=progress) + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) model.load_state_dict(state_dict) return model @@ -330,12 +304,16 @@ def r3d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Vi nn.Module: R3D-18 network """ - return _video_resnet('r3d_18', - pretrained, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] * 4, - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs) + return _video_resnet( + "r3d_18", + pretrained, + progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] * 4, + layers=[2, 2, 2, 2], + stem=BasicStem, + **kwargs, + ) def 
mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: @@ -349,12 +327,16 @@ def mc3_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> Vi Returns: nn.Module: MC3 Network definition """ - return _video_resnet('mc3_18', - pretrained, progress, - block=BasicBlock, - conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] - layers=[2, 2, 2, 2], - stem=BasicStem, **kwargs) + return _video_resnet( + "mc3_18", + pretrained, + progress, + block=BasicBlock, + conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, # type: ignore[list-item] + layers=[2, 2, 2, 2], + stem=BasicStem, + **kwargs, + ) def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VideoResNet: @@ -368,9 +350,13 @@ def r2plus1d_18(pretrained: bool = False, progress: bool = True, **kwargs: Any) Returns: nn.Module: R(2+1)D-18 network """ - return _video_resnet('r2plus1d_18', - pretrained, progress, - block=BasicBlock, - conv_makers=[Conv2Plus1D] * 4, - layers=[2, 2, 2, 2], - stem=R2Plus1dStem, **kwargs) + return _video_resnet( + "r2plus1d_18", + pretrained, + progress, + block=BasicBlock, + conv_makers=[Conv2Plus1D] * 4, + layers=[2, 2, 2, 2], + stem=R2Plus1dStem, + **kwargs, + ) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 33b35dc93b9..a047f26c321 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -1,26 +1,50 @@ -from .boxes import nms, batched_nms, remove_small_boxes, clip_boxes_to_image, box_area, box_iou, generalized_box_iou, \ - masks_to_boxes +from ._register_onnx_ops import _register_custom_op +from .boxes import ( + nms, + batched_nms, + remove_small_boxes, + clip_boxes_to_image, + box_area, + box_iou, + generalized_box_iou, + masks_to_boxes, +) from .boxes import box_convert from .deform_conv import deform_conv2d, DeformConv2d -from .roi_align import roi_align, RoIAlign -from .roi_pool import roi_pool, RoIPool -from .ps_roi_align import ps_roi_align, PSRoIAlign -from .ps_roi_pool import ps_roi_pool, PSRoIPool -from .poolers import MultiScaleRoIAlign from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss +from .poolers import MultiScaleRoIAlign +from .ps_roi_align import ps_roi_align, PSRoIAlign +from .ps_roi_pool import ps_roi_pool, PSRoIPool +from .roi_align import roi_align, RoIAlign +from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth -from ._register_onnx_ops import _register_custom_op - _register_custom_op() __all__ = [ - 'deform_conv2d', 'DeformConv2d', 'nms', 'batched_nms', 'remove_small_boxes', - 'clip_boxes_to_image', 'box_convert', - 'box_area', 'box_iou', 'generalized_box_iou', 'roi_align', 'RoIAlign', 'roi_pool', - 'RoIPool', 'ps_roi_align', 'PSRoIAlign', 'ps_roi_pool', - 'PSRoIPool', 'MultiScaleRoIAlign', 'FeaturePyramidNetwork', - 'sigmoid_focal_loss', 'stochastic_depth', 'StochasticDepth' + "deform_conv2d", + "DeformConv2d", + "nms", + "batched_nms", + "remove_small_boxes", + "clip_boxes_to_image", + "box_convert", + "box_area", + "box_iou", + "generalized_box_iou", + "roi_align", + "RoIAlign", + "roi_pool", + "RoIPool", + "ps_roi_align", + "PSRoIAlign", + "ps_roi_pool", + "PSRoIPool", + "MultiScaleRoIAlign", + "FeaturePyramidNetwork", + "sigmoid_focal_loss", + "stochastic_depth", + "StochasticDepth", ] diff --git a/torchvision/ops/_register_onnx_ops.py b/torchvision/ops/_register_onnx_ops.py index e8dd90b4672..76e62ae1728 100644 --- 
a/torchvision/ops/_register_onnx_ops.py +++ b/torchvision/ops/_register_onnx_ops.py @@ -1,51 +1,70 @@ import sys -import torch import warnings +import torch + _onnx_opset_version = 11 def _register_custom_op(): - from torch.onnx.symbolic_helper import parse_args, scalar_type_to_onnx, scalar_type_to_pytorch_type, \ - cast_pytorch_to_onnx - from torch.onnx.symbolic_opset9 import _cast_Long + from torch.onnx.symbolic_helper import ( + parse_args, + scalar_type_to_onnx, + scalar_type_to_pytorch_type, + cast_pytorch_to_onnx, + ) from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze + from torch.onnx.symbolic_opset9 import _cast_Long - @parse_args('v', 'v', 'f') + @parse_args("v", "v", "f") def symbolic_multi_label_nms(g, boxes, scores, iou_threshold): boxes = unsqueeze(g, boxes, 0) scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) - max_output_per_class = g.op('Constant', value_t=torch.tensor([sys.maxsize], dtype=torch.long)) - iou_threshold = g.op('Constant', value_t=torch.tensor([iou_threshold], dtype=torch.float)) - nms_out = g.op('NonMaxSuppression', boxes, scores, max_output_per_class, iou_threshold) - return squeeze(g, select(g, nms_out, 1, g.op('Constant', value_t=torch.tensor([2], dtype=torch.long))), 1) + max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long)) + iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float)) + nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold) + return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1) - @parse_args('v', 'v', 'f', 'i', 'i', 'i', 'i') + @parse_args("v", "v", "f", "i", "i", "i", "i") def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): - batch_indices = _cast_Long(g, squeeze(g, select(g, rois, 1, g.op('Constant', - value_t=torch.tensor([0], dtype=torch.long))), 1), False) - rois = select(g, rois, 1, g.op('Constant', value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) + batch_indices = _cast_Long( + g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False + ) + rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long))) if aligned: - warnings.warn("ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes," - " ONNX forces ROIs to be 1x1 or larger.") + warnings.warn( + "ONNX export of ROIAlign with aligned=True does not match PyTorch when using malformed boxes," + " ONNX forces ROIs to be 1x1 or larger." + ) scale = torch.tensor(0.5 / spatial_scale).to(dtype=torch.float) rois = g.op("Sub", rois, scale) # ONNX doesn't support negative sampling_ratio if sampling_ratio < 0: - warnings.warn("ONNX doesn't support negative sampling ratio," - "therefore is is set to 0 in order to be exported.") + warnings.warn( + "ONNX doesn't support negative sampling ratio," "therefore is is set to 0 in order to be exported." 
+ ) sampling_ratio = 0 - return g.op('RoiAlign', input, rois, batch_indices, spatial_scale_f=spatial_scale, - output_height_i=pooled_height, output_width_i=pooled_width, sampling_ratio_i=sampling_ratio) + return g.op( + "RoiAlign", + input, + rois, + batch_indices, + spatial_scale_f=spatial_scale, + output_height_i=pooled_height, + output_width_i=pooled_width, + sampling_ratio_i=sampling_ratio, + ) - @parse_args('v', 'v', 'f', 'i', 'i') + @parse_args("v", "v", "f", "i", "i") def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width): - roi_pool = g.op('MaxRoiPool', input, rois, - pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale) + roi_pool = g.op( + "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale + ) return roi_pool, None from torch.onnx import register_custom_op_symbolic - register_custom_op_symbolic('torchvision::nms', symbolic_multi_label_nms, _onnx_opset_version) - register_custom_op_symbolic('torchvision::roi_align', roi_align, _onnx_opset_version) - register_custom_op_symbolic('torchvision::roi_pool', roi_pool, _onnx_opset_version) + + register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version) + register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version) + register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 7cc6367a7a4..86dfce46509 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -1,6 +1,7 @@ +from typing import List, Union + import torch from torch import Tensor -from typing import List, Union def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor: @@ -27,10 +28,11 @@ def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor: def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]): if isinstance(boxes, (list, tuple)): for _tensor in boxes: - assert _tensor.size(1) == 4, \ - 'The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]' + assert ( + _tensor.size(1) == 4 + ), "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]" elif isinstance(boxes, torch.Tensor): - assert boxes.size(1) == 5, 'The boxes tensor shape is not correct as Tensor[K, 5]' + assert boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]" else: - assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]' + assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]" return diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 2c0cfd84eac..f8d9c596606 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -1,10 +1,12 @@ -import torch -from torch import Tensor from typing import Tuple -from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh + +import torch import torchvision +from torch import Tensor from torchvision.extension import _assert_has_ops +from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh + def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: """ @@ -183,13 +185,13 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: if in_fmt == out_fmt: return boxes.clone() - if in_fmt != 'xyxy' and out_fmt != 'xyxy': + if in_fmt != "xyxy" and out_fmt != "xyxy": # convert to xyxy and change in_fmt xyxy if in_fmt == "xywh": boxes = 
_box_xywh_to_xyxy(boxes) elif in_fmt == "cxcywh": boxes = _box_cxcywh_to_xyxy(boxes) - in_fmt = 'xyxy' + in_fmt = "xyxy" if in_fmt == "xyxy": if out_fmt == "xywh": diff --git a/torchvision/ops/deform_conv.py b/torchvision/ops/deform_conv.py index 4399d441843..550f659b07d 100644 --- a/torchvision/ops/deform_conv.py +++ b/torchvision/ops/deform_conv.py @@ -1,11 +1,11 @@ import math +from typing import Optional, Tuple import torch from torch import nn, Tensor from torch.nn import init -from torch.nn.parameter import Parameter from torch.nn.modules.utils import _pair -from typing import Optional, Tuple +from torch.nn.parameter import Parameter from torchvision.extension import _assert_has_ops @@ -84,7 +84,9 @@ def deform_conv2d( "the shape of the offset tensor at dimension 1 is not valid. It should " "be a multiple of 2 * weight.size[2] * weight.size[3].\n" "Got offset.shape[1]={}, while 2 * weight.size[2] * weight.size[3]={}".format( - offset.shape[1], 2 * weights_h * weights_w)) + offset.shape[1], 2 * weights_h * weights_w + ) + ) return torch.ops.torchvision.deform_conv2d( input, @@ -92,12 +94,16 @@ def deform_conv2d( offset, mask, bias, - stride_h, stride_w, - pad_h, pad_w, - dil_h, dil_w, + stride_h, + stride_w, + pad_h, + pad_w, + dil_h, + dil_w, n_weight_grps, n_offset_grps, - use_mask,) + use_mask, + ) class DeformConv2d(nn.Module): @@ -119,9 +125,9 @@ def __init__( super(DeformConv2d, self).__init__() if in_channels % groups != 0: - raise ValueError('in_channels must be divisible by groups') + raise ValueError("in_channels must be divisible by groups") if out_channels % groups != 0: - raise ValueError('out_channels must be divisible by groups') + raise ValueError("out_channels must be divisible by groups") self.in_channels = in_channels self.out_channels = out_channels @@ -131,13 +137,14 @@ def __init__( self.dilation = _pair(dilation) self.groups = groups - self.weight = Parameter(torch.empty(out_channels, in_channels // groups, - self.kernel_size[0], self.kernel_size[1])) + self.weight = Parameter( + torch.empty(out_channels, in_channels // groups, self.kernel_size[0], self.kernel_size[1]) + ) if bias: self.bias = Parameter(torch.empty(out_channels)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) self.reset_parameters() @@ -160,18 +167,26 @@ def forward(self, input: Tensor, offset: Tensor, mask: Optional[Tensor] = None) out_height, out_width]): masks to be applied for each position in the convolution kernel. 
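For reference, a minimal sketch (shapes illustrative, a single offset group assumed) of how the offset and mask arguments described in the DeformConv2d docstring above are laid out:

import torch
from torchvision.ops import DeformConv2d

x = torch.rand(1, 3, 10, 10)
kh, kw = 3, 3
conv = DeformConv2d(3, 8, kernel_size=(kh, kw), padding=1)
# two coordinates (dy, dx) per kernel element, per output location
offset = torch.rand(1, 2 * kh * kw, 10, 10)
# optional modulation mask: one weight per kernel element, per output location
mask = torch.rand(1, kh * kw, 10, 10)
out = conv(x, offset, mask)  # -> shape (1, 8, 10, 10)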
""" - return deform_conv2d(input, offset, self.weight, self.bias, stride=self.stride, - padding=self.padding, dilation=self.dilation, mask=mask) + return deform_conv2d( + input, + offset, + self.weight, + self.bias, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + mask=mask, + ) def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += '{in_channels}' - s += ', {out_channels}' - s += ', kernel_size={kernel_size}' - s += ', stride={stride}' - s += ', padding={padding}' if self.padding != (0, 0) else '' - s += ', dilation={dilation}' if self.dilation != (1, 1) else '' - s += ', groups={groups}' if self.groups != 1 else '' - s += ', bias=False' if self.bias is None else '' - s += ')' + s = self.__class__.__name__ + "(" + s += "{in_channels}" + s += ", {out_channels}" + s += ", kernel_size={kernel_size}" + s += ", stride={stride}" + s += ", padding={padding}" if self.padding != (0, 0) else "" + s += ", dilation={dilation}" if self.dilation != (1, 1) else "" + s += ", groups={groups}" if self.groups != 1 else "" + s += ", bias=False" if self.bias is None else "" + s += ")" return s.format(**self.__dict__) diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py index 7d72769ab07..dd40e7bd6d5 100644 --- a/torchvision/ops/feature_pyramid_network.py +++ b/torchvision/ops/feature_pyramid_network.py @@ -1,10 +1,9 @@ from collections import OrderedDict +from typing import Tuple, List, Dict, Optional import torch.nn.functional as F from torch import nn, Tensor -from typing import Tuple, List, Dict, Optional - class ExtraFPNBlock(nn.Module): """ @@ -21,6 +20,7 @@ class ExtraFPNBlock(nn.Module): of the FPN names (List[str]): the extended set of names for the results """ + def forward( self, results: List[Tensor], @@ -67,6 +67,7 @@ class FeaturePyramidNetwork(nn.Module): >>> ('feat3', torch.Size([1, 5, 8, 8]))] """ + def __init__( self, in_channels_list: List[int], @@ -165,6 +166,7 @@ class LastLevelMaxPool(ExtraFPNBlock): """ Applies a max_pool2d on top of the last feature map """ + def forward( self, x: List[Tensor], @@ -180,6 +182,7 @@ class LastLevelP6P7(ExtraFPNBlock): """ This module is used in RetinaNet to generate extra layers, P6 and P7. """ + def __init__(self, in_channels: int, out_channels: int): super(LastLevelP6P7, self).__init__() self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py index de18f30c83a..3f72273c39c 100644 --- a/torchvision/ops/focal_loss.py +++ b/torchvision/ops/focal_loss.py @@ -31,9 +31,7 @@ def sigmoid_focal_loss( Loss tensor with the reduction option applied. 
""" p = torch.sigmoid(inputs) - ce_loss = F.binary_cross_entropy_with_logits( - inputs, targets, reduction="none" - ) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") p_t = p * targets + (1 - p) * (1 - targets) loss = ce_loss * ((1 - p_t) ** gamma) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 7ee8df3371e..3df290bc6c5 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -9,9 +9,10 @@ """ import warnings +from typing import Callable, List, Optional + import torch from torch import Tensor -from typing import Callable, List, Optional class Conv2d(torch.nn.Conv2d): @@ -19,7 +20,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.Conv2d is deprecated and will be " - "removed in future versions, use torch.nn.Conv2d instead.", FutureWarning) + "removed in future versions, use torch.nn.Conv2d instead.", + FutureWarning, + ) class ConvTranspose2d(torch.nn.ConvTranspose2d): @@ -27,7 +30,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.ConvTranspose2d is deprecated and will be " - "removed in future versions, use torch.nn.ConvTranspose2d instead.", FutureWarning) + "removed in future versions, use torch.nn.ConvTranspose2d instead.", + FutureWarning, + ) class BatchNorm2d(torch.nn.BatchNorm2d): @@ -35,7 +40,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) warnings.warn( "torchvision.ops.misc.BatchNorm2d is deprecated and will be " - "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning) + "removed in future versions, use torch.nn.BatchNorm2d instead.", + FutureWarning, + ) interpolate = torch.nn.functional.interpolate @@ -56,8 +63,7 @@ def __init__( ): # n=None for backward-compatibility if n is not None: - warnings.warn("`n` argument is deprecated and has been renamed `num_features`", - DeprecationWarning) + warnings.warn("`n` argument is deprecated and has been renamed `num_features`", DeprecationWarning) num_features = n super(FrozenBatchNorm2d, self).__init__() self.eps = eps @@ -76,13 +82,13 @@ def _load_from_state_dict( unexpected_keys: List[str], error_msgs: List[str], ): - num_batches_tracked_key = prefix + 'num_batches_tracked' + num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] super(FrozenBatchNorm2d, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def forward(self, x: Tensor) -> Tensor: # move reshapes to the beginning @@ -115,8 +121,18 @@ def __init__( ) -> None: if padding is None: padding = (kernel_size - 1) // 2 * dilation - layers = [torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, - dilation=dilation, groups=groups, bias=norm_layer is None)] + layers = [ + torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=norm_layer is None, + ) + ] if norm_layer is not None: layers.append(norm_layer(out_channels)) if activation_layer is not None: diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py index f4ff289299b..6d2388c74b3 100644 --- a/torchvision/ops/poolers.py +++ b/torchvision/ops/poolers.py @@ -1,11 +1,11 @@ -import torch -from torch import nn, Tensor +from typing import 
Optional, List, Dict, Tuple, Union +import torch import torchvision -from torchvision.ops import roi_align +from torch import nn, Tensor from torchvision.ops.boxes import box_area -from typing import Optional, List, Dict, Tuple, Union +from .roi_align import roi_align # copying result_idx_in_level to a specific index in result[] @@ -16,15 +16,17 @@ def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor: first_result = unmerged_results[0] dtype, device = first_result.dtype, first_result.device - res = torch.zeros((levels.size(0), first_result.size(1), - first_result.size(2), first_result.size(3)), - dtype=dtype, device=device) + res = torch.zeros( + (levels.size(0), first_result.size(1), first_result.size(2), first_result.size(3)), dtype=dtype, device=device + ) for level in range(len(unmerged_results)): index = torch.where(levels == level)[0].view(-1, 1, 1, 1) - index = index.expand(index.size(0), - unmerged_results[level].size(1), - unmerged_results[level].size(2), - unmerged_results[level].size(3)) + index = index.expand( + index.size(0), + unmerged_results[level].size(1), + unmerged_results[level].size(2), + unmerged_results[level].size(3), + ) res = res.scatter(0, index, unmerged_results[level]) return res @@ -116,10 +118,7 @@ class MultiScaleRoIAlign(nn.Module): """ - __annotations__ = { - 'scales': Optional[List[float]], - 'map_levels': Optional[LevelMapper] - } + __annotations__ = {"scales": Optional[List[float]], "map_levels": Optional[LevelMapper]} def __init__( self, @@ -224,10 +223,11 @@ def forward( if num_levels == 1: return roi_align( - x_filtered[0], rois, + x_filtered[0], + rois, output_size=self.output_size, spatial_scale=scales[0], - sampling_ratio=self.sampling_ratio + sampling_ratio=self.sampling_ratio, ) mapper = self.map_levels @@ -240,7 +240,11 @@ def forward( dtype, device = x_filtered[0].dtype, x_filtered[0].device result = torch.zeros( - (num_rois, num_channels,) + self.output_size, + ( + num_rois, + num_channels, + ) + + self.output_size, dtype=dtype, device=device, ) @@ -251,9 +255,12 @@ def forward( rois_per_level = rois[idx_in_level] result_idx_in_level = roi_align( - per_level_feature, rois_per_level, + per_level_feature, + rois_per_level, output_size=self.output_size, - spatial_scale=scale, sampling_ratio=self.sampling_ratio) + spatial_scale=scale, + sampling_ratio=self.sampling_ratio, + ) if torchvision._is_tracing(): tracing_results.append(result_idx_in_level.to(dtype)) @@ -273,5 +280,7 @@ def forward( return result def __repr__(self) -> str: - return (f"{self.__class__.__name__}(featmap_names={self.featmap_names}, " - f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})") + return ( + f"{self.__class__.__name__}(featmap_names={self.featmap_names}, " + f"output_size={self.output_size}, sampling_ratio={self.sampling_ratio})" + ) diff --git a/torchvision/ops/ps_roi_align.py b/torchvision/ops/ps_roi_align.py index be64bdddd87..0bfeefc867b 100644 --- a/torchvision/ops/ps_roi_align.py +++ b/torchvision/ops/ps_roi_align.py @@ -1,9 +1,8 @@ import torch from torch import nn, Tensor - from torch.nn.modules.utils import _pair - from torchvision.extension import _assert_has_ops + from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape @@ -49,10 +48,9 @@ def ps_roi_align( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.ps_roi_align(input, rois, spatial_scale, - output_size[0], - output_size[1], - 
sampling_ratio) + output, _ = torch.ops.torchvision.ps_roi_align( + input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio + ) return output @@ -60,6 +58,7 @@ class PSRoIAlign(nn.Module): """ See :func:`ps_roi_align`. """ + def __init__( self, output_size: int, @@ -72,13 +71,12 @@ def __init__( self.sampling_ratio = sampling_ratio def forward(self, input: Tensor, rois: Tensor) -> Tensor: - return ps_roi_align(input, rois, self.output_size, self.spatial_scale, - self.sampling_ratio) + return ps_roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ', sampling_ratio=' + str(self.sampling_ratio) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ")" return tmpstr diff --git a/torchvision/ops/ps_roi_pool.py b/torchvision/ops/ps_roi_pool.py index 0084c0e0c74..cde3543300c 100644 --- a/torchvision/ops/ps_roi_pool.py +++ b/torchvision/ops/ps_roi_pool.py @@ -1,9 +1,8 @@ import torch from torch import nn, Tensor - from torch.nn.modules.utils import _pair - from torchvision.extension import _assert_has_ops + from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape @@ -43,9 +42,7 @@ def ps_roi_pool( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, - output_size[0], - output_size[1]) + output, _ = torch.ops.torchvision.ps_roi_pool(input, rois, spatial_scale, output_size[0], output_size[1]) return output @@ -53,6 +50,7 @@ class PSRoIPool(nn.Module): """ See :func:`ps_roi_pool`. 
""" + def __init__(self, output_size: int, spatial_scale: float): super(PSRoIPool, self).__init__() self.output_size = output_size @@ -62,8 +60,8 @@ def forward(self, input: Tensor, rois: Tensor) -> Tensor: return ps_roi_pool(input, rois, self.output_size, self.spatial_scale) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ")" return tmpstr diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py index 57494cff1e1..1178e8cd52c 100644 --- a/torchvision/ops/roi_align.py +++ b/torchvision/ops/roi_align.py @@ -2,8 +2,8 @@ import torch from torch import nn, Tensor -from torch.nn.modules.utils import _pair from torch.jit.annotations import BroadcastingList2 +from torch.nn.modules.utils import _pair from torchvision.extension import _assert_has_ops from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape @@ -55,15 +55,16 @@ def roi_align( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - return torch.ops.torchvision.roi_align(input, rois, spatial_scale, - output_size[0], output_size[1], - sampling_ratio, aligned) + return torch.ops.torchvision.roi_align( + input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned + ) class RoIAlign(nn.Module): """ See :func:`roi_align`. """ + def __init__( self, output_size: BroadcastingList2[int], @@ -81,10 +82,10 @@ def forward(self, input: Tensor, rois: Tensor) -> Tensor: return roi_align(input, rois, self.output_size, self.spatial_scale, self.sampling_ratio, self.aligned) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ', sampling_ratio=' + str(self.sampling_ratio) - tmpstr += ', aligned=' + str(self.aligned) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ", aligned=" + str(self.aligned) + tmpstr += ")" return tmpstr diff --git a/torchvision/ops/roi_pool.py b/torchvision/ops/roi_pool.py index 6ad787af725..5eb19154054 100644 --- a/torchvision/ops/roi_pool.py +++ b/torchvision/ops/roi_pool.py @@ -2,8 +2,8 @@ import torch from torch import nn, Tensor -from torch.nn.modules.utils import _pair from torch.jit.annotations import BroadcastingList2 +from torch.nn.modules.utils import _pair from torchvision.extension import _assert_has_ops from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape @@ -44,8 +44,7 @@ def roi_pool( output_size = _pair(output_size) if not isinstance(rois, torch.Tensor): rois = convert_boxes_to_roi_format(rois) - output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, - output_size[0], output_size[1]) + output, _ = torch.ops.torchvision.roi_pool(input, rois, spatial_scale, output_size[0], output_size[1]) return output @@ -53,6 +52,7 @@ class RoIPool(nn.Module): """ See :func:`roi_pool`. 
""" + def __init__(self, output_size: BroadcastingList2[int], spatial_scale: float): super(RoIPool, self).__init__() self.output_size = output_size @@ -62,8 +62,8 @@ def forward(self, input: Tensor, rois: Tensor) -> Tensor: return roi_pool(input, rois, self.output_size, self.spatial_scale) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'output_size=' + str(self.output_size) - tmpstr += ', spatial_scale=' + str(self.spatial_scale) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ")" return tmpstr diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index f15d6c53933..de120862941 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -38,13 +38,14 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) return input * noise -torch.fx.wrap('stochastic_depth') +torch.fx.wrap("stochastic_depth") class StochasticDepth(nn.Module): """ See :func:`stochastic_depth`. """ + def __init__(self, p: float, mode: str) -> None: super().__init__() self.p = p @@ -54,8 +55,8 @@ def forward(self, input: Tensor) -> Tensor: return stochastic_depth(input, self.p, self.mode, self.training) def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'p=' + str(self.p) - tmpstr += ', mode=' + str(self.mode) - tmpstr += ')' + tmpstr = self.__class__.__name__ + "(" + tmpstr += "p=" + str(self.p) + tmpstr += ", mode=" + str(self.mode) + tmpstr += ")" return tmpstr diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index 6fa4d8dbc8f..c677cff0878 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -8,9 +8,9 @@ ) from error -from ._home import home from . import decoder, utils # Load this last, since some parts depend on the above being loaded first from ._api import register, _list as list, info, load from ._folder import from_data_folder, from_image_folder +from ._home import home diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 97d95eef6c1..5c613035e2b 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -3,11 +3,11 @@ import torch from torch.utils.data import IterDataPipe - from torchvision.prototype.datasets import home from torchvision.prototype.datasets.decoder import pil from torchvision.prototype.datasets.utils import Dataset, DatasetInfo from torchvision.prototype.datasets.utils._internal import add_suggestion + from . 
import _builtin DATASETS: Dict[str, Dataset] = {} @@ -18,12 +18,7 @@ def register(dataset: Dataset) -> None: for name, obj in _builtin.__dict__.items(): - if ( - not name.startswith("_") - and isinstance(obj, type) - and issubclass(obj, Dataset) - and obj is not Dataset - ): + if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset: register(obj()) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index a97d80fa8cb..d2ce41c0d0f 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -1,10 +1,9 @@ import io import pathlib -from typing import Any, Callable, Dict, List, Optional, Tuple, Union import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np - import torch from torch.utils.data import IterDataPipe from torch.utils.data.datapipes.iter import ( @@ -13,7 +12,6 @@ Shuffler, Filter, ) - from torchdata.datapipes.iter import KeyZipper from torchvision.prototype.datasets.utils import ( Dataset, diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py index 5626f68650f..55e48387d6a 100644 --- a/torchvision/prototype/datasets/_folder.py +++ b/torchvision/prototype/datasets/_folder.py @@ -8,7 +8,6 @@ import torch from torch.utils.data import IterDataPipe from torch.utils.data.datapipes.iter import FileLister, FileLoader, Mapper, Shuffler, Filter - from torchvision.prototype.datasets.decoder import pil from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE diff --git a/torchvision/prototype/datasets/decoder.py b/torchvision/prototype/datasets/decoder.py index d4897bebc91..64cea43e5f0 100644 --- a/torchvision/prototype/datasets/decoder.py +++ b/torchvision/prototype/datasets/decoder.py @@ -2,7 +2,6 @@ import PIL.Image import torch - from torchvision.transforms.functional import pil_to_tensor __all__ = ["pil"] diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 19fb3b1d596..b43dc3fc4c4 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -19,11 +19,11 @@ import torch from torch.utils.data import IterDataPipe - from torchvision.prototype.datasets.utils._internal import ( add_suggestion, sequence_to_str, ) + from ._resource import OnlineResource @@ -64,9 +64,7 @@ def __getattr__(self, name: str) -> Any: try: return self[name] except KeyError as error: - raise AttributeError( - f"'{type(self).__name__}' object has no attribute '{name}'" - ) from error + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error def __setitem__(self, key: Any, value: Any) -> NoReturn: raise RuntimeError(f"'{type(self).__name__}' object is immutable") @@ -133,9 +131,7 @@ def __init__( @property def default_config(self) -> DatasetConfig: - return DatasetConfig( - {name: valid_args[0] for name, valid_args in self._valid_options.items()} - ) + return DatasetConfig({name: valid_args[0] for name, valid_args in self._valid_options.items()}) def make_config(self, **options: Any) -> DatasetConfig: for name, arg in options.items(): @@ -167,12 +163,7 @@ def __repr__(self) -> str: value = getattr(self, key) if value is not None: items.append((key, value)) - items.extend( - sorted( - (key, sequence_to_str(value)) - for key, value in self._valid_options.items() - ) - ) + 
items.extend(sorted((key, sequence_to_str(value)) for key, value in self._valid_options.items())) return make_repr(type(self).__name__, items) @@ -214,7 +205,5 @@ def to_datapipe( if not config: config = self.info.default_config - resource_dps = [ - resource.to_datapipe(root) for resource in self.resources(config) - ] + resource_dps = [resource.to_datapipe(root) for resource in self.resources(config)] return self._make_datapipe(resource_dps, config=config, decoder=decoder) diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index 56c9a2d8c07..7a1d34ffa0e 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -5,13 +5,7 @@ from typing import Collection, Sequence, Callable, Union, Any -__all__ = [ - "INFINITE_BUFFER_SIZE", - "sequence_to_str", - "add_suggestion", - "create_categories_file", - "read_mat" -] +__all__ = ["INFINITE_BUFFER_SIZE", "sequence_to_str", "add_suggestion", "create_categories_file", "read_mat"] # pseudo-infinite until a true infinite buffer is supported by all datapipes INFINITE_BUFFER_SIZE = 1_000_000_000 @@ -21,10 +15,7 @@ def sequence_to_str(seq: Sequence, separate_last: str = "") -> str: if len(seq) == 1: return f"'{seq[0]}'" - return ( - f"""'{"', '".join([str(item) for item in seq[:-1]])}', """ - f"""{separate_last}'{seq[-1]}'.""" - ) + return f"""'{"', '".join([str(item) for item in seq[:-1]])}', """ f"""{separate_last}'{seq[-1]}'.""" def add_suggestion( @@ -32,9 +23,7 @@ def add_suggestion( *, word: str, possibilities: Collection[str], - close_match_hint: Callable[ - [str], str - ] = lambda close_match: f"Did you mean '{close_match}'?", + close_match_hint: Callable[[str], str] = lambda close_match: f"Did you mean '{close_match}'?", alternative_hint: Callable[ [Sequence[str]], str ] = lambda possibilities: f"Can be {sequence_to_str(possibilities, separate_last='or ')}.", @@ -42,17 +31,11 @@ def add_suggestion( if not isinstance(possibilities, collections.abc.Sequence): possibilities = sorted(possibilities) suggestions = difflib.get_close_matches(word, possibilities, 1) - hint = ( - close_match_hint(suggestions[0]) - if suggestions - else alternative_hint(possibilities) - ) + hint = close_match_hint(suggestions[0]) if suggestions else alternative_hint(possibilities) return f"{msg.strip()} {hint}" -def create_categories_file( - root: Union[str, pathlib.Path], name: str, categories: Sequence[str] -) -> None: +def create_categories_file(root: Union[str, pathlib.Path], name: str, categories: Sequence[str]) -> None: with open(pathlib.Path(root) / f"{name}.categories", "w") as fh: fh.write("\n".join(categories) + "\n") @@ -61,8 +44,6 @@ def read_mat(buffer: io.IOBase, **kwargs: Any) -> Any: try: import scipy.io as sio except ImportError as error: - raise ModuleNotFoundError( - "Package `scipy` is required to be installed to read .mat files." 
- ) from error + raise ModuleNotFoundError("Package `scipy` is required to be installed to read .mat files.") from error return sio.loadmat(buffer, **kwargs) diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 3f372d0f5b7..e91fdfa2e8f 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -13,9 +13,7 @@ def compute_sha256(_) -> str: class LocalResource: - def __init__( - self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None - ) -> None: + def __init__(self, path: Union[str, pathlib.Path], *, sha256: Optional[str] = None) -> None: self.path = pathlib.Path(path).expanduser().resolve() self.file_name = self.path.name self.sha256 = sha256 or compute_sha256(self.path) @@ -39,9 +37,7 @@ def to_datapipe(self, root: Union[str, pathlib.Path]) -> IterDataPipe: # TODO: add support for mirrors # TODO: add support for http -> https class HttpResource(OnlineResource): - def __init__( - self, url: str, *, sha256: str, file_name: Optional[str] = None - ) -> None: + def __init__(self, url: str, *, sha256: str, file_name: Optional[str] = None) -> None: if not file_name: file_name = os.path.basename(urlparse(url).path) super().__init__(url, sha256=sha256, file_name=file_name) diff --git a/torchvision/transforms/_functional_video.py b/torchvision/transforms/_functional_video.py index 9eba0463a4f..2b4fe371b98 100644 --- a/torchvision/transforms/_functional_video.py +++ b/torchvision/transforms/_functional_video.py @@ -1,10 +1,9 @@ -import torch import warnings +import torch + -warnings.warn( - "The _functional_video module is deprecated. Please use the functional module instead." -) +warnings.warn("The _functional_video module is deprecated. Please use the functional module instead.") def _is_tensor_video_clip(clip): @@ -23,14 +22,12 @@ def crop(clip, i, j, h, w): clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W) """ assert len(clip.size()) == 4, "clip should be a 4D tensor" - return clip[..., i:i + h, j:j + w] + return clip[..., i : i + h, j : j + w] def resize(clip, target_size, interpolation_mode): assert len(target_size) == 2, "target size should be tuple (height, width)" - return torch.nn.functional.interpolate( - clip, size=target_size, mode=interpolation_mode, align_corners=False - ) + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): diff --git a/torchvision/transforms/_transforms_video.py b/torchvision/transforms/_transforms_video.py index bfef1b440d1..f5c7836543a 100644 --- a/torchvision/transforms/_transforms_video.py +++ b/torchvision/transforms/_transforms_video.py @@ -22,9 +22,7 @@ ] -warnings.warn( - "The _transforms_video module is deprecated. Please use the transforms module instead." -) +warnings.warn("The _transforms_video module is deprecated. 
Please use the transforms module instead.") class RandomCropVideo(RandomCrop): @@ -46,7 +44,7 @@ def __call__(self, clip): return F.crop(clip, i, j, h, w) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) class RandomResizedCropVideo(RandomResizedCrop): @@ -79,10 +77,9 @@ def __call__(self, clip): return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode) def __repr__(self): - return self.__class__.__name__ + \ - '(size={0}, interpolation_mode={1}, scale={2}, ratio={3})'.format( - self.size, self.interpolation_mode, self.scale, self.ratio - ) + return self.__class__.__name__ + "(size={0}, interpolation_mode={1}, scale={2}, ratio={3})".format( + self.size, self.interpolation_mode, self.scale, self.ratio + ) class CenterCropVideo(object): @@ -103,7 +100,7 @@ def __call__(self, clip): return F.center_crop(clip, self.crop_size) def __repr__(self): - return self.__class__.__name__ + '(crop_size={0})'.format(self.crop_size) + return self.__class__.__name__ + "(crop_size={0})".format(self.crop_size) class NormalizeVideo(object): @@ -128,8 +125,7 @@ def __call__(self, clip): return F.normalize(clip, self.mean, self.std, self.inplace) def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format( - self.mean, self.std, self.inplace) + return self.__class__.__name__ + "(mean={0}, std={1}, inplace={2})".format(self.mean, self.std, self.inplace) class ToTensorVideo(object): diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index bffc4a24f67..f99e0aa2950 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -1,29 +1,58 @@ import math -import torch - from enum import Enum -from torch import Tensor from typing import List, Tuple, Optional, Dict +import torch +from torch import Tensor + from . 
import functional as F, InterpolationMode __all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] -def _apply_op(img: Tensor, op_name: str, magnitude: float, - interpolation: InterpolationMode, fill: Optional[List[float]]): +def _apply_op( + img: Tensor, op_name: str, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]] +): if op_name == "ShearX": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - interpolation=interpolation, fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, + fill=fill, + ) elif op_name == "ShearY": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - interpolation=interpolation, fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, + fill=fill, + ) elif op_name == "TranslateX": - img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0, - interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, + ) elif op_name == "TranslateY": - img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0, - interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + img = F.affine( + img, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + interpolation=interpolation, + shear=[0.0, 0.0], + fill=fill, + ) elif op_name == "Rotate": img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) elif op_name == "Brightness": @@ -55,6 +84,7 @@ class AutoAugmentPolicy(Enum): """AutoAugment policies learned on different datasets. Available policies are IMAGENET, CIFAR10 and SVHN. """ + IMAGENET = "imagenet" CIFAR10 = "cifar10" SVHN = "svhn" @@ -82,7 +112,7 @@ def __init__( self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None + fill: Optional[List[float]] = None, ) -> None: super().__init__() self.policy = policy @@ -91,8 +121,7 @@ def __init__( self.policies = self._get_policies(policy) def _get_policies( - self, - policy: AutoAugmentPolicy + self, policy: AutoAugmentPolicy ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: if policy == AutoAugmentPolicy.IMAGENET: return [ @@ -241,7 +270,7 @@ def forward(self, img: Tensor) -> Tensor: return img def __repr__(self) -> str: - return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill) + return self.__class__.__name__ + "(policy={}, fill={})".format(self.policy, self.fill) class RandAugment(torch.nn.Module): @@ -261,11 +290,16 @@ class RandAugment(torch.nn.Module): If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. 
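A minimal sketch of constructing the RandAugment transform whose parameters are documented above (the uint8 image tensor is illustrative):

import torch
from torchvision import transforms

augment = transforms.RandAugment(num_ops=2, magnitude=9)
img = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # dummy image tensor
out = augment(img)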
- """ + """ - def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None) -> None: + def __init__( + self, + num_ops: int = 2, + magnitude: int = 9, + num_magnitude_bins: int = 31, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + ) -> None: super().__init__() self.num_ops = num_ops self.magnitude = magnitude @@ -319,13 +353,13 @@ def forward(self, img: Tensor) -> Tensor: return img def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_ops={num_ops}' - s += ', magnitude={magnitude}' - s += ', num_magnitude_bins={num_magnitude_bins}' - s += ', interpolation={interpolation}' - s += ', fill={fill}' - s += ')' + s = self.__class__.__name__ + "(" + s += "num_ops={num_ops}" + s += ", magnitude={magnitude}" + s += ", num_magnitude_bins={num_magnitude_bins}" + s += ", interpolation={interpolation}" + s += ", fill={fill}" + s += ")" return s.format(**self.__dict__) @@ -343,10 +377,14 @@ class TrivialAugmentWide(torch.nn.Module): If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. - """ + """ - def __init__(self, num_magnitude_bins: int = 31, interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None) -> None: + def __init__( + self, + num_magnitude_bins: int = 31, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + ) -> None: super().__init__() self.num_magnitude_bins = num_magnitude_bins self.interpolation = interpolation @@ -389,17 +427,20 @@ def forward(self, img: Tensor) -> Tensor: op_index = int(torch.randint(len(op_meta), (1,)).item()) op_name = list(op_meta.keys())[op_index] magnitudes, signed = op_meta[op_name] - magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ - if magnitudes.ndim > 0 else 0.0 + magnitude = ( + float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) + if magnitudes.ndim > 0 + else 0.0 + ) if signed and torch.randint(2, (1,)): magnitude *= -1.0 return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) def __repr__(self) -> str: - s = self.__class__.__name__ + '(' - s += 'num_magnitude_bins={num_magnitude_bins}' - s += ', interpolation={interpolation}' - s += ', fill={fill}' - s += ')' + s = self.__class__.__name__ + "(" + s += "num_magnitude_bins={num_magnitude_bins}" + s += ", interpolation={interpolation}" + s += ", fill={fill}" + s += ")" return s.format(**self.__dict__) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 8c375cc9273..578f90ef62e 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,13 +2,12 @@ import numbers import warnings from enum import Enum +from typing import List, Tuple, Any, Optional import numpy as np -from PIL import Image - import torch +from PIL import Image from torch import Tensor -from typing import List, Tuple, Any, Optional try: import accimage @@ -23,6 +22,7 @@ class InterpolationMode(Enum): """Interpolation modes Available interpolation methods are ``nearest``, ``bilinear``, ``bicubic``, ``box``, ``hamming``, and ``lanczos``. 
""" + NEAREST = "nearest" BILINEAR = "bilinear" BICUBIC = "bicubic" @@ -110,11 +110,11 @@ def to_tensor(pic): Returns: Tensor: Converted image. """ - if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): - raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic))) + if not (F_pil._is_pil_image(pic) or _is_numpy(pic)): + raise TypeError("pic should be PIL Image or ndarray. Got {}".format(type(pic))) if _is_numpy(pic) and not _is_numpy_image(pic): - raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim)) default_float_dtype = torch.get_default_dtype() @@ -136,12 +136,10 @@ def to_tensor(pic): return torch.from_numpy(nppic).to(dtype=default_float_dtype) # handle PIL Image - mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32} - img = torch.from_numpy( - np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True) - ) + mode_to_nptype = {"I": np.int32, "I;16": np.int16, "F": np.float32} + img = torch.from_numpy(np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True)) - if pic.mode == '1': + if pic.mode == "1": img = 255 * img img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) # put it from HWC to CHW format @@ -165,7 +163,7 @@ def pil_to_tensor(pic): Tensor: Converted image. """ if not F_pil._is_pil_image(pic): - raise TypeError('pic should be PIL Image. Got {}'.format(type(pic))) + raise TypeError("pic should be PIL Image. Got {}".format(type(pic))) if accimage is not None and isinstance(pic, accimage.Image): # accimage format is always uint8 internally, so always return uint8 here @@ -204,7 +202,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - of the integer ``dtype``. """ if not isinstance(image, torch.Tensor): - raise TypeError('Input img should be Tensor Image') + raise TypeError("Input img should be Tensor Image") return F_t.convert_image_dtype(image, dtype) @@ -223,12 +221,12 @@ def to_pil_image(pic, mode=None): Returns: PIL Image: Image converted to PIL Image. """ - if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): - raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) + if not (isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)): + raise TypeError("pic should be Tensor or ndarray. Got {}.".format(type(pic))) elif isinstance(pic, torch.Tensor): if pic.ndimension() not in {2, 3}: - raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndimension())) + raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndimension())) elif pic.ndimension() == 2: # if 2D image, add channel dimension (CHW) @@ -236,11 +234,11 @@ def to_pil_image(pic, mode=None): # check number of channels if pic.shape[-3] > 4: - raise ValueError('pic should not have > 4 channels. Got {} channels.'.format(pic.shape[-3])) + raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-3])) elif isinstance(pic, np.ndarray): if pic.ndim not in {2, 3}: - raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + raise ValueError("pic should be 2/3 dimensional. Got {} dimensions.".format(pic.ndim)) elif pic.ndim == 2: # if 2D image, add channel dimension (HWC) @@ -248,58 +246,58 @@ def to_pil_image(pic, mode=None): # check number of channels if pic.shape[-1] > 4: - raise ValueError('pic should not have > 4 channels. 
Got {} channels.'.format(pic.shape[-1])) + raise ValueError("pic should not have > 4 channels. Got {} channels.".format(pic.shape[-1])) npimg = pic if isinstance(pic, torch.Tensor): - if pic.is_floating_point() and mode != 'F': + if pic.is_floating_point() and mode != "F": pic = pic.mul(255).byte() npimg = np.transpose(pic.cpu().numpy(), (1, 2, 0)) if not isinstance(npimg, np.ndarray): - raise TypeError('Input pic must be a torch.Tensor or NumPy ndarray, ' + - 'not {}'.format(type(npimg))) + raise TypeError("Input pic must be a torch.Tensor or NumPy ndarray, " + "not {}".format(type(npimg))) if npimg.shape[2] == 1: expected_mode = None npimg = npimg[:, :, 0] if npimg.dtype == np.uint8: - expected_mode = 'L' + expected_mode = "L" elif npimg.dtype == np.int16: - expected_mode = 'I;16' + expected_mode = "I;16" elif npimg.dtype == np.int32: - expected_mode = 'I' + expected_mode = "I" elif npimg.dtype == np.float32: - expected_mode = 'F' + expected_mode = "F" if mode is not None and mode != expected_mode: - raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" - .format(mode, np.dtype, expected_mode)) + raise ValueError( + "Incorrect mode ({}) supplied for input type {}. Should be {}".format(mode, np.dtype, expected_mode) + ) mode = expected_mode elif npimg.shape[2] == 2: - permitted_2_channel_modes = ['LA'] + permitted_2_channel_modes = ["LA"] if mode is not None and mode not in permitted_2_channel_modes: raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes)) if mode is None and npimg.dtype == np.uint8: - mode = 'LA' + mode = "LA" elif npimg.shape[2] == 4: - permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX'] + permitted_4_channel_modes = ["RGBA", "CMYK", "RGBX"] if mode is not None and mode not in permitted_4_channel_modes: raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) if mode is None and npimg.dtype == np.uint8: - mode = 'RGBA' + mode = "RGBA" else: - permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] + permitted_3_channel_modes = ["RGB", "YCbCr", "HSV"] if mode is not None and mode not in permitted_3_channel_modes: raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) if mode is None and npimg.dtype == np.uint8: - mode = 'RGB' + mode = "RGB" if mode is None: - raise TypeError('Input type {} is not supported'.format(npimg.dtype)) + raise TypeError("Input type {} is not supported".format(npimg.dtype)) return Image.fromarray(npimg, mode=mode) @@ -323,14 +321,16 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool Tensor: Normalized Tensor image. """ if not isinstance(tensor, torch.Tensor): - raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor))) + raise TypeError("Input tensor should be a torch tensor. Got {}.".format(type(tensor))) if not tensor.is_floating_point(): - raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype)) + raise TypeError("Input tensor should be a float tensor. Got {}.".format(tensor.dtype)) if tensor.ndim < 3: - raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ' - '{}.'.format(tensor.size())) + raise ValueError( + "Expected tensor to be a tensor image of size (..., C, H, W). 
Got tensor.size() = " + "{}.".format(tensor.size()) + ) if not inplace: tensor = tensor.clone() @@ -339,7 +339,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device) if (std == 0).any(): - raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) + raise ValueError("std evaluated to zero after conversion to {}, leading to division by zero.".format(dtype)) if mean.ndim == 1: mean = mean.view(-1, 1, 1) if std.ndim == 1: @@ -348,8 +348,13 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool return tensor -def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, antialias: Optional[bool] = None) -> Tensor: +def resize( + img: Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> Tensor: r"""Resize the input image to the given size. If the image is torch Tensor, it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -408,9 +413,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte if not isinstance(img, torch.Tensor): if antialias is not None and not antialias: - warnings.warn( - "Anti-alias option is always applied for PIL Image input. Argument antialias is ignored." - ) + warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") pil_interpolation = pil_modes_mapping[interpolation] return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size) @@ -418,8 +421,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte def scale(*args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") + warnings.warn("The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead.") return resize(*args, **kwargs) @@ -527,14 +529,19 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: if crop_width == image_width and crop_height == image_height: return img - crop_top = int(round((image_height - crop_height) / 2.)) - crop_left = int(round((image_width - crop_width) / 2.)) + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) return crop(img, crop_top, crop_left, crop_height, crop_width) def resized_crop( - img: Tensor, top: int, left: int, height: int, width: int, size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR + img: Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> Tensor: """Crop the given image and resize it to desired size. If the image is torch Tensor, it is expected @@ -581,9 +588,7 @@ def hflip(img: Tensor) -> Tensor: return F_t.hflip(img) -def _get_perspective_coeffs( - startpoints: List[List[int]], endpoints: List[List[int]] -) -> List[float]: +def _get_perspective_coeffs(startpoints: List[List[int]], endpoints: List[List[int]]) -> List[float]: """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. 
In Perspective Transform each pixel (x, y) in the original image gets transformed as, @@ -605,18 +610,18 @@ def _get_perspective_coeffs( a_matrix[2 * i + 1, :] = torch.tensor([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) b_matrix = torch.tensor(startpoints, dtype=torch.float).view(8) - res = torch.linalg.lstsq(a_matrix, b_matrix, driver='gels').solution + res = torch.linalg.lstsq(a_matrix, b_matrix, driver="gels").solution output: List[float] = res.tolist() return output def perspective( - img: Tensor, - startpoints: List[List[int]], - endpoints: List[List[int]], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - fill: Optional[List[float]] = None + img: Tensor, + startpoints: List[List[int]], + endpoints: List[List[int]], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, ) -> Tensor: """Perform perspective transform of the given image. If the image is torch Tensor, it is expected @@ -892,7 +897,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def _get_inverse_affine_matrix( - center: List[float], angle: float, translate: List[float], scale: float, shear: List[float] + center: List[float], angle: float, translate: List[float], scale: float, shear: List[float] ) -> List[float]: # Helper method to compute inverse matrix for affine transformation @@ -942,9 +947,13 @@ def _get_inverse_affine_matrix( def rotate( - img: Tensor, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, - expand: bool = False, center: Optional[List[int]] = None, - fill: Optional[List[float]] = None, resample: Optional[int] = None + img: Tensor, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[int]] = None, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, ) -> Tensor: """Rotate the image by angle. If the image is torch Tensor, it is expected @@ -1016,9 +1025,15 @@ def rotate( def affine( - img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], - interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, - resample: Optional[int] = None, fillcolor: Optional[List[float]] = None + img: Tensor, + angle: float, + translate: List[int], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, + fillcolor: Optional[List[float]] = None, ) -> Tensor: """Apply affine transformation on the image keeping image center invariant. If the image is torch Tensor, it is expected @@ -1065,9 +1080,7 @@ def affine( interpolation = _interpolation_modes_from_int(interpolation) if fillcolor is not None: - warnings.warn( - "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" - ) + warnings.warn("Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead") fill = fillcolor if not isinstance(angle, (int, float)): @@ -1168,7 +1181,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: - """ Erase the input Tensor Image with given value. + """Erase the input Tensor Image with given value. This transform does not support PIL Image. 
Args: @@ -1184,12 +1197,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool Tensor Image: Erased image. """ if not isinstance(img, torch.Tensor): - raise TypeError('img should be Tensor Image. Got {}'.format(type(img))) + raise TypeError("img should be Tensor Image. Got {}".format(type(img))) if not inplace: img = img.clone() - img[..., i:i + h, j:j + w] = v + img[..., i : i + h, j : j + w] = v return img @@ -1220,34 +1233,34 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa PIL Image or Tensor: Gaussian Blurred version of the image. """ if not isinstance(kernel_size, (int, list, tuple)): - raise TypeError('kernel_size should be int or a sequence of integers. Got {}'.format(type(kernel_size))) + raise TypeError("kernel_size should be int or a sequence of integers. Got {}".format(type(kernel_size))) if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] if len(kernel_size) != 2: - raise ValueError('If kernel_size is a sequence its length should be 2. Got {}'.format(len(kernel_size))) + raise ValueError("If kernel_size is a sequence its length should be 2. Got {}".format(len(kernel_size))) for ksize in kernel_size: if ksize % 2 == 0 or ksize < 0: - raise ValueError('kernel_size should have odd and positive integers. Got {}'.format(kernel_size)) + raise ValueError("kernel_size should have odd and positive integers. Got {}".format(kernel_size)) if sigma is None: sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): - raise TypeError('sigma should be either float or sequence of floats. Got {}'.format(type(sigma))) + raise TypeError("sigma should be either float or sequence of floats. Got {}".format(type(sigma))) if isinstance(sigma, (int, float)): sigma = [float(sigma), float(sigma)] if isinstance(sigma, (list, tuple)) and len(sigma) == 1: sigma = [sigma[0], sigma[0]] if len(sigma) != 2: - raise ValueError('If sigma is a sequence, its length should be 2. Got {}'.format(len(sigma))) + raise ValueError("If sigma is a sequence, its length should be 2. Got {}".format(len(sigma))) for s in sigma: - if s <= 0.: - raise ValueError('sigma should have positive values. Got {}'.format(sigma)) + if s <= 0.0: + raise ValueError("sigma should have positive values. Got {}".format(sigma)) t_img = img if not isinstance(img, torch.Tensor): if not F_pil._is_pil_image(img): - raise TypeError('img should be PIL Image or Tensor. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image or Tensor. Got {}".format(type(img))) t_img = to_tensor(img) @@ -1290,7 +1303,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: PIL Image or Tensor: Posterized image. """ if not (0 <= bits <= 8): - raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits)) + raise ValueError("The number if bits should be between 0 and 8. 
Got {}".format(bits)) if not isinstance(img, torch.Tensor): return F_pil.posterize(img, bits) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 67f4ff4bb33..eb2ab31a4a9 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -29,14 +29,14 @@ def get_image_size(img: Any) -> List[int]: @torch.jit.unused def get_image_num_channels(img: Any) -> int: if _is_pil_image(img): - return 1 if img.mode == 'L' else 3 + return 1 if img.mode == "L" else 3 raise TypeError("Unexpected type {}".format(type(img))) @torch.jit.unused def hflip(img: Image.Image) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return img.transpose(Image.FLIP_LEFT_RIGHT) @@ -44,7 +44,7 @@ def hflip(img: Image.Image) -> Image.Image: @torch.jit.unused def vflip(img: Image.Image) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return img.transpose(Image.FLIP_TOP_BOTTOM) @@ -52,7 +52,7 @@ def vflip(img: Image.Image) -> Image.Image: @torch.jit.unused def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Brightness(img) img = enhancer.enhance(brightness_factor) @@ -62,7 +62,7 @@ def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image @torch.jit.unused def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Contrast(img) img = enhancer.enhance(contrast_factor) @@ -72,7 +72,7 @@ def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image: @torch.jit.unused def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Color(img) img = enhancer.enhance(saturation_factor) @@ -81,25 +81,25 @@ def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image @torch.jit.unused def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image: - if not(-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + if not (-0.5 <= hue_factor <= 0.5): + raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)) if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. 
Got {}".format(type(img))) input_mode = img.mode - if input_mode in {'L', '1', 'I', 'F'}: + if input_mode in {"L", "1", "I", "F"}: return img - h, s, v = img.convert('HSV').split() + h, s, v = img.convert("HSV").split() np_h = np.array(h, dtype=np.uint8) # uint8 addition take cares of rotation across boundaries - with np.errstate(over='ignore'): + with np.errstate(over="ignore"): np_h += np.uint8(hue_factor * 255) - h = Image.fromarray(np_h, 'L') + h = Image.fromarray(np_h, "L") - img = Image.merge('HSV', (h, s, v)).convert(input_mode) + img = Image.merge("HSV", (h, s, v)).convert(input_mode) return img @@ -111,14 +111,14 @@ def adjust_gamma( ) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) if gamma < 0: - raise ValueError('Gamma should be a non-negative real number') + raise ValueError("Gamma should be a non-negative real number") input_mode = img.mode - img = img.convert('RGB') - gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 + img = img.convert("RGB") + gamma_map = [(255 + 1 - 1e-3) * gain * pow(ele / 255.0, gamma) for ele in range(256)] * 3 img = img.point(gamma_map) # use PIL's point-function to accelerate this part img = img.convert(input_mode) @@ -147,8 +147,9 @@ def pad( padding = tuple(padding) if isinstance(padding, tuple) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) + raise ValueError( + "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)) + ) if isinstance(padding, tuple) and len(padding) == 1: # Compatibility with `functional_tensor.pad` @@ -187,7 +188,7 @@ def pad( pad_left, pad_top, pad_right, pad_bottom = np.maximum(p, 0) - if img.mode == 'P': + if img.mode == "P": palette = img.getpalette() img = np.asarray(img) img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) @@ -216,7 +217,7 @@ def crop( ) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return img.crop((left, top, left + width, top + height)) @@ -230,9 +231,9 @@ def resize( ) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))): - raise TypeError('Got inappropriate size arg: {}'.format(size)) + raise TypeError("Got inappropriate size arg: {}".format(size)) if isinstance(size, Sequence) and len(size) == 1: size = size[0] @@ -280,8 +281,7 @@ def _parse_fill( fill = tuple([fill] * num_bands) if isinstance(fill, (list, tuple)): if len(fill) != num_bands: - msg = ("The number of elements in 'fill' does not match the number of " - "bands of the image ({} != {})") + msg = "The number of elements in 'fill' does not match the number of " "bands of the image ({} != {})" raise ValueError(msg.format(len(fill), num_bands)) fill = tuple(fill) @@ -298,7 +298,7 @@ def affine( ) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. 
Got {}".format(type(img))) output_size = img.size opts = _parse_fill(fill, img) @@ -331,7 +331,7 @@ def perspective( ) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) opts = _parse_fill(fill, img) @@ -341,17 +341,17 @@ def perspective( @torch.jit.unused def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) if num_output_channels == 1: - img = img.convert('L') + img = img.convert("L") elif num_output_channels == 3: - img = img.convert('L') + img = img.convert("L") np_img = np.array(img, dtype=np.uint8) np_img = np.dstack([np_img, np_img, np_img]) - img = Image.fromarray(np_img, 'RGB') + img = Image.fromarray(np_img, "RGB") else: - raise ValueError('num_output_channels should be either 1 or 3') + raise ValueError("num_output_channels should be either 1 or 3") return img @@ -359,28 +359,28 @@ def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image: @torch.jit.unused def invert(img: Image.Image) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.invert(img) @torch.jit.unused def posterize(img: Image.Image, bits: int) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.posterize(img, bits) @torch.jit.unused def solarize(img: Image.Image, threshold: int) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.solarize(img, threshold) @torch.jit.unused def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) enhancer = ImageEnhance.Sharpness(img) img = enhancer.enhance(sharpness_factor) @@ -390,12 +390,12 @@ def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image: @torch.jit.unused def autocontrast(img: Image.Image) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. Got {}".format(type(img))) return ImageOps.autocontrast(img) @torch.jit.unused def equalize(img: Image.Image) -> Image.Image: if not _is_pil_image(img): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + raise TypeError("img should be PIL Image. 
Got {}".format(type(img))) return ImageOps.equalize(img) diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index a95e29fda0d..d0fd78346b6 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -1,10 +1,10 @@ import warnings +from typing import Optional, Tuple, List import torch from torch import Tensor -from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad from torch.jit.annotations import BroadcastingList2 -from typing import Optional, Tuple, List +from torch.nn.functional import grid_sample, conv2d, interpolate, pad as torch_pad def _is_tensor_a_torch_image(x: Tensor) -> bool: @@ -97,7 +97,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - # factor should be forced to int for torch jit script # otherwise factor is a float and image // factor can produce different results factor = int((input_max + 1) // (output_max + 1)) - image = torch.div(image, factor, rounding_mode='floor') + image = torch.div(image, factor, rounding_mode="floor") return image.to(dtype) else: # factor should be forced to int for torch jit script @@ -128,7 +128,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor: if left < 0 or top < 0 or right > w or bottom > h: padding_ltrb = [max(-left, 0), max(-top, 0), max(right - w, 0), max(bottom - h, 0)] - return pad(img[..., max(top, 0):bottom, max(left, 0):right], padding_ltrb, fill=0) + return pad(img[..., max(top, 0) : bottom, max(left, 0) : right], padding_ltrb, fill=0) return img[..., top:bottom, left:right] @@ -138,7 +138,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: _assert_channels(img, [3]) if num_output_channels not in (1, 3): - raise ValueError('num_output_channels should be either 1 or 3') + raise ValueError("num_output_channels should be either 1 or 3") r, g, b = img.unbind(dim=-3) # This implementation closely follows the TF one: @@ -154,7 +154,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor: def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: if brightness_factor < 0: - raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor)) + raise ValueError("brightness_factor ({}) is not non-negative.".format(brightness_factor)) _assert_image_tensor(img) @@ -165,7 +165,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor: def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: if contrast_factor < 0: - raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor)) + raise ValueError("contrast_factor ({}) is not non-negative.".format(contrast_factor)) _assert_image_tensor(img) @@ -182,10 +182,10 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor: def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: if not (-0.5 <= hue_factor <= 0.5): - raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) + raise ValueError("hue_factor ({}) is not in [-0.5, 0.5].".format(hue_factor)) if not (isinstance(img, torch.Tensor)): - raise TypeError('Input img should be Tensor image') + raise TypeError("Input img should be Tensor image") _assert_image_tensor(img) @@ -211,7 +211,7 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor: def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: if saturation_factor < 0: - raise ValueError('saturation_factor ({}) is 
not non-negative.'.format(saturation_factor)) + raise ValueError("saturation_factor ({}) is not non-negative.".format(saturation_factor)) _assert_image_tensor(img) @@ -225,12 +225,12 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor: def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: if not isinstance(img, torch.Tensor): - raise TypeError('Input img should be a Tensor.') + raise TypeError("Input img should be a Tensor.") _assert_channels(img, [1, 3]) if gamma < 0: - raise ValueError('Gamma should be a non-negative real number') + raise ValueError("Gamma should be a non-negative real number") result = img dtype = img.dtype @@ -244,11 +244,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor: def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: - """DEPRECATED - """ + """DEPRECATED""" warnings.warn( - "This method is deprecated and will be removed in future releases. " - "Please, use ``F.center_crop`` instead." + "This method is deprecated and will be removed in future releases. " "Please, use ``F.center_crop`` instead." ) _assert_image_tensor(img) @@ -268,11 +266,9 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor: def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: - """DEPRECATED - """ + """DEPRECATED""" warnings.warn( - "This method is deprecated and will be removed in future releases. " - "Please, use ``F.five_crop`` instead." + "This method is deprecated and will be removed in future releases. " "Please, use ``F.five_crop`` instead." ) _assert_image_tensor(img) @@ -295,11 +291,9 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]: def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]: - """DEPRECATED - """ + """DEPRECATED""" warnings.warn( - "This method is deprecated and will be removed in future releases. " - "Please, use ``F.ten_crop`` instead." + "This method is deprecated and will be removed in future releases. " "Please, use ``F.ten_crop`` instead." 
) _assert_image_tensor(img) @@ -357,7 +351,7 @@ def _rgb2hsv(img: Tensor) -> Tensor: hr = (maxc == r) * (bc - gc) hg = ((maxc == g) & (maxc != r)) * (2.0 + rc - bc) hb = ((maxc != g) & (maxc != r)) * (4.0 + gc - rc) - h = (hr + hg + hb) + h = hr + hg + hb h = torch.fmod((h / 6.0 + 1.0), 1.0) return torch.stack((h, s, maxc), dim=-3) @@ -389,7 +383,7 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor: # crop if needed if padding[0] < 0 or padding[1] < 0 or padding[2] < 0 or padding[3] < 0: crop_left, crop_right, crop_top, crop_bottom = [-min(x, 0) for x in padding] - img = img[..., crop_top:img.shape[-2] - crop_bottom, crop_left:img.shape[-1] - crop_right] + img = img[..., crop_top : img.shape[-2] - crop_bottom, crop_left : img.shape[-1] - crop_right] padding = [max(x, 0) for x in padding] in_sizes = img.size() @@ -427,8 +421,9 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con padding = list(padding) if isinstance(padding, list) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) + raise ValueError( + "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)) + ) if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") @@ -488,7 +483,7 @@ def resize( size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None, - antialias: Optional[bool] = None + antialias: Optional[bool] = None, ) -> Tensor: _assert_image_tensor(img) @@ -505,8 +500,9 @@ def resize( if isinstance(size, list): if len(size) not in [1, 2]: - raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a " - "{} element tuple/list".format(len(size))) + raise ValueError( + "Size must be an int or a 1 or 2 element tuple/list, not a " "{} element tuple/list".format(len(size)) + ) if max_size is not None and len(size) != 1: raise ValueError( "max_size should only be passed if size specifies the length of the smaller edge, " @@ -594,8 +590,10 @@ def _assert_grid_transform_inputs( # Check fill num_channels = get_image_num_channels(img) if isinstance(fill, (tuple, list)) and (len(fill) > 1 and len(fill) != num_channels): - msg = ("The number of elements in 'fill' cannot broadcast to match the number of " - "channels of the image ({} != {})") + msg = ( + "The number of elements in 'fill' cannot broadcast to match the number of " + "channels of the image ({} != {})" + ) raise ValueError(msg.format(len(fill), num_channels)) if interpolation not in supported_interpolation_modes: @@ -633,7 +631,12 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[List[float]]) -> Tensor: - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ]) + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( + img, + [ + grid.dtype, + ], + ) if img.shape[0] > 1: # Apply same grid to a batch of images @@ -653,7 +656,7 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L mask = mask.expand_as(img) len_fill = len(fill) if isinstance(fill, (tuple, list)) else 1 fill_img = torch.tensor(fill, dtype=img.dtype, device=img.device).view(1, len_fill, 1, 1).expand_as(img) - if mode == 'nearest': + if mode == "nearest": mask = mask < 0.5 img[mask] = fill_img[mask] else: # 
'bilinear' @@ -664,7 +667,11 @@ def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str, fill: Optional[L def _gen_affine_grid( - theta: Tensor, w: int, h: int, ow: int, oh: int, + theta: Tensor, + w: int, + h: int, + ow: int, + oh: int, ) -> Tensor: # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/ # AffineGridGenerator.cpp#L18 @@ -686,7 +693,7 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None + img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[List[float]] = None ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) @@ -704,12 +711,14 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054 # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. - pts = torch.tensor([ - [-0.5 * w, -0.5 * h, 1.0], - [-0.5 * w, 0.5 * h, 1.0], - [0.5 * w, 0.5 * h, 1.0], - [0.5 * w, -0.5 * h, 1.0], - ]) + pts = torch.tensor( + [ + [-0.5 * w, -0.5 * h, 1.0], + [-0.5 * w, 0.5 * h, 1.0], + [0.5 * w, 0.5 * h, 1.0], + [0.5 * w, -0.5 * h, 1.0], + ] + ) theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3) new_pts = pts.view(1, 4, 3).bmm(theta.transpose(1, 2)).view(4, 2) min_vals, _ = new_pts.min(dim=0) @@ -724,8 +733,11 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] def rotate( - img: Tensor, matrix: List[float], interpolation: str = "nearest", - expand: bool = False, fill: Optional[List[float]] = None + img: Tensor, + matrix: List[float], + interpolation: str = "nearest", + expand: bool = False, + fill: Optional[List[float]] = None, ) -> Tensor: _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) w, h = img.shape[-1], img.shape[-2] @@ -746,14 +758,10 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1) # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1) # - theta1 = torch.tensor([[ - [coeffs[0], coeffs[1], coeffs[2]], - [coeffs[3], coeffs[4], coeffs[5]] - ]], dtype=dtype, device=device) - theta2 = torch.tensor([[ - [coeffs[6], coeffs[7], 1.0], - [coeffs[6], coeffs[7], 1.0] - ]], dtype=dtype, device=device) + theta1 = torch.tensor( + [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device + ) + theta2 = torch.tensor([[[coeffs[6], coeffs[7], 1.0], [coeffs[6], coeffs[7], 1.0]]], dtype=dtype, device=device) d = 0.5 base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device) @@ -775,7 +783,7 @@ def perspective( img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[List[float]] = None ) -> Tensor: if not (isinstance(img, torch.Tensor)): - raise TypeError('Input img should be Tensor.') + raise TypeError("Input img should be Tensor.") _assert_image_tensor(img) @@ -785,7 +793,7 @@ def perspective( interpolation=interpolation, fill=fill, supported_interpolation_modes=["nearest", "bilinear"], - coeffs=perspective_coeffs + coeffs=perspective_coeffs, ) ow, oh = img.shape[-1], img.shape[-2] @@ -805,7 +813,7 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: def _get_gaussian_kernel2d( - kernel_size: 
List[int], sigma: List[float], dtype: torch.dtype, device: torch.device + kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device ) -> Tensor: kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype) kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype) @@ -815,7 +823,7 @@ def _get_gaussian_kernel2d( def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Tensor: if not (isinstance(img, torch.Tensor)): - raise TypeError('img should be Tensor. Got {}'.format(type(img))) + raise TypeError("img should be Tensor. Got {}".format(type(img))) _assert_image_tensor(img) @@ -823,7 +831,12 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device) kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) + img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( + img, + [ + kernel.dtype, + ], + ) # padding = (left, right, top, bottom) padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] @@ -857,7 +870,7 @@ def posterize(img: Tensor, bits: int) -> Tensor: raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype)) _assert_channels(img, [1, 3]) - mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) + mask = -int(2 ** (8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1) return img & mask @@ -882,7 +895,12 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: kernel /= kernel.sum() kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1]) - result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ]) + result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in( + img, + [ + kernel.dtype, + ], + ) result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3]) result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype) @@ -894,7 +912,7 @@ def _blurred_degenerate_image(img: Tensor) -> Tensor: def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor: if sharpness_factor < 0: - raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor)) + raise ValueError("sharpness_factor ({}) is not non-negative.".format(sharpness_factor)) _assert_image_tensor(img) @@ -939,13 +957,11 @@ def _scale_channel(img_chan: Tensor) -> Tensor: hist = torch.bincount(img_chan.view(-1), minlength=256) nonzero_hist = hist[hist != 0] - step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor') + step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor") if step == 0: return img_chan - lut = torch.div( - torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'), - step, rounding_mode='floor') + lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor") lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) return lut[img_chan.to(torch.int64)].to(torch.uint8) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 9310f48450a..e920f92c91a 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -17,12 +17,45 @@ from .functional import InterpolationMode, _interpolation_modes_from_int -__all__ = ["Compose", "ToTensor", "PILToTensor", 
"ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", - "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", - "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", - "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize", - "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"] +__all__ = [ + "Compose", + "ToTensor", + "PILToTensor", + "ConvertImageDtype", + "ToPILImage", + "Normalize", + "Resize", + "Scale", + "CenterCrop", + "Pad", + "Lambda", + "RandomApply", + "RandomChoice", + "RandomOrder", + "RandomCrop", + "RandomHorizontalFlip", + "RandomVerticalFlip", + "RandomResizedCrop", + "RandomSizedCrop", + "FiveCrop", + "TenCrop", + "LinearTransformation", + "ColorJitter", + "RandomRotation", + "RandomAffine", + "Grayscale", + "RandomGrayscale", + "RandomPerspective", + "RandomErasing", + "GaussianBlur", + "InterpolationMode", + "RandomInvert", + "RandomPosterize", + "RandomSolarize", + "RandomAdjustSharpness", + "RandomAutocontrast", + "RandomEqualize", +] class Compose: @@ -62,11 +95,11 @@ def __call__(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string @@ -98,7 +131,7 @@ def __call__(self, pic): return F.to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class PILToTensor: @@ -118,7 +151,7 @@ def __call__(self, pic): return F.pil_to_tensor(pic) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class ConvertImageDtype(torch.nn.Module): @@ -165,6 +198,7 @@ class ToPILImage: .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes """ + def __init__(self, mode=None): self.mode = mode @@ -180,10 +214,10 @@ def __call__(self, pic): return F.to_pil_image(pic, self.mode) def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" if self.mode is not None: - format_string += 'mode={0}'.format(self.mode) - format_string += ')' + format_string += "mode={0}".format(self.mode) + format_string += ")" return format_string @@ -222,7 +256,7 @@ def forward(self, tensor: Tensor) -> Tensor: return F.normalize(tensor, self.mean, self.std, self.inplace) def __repr__(self): - return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) + return self.__class__.__name__ + "(mean={0}, std={1})".format(self.mean, self.std) class Resize(torch.nn.Module): @@ -301,17 +335,20 @@ def forward(self, img): def __repr__(self): interpolate_str = self.interpolation.value - return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2}, antialias={3})'.format( - self.size, interpolate_str, self.max_size, self.antialias) + return self.__class__.__name__ + "(size={0}, interpolation={1}, max_size={2}, antialias={3})".format( + self.size, interpolate_str, self.max_size, self.antialias + ) class Scale(Resize): """ Note: This transform is deprecated in favor of Resize. 
""" + def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.Scale transform is deprecated, " + - "please use transforms.Resize instead.") + warnings.warn( + "The use of the transforms.Scale transform is deprecated, " + "please use transforms.Resize instead." + ) super(Scale, self).__init__(*args, **kwargs) @@ -342,7 +379,7 @@ def forward(self, img): return F.center_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) class Pad(torch.nn.Module): @@ -395,8 +432,9 @@ def __init__(self, padding, fill=0, padding_mode="constant"): raise ValueError("Padding mode should be either constant, edge, reflect or symmetric") if isinstance(padding, Sequence) and len(padding) not in [1, 2, 4]: - raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + - "{} element tuple".format(len(padding))) + raise ValueError( + "Padding must be an int or a 1, 2, or 4 element tuple, not a " + "{} element tuple".format(len(padding)) + ) self.padding = padding self.fill = fill @@ -413,8 +451,9 @@ def forward(self, img): return F.pad(img, self.padding, self.fill, self.padding_mode) def __repr__(self): - return self.__class__.__name__ + '(padding={0}, fill={1}, padding_mode={2})'.\ - format(self.padding, self.fill, self.padding_mode) + return self.__class__.__name__ + "(padding={0}, fill={1}, padding_mode={2})".format( + self.padding, self.fill, self.padding_mode + ) class Lambda: @@ -433,7 +472,7 @@ def __call__(self, img): return self.lambd(img) def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" class RandomTransforms: @@ -452,11 +491,11 @@ def __call__(self, *args, **kwargs): raise NotImplementedError() def __repr__(self): - format_string = self.__class__.__name__ + '(' + format_string = self.__class__.__name__ + "(" for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string @@ -493,18 +532,18 @@ def forward(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += '\n p={}'.format(self.p) + format_string = self.__class__.__name__ + "(" + format_string += "\n p={}".format(self.p) for t in self.transforms: - format_string += '\n' - format_string += ' {0}'.format(t) - format_string += '\n)' + format_string += "\n" + format_string += " {0}".format(t) + format_string += "\n)" return format_string class RandomOrder(RandomTransforms): - """Apply a list of transformations in a random order. This transform does not support torchscript. - """ + """Apply a list of transformations in a random order. This transform does not support torchscript.""" + def __call__(self, img): order = list(range(len(self.transforms))) random.shuffle(order) @@ -514,8 +553,8 @@ def __call__(self, img): class RandomChoice(RandomTransforms): - """Apply single transformation randomly picked from a list. This transform does not support torchscript. - """ + """Apply single transformation randomly picked from a list. 
This transform does not support torchscript.""" + def __init__(self, transforms, p=None): super().__init__(transforms) if p is not None and not isinstance(p, Sequence): @@ -528,7 +567,7 @@ def __call__(self, *args): def __repr__(self): format_string = super().__repr__() - format_string += '(p={0})'.format(self.p) + format_string += "(p={0})".format(self.p) return format_string @@ -591,23 +630,19 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int th, tw = output_size if h + 1 < th or w + 1 < tw: - raise ValueError( - "Required crop size {} is larger then input image size {}".format((th, tw), (h, w)) - ) + raise ValueError("Required crop size {} is larger then input image size {}".format((th, tw), (h, w))) if w == tw and h == th: return 0, 0, h, w - i = torch.randint(0, h - th + 1, size=(1, )).item() - j = torch.randint(0, w - tw + 1, size=(1, )).item() + i = torch.randint(0, h - th + 1, size=(1,)).item() + j = torch.randint(0, w - tw + 1, size=(1,)).item() return i, j, th, tw def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode="constant"): super().__init__() - self.size = tuple(_setup_size( - size, error_msg="Please provide only two dimensions (h, w) for size." - )) + self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) self.padding = padding self.pad_if_needed = pad_if_needed @@ -670,7 +705,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomVerticalFlip(torch.nn.Module): @@ -700,7 +735,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomPerspective(torch.nn.Module): @@ -780,27 +815,27 @@ def get_params(width: int, height: int, distortion_scale: float) -> Tuple[List[L half_height = height // 2 half_width = width // 2 topleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), ] topright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item()) + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()), ] botright = [ - int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), ] botleft = [ - int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()), - int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item()) + int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()), + int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()), ] startpoints = [[0, 
0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]] endpoints = [topleft, topright, botright, botleft] return startpoints, endpoints def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomResizedCrop(torch.nn.Module): @@ -832,7 +867,7 @@ class RandomResizedCrop(torch.nn.Module): """ - def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationMode.BILINEAR): + def __init__(self, size, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0), interpolation=InterpolationMode.BILINEAR): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -856,9 +891,7 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat self.ratio = ratio @staticmethod - def get_params( - img: Tensor, scale: List[float], ratio: List[float] - ) -> Tuple[int, int, int, int]: + def get_params(img: Tensor, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random sized crop. Args: @@ -876,9 +909,7 @@ def get_params( log_ratio = torch.log(torch.tensor(ratio)) for _ in range(10): target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) - ).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) @@ -916,10 +947,10 @@ def forward(self, img): def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + '(size={0}'.format(self.size) - format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) - format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) - format_string += ', interpolation={0})'.format(interpolate_str) + format_string = self.__class__.__name__ + "(size={0}".format(self.size) + format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) + format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) + format_string += ", interpolation={0})".format(interpolate_str) return format_string @@ -927,9 +958,12 @@ class RandomSizedCrop(RandomResizedCrop): """ Note: This transform is deprecated in favor of RandomResizedCrop. """ + def __init__(self, *args, **kwargs): - warnings.warn("The use of the transforms.RandomSizedCrop transform is deprecated, " + - "please use transforms.RandomResizedCrop instead.") + warnings.warn( + "The use of the transforms.RandomSizedCrop transform is deprecated, " + + "please use transforms.RandomResizedCrop instead." 
+ ) super(RandomSizedCrop, self).__init__(*args, **kwargs) @@ -976,7 +1010,7 @@ def forward(self, img): return F.five_crop(img, self.size) def __repr__(self): - return self.__class__.__name__ + '(size={0})'.format(self.size) + return self.__class__.__name__ + "(size={0})".format(self.size) class TenCrop(torch.nn.Module): @@ -1025,7 +1059,7 @@ def forward(self, img): return F.ten_crop(img, self.size, self.vertical_flip) def __repr__(self): - return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) + return self.__class__.__name__ + "(size={0}, vertical_flip={1})".format(self.size, self.vertical_flip) class LinearTransformation(torch.nn.Module): @@ -1050,17 +1084,25 @@ class LinearTransformation(torch.nn.Module): def __init__(self, transformation_matrix, mean_vector): super().__init__() if transformation_matrix.size(0) != transformation_matrix.size(1): - raise ValueError("transformation_matrix should be square. Got " + - "[{} x {}] rectangular matrix.".format(*transformation_matrix.size())) + raise ValueError( + "transformation_matrix should be square. Got " + + "[{} x {}] rectangular matrix.".format(*transformation_matrix.size()) + ) if mean_vector.size(0) != transformation_matrix.size(0): - raise ValueError("mean_vector should have the same length {}".format(mean_vector.size(0)) + - " as any one of the dimensions of the transformation_matrix [{}]" - .format(tuple(transformation_matrix.size()))) + raise ValueError( + "mean_vector should have the same length {}".format(mean_vector.size(0)) + + " as any one of the dimensions of the transformation_matrix [{}]".format( + tuple(transformation_matrix.size()) + ) + ) if transformation_matrix.device != mean_vector.device: - raise ValueError("Input tensors should be on the same device. Got {} and {}" - .format(transformation_matrix.device, mean_vector.device)) + raise ValueError( + "Input tensors should be on the same device. Got {} and {}".format( + transformation_matrix.device, mean_vector.device + ) + ) self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector @@ -1076,13 +1118,17 @@ def forward(self, tensor: Tensor) -> Tensor: shape = tensor.shape n = shape[-3] * shape[-2] * shape[-1] if n != self.transformation_matrix.shape[0]: - raise ValueError("Input tensor and transformation matrix have incompatible shape." + - "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + - "{}".format(self.transformation_matrix.shape[0])) + raise ValueError( + "Input tensor and transformation matrix have incompatible shape." + + "[{} x {} x {}] != ".format(shape[-3], shape[-2], shape[-1]) + + "{}".format(self.transformation_matrix.shape[0]) + ) if tensor.device.type != self.mean_vector.device.type: - raise ValueError("Input tensor should be on the same device as transformation matrix and mean vector. " - "Got {} vs {}".format(tensor.device, self.mean_vector.device)) + raise ValueError( + "Input tensor should be on the same device as transformation matrix and mean vector. 
" + "Got {} vs {}".format(tensor.device, self.mean_vector.device) + ) flat_tensor = tensor.view(-1, n) - self.mean_vector transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) @@ -1090,9 +1136,9 @@ def forward(self, tensor: Tensor) -> Tensor: return tensor def __repr__(self): - format_string = self.__class__.__name__ + '(transformation_matrix=' - format_string += (str(self.transformation_matrix.tolist()) + ')') - format_string += (", (mean_vector=" + str(self.mean_vector.tolist()) + ')') + format_string = self.__class__.__name__ + "(transformation_matrix=" + format_string += str(self.transformation_matrix.tolist()) + ")" + format_string += ", (mean_vector=" + str(self.mean_vector.tolist()) + ")" return format_string @@ -1119,14 +1165,13 @@ class ColorJitter(torch.nn.Module): def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): super().__init__() - self.brightness = self._check_input(brightness, 'brightness') - self.contrast = self._check_input(contrast, 'contrast') - self.saturation = self._check_input(saturation, 'saturation') - self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), - clip_first_on_zero=False) + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) @torch.jit.unused - def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True): if isinstance(value, numbers.Number): if value < 0: raise ValueError("If {} is a single number, it must be non negative.".format(name)) @@ -1146,11 +1191,12 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs return value @staticmethod - def get_params(brightness: Optional[List[float]], - contrast: Optional[List[float]], - saturation: Optional[List[float]], - hue: Optional[List[float]] - ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: + def get_params( + brightness: Optional[List[float]], + contrast: Optional[List[float]], + saturation: Optional[List[float]], + hue: Optional[List[float]], + ) -> Tuple[Tensor, Optional[float], Optional[float], Optional[float], Optional[float]]: """Get the parameters for the randomized transform to be applied on image. Args: @@ -1184,8 +1230,9 @@ def forward(self, img): Returns: PIL Image or Tensor: Color jittered image. 
""" - fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ - self.get_params(self.brightness, self.contrast, self.saturation, self.hue) + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue + ) for fn_id in fn_idx: if fn_id == 0 and brightness_factor is not None: @@ -1200,11 +1247,11 @@ def forward(self, img): return img def __repr__(self): - format_string = self.__class__.__name__ + '(' - format_string += 'brightness={0}'.format(self.brightness) - format_string += ', contrast={0}'.format(self.contrast) - format_string += ', saturation={0}'.format(self.saturation) - format_string += ', hue={0})'.format(self.hue) + format_string = self.__class__.__name__ + "(" + format_string += "brightness={0}".format(self.brightness) + format_string += ", contrast={0}".format(self.contrast) + format_string += ", saturation={0}".format(self.saturation) + format_string += ", hue={0})".format(self.hue) return format_string @@ -1254,10 +1301,10 @@ def __init__( ) interpolation = _interpolation_modes_from_int(interpolation) - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if center is not None: - _check_sequence_input(center, "center", req_sizes=(2, )) + _check_sequence_input(center, "center", req_sizes=(2,)) self.center = center @@ -1301,14 +1348,14 @@ def forward(self, img): def __repr__(self): interpolate_str = self.interpolation.value - format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) - format_string += ', interpolation={0}'.format(interpolate_str) - format_string += ', expand={0}'.format(self.expand) + format_string = self.__class__.__name__ + "(degrees={0}".format(self.degrees) + format_string += ", interpolation={0}".format(interpolate_str) + format_string += ", expand={0}".format(self.expand) if self.center is not None: - format_string += ', center={0}'.format(self.center) + format_string += ", center={0}".format(self.center) if self.fill is not None: - format_string += ', fill={0}'.format(self.fill) - format_string += ')' + format_string += ", fill={0}".format(self.fill) + format_string += ")" return format_string @@ -1349,8 +1396,15 @@ class RandomAffine(torch.nn.Module): """ def __init__( - self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationMode.NEAREST, fill=0, - fillcolor=None, resample=None + self, + degrees, + translate=None, + scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + fillcolor=None, + resample=None, ): super().__init__() if resample is not None: @@ -1373,17 +1427,17 @@ def __init__( ) fill = fillcolor - self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if translate is not None: - _check_sequence_input(translate, "translate", req_sizes=(2, )) + _check_sequence_input(translate, "translate", req_sizes=(2,)) for t in translate: if not (0.0 <= t <= 1.0): raise ValueError("translation values should be between 0 and 1") self.translate = translate if scale is not None: - _check_sequence_input(scale, "scale", req_sizes=(2, )) + _check_sequence_input(scale, "scale", req_sizes=(2,)) for s in scale: if s <= 0: raise ValueError("scale values should be positive") @@ -1405,11 +1459,11 @@ def __init__( @staticmethod def get_params( - degrees: List[float], - translate: Optional[List[float]], - scale_ranges: 
Optional[List[float]], - shears: Optional[List[float]], - img_size: List[int] + degrees: List[float], + translate: Optional[List[float]], + scale_ranges: Optional[List[float]], + shears: Optional[List[float]], + img_size: List[int], ) -> Tuple[float, Tuple[int, int], float, Tuple[float, float]]: """Get parameters for affine transformation @@ -1462,20 +1516,20 @@ def forward(self, img): return F.affine(img, *ret, interpolation=self.interpolation, fill=fill) def __repr__(self): - s = '{name}(degrees={degrees}' + s = "{name}(degrees={degrees}" if self.translate is not None: - s += ', translate={translate}' + s += ", translate={translate}" if self.scale is not None: - s += ', scale={scale}' + s += ", scale={scale}" if self.shear is not None: - s += ', shear={shear}' + s += ", shear={shear}" if self.interpolation != InterpolationMode.NEAREST: - s += ', interpolation={interpolation}' + s += ", interpolation={interpolation}" if self.fill != 0: - s += ', fill={fill}' - s += ')' + s += ", fill={fill}" + s += ")" d = dict(self.__dict__) - d['interpolation'] = self.interpolation.value + d["interpolation"] = self.interpolation.value return s.format(name=self.__class__.__name__, **d) @@ -1510,7 +1564,7 @@ def forward(self, img): return F.rgb_to_grayscale(img, num_output_channels=self.num_output_channels) def __repr__(self): - return self.__class__.__name__ + '(num_output_channels={0})'.format(self.num_output_channels) + return self.__class__.__name__ + "(num_output_channels={0})".format(self.num_output_channels) class RandomGrayscale(torch.nn.Module): @@ -1547,11 +1601,11 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={0})'.format(self.p) + return self.__class__.__name__ + "(p={0})".format(self.p) class RandomErasing(torch.nn.Module): - """ Randomly selects a rectangle region in an torch Tensor image and erases its pixels. + """Randomly selects a rectangle region in an torch Tensor image and erases its pixels. This transform does not support PIL Image. 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896 @@ -1603,7 +1657,7 @@ def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace @staticmethod def get_params( - img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None + img: Tensor, scale: Tuple[float, float], ratio: Tuple[float, float], value: Optional[List[float]] = None ) -> Tuple[int, int, int, int, Tensor]: """Get parameters for ``erase`` for a random erasing. 
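`RandomErasing.get_params` (continued in the next hunk) selects a rectangle and a fill tensor `v`. A brief usage sketch; the parameter values are arbitrary and only the tensor path is shown, since the transform does not support PIL Images:

    import torch
    from torchvision import transforms

    # value=0 fills the erased patch with zeros, a 3-tuple fills it with a fixed
    # per-channel color, and value="random" fills it with standard-normal noise.
    eraser = transforms.RandomErasing(p=1.0, scale=(0.02, 0.2), ratio=(0.3, 3.3), value="random")

    out = eraser(torch.rand(3, 64, 64))
    print(out.shape)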
@@ -1624,9 +1678,7 @@ def get_params( log_ratio = torch.log(torch.tensor(ratio)) for _ in range(10): erase_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() - aspect_ratio = torch.exp( - torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) - ).item() + aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() h = int(round(math.sqrt(erase_area * aspect_ratio))) w = int(round(math.sqrt(erase_area / aspect_ratio))) @@ -1638,8 +1690,8 @@ def get_params( else: v = torch.tensor(value)[:, None, None] - i = torch.randint(0, img_h - h + 1, size=(1, )).item() - j = torch.randint(0, img_w - w + 1, size=(1, )).item() + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() return i, j, h, w, v # Return original image @@ -1657,7 +1709,9 @@ def forward(self, img): # cast self.value to script acceptable type if isinstance(self.value, (int, float)): - value = [self.value, ] + value = [ + self.value, + ] elif isinstance(self.value, str): value = None elif isinstance(self.value, tuple): @@ -1676,11 +1730,11 @@ def forward(self, img): return img def __repr__(self): - s = '(p={}, '.format(self.p) - s += 'scale={}, '.format(self.scale) - s += 'ratio={}, '.format(self.ratio) - s += 'value={}, '.format(self.value) - s += 'inplace={})'.format(self.inplace) + s = "(p={}, ".format(self.p) + s += "scale={}, ".format(self.scale) + s += "ratio={}, ".format(self.ratio) + s += "value={}, ".format(self.value) + s += "inplace={})".format(self.inplace) return self.__class__.__name__ + s @@ -1713,7 +1767,7 @@ def __init__(self, kernel_size, sigma=(0.1, 2.0)): raise ValueError("If sigma is a single number, it must be positive.") sigma = (sigma, sigma) elif isinstance(sigma, Sequence) and len(sigma) == 2: - if not 0. 
< sigma[0] <= sigma[1]: + if not 0.0 < sigma[0] <= sigma[1]: raise ValueError("sigma values should be positive and of the form (min, max).") else: raise ValueError("sigma should be a single number or a list/tuple with length 2.") @@ -1745,8 +1799,8 @@ def forward(self, img: Tensor) -> Tensor: return F.gaussian_blur(img, self.kernel_size, [sigma, sigma]) def __repr__(self): - s = '(kernel_size={}, '.format(self.kernel_size) - s += 'sigma={})'.format(self.sigma) + s = "(kernel_size={}, ".format(self.kernel_size) + s += "sigma={})".format(self.sigma) return self.__class__.__name__ + s @@ -1771,7 +1825,7 @@ def _check_sequence_input(x, name, req_sizes): raise ValueError("{} should be sequence of length {}.".format(name, msg)) -def _setup_angle(x, name, req_sizes=(2, )): +def _setup_angle(x, name, req_sizes=(2,)): if isinstance(x, numbers.Number): if x < 0: raise ValueError("If {} is a single number, it must be positive.".format(name)) @@ -1809,7 +1863,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomPosterize(torch.nn.Module): @@ -1841,7 +1895,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p) + return self.__class__.__name__ + "(bits={},p={})".format(self.bits, self.p) class RandomSolarize(torch.nn.Module): @@ -1873,7 +1927,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p) + return self.__class__.__name__ + "(threshold={},p={})".format(self.threshold, self.p) class RandomAdjustSharpness(torch.nn.Module): @@ -1905,7 +1959,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p) + return self.__class__.__name__ + "(sharpness_factor={},p={})".format(self.sharpness_factor, self.p) class RandomAutocontrast(torch.nn.Module): @@ -1935,7 +1989,7 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) class RandomEqualize(torch.nn.Module): @@ -1965,4 +2019,4 @@ def forward(self, img): return img def __repr__(self): - return self.__class__.__name__ + '(p={})'.format(self.p) + return self.__class__.__name__ + "(p={})".format(self.p) diff --git a/torchvision/utils.py b/torchvision/utils.py index 494661e6ad8..6e8d63b1d7e 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -1,9 +1,10 @@ -from typing import Union, Optional, List, Tuple, Text, BinaryIO -import pathlib -import torch import math +import pathlib import warnings +from typing import Union, Optional, List, Tuple, Text, BinaryIO + import numpy as np +import torch from PIL import Image, ImageDraw, ImageFont, ImageColor __all__ = ["make_grid", "save_image", "draw_bounding_boxes", "draw_segmentation_masks"] @@ -18,7 +19,7 @@ def make_grid( value_range: Optional[Tuple[int, int]] = None, scale_each: bool = False, pad_value: int = 0, - **kwargs + **kwargs, ) -> torch.Tensor: """ Make a grid of images. @@ -41,9 +42,8 @@ def make_grid( Returns: grid (Tensor): the tensor containing grid of images. 
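For the `make_grid` signature and docstring above, a small usage sketch with random data, only to show how `nrow` and `padding` determine the tiled shape:

    import torch
    from torchvision.utils import make_grid

    batch = torch.rand(16, 3, 32, 32)

    # nrow is the number of images per row; padding is inserted between tiles
    # and around the border, so 16 images of 32x32 tile to 4*(32+2)+2 = 138.
    grid = make_grid(batch, nrow=4, padding=2, normalize=True)
    print(grid.shape)   # torch.Size([3, 138, 138])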
""" - if not (torch.is_tensor(tensor) or - (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): - raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") if "range" in kwargs.keys(): warning = "range will be deprecated, please use value_range instead." @@ -67,8 +67,9 @@ def make_grid( if normalize is True: tensor = tensor.clone() # avoid modifying tensor in-place if value_range is not None: - assert isinstance(value_range, tuple), \ - "value_range has to be a tuple (min, max) if specified. min and max are numbers" + assert isinstance( + value_range, tuple + ), "value_range has to be a tuple (min, max) if specified. min and max are numbers" def norm_ip(img, low, high): img.clamp_(min=low, max=high) @@ -115,7 +116,7 @@ def save_image( tensor: Union[torch.Tensor, List[torch.Tensor]], fp: Union[Text, pathlib.Path, BinaryIO], format: Optional[str] = None, - **kwargs + **kwargs, ) -> None: """ Save a given Tensor into an image file. @@ -131,7 +132,7 @@ def save_image( grid = make_grid(tensor, **kwargs) # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer - ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() im = Image.fromarray(ndarr) im.save(fp, format=format) @@ -145,7 +146,7 @@ def draw_bounding_boxes( fill: Optional[bool] = False, width: int = 1, font: Optional[str] = None, - font_size: int = 10 + font_size: int = 10, ) -> torch.Tensor: """