From 3864472c3f2371604fb063cc4a6c836f2dfaa1ab Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 26 Jul 2023 13:35:02 -0400 Subject: [PATCH] Move env util --- .../common_utils/case_utils.py | 26 +++---------------- torchaudio/_internal/module_utils.py | 20 ++++++++++++++ 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/test/torchaudio_unittest/common_utils/case_utils.py b/test/torchaudio_unittest/common_utils/case_utils.py index ad760afa53..1eca025a9c 100644 --- a/test/torchaudio_unittest/common_utils/case_utils.py +++ b/test/torchaudio_unittest/common_utils/case_utils.py @@ -11,7 +11,7 @@ import torch import torchaudio from torch.testing._internal.common_utils import TestCase as PytorchTestCase -from torchaudio._internal.module_utils import is_module_available +from torchaudio._internal.module_utils import eval_env, is_module_available from torchaudio.utils.ffmpeg_utils import get_video_decoders, get_video_encoders from .backend_utils import set_audio_backend @@ -143,24 +143,6 @@ def is_cuda_ctc_decoder_available(): return _IS_CUDA_CTC_DECODER_AVAILABLE -def _eval_env(var, default): - if var not in os.environ: - return default - - val = os.environ.get(var, "0") - trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"] - falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"] - if val in trues: - return True - if val not in falses: - # fmt: off - raise RuntimeError( - f"Unexpected environment variable value `{var}={val}`. " - f"Expected one of {trues + falses}") - # fmt: on - return False - - def _fail(reason): def deco(test_item): if isinstance(test_item, type): @@ -185,7 +167,7 @@ def _pass(test_item): return test_item -_IN_CI = _eval_env("CI", default=False) +_IN_CI = eval_env("CI", default=False) def _skipIf(condition, reason, key): @@ -195,7 +177,7 @@ def _skipIf(condition, reason, key): # In CI, default to fail, so as to prevent accidental skip. # In other env, default to skip var = f"TORCHAUDIO_TEST_ALLOW_SKIP_IF_{key}" - skip_allowed = _eval_env(var, default=not _IN_CI) + skip_allowed = eval_env(var, default=not _IN_CI) if skip_allowed: return unittest.skip(reason) return _fail(f"{reason} But the test cannot be skipped. (CI={_IN_CI}, {var}={skip_allowed}.)") @@ -268,7 +250,7 @@ def skipIfNoSoxEncoder(ext): key="NO_CUCTC_DECODER", ) skipIfRocm = _skipIf( - _eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False), + eval_env("TORCHAUDIO_TEST_WITH_ROCM", default=False), reason="The test doesn't currently work on the ROCm stack.", key="ON_ROCM", ) diff --git a/torchaudio/_internal/module_utils.py b/torchaudio/_internal/module_utils.py index d5ab186b0c..e0fb541fc1 100644 --- a/torchaudio/_internal/module_utils.py +++ b/torchaudio/_internal/module_utils.py @@ -1,9 +1,29 @@ import importlib.util +import os import warnings from functools import wraps from typing import Optional +def eval_env(var, default): + """Check if environment varable has True-y value""" + if var not in os.environ: + return default + + val = os.environ.get(var, "0") + trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"] + falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"] + if val in trues: + return True + if val not in falses: + # fmt: off + raise RuntimeError( + f"Unexpected environment variable value `{var}={val}`. " + f"Expected one of {trues + falses}") + # fmt: on + return False + + def is_module_available(*modules: str) -> bool: r"""Returns if a top-level module with :attr:`name` exists *without** importing it. This is generally safer than try-catch block around a