Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow registering of custom resolvers to OmegaConfigLoader #2869

Merged
merged 13 commits into from
Aug 2, 2023
14 changes: 13 additions & 1 deletion kedro/config/omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import logging
import mimetypes
from pathlib import Path
from typing import Any, Iterable
from typing import Any, Callable, Iterable

import fsspec
from omegaconf import OmegaConf
Expand Down Expand Up @@ -82,6 +82,7 @@ def __init__( # noqa: too-many-arguments
config_patterns: dict[str, list[str]] = None,
base_env: str = "base",
default_run_env: str = "local",
custom_resolvers: dict[str, Callable] = None,
):
"""Instantiates a ``OmegaConfigLoader``.

Expand All @@ -97,6 +98,8 @@ def __init__( # noqa: too-many-arguments
the configuration paths.
default_run_env: Name of the default run environment. Defaults to `"local"`.
Can be overridden by supplying the `env` argument.
custom_resolvers: A dictionary of custom resolvers to be registered. For more information,
see here: https://omegaconf.readthedocs.io/en/2.3_branch/custom_resolvers.html#custom-resolvers
"""
self.base_env = base_env
self.default_run_env = default_run_env
Expand All @@ -111,6 +114,9 @@ def __init__( # noqa: too-many-arguments

# Deactivate oc.env built-in resolver for OmegaConf
OmegaConf.clear_resolver("oc.env")
# Register user provided custom resolvers
if custom_resolvers:
self._register_new_resolvers(custom_resolvers)

file_mimetype, _ = mimetypes.guess_type(conf_source)
if file_mimetype == "application/x-tar":
Expand Down Expand Up @@ -302,6 +308,12 @@ def _is_valid_config_path(self, path):
".json",
]

@staticmethod
def _register_new_resolvers(resolvers: dict[str, Callable]):
    """Register every user-supplied resolver with ``OmegaConf``.

    Args:
        resolvers: Mapping of resolver name to the callable implementing it.
            Each entry is registered with ``replace=True``, so a resolver
            already registered under the same name is overwritten.
    """
    for resolver_name in resolvers:
        OmegaConf.register_new_resolver(
            resolver_name, resolvers[resolver_name], replace=True
        )

@staticmethod
def _check_duplicates(seen_files_to_keys: dict[Path, set[Any]]):
duplicates = []
Expand Down
40 changes: 40 additions & 0 deletions tests/config/test_omegaconf_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,3 +649,43 @@ def test_variable_interpolation_in_catalog_with_separate_templates_file(
conf = OmegaConfigLoader(str(tmp_path))
conf.default_run_env = ""
assert conf["catalog"]["companies"]["type"] == "pandas.CSVDataSet"

def test_custom_resolvers(self, tmp_path):
base_params = tmp_path / _BASE_ENV / "parameters.yml"
param_config = {
"model_options": {
"test_size": "${add: 3, 4}",
"random_state": "${plus_2: 1}",
}
}
_write_yaml(base_params, param_config)
custom_resolvers = {
"add": lambda *x: sum(x),
"plus_2": lambda x: x + 2,
}
conf = OmegaConfigLoader(str(tmp_path), custom_resolvers=custom_resolvers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is str(tmp_path) strictly needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so; conf_source is a required argument.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant: does tmp_path alone work? I checked the definition of AbstractConfigLoader, and its signature has conf_source: str, but I would expect it to be conf_source: str | Path.
image

A quick experiment suggests that it should work

image

cc @merelcht

conf.default_run_env = ""
assert conf["parameters"]["model_options"]["test_size"] == 7
assert conf["parameters"]["model_options"]["random_state"] == 3
noklam marked this conversation as resolved.
Show resolved Hide resolved

def test_overwrite_resolvers(self, tmp_path):
    """A resolver passed to ``OmegaConfigLoader`` replaces one already registered.

    The test first registers ``custom`` directly on ``OmegaConf`` (x + 10),
    checks that plain ``OmegaConf`` loading uses it, then constructs an
    ``OmegaConfigLoader`` with a different ``custom`` resolver (x + 20) and
    checks that the loader's version wins.
    """
    base_params = tmp_path / _BASE_ENV / "parameters.yml"
    # OmegaConf is a singleton, so registered resolvers outlive a single test.
    # Register with replace=True so this line does not raise if "custom" is
    # already registered (e.g. the test is run twice in the same process).
    OmegaConf.register_new_resolver("custom", lambda x: x + 10, replace=True)
    try:
        param_config = {
            "model_options": {
                "test_size": "${custom: 10}",
            }
        }
        _write_yaml(base_params, param_config)
        conf_original = OmegaConf.load(base_params)
        # test_size should be calculated using custom resolver (x + 10)
        assert conf_original["model_options"]["test_size"] == 20
        custom_resolvers = {
            "custom": lambda x: x + 20,
        }
        conf = OmegaConfigLoader(str(tmp_path), custom_resolvers=custom_resolvers)
        conf.default_run_env = ""
        # test_size should be calculated using overwritten custom resolver (x + 20)
        assert conf["parameters"]["model_options"]["test_size"] == 30
    finally:
        # Remove the resolver again so it cannot leak into other tests.
        OmegaConf.clear_resolver("custom")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add a check/assert first to show that "test_size" is initially 20 and then, after the resolver is overwritten, becomes 30? Maybe just by calling OmegaConf directly on the param_config?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the test