diff --git a/.github/workflows/codestyle-check.yml b/.github/workflows/codestyle-check.yml index 849a132..142a8e1 100644 --- a/.github/workflows/codestyle-check.yml +++ b/.github/workflows/codestyle-check.yml @@ -22,9 +22,9 @@ jobs: - name: Install dependencies run: | pip install -r requirements.txt - pip install -r codestyle_requirements.txt + pip install pre-commit poetry install - name: Run codestyle-check run: | - chmod +x codestyle-format.sh - ./codestyle-format.sh all + pre-commit install + pre-commit run diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..fd32058 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.1.4 + hooks: + - id: ruff + types_or: [python, pyi, jupyter] + args: [--fix] + - id: ruff-format + types_or: [python, pyi, jupyter] diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index cd3dfe0..0000000 --- a/.pylintrc +++ /dev/null @@ -1,522 +0,0 @@ -[MASTER] - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code. -extension-pkg-whitelist=['cv2'] - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=CVS,_internal - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the -# number of processors available to use. -jobs=1 - -# Control the amount of potential inferred values when inferring a single -# object. This can help the performance when dealing with large functions or -# complex, nested conditions. -limit-inference-results=100 - -# List of plugins (as comma separated values of python module names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# Specify a configuration file. -#rcfile= - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages. -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once). You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use "--disable=all --enable=classes -# --disable=W". -disable=no-else-return, - len-as-condition, - redefined-builtin, # TBD - arguments-differ, # TBD - invalid-name, # TBD - too-many-locals, # TBD - too-many-arguments, # TBD - too-many-instance-attributes, # TBD - too-few-public-methods, # TBD - protected-access, # TBD - cyclic-import, # TBD - no-else-raise, # TBD - line-too-long, # Although We ensure line length by YAPF, there are some cases YAPF won't split lines (e.g. comments). We just ignore these cases. - fixme, # Will be enable after a while - duplicate-code, # Will be enable after a while - missing-function-docstring, # Will be enable after a while - missing-module-docstring, # Will be enable after a while - missing-class-docstring, # Will be enable after a while - singleton-comparison, - raw-checker-failed, - bad-inline-option, - locally-disabled, - file-ignored, - suppressed-message, - useless-suppression, - deprecated-pragma, - use-symbolic-message-instead, - attribute-defined-outside-init, - logging-format-interpolation, - unsubscriptable-object, # pylint bug for numpy array - abstract-class-instantiated, # FIXME: Abstract class 'Tensor' with abstract methods instantiated - arguments-renamed, - consider-using-f-string, - use-dict-literal - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a score less than or equal to 10. You -# have access to the variables 'error', 'warning', 'refactor', and 'convention' -# which contain the number of messages in each category, as well as 'statement' -# which is the total number of statements analyzed. This score is used by the -# global evaluation report (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details. -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio). You can also give a reporter class, e.g. -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages. -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=sys.exit - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid defining new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expected to -# not be used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore. -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members=numpy.*, torch.*, cv2.* - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# Tells whether to warn about missing members when the owner of the attribute -# is inferred to be None. -ignore-none=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis). It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules=numpy.random - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. -missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - -# List of decorators that change the signature of a decorated function. -signature-mutators= - - -[STRING] - -# This flag controls whether the implicit-str-concat-in-sequence should -# generate a warning on implicit string concatenation in sequences defined over -# several lines. -check-str-concat-over-line-jumps=no - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes. -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it work, -# install the python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains the private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to the private dictionary (see the -# --spelling-private-dict-file option) instead of raising a message. -spelling-store-unknown-words=no - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=4 - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - - -[LOGGING] - -# Format style used to check logging format string. `old` means using % -# formatting, `new` is for `{}` formatting,and `fstr` is for f-strings. -logging-format-style=old - -# Logging modules to check that the string format arguments are in logging -# function parameter format. -logging-modules=logging - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=88 - -# Maximum number of lines in a module. -max-module-lines=1000 - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[BASIC] - -# Naming style matching correct argument names. -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style. -#argument-rgx= - -# Naming style matching correct attribute names. -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style. -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma. -bad-names=foo, - baz, - toto, - tutu, - tata - -# Naming style matching correct class attribute names. -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style. -#class-attribute-rgx= - -# Naming style matching correct class names. -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming- -# style. -#class-rgx= - -# Naming style matching correct constant names. -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style. -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names. -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style. -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma. -good-names=i, - j, - k, - f, - ex, - Run, - _ - -# Include a hint for the correct naming format with invalid-name. -include-naming-hint=no - -# Naming style matching correct inline iteration names. -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. Overrides -# inlinevar-naming-style. -#inlinevar-rgx= - -# Naming style matching correct method names. -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style. -#method-rgx= - -# Naming style matching correct module names. -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style. -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -# These decorators are taken in consideration only for invalid-name. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names. -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style. -#variable-rgx= - - -[IMPORTS] - -# List of modules that can be imported at any level, not just the top level -# one. -allow-any-import-level= - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma. -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled). -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled). -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled). -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - -# Couples of modules and preferred modules, separated by a comma. -preferred-modules= - - -[DESIGN] - -# Maximum number of arguments for function / method. -max-args=5 - -# Maximum number of attributes for a class (see R0902). -max-attributes=7 - -# Maximum number of boolean expressions in an if statement (see R0916). -max-bool-expr=5 - -# Maximum number of branch for function / method body. -max-branches=12 - -# Maximum number of locals for function / method body. -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body. -max-returns=6 - -# Maximum number of statements in function / method body. -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp, - __post_init__ - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=cls - diff --git a/codestyle-format.sh b/codestyle-format.sh deleted file mode 100644 index af812c3..0000000 --- a/codestyle-format.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/usr/bin/env bash - -check=$1 - -if [ "$check" = "all" ]; then - git_files=$(git ls-files) - py_files=$(echo "$git_files" | grep '.py$') -else - py_files=$(git diff --name-only --cached | grep '.py$') -fi - -if [ "$py_files" ]; then - target="__init__.py" - echo "use 'black' to format code" - black $py_files -q - echo "use 'isort' to sort imports" - isort $py_files -q - echo "use 'autoflake' to remove-unused-variables" - autoflake -r --in-place --remove-unused-variables $py_files - - echo "use 'pylint' to check codestyle" - pylint $py_files --rcfile=.pylintrc || pylint_ret=$? - if [ "$pylint_ret" ]; then - echo "'pylint' check failed" - exit $pylint_ret - fi - echo "use 'flake8' to check codestyle" - for file in $py_files - do - if ! [[ $file =~ $target ]]; then - flake8 --max-line-length 100 --max-doc-length 120 $file || flake8_ret=$? - fi - if [ "$flake8_ret" ]; then - echo "'flake8' check failed" - exit $flake8_ret - fi - done -else - echo "No files to format" -fi diff --git a/codestyle_requirements.txt b/codestyle_requirements.txt deleted file mode 100644 index 948ccca..0000000 --- a/codestyle_requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -flake8==6.0.0 -black==22.1.0 -isort==5.12.0 -autoflake==2.0.2 -click==8.0.4 -pylint==3.0.2 diff --git a/example/run.py b/example/run.py index a36d302..59aa4de 100644 --- a/example/run.py +++ b/example/run.py @@ -11,10 +11,7 @@ def _check_func(values): - for v in values: - if v: - return True - return False + return any(v for v in values) add_logger("file_log", "./run.log") @@ -25,14 +22,10 @@ def _check_func(values): logger.info(MODELS.module_table(select_info=["is_backbone"])) # 打印 is_pretrained 为 true 的 module 和 is_backbone 的信息 filtered_module_name = MODELS.filter("is_pretrained") -logger.info( - MODELS.module_table(select_info=["is_backbone"], module_list=filtered_module_name) -) +logger.info(MODELS.module_table(select_info=["is_backbone"], module_list=filtered_module_name)) # 打印 is_pretrained 为 true 的 module 和 is_pretrained 的信息 filtered_module_name = MODELS.filter("is_pretrained", _check_func) -logger.info( - MODELS.module_table(select_info=["is_pretrained"], module_list=filtered_module_name) -) +logger.info(MODELS.module_table(select_info=["is_pretrained"], module_list=filtered_module_name)) logger.info(Registry.registry_table()) config.set_target_fields(config.AttrNode.target_fields + ["Backbone"]) diff --git a/example/src/hook/hooks.py b/example/src/hook/hooks.py index 43ebe37..a05829b 100644 --- a/example/src/hook/hooks.py +++ b/example/src/hook/hooks.py @@ -1,7 +1,6 @@ -from src import HOOK - from excore import ConfigArgumentHookProtocol from excore.logger import logger +from src import HOOK @HOOK.register() diff --git a/example/src/model/backbone/resnet.py b/example/src/model/backbone/resnet.py index 07ae3f1..5048a61 100644 --- a/example/src/model/backbone/resnet.py +++ b/example/src/model/backbone/resnet.py @@ -11,9 +11,7 @@ def __init__(self, in_chan, out_chan): @MODEL.register(is_pretrained=True, is_backbone=True) class ResNet: - def __init__( - self, in_channel: int, depth: int, block: BasicBlock, layers: List[int] - ): + def __init__(self, in_channel: int, depth: int, block: BasicBlock, layers: List[int]): assert block == BasicBlock self.in_channel = in_channel self.depth = depth diff --git a/example/src/optim/op.py b/example/src/optim/op.py index 490b4f6..54e7f79 100644 --- a/example/src/optim/op.py +++ b/example/src/optim/op.py @@ -1,6 +1,5 @@ -from src import OPTIM - from excore.logger import logger +from src import OPTIM @OPTIM.register() diff --git a/excore/__init__.py b/excore/__init__.py index 7cc6d96..5fb7a09 100644 --- a/excore/__init__.py +++ b/excore/__init__.py @@ -1,14 +1,40 @@ from . import hub -from ._constants import (__author__, __version__, _load_workspace_config, - _workspace_cfg) -from .config import (ConfigArgumentHookProtocol, build_all, load_config, - set_target_fields) +from ._constants import ( + __author__, + __version__, + _load_workspace_config, + _workspace_cfg, +) +from .config import ( + ConfigArgumentHookProtocol, + build_all, + load, + set_target_fields, +) from .hook import Hook, HookManager -from .logger import (add_logger, enable_excore_debug, init_logger, logger, - remove_logger) +from .logger import _enable_excore_debug, add_logger, init_logger, remove_logger from .registry import Registry, load_registries +__all__ = [ + "__author__", + "__version__", + "hub", + "config", + "hook", + "logger", + "registry", + "ConfigArgumentHookProtocol", + "build_all", + "load", + "Hook", + "HookManager", + "add_logger", + "remove_logger", + "Registry", + "load_registries", +] + init_logger() _load_workspace_config() set_target_fields(_workspace_cfg["target_fields"]) -enable_excore_debug() +_enable_excore_debug() diff --git a/excore/_json_schema.py b/excore/_json_schema.py index d7d8809..fe2ad96 100644 --- a/excore/_json_schema.py +++ b/excore/_json_schema.py @@ -86,7 +86,7 @@ def _check(bases): for b in bases: if b is object: return False - if hasattr(b, "__call__"): + if callable(b): return True return False @@ -131,8 +131,7 @@ def _clean(anno): return anno if anno.__origin__ == type or ( # Optional - anno.__origin__ == Union - and anno.__args__[1] == NoneType + anno.__origin__ == Union and anno.__args__[1] == NoneType ): return _clean(anno.__args__[0]) return anno diff --git a/excore/cli.py b/excore/cli.py index 61cb0fc..b8f3221 100644 --- a/excore/cli.py +++ b/excore/cli.py @@ -12,8 +12,14 @@ from typer import Option as COp from typing_extensions import Annotated -from ._constants import (LOGO, _base_name, _cache_base_dir, _cache_dir, - _workspace_cfg, _workspace_config_file) +from ._constants import ( + LOGO, + _base_name, + _cache_base_dir, + _cache_dir, + _workspace_cfg, + _workspace_config_file, +) from ._json_schema import _generate_json_shcema, _generate_taplo_config from .logger import logger from .registry import Registry @@ -63,7 +69,7 @@ def _generate_registries(entry="__init__"): with open(target_file, "w", encoding="UTF-8") as f: f.write("") - with open(target_file, "r", encoding="UTF-8") as f: + with open(target_file, encoding="UTF-8") as f: source_code = ast.parse(f.read()) flag = _has_import_excore(source_code) if flag == 1: @@ -79,25 +85,27 @@ def _generate_registries(entry="__init__"): source_code = astor.to_source(source_code) with open(target_file, "w", encoding="UTF-8") as f: f.write(source_code) - logger.success( - "Generate Registry definition in {} according to `target_fields`", target_file - ) + logger.success("Generate Registry definition in {} according to `target_fields`", target_file) def _detect_assign(node, definition): if isinstance(node, ast.Module): for child in node.body: _detect_assign(child, definition) - elif isinstance(node, ast.Assign) and isinstance(node.value, ast.Call): - if hasattr(node.value.func, "id") and node.value.func.id == "Registry": - definition.append(node.value.args[0].value) + elif ( + isinstance(node, ast.Assign) + and isinstance(node.value, ast.Call) + and hasattr(node.value.func, "id") + and node.value.func.id == "Registry" + ): + definition.append(node.value.args[0].value) def _detect_registy_difinition() -> bool: target_file = osp.join(_workspace_cfg["src_dir"], "__init__.py") logger.info("Detect Registry definition in {}", target_file) definition = [] - with open(target_file, "r", encoding="UTF-8") as f: + with open(target_file, encoding="UTF-8") as f: source_code = ast.parse(f.read()) _detect_assign(source_code, definition) if len(definition) > 0: @@ -148,9 +156,7 @@ def _update(is_init=True, entry="__init__"): _workspace_cfg["registries"] = regs else: logger.imp("You can define fields later.") - _workspace_cfg["target_fields"] = _get_target_fields( - _workspace_cfg["registries"] - ) + _workspace_cfg["target_fields"] = _get_target_fields(_workspace_cfg["registries"]) _generate_registries(entry) else: logger.imp( @@ -207,9 +213,7 @@ def init( _update(True, entry) - logger.success( - "Welcome to ExCore. You can modify the `.excore.toml` file mannully." - ) + logger.success("Welcome to ExCore. You can modify the `.excore.toml` file mannully.") def _clear_cache(cache_dir): @@ -324,7 +328,7 @@ def registries(): def generate_registries( entry: Annotated[ str, CArg(help="Used for detect or generate Registry definition code") - ] = "__init__" + ] = "__init__", ): """ Generate registries definition code according to workspace config. diff --git a/excore/config.py b/excore/config.py index d4a510f..963c344 100644 --- a/excore/config.py +++ b/excore/config.py @@ -9,8 +9,12 @@ import toml -from ._exceptions import (CoreConfigBuildError, CoreConfigParseError, - CoreConfigSupportError, ModuleBuildError) +from ._exceptions import ( + CoreConfigBuildError, + CoreConfigParseError, + CoreConfigSupportError, + ModuleBuildError, +) from .hook import ConfigHookManager from .logger import logger from .registry import Registry, load_registries @@ -275,9 +279,7 @@ def __init__(self, modules: Optional[Union[List[ModuleNode], ModuleNode]] = None if isinstance(modules, ModuleNode): modules = [modules] elif not isinstance(modules, list): - raise TypeError( - f"Expect modules to be `list` or `ModuleNode`, but got {type(modules)}" - ) + raise TypeError(f"Expect modules to be `list` or `ModuleNode`, but got {type(modules)}") for m in modules: self[m.name] = m @@ -318,9 +320,7 @@ def __init__( self._is_initialized = True def hook(self): - raise NotImplementedError( - f"`{self.__class__.__name__}` do not implement `hook` method." - ) + raise NotImplementedError(f"`{self.__class__.__name__}` do not implement `hook` method.") @final def __call__(self): @@ -502,8 +502,7 @@ def _parse_module_node(self, node): if not isinstance(target_module_names, list): target_module_names = [target_module_names] converted_modules = [ - self._parse_single_param(name, params) - for name in target_module_names + self._parse_single_param(name, params) for name in target_module_names ] to_pop.extend(target_module_names) node[param_name] = ModuleWrapper(converted_modules) @@ -622,14 +621,10 @@ def load_config(filename: str, base_key: str = "__base__") -> AttrNode: path = os.path.dirname(filename) if ext != ".toml": - raise CoreConfigSupportError( - "Only support `toml` files for now, but got {}".format(filename) - ) + raise CoreConfigSupportError(f"Only support `toml` files for now, but got {filename}") config = toml.load(filename, AttrNode) - base_cfgs = [ - load_config(os.path.join(path, i), base_key) for i in config.pop(base_key, []) - ] + base_cfgs = [load_config(os.path.join(path, i), base_key) for i in config.pop(base_key, [])] base_cfg = AttrNode() for c in base_cfgs: base_cfg.update(c) diff --git a/excore/cuda_helper.py b/excore/cuda_helper.py index dad3d51..5863759 100644 --- a/excore/cuda_helper.py +++ b/excore/cuda_helper.py @@ -1,13 +1,16 @@ from enum import Enum from typing import List -from pynvml import nvmlDeviceGetCount # noqa +from pynvml import ( + nvmlDeviceGetCount, # noqa + nvmlInit, + nvmlShutdown, +) from pynvml import nvmlDeviceGetHandleByIndex as get_device_handle from pynvml import nvmlDeviceGetMemoryInfo as _get_memory_info from pynvml import nvmlDeviceGetName as get_device_name from pynvml import nvmlDeviceGetPowerState as _get_device_powerstate from pynvml import nvmlDeviceGetTemperature as _get_device_temperature -from pynvml import nvmlInit, nvmlShutdown __all__ = [ "get_device_handle", diff --git a/excore/hook.py b/excore/hook.py index f8a6983..40b9f28 100644 --- a/excore/hook.py +++ b/excore/hook.py @@ -68,7 +68,7 @@ def __new__(cls, name, bases, attrs): stages = inst.stages if inst.__name__ != "HookManager" and stages is None: raise HookManagerBuildError( - "The hook manager `{}` must have valid stages".format(inst.__name__) + f"The hook manager `{inst.__name__}` must have valid stages" ) return inst @@ -116,21 +116,15 @@ def __init__(self, hooks: Sequence[Hook]): for h in hooks: if not hasattr(h, "__HookType__") or h.__HookType__ not in self.stages: raise HookBuildError( - __error_msg.format( - h.__class__.__name__, "__HookType__", h.__HookType__ - ) + __error_msg.format(h.__class__.__name__, "__HookType__", h.__HookType__) ) if not hasattr(h, "__LifeSpan__") or h.__LifeSpan__ <= 0: raise HookBuildError( - __error_msg.format( - h.__class__.__name__, "__LifeSpan__", h.__LifeSpan__ - ) + __error_msg.format(h.__class__.__name__, "__LifeSpan__", h.__LifeSpan__) ) if not hasattr(h, "__CallInter__") or h.__CallInter__ <= 0: raise HookBuildError( - __error_msg.format( - h.__class__.__name__, "__CallInter__", h.__CallInter__ - ) + __error_msg.format(h.__class__.__name__, "__CallInter__", h.__CallInter__) ) self.hooks = defaultdict(list) self.calls = defaultdict(int) diff --git a/excore/hub.py b/excore/hub.py index cfa0635..f2ba88a 100644 --- a/excore/hub.py +++ b/excore/hub.py @@ -21,8 +21,14 @@ from tqdm import tqdm from ._constants import __version__, _cache_dir -from ._exceptions import (GitCheckoutError, GitPullError, HTTPDownloadError, - InvalidGitHost, InvalidProtocol, InvalidRepo) +from ._exceptions import ( + GitCheckoutError, + GitPullError, + HTTPDownloadError, + InvalidGitHost, + InvalidProtocol, + InvalidRepo, +) from .logger import logger __all__ = [ @@ -84,7 +90,7 @@ def _parse_repo_info(cls, repo_info: str) -> Tuple[str, str, str]: repo_owner, repo_name = prefix_info.split("/") return repo_owner, repo_name, branch_info except ValueError as exc: - raise InvalidRepo("repo_info: '{}' is invalid.".format(repo_info)) from exc + raise InvalidRepo(f"repo_info: '{repo_info}' is invalid.") from exc @classmethod def _check_git_host(cls, git_host): @@ -120,17 +126,15 @@ def fetch( silent: bool = True, ) -> str: if not cls._check_git_host(git_host): - raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) + raise InvalidGitHost(f"git_host: '{git_host}' is malformed.") repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) normalized_branch_info = branch_info.replace("/", "_") - repo_dir_raw = "{}_{}_{}".format( - repo_owner, repo_name, normalized_branch_info - ) + ("_{}".format(commit) if commit else "") - repo_dir = ( - "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw) - ) - git_url = "git@{}:{}/{}.git".format(git_host, repo_owner, repo_name) + repo_dir_raw = f"{repo_owner}_{repo_name}_{normalized_branch_info}" + if commit: + repo_dir_raw += f"_{commit}" + repo_dir = "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw) + git_url = f"git@{git_host}:{repo_owner}/{repo_name}.git" if use_cache and os.path.exists(repo_dir): # use cache logger.debug("Cache Found in {}", repo_dir) @@ -145,9 +149,7 @@ def fetch( repo_dir, ) - kwargs = ( - {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {} - ) + kwargs = {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {} if commit is None: # shallow clone repo by branch/tag p = subprocess.Popen( @@ -179,8 +181,7 @@ def fetch( if p.returncode: shutil.rmtree(repo_dir, ignore_errors=True) raise GitCheckoutError( - "Git checkout error, please check the commit id.\n" - + err.decode() + "Git checkout error, please check the commit id.\n" + err.decode() ) with cd(repo_dir): shutil.rmtree(".git") @@ -191,9 +192,7 @@ def fetch( def _check_clone_pipe(cls, p): _, err = p.communicate() if p.returncode: - raise GitPullError( - "Repo pull error, please check repo info.\n" + err.decode() - ) + raise GitPullError("Repo pull error, please check repo info.\n" + err.decode()) class GitHTTPSFetcher(RepoFetcherBase): @@ -209,16 +208,14 @@ def fetch( silent: bool = True, ) -> str: if not cls._check_git_host(git_host): - raise InvalidGitHost("git_host: '{}' is malformed.".format(git_host)) + raise InvalidGitHost(f"git_host: '{git_host}' is malformed.") repo_owner, repo_name, branch_info = cls._parse_repo_info(repo_info) normalized_branch_info = branch_info.replace("/", "_") - repo_dir_raw = "{}_{}_{}".format( - repo_owner, repo_name, normalized_branch_info - ) + ("_{}".format(commit) if commit else "") - repo_dir = ( - "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw) + repo_dir_raw = f"{repo_owner}_{repo_name}_{normalized_branch_info}" + ( + f"_{commit}" if commit else "" ) + repo_dir = "_".join(__version__.split(".")) + "_" + cls._gen_repo_dir(repo_dir_raw) archive_url = "https://{}/{}/{}/archive/{}.zip".format( git_host, repo_owner, repo_name, commit or branch_info ) @@ -229,7 +226,7 @@ def fetch( shutil.rmtree(repo_dir, ignore_errors=True) # ignore and clear cache - logger.debug("Downloading from {} to {}".format(archive_url, repo_dir)) + logger.debug(f"Downloading from {archive_url} to {repo_dir}") cls._download_zip_and_extract(archive_url, repo_dir) return repo_dir @@ -238,9 +235,7 @@ def fetch( def _download_zip_and_extract(cls, url, target_dir): resp = requests.get(url, timeout=cls.HTTP_TIMEOUT, stream=True) if resp.status_code != 200: - raise HTTPDownloadError( - "An error occurred when downloading from {}".format(url) - ) + raise HTTPDownloadError(f"An error occurred when downloading from {url}") total_size = int(resp.headers.get("Content-Length", 0)) _bar = tqdm(total=total_size, unit="iB", unit_scale=True) @@ -268,9 +263,7 @@ def _download_zip_and_extract(cls, url, target_dir): def download_from_url(url: str, dst: str): resp = requests.get(url, timeout=120, stream=True) if resp.status_code != 200: - raise HTTPDownloadError( - "An error occurred when downloading from {}".format(url) - ) + raise HTTPDownloadError(f"An error occurred when downloading from {url}") total_size = int(resp.headers.get("Content-Length", 0)) _bar = tqdm(total=total_size, unit="iB", unit_scale=True) @@ -293,9 +286,7 @@ def _get_repo( ) -> str: if protocol not in PROTOCOLS: raise InvalidProtocol( - "Invalid protocol, the value should be one of {}.".format( - ", ".join(PROTOCOLS.keys()) - ) + "Invalid protocol, the value should be one of {}.".format(", ".join(PROTOCOLS.keys())) ) cache_dir = os.path.expanduser(os.path.join(_cache_dir, "hub")) with cd(cache_dir): @@ -337,9 +328,7 @@ def _init_hub( git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol ) sys.path.insert(0, absolute_repo_dir) - hubmodule = load_module( - hubconf_entry, os.path.join(absolute_repo_dir, hubconf_entry) - ) + hubmodule = load_module(hubconf_entry, os.path.join(absolute_repo_dir, hubconf_entry)) sys.path.remove(absolute_repo_dir) return hubmodule @@ -359,11 +348,7 @@ def list( protocol: str = DEFAULT_PROTOCOL, ) -> List[str]: hubmodule = _init_hub(repo_info, git_host, entry, use_cache, commit, protocol) - return [ - _ - for _ in dir(hubmodule) - if not _.startswith("__") and callable(getattr(hubmodule, _)) - ] + return [_ for _ in dir(hubmodule) if not _.startswith("__") and callable(getattr(hubmodule, _))] def load( @@ -375,14 +360,12 @@ def load( use_cache: bool = True, commit: Optional[str] = None, protocol: str = DEFAULT_PROTOCOL, - **kwargs + **kwargs, ) -> Any: - hubmodule = _init_hub( - repo_info, git_host, hubconf_entry, use_cache, commit, protocol - ) + hubmodule = _init_hub(repo_info, git_host, hubconf_entry, use_cache, commit, protocol) if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): - raise RuntimeError("Cannot find callable {} in {}".format(entry, hubconf_entry)) + raise RuntimeError(f"Cannot find callable {entry} in {hubconf_entry}") _check_dependencies(hubmodule) @@ -399,12 +382,10 @@ def help( commit: Optional[str] = None, protocol: str = DEFAULT_PROTOCOL, ) -> str: - hubmodule = _init_hub( - repo_info, git_host, hubconf_entry, use_cache, commit, protocol - ) + hubmodule = _init_hub(repo_info, git_host, hubconf_entry, use_cache, commit, protocol) if not hasattr(hubmodule, entry) or not callable(getattr(hubmodule, entry)): - raise RuntimeError("Cannot find callable {} in hubconf.py".format(entry)) + raise RuntimeError(f"Cannot find callable {entry} in hubconf.py") doc = getattr(hubmodule, entry).__doc__ return doc diff --git a/excore/logger.py b/excore/logger.py index 4f459a7..71d4ebc 100644 --- a/excore/logger.py +++ b/excore/logger.py @@ -49,15 +49,15 @@ def remove_logger(name: str) -> None: id = LOGGERS.pop(name, None) if id: logger.remove(id) - logger.success("Remove logger whose name is {}".format(name)) + logger.success(f"Remove logger whose name is {name}") else: - logger.warning("Cannot find logger with name {}".format(name)) + logger.warning(f"Cannot find logger with name {name}") def log_to_file_only(file_name: str, *args, **kwargs) -> None: logger.remove(None) logger.add(file_name, *args, **kwargs) - logger.success("Log to file {} only".format(file_name)) + logger.success(f"Log to file {file_name} only") def debug_only(*args, **kwargs) -> None: @@ -80,7 +80,7 @@ def _excore_debug(__message: str, *args, **kwargs): logger.log("EXCORE", __message, *args, **kwargs) -def enable_excore_debug(): +def _enable_excore_debug(): if os.getenv("EXCORE_DEBUG"): logger.remove() logger.add(sys.stdout, format=FORMAT, level="EXCORE") diff --git a/excore/registry.py b/excore/registry.py index 9a1ca87..6d06de8 100644 --- a/excore/registry.py +++ b/excore/registry.py @@ -10,8 +10,7 @@ from tabulate import tabulate -from ._constants import (_cache_dir, _registry_cache_file, - _workspace_config_file) +from ._constants import _cache_dir, _registry_cache_file, _workspace_config_file from .logger import logger from .utils import FileLock @@ -27,10 +26,8 @@ def _is_pure_ascii(name: str): if not _name_re.match(name): raise ValueError( - """Unexpected name, only support ASCII letters, ASCII digits, - underscores, and dashes, but got {}""".format( - name - ) + f"""Unexpected name, only support ASCII letters, ASCII digits, + underscores, and dashes, but got {name}""" ) @@ -39,10 +36,7 @@ def _is_function_or_class(module): def _default_filter_func(values: Sequence[Any]) -> bool: - for v in values: - if not v: - return False - return True + return all(v for v in values) def _default_match_func(m, base_module): @@ -81,11 +75,7 @@ def __call__(cls, name, **kwargs) -> "Registry": if name in cls._registry_pool: extra_field = [extra_field] if isinstance(extra_field, str) else extra_field target = cls._registry_pool[name] - if ( - extra_field - and hasattr(target, "extra_field") - and extra_field != target.extra_field - ): + if extra_field and hasattr(target, "extra_field") and extra_field != target.extra_field: logger.warning( f"{cls.__name__}: `{name}` has already existed," " different arguments will be ignored" @@ -123,9 +113,7 @@ def __init__( super().__init__() self.name = name if extra_field: - self.extra_field = ( - [extra_field] if isinstance(extra_field, str) else extra_field - ) + self.extra_field = [extra_field] if isinstance(extra_field, str) else extra_field self.extra_info = dict() @classmethod @@ -134,7 +122,7 @@ def dump(cls): os.makedirs(os.path.join(_cache_dir, cls._registry_dir), exist_ok=True) import pickle # pylint: disable=import-outside-toplevel - with FileLock(file_path): + with FileLock(file_path): # noqa: SIM117 with open(file_path, "wb") as f: pickle.dump(cls._registry_pool, f) @@ -153,7 +141,7 @@ def load(cls): sys.exit(0) import pickle # pylint: disable=import-outside-toplevel - with FileLock(file_path): + with FileLock(file_path): # noqa: SIM117 with open(file_path, "rb") as f: data = pickle.load(f) cls._registry_pool.update(data) @@ -205,9 +193,7 @@ def __setitem__(self, k, v) -> None: super().__setitem__(k, v) def __repr__(self) -> str: - s = json.dumps( - self, indent=4, ensure_ascii=False, sort_keys=False, separators=(",", ":") - ) + s = json.dumps(self, indent=4, ensure_ascii=False, sort_keys=False, separators=(",", ":")) return "\n" + s __str__ = __repr__ @@ -223,25 +209,19 @@ def _register( logger.ex("Registry has been locked!!!") return module if not (_is_function_or_class(module) or isinstance(module, ModuleType)): - raise TypeError( - "Only support function or class, but got {}".format(type(module)) - ) + raise TypeError(f"Only support function or class, but got {type(module)}") true_name = _get_module_name(module) name = name or true_name if not force and name in self and not self[name] == module: - raise ValueError("The name {} exists".format(name)) + raise ValueError(f"The name {name} exists") if extra_info: if not hasattr(self, "extra_field"): - raise ValueError( - "Registry `{}` does not have `extra_field`.".format(self.name) - ) + raise ValueError(f"Registry `{self.name}` does not have `extra_field`.") for k in extra_info: if k not in self.extra_field: raise ValueError( - "Registry `{}`: 'extra_info' does not has expected key {}.".format( - self.name, k - ) + f"Registry `{self.name}`: 'extra_info' does not has expected key {k}." ) self.extra_info[name] = [extra_info.get(k, None) for k in self.extra_field] elif hasattr(self, "extra_field"): @@ -259,9 +239,7 @@ def _register( return module - def register( - self, force: bool = False, name: Optional[str] = None, **extra_info - ) -> Callable: + def register(self, force: bool = False, name: Optional[str] = None, **extra_info) -> Callable: """ Decorator that registers a function or class with the current `Registry`. Any keyword arguments provided are added to the `extra_info` list for the @@ -300,9 +278,7 @@ def merge( if not isinstance(others, list): others = [others] if not isinstance(others[0], Registry): - raise TypeError( - "Expect `Registry` type, but got {}".format(type(others[0])) - ) + raise TypeError(f"Expect `Registry` type, but got {type(others[0])}") for other in others: modules = list(other.values()) names = list(other.keys()) @@ -319,9 +295,7 @@ def filter( """ filter_field = [filter_field] if isinstance(filter_field, str) else filter_field - filter_idx = [ - i for i, name in enumerate(self.extra_field) if name in filter_field - ] + filter_idx = [i for i, name in enumerate(self.extra_field) if name in filter_field] out = [] for name in self.keys(): info = self.extra_info[name] @@ -338,7 +312,7 @@ def match(self, base_module, match_func=_default_match_func): """ matched_modules = [ getattr(base_module, name) - for name in base_module.__dict__.keys() + for name in base_module.__dict__ if match_func(name, base_module) ] matched_modules = list(filter(_is_function_or_class, matched_modules)) @@ -362,7 +336,7 @@ def module_table( select_info = [select_info] if isinstance(select_info, str) else select_info for info_key in select_info: if info_key not in self.extra_field: - raise ValueError("Got unexpected info key {}".format(info_key)) + raise ValueError(f"Got unexpected info key {info_key}") else: select_info = [] @@ -383,9 +357,7 @@ def module_table( table_headers = [f"{item}" for item in [self.name, *select_info]] if select_info: - select_idx = [ - idx for idx, name in enumerate(self.extra_field) if name in select_info - ] + select_idx = [idx for idx, name in enumerate(self.extra_field) if name in select_info] else: select_idx = [] @@ -415,9 +387,7 @@ def registry_table(cls, **table_kwargs) -> Any: def load_registries(): - if not os.path.exists( - os.path.join(_cache_dir, Registry._registry_dir, _registry_cache_file) - ): + if not os.path.exists(os.path.join(_cache_dir, Registry._registry_dir, _registry_cache_file)): logger.warning("Please run `excore auto-register` in your command line first!") return Registry.load() diff --git a/excore/utils.py b/excore/utils.py index 8b648cc..df812fc 100644 --- a/excore/utils.py +++ b/excore/utils.py @@ -8,8 +8,8 @@ def __call__(self, func): @functools.wraps(func) def _cache(self): if not hasattr(self, "cached_elem"): - setattr(self, "cached_elem", func(self)) - return getattr(self, "cached_elem") + self.cached_elem = func(self) + return self.cached_elem return _cache diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000..2cff3b5 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,21 @@ +line-length = 100 + +[lint] +extend-select = ["E501"] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", +] + +[lint.pydocstyle] +convention = "google"