-
Notifications
You must be signed in to change notification settings - Fork 914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow registering of custom resolvers to OmegaConfigLoader
#2869
Changes from 6 commits
6a14f7a
951957a
89ebb12
3163c5e
b03b1e4
7420b28
499491a
29da634
557325e
5b0b863
c609179
e67c805
2ccf458
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -649,3 +649,43 @@ def test_variable_interpolation_in_catalog_with_separate_templates_file( | |
conf = OmegaConfigLoader(str(tmp_path)) | ||
conf.default_run_env = "" | ||
assert conf["catalog"]["companies"]["type"] == "pandas.CSVDataSet" | ||
|
||
def test_custom_resolvers(self, tmp_path): | ||
base_params = tmp_path / _BASE_ENV / "parameters.yml" | ||
param_config = { | ||
"model_options": { | ||
"test_size": "${add: 3, 4}", | ||
"random_state": "${plus_2: 1}", | ||
} | ||
} | ||
_write_yaml(base_params, param_config) | ||
custom_resolvers = { | ||
"add": lambda *x: sum(x), | ||
"plus_2": lambda x: x + 2, | ||
} | ||
conf = OmegaConfigLoader(str(tmp_path), custom_resolvers=custom_resolvers) | ||
conf.default_run_env = "" | ||
assert conf["parameters"]["model_options"]["test_size"] == 7 | ||
assert conf["parameters"]["model_options"]["random_state"] == 3 | ||
noklam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def test_overwrite_resolvers(self, tmp_path): | ||
base_params = tmp_path / _BASE_ENV / "parameters.yml" | ||
# OmegaConf is a singleton, register a resolver to be overwritten | ||
OmegaConf.register_new_resolver("custom", lambda x: x + 10) | ||
|
||
param_config = { | ||
"model_options": { | ||
"test_size": "${custom: 10}", | ||
} | ||
} | ||
_write_yaml(base_params, param_config) | ||
conf_original = OmegaConf.load(base_params) | ||
# test_size should be calculated using custom resolver (x + 10) | ||
assert conf_original["model_options"]["test_size"] == 20 | ||
custom_resolvers = { | ||
"custom": lambda x: x + 20, | ||
} | ||
conf = OmegaConfigLoader(str(tmp_path), custom_resolvers=custom_resolvers) | ||
conf.default_run_env = "" | ||
# test_size should be calculated using overwritten custom resolver (x + 20) | ||
assert conf["parameters"]["model_options"]["test_size"] == 30 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to add a check/assert first to show that "test_size" is set to 20 and then after the overwriting it will be 30? Maybe just by calling omegaconf directly on the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated the test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is
str(tmp_path)
strictly needed?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so,
conf_source
is required argumentThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I mean if
![image](https://private-user-images.githubusercontent.com/18221871/257535717-140dbd86-cca9-4a5f-8dd9-3520b868c619.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzg5ODAxOTgsIm5iZiI6MTczODk3OTg5OCwicGF0aCI6Ii8xODIyMTg3MS8yNTc1MzU3MTctMTQwZGJkODYtY2NhOS00YTVmLThkZDktMzUyMGI4NjhjNjE5LnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMDglMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjA4VDAxNTgxOFomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPThhOWM5MjVhZmVlZjhhYjgzNzY5N2Q3MjQ2OTA2NDRiOWI4ZDk1OTNkMzkxYzc1NGRiZTFjOTAwMjA1YmRkYjImWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.D-fx_Vtzr9HRryvH_Cv39bTBe2xNnfmyR1-2IY0LdOY)
tmp_path
alone works? I check the definition ofAbstractConfigLoader
have a signature ofconf_source: str
, but I expect it should beconf_source : str | Path
.A quick experiment suggests that it should work
cc @merelcht