Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy strict issues #2 #3497

Merged
merged 21 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kedro/config/abstract_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
conf_source: str,
env: None | str = None,
runtime_params: None | dict[str, Any] = None,
**kwargs,
**kwargs: Any,
):
super().__init__()
self.conf_source = conf_source
Expand Down
40 changes: 21 additions & 19 deletions kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ def __init__( # noqa: PLR0913
except MissingConfigException:
self._globals = {}

def __setitem__(self, key: str, value: Any) -> None:
    """Store a configuration entry, keeping the cached globals in sync.

    Args:
        key: Configuration key (e.g. ``"catalog"``, ``"globals"``).
        value: The configuration dictionary to store under ``key``.
    """
    if key == "globals":
        # Update the cached value at self._globals since it is used by the globals resolver
        self._globals = value
    super().__setitem__(key, value)

def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
def __getitem__(self, key: str) -> dict[str, Any]: # noqa: PLR0912
"""Get configuration files by key, load and merge them, and
return them in the form of a config dictionary.

Expand All @@ -175,7 +175,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
self._register_runtime_params_resolver()

if key in self:
return super().__getitem__(key)
return super().__getitem__(key) # type: ignore[no-any-return]

if key not in self.config_patterns:
raise KeyError(
Expand All @@ -196,7 +196,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
else:
base_path = str(Path(self._fs.ls("", detail=False)[-1]) / self.base_env)
try:
base_config = self.load_and_merge_dir_config(
base_config = self.load_and_merge_dir_config( # type: ignore[no-untyped-call]
base_path, patterns, key, processed_files, read_environment_variables
)
except UnsupportedInterpolationType as exc:
Expand All @@ -216,7 +216,7 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
else:
env_path = str(Path(self._fs.ls("", detail=False)[-1]) / run_env)
try:
env_config = self.load_and_merge_dir_config(
env_config = self.load_and_merge_dir_config( # type: ignore[no-untyped-call]
env_path, patterns, key, processed_files, read_environment_variables
)
except UnsupportedInterpolationType as exc:
Expand Down Expand Up @@ -244,9 +244,9 @@ def __getitem__(self, key) -> dict[str, Any]: # noqa: PLR0912
f" the glob pattern(s): {[*self.config_patterns[key]]}"
)

return resulting_config
return resulting_config # type: ignore[no-any-return]

def __repr__(self): # pragma: no cover
def __repr__(self) -> str: # pragma: no cover
return (
f"OmegaConfigLoader(conf_source={self.conf_source}, env={self.env}, "
f"config_patterns={self.config_patterns})"
Expand Down Expand Up @@ -312,8 +312,8 @@ def load_and_merge_dir_config( # noqa: PLR0913
self._resolve_environment_variables(config)
config_per_file[config_filepath] = config
except (ParserError, ScannerError) as exc:
line = exc.problem_mark.line # type: ignore
cursor = exc.problem_mark.column # type: ignore
line = exc.problem_mark.line
cursor = exc.problem_mark.column
raise ParserError(
f"Invalid YAML or JSON file {Path(conf_path, config_filepath.name).as_posix()},"
f" unable to read line {line}, position {cursor}."
Expand Down Expand Up @@ -342,7 +342,7 @@ def load_and_merge_dir_config( # noqa: PLR0913
if not k.startswith("_")
}

def _is_valid_config_path(self, path):
def _is_valid_config_path(self, path: Path) -> bool:
"""Check if given path is a file path and file type is yaml or json."""
posix_path = path.as_posix()
return self._fs.isfile(str(posix_path)) and path.suffix in [
Expand All @@ -351,22 +351,22 @@ def _is_valid_config_path(self, path):
".json",
]

def _register_globals_resolver(self) -> None:
    """Register the globals resolver"""
    # replace=True so repeated registration (e.g. repeated config loads in one
    # process) overwrites the previous resolver instead of raising.
    OmegaConf.register_new_resolver(
        "globals",
        self._get_globals_value,
        replace=True,
    )

def _register_runtime_params_resolver(self) -> None:
    """Register the ``runtime_params`` resolver backed by ``_get_runtime_value``."""
    # replace=True keeps re-registration idempotent across repeated loads.
    OmegaConf.register_new_resolver(
        "runtime_params",
        self._get_runtime_value,
        replace=True,
    )

def _get_globals_value(self, variable, default_value=_NO_VALUE):
def _get_globals_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any:
"""Return the globals values to the resolver"""
if variable.startswith("_"):
raise InterpolationResolutionError(
Expand All @@ -383,7 +383,7 @@ def _get_globals_value(self, variable, default_value=_NO_VALUE):
f"Globals key '{variable}' not found and no default value provided."
)

def _get_runtime_value(self, variable, default_value=_NO_VALUE):
def _get_runtime_value(self, variable: str, default_value: Any = _NO_VALUE) -> Any:
"""Return the runtime params values to the resolver"""
runtime_oc = OmegaConf.create(self.runtime_params)
interpolated_value = OmegaConf.select(
Expand All @@ -397,7 +397,7 @@ def _get_runtime_value(self, variable, default_value=_NO_VALUE):
)

@staticmethod
def _register_new_resolvers(resolvers: dict[str, Callable]):
def _register_new_resolvers(resolvers: dict[str, Callable]) -> None:
"""Register custom resolvers"""
for name, resolver in resolvers.items():
if not OmegaConf.has_resolver(name):
Expand All @@ -406,7 +406,7 @@ def _register_new_resolvers(resolvers: dict[str, Callable]):
OmegaConf.register_new_resolver(name=name, resolver=resolver)

@staticmethod
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]):
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]) -> None:
duplicates = []

filepaths = list(seen_files_to_keys.keys())
Expand Down Expand Up @@ -449,7 +449,9 @@ def _resolve_environment_variables(config: DictConfig | ListConfig) -> None:
OmegaConf.resolve(config)

@staticmethod
def _destructive_merge(config, env_config, env_path):
def _destructive_merge(
config: dict[str, Any], env_config: dict[str, Any], env_path: str
) -> dict[str, Any]:
# Destructively merge the two env dirs. The chosen env will override base.
common_keys = config.keys() & env_config.keys()
if common_keys:
Expand All @@ -464,11 +466,11 @@ def _destructive_merge(config, env_config, env_path):
return config

@staticmethod
def _soft_merge(config: dict[str, Any], env_config: dict[str, Any]) -> Any:
    """Recursively merge ``env_config`` on top of ``config`` via OmegaConf.

    Returns the merged result converted back to plain containers.
    """
    # Soft merge the two env dirs. The chosen env will override base if keys clash.
    return OmegaConf.to_container(OmegaConf.merge(config, env_config))

def _is_hidden(self, path_str: str):
def _is_hidden(self, path_str: str) -> bool:
"""Check if path contains any hidden directory or is a hidden file"""
path = Path(path_str)
conf_path = Path(self.conf_source).resolve().as_posix()
Expand Down
4 changes: 2 additions & 2 deletions kedro/framework/cli/hooks/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def before_command_run(
self,
project_metadata: ProjectMetadata,
command_args: list[str],
):
) -> None:
"""Hooks to be invoked before a CLI command runs.
It receives the ``project_metadata`` as well as
all command line arguments that were used, including the command
Expand All @@ -32,7 +32,7 @@ def before_command_run(
@cli_hook_spec
def after_command_run(
self, project_metadata: ProjectMetadata, command_args: list[str], exit_code: int
):
) -> None:
"""Hooks to be invoked after a CLI command runs.
It receives the ``project_metadata`` as well as
all command line arguments that were used, including the command
Expand Down
10 changes: 5 additions & 5 deletions kedro/framework/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _convert_paths_to_absolute_posix(
return conf_dictionary


def _validate_transcoded_datasets(catalog: DataCatalog):
def _validate_transcoded_datasets(catalog: DataCatalog) -> None:
"""Validates transcoded datasets are correctly named

Args:
Expand Down Expand Up @@ -205,7 +205,7 @@ def params(self) -> dict[str, Any]:
# Merge nested structures
params = OmegaConf.merge(params, self._extra_params)

return OmegaConf.to_container(params) if OmegaConf.is_config(params) else params
return OmegaConf.to_container(params) if OmegaConf.is_config(params) else params # type: ignore[no-any-return]

def _get_catalog(
self,
Expand All @@ -229,7 +229,7 @@ def _get_catalog(
)
conf_creds = self._get_config_credentials()

catalog = settings.DATA_CATALOG_CLASS.from_config(
catalog: DataCatalog = settings.DATA_CATALOG_CLASS.from_config(
catalog=conf_catalog,
credentials=conf_creds,
load_versions=load_versions,
Expand All @@ -254,7 +254,7 @@ def _get_feed_dict(self) -> dict[str, Any]:
params = self.params
feed_dict = {"parameters": params}

def _add_param_to_feed_dict(param_name, param_value):
def _add_param_to_feed_dict(param_name: str, param_value: Any) -> None:
"""This recursively adds parameter paths to the `feed_dict`,
whenever `param_value` is a dictionary itself, so that users can
specify specific nested parameters in their node inputs.
Expand All @@ -281,7 +281,7 @@ def _add_param_to_feed_dict(param_name, param_value):
def _get_config_credentials(self) -> dict[str, Any]:
"""Getter for credentials specified in credentials directory."""
try:
conf_creds = self.config_loader["credentials"]
conf_creds: dict[str, Any] = self.config_loader["credentials"]
except MissingConfigException as exc:
logging.getLogger(__name__).debug(
"Credentials not found in your Kedro project config.\n %s", str(exc)
Expand Down
6 changes: 3 additions & 3 deletions kedro/framework/hooks/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,11 @@ class _NullPluginManager:
"""This class creates an empty ``hook_manager`` that will ignore all calls to hooks,
allowing the runner to function if no ``hook_manager`` has been instantiated."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
pass

def __getattr__(self, name):
def __getattr__(self, name: str) -> Any:
return self

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> None:
pass
4 changes: 2 additions & 2 deletions kedro/framework/hooks/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def on_node_error( # noqa: PLR0913
inputs: dict[str, Any],
is_async: bool,
session_id: str,
):
) -> None:
"""Hook to be invoked if a node run throws an uncaught error.
The signature of this error hook should match the signature of ``before_node_run``
along with the error that was raised.
Expand Down Expand Up @@ -211,7 +211,7 @@ def on_pipeline_error(
run_params: dict[str, Any],
pipeline: Pipeline,
catalog: DataCatalog,
):
) -> None:
"""Hook to be invoked if a pipeline run throws an uncaught Exception.
The signature of this error hook should match the signature of ``before_pipeline_run``
along with the error that was raised.
Expand Down
23 changes: 14 additions & 9 deletions kedro/framework/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pathlib import Path
from typing import Any

import dynaconf
import importlib_resources
import yaml
from dynaconf import LazySettings
Expand All @@ -28,10 +29,10 @@
)


def _get_default_class(class_import_path):
def _get_default_class(class_import_path: str) -> Any:
module, _, class_name = class_import_path.rpartition(".")

def validator_func(settings, validators):
def validator_func(settings: dynaconf.base.Settings, validators: Any) -> Any:
return getattr(importlib.import_module(module), class_name)

return validator_func
Expand All @@ -40,7 +41,9 @@ def validator_func(settings, validators):
class _IsSubclassValidator(Validator):
"""A validator to check if the supplied setting value is a subclass of the default class"""

def validate(self, settings, *args, **kwargs):
def validate(
self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any
) -> None:
super().validate(settings, *args, **kwargs)

default_class = self.default(settings, self)
Expand All @@ -58,7 +61,9 @@ class _HasSharedParentClassValidator(Validator):
"""A validator to check that the parent of the default class is an ancestor of
the settings value."""

def validate(self, settings, *args, **kwargs):
def validate(
self, settings: dynaconf.base.Settings, *args: Any, **kwargs: Any
) -> None:
super().validate(settings, *args, **kwargs)

default_class = self.default(settings, self)
Expand Down Expand Up @@ -112,7 +117,7 @@ class _ProjectSettings(LazySettings):
"DATA_CATALOG_CLASS", default=_get_default_class("kedro.io.DataCatalog")
)

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
kwargs.update(
validators=[
self._CONF_SOURCE,
Expand All @@ -135,7 +140,7 @@ def _load_data_wrapper(func: Any) -> Any:
"""

# noqa: protected-access
def inner(self, *args, **kwargs):
def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
self._load_data()
return func(self._content, *args, **kwargs)

Expand Down Expand Up @@ -165,12 +170,12 @@ def __init__(self) -> None:
self._content: dict[str, Pipeline] = {}

@staticmethod
def _get_pipelines_registry_callable(pipelines_module: str):
def _get_pipelines_registry_callable(pipelines_module: str) -> Any:
module_obj = importlib.import_module(pipelines_module)
register_pipelines = getattr(module_obj, "register_pipelines")
return register_pipelines

def _load_data(self):
def _load_data(self) -> None:
"""Lazily read pipelines defined in the pipelines registry module."""

# If the pipelines dictionary has not been configured with a pipelines module
Expand Down Expand Up @@ -212,7 +217,7 @@ def configure(self, pipelines_module: str | None = None) -> None:

class _ProjectLogging(UserDict):
# noqa: super-init-not-called
def __init__(self):
def __init__(self) -> None:
"""Initialise project logging. The path to logging configuration is given in
environment variable KEDRO_LOGGING_CONFIG (defaults to default_logging.yml)."""
path = os.environ.get(
Expand Down
14 changes: 7 additions & 7 deletions kedro/framework/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def _init_store(self) -> BaseSessionStore:
store_args["session_id"] = self.session_id

try:
return store_class(**store_args)
return store_class(**store_args) # type: ignore[no-any-return]
except TypeError as err:
raise ValueError(
f"\n{err}.\nStore config must only contain arguments valid "
Expand All @@ -204,7 +204,7 @@ def _init_store(self) -> BaseSessionStore:
f"\n{err}.\nFailed to instantiate session store of type '{classpath}'."
) from err

def _log_exception(self, exc_type, exc_value, exc_tb):
def _log_exception(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
type_ = [] if exc_type.__module__ == "builtins" else [exc_type.__module__]
type_.append(exc_type.__qualname__)

Expand Down Expand Up @@ -240,32 +240,32 @@ def load_context(self) -> KedroContext:
)
self._hook_manager.hook.after_context_created(context=context)

return context
return context # type: ignore[no-any-return]

def _get_config_loader(self) -> AbstractConfigLoader:
    """An instance of the config loader.

    Instantiates ``settings.CONFIG_LOADER_CLASS`` with the env and extra
    params recorded in the session store plus ``settings.CONFIG_LOADER_ARGS``.
    """
    env = self.store.get("env")
    extra_params = self.store.get("extra_params")

    config_loader_class = settings.CONFIG_LOADER_CLASS
    # mypy cannot see that CONFIG_LOADER_CLASS is validated elsewhere to be an
    # AbstractConfigLoader subclass, hence the targeted ignore.
    return config_loader_class(  # type: ignore[no-any-return]
        conf_source=self._conf_source,
        env=env,
        runtime_params=extra_params,
        **settings.CONFIG_LOADER_ARGS,
    )

def close(self) -> None:
    """Close the current session and save its store to disk
    if `save_on_close` attribute is True.
    """
    if self.save_on_close:
        self._store.save()

def __enter__(self):
def __enter__(self) -> KedroSession:
return self

def __exit__(self, exc_type, exc_value, tb_):
def __exit__(self, exc_type: Any, exc_value: Any, tb_: Any) -> None:
if exc_type:
self._log_exception(exc_type, exc_value, tb_)
self.close()
Expand Down
Loading