Skip to content

Commit

Permalink
feat: support of recursive get_from_params added
Browse files Browse the repository at this point in the history
  • Loading branch information
bagxi committed Jun 21, 2021
1 parent d337280 commit 32568dc
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 287 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `dataset_from_params` support in config API ([#1231](https://github.com/catalyst-team/catalyst/pull/1231))
- transform from params support for config API added ([#1236](https://github.com/catalyst-team/catalyst/pull/1236))
- samplers from params support for config API added ([#1240](https://github.com/catalyst-team/catalyst/pull/1240))
- recursive registry.get_from_params added ([#1241](https://github.com/catalyst-team/catalyst/pull/1241))

### Changed

Expand Down
78 changes: 10 additions & 68 deletions catalyst/runners/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Dict, List
from typing import Any, Dict, List
from collections import OrderedDict
from copy import deepcopy
import logging
Expand Down Expand Up @@ -216,46 +216,6 @@ def get_loggers(self) -> Dict[str, ILogger]:

return loggers

def _get_transform_from_params(self, **params) -> Callable:
    """Build a transform object from config ``params``.

    Keys listed under ``_transforms_`` (default: ``("transforms",)``)
    are treated as lists of nested transform configs and are
    instantiated recursively before the top-level factory is resolved.
    """
    nested_keys = params.pop("_transforms_", ("transforms",))
    for nested_key in nested_keys:
        if nested_key not in params:
            continue
        params[nested_key] = [
            self._get_transform_from_params(**sub_params)
            for sub_params in params[nested_key]
        ]
    return REGISTRY.get_from_params(**params)

def get_transform(self, **params) -> Callable:
    """Return the data transform built from ``**params``.

    Args:
        **params: parameters of the transformation.

    Returns:
        Data transformation to use.
    """
    # operate on a copy so the caller's config dict stays untouched
    return self._get_transform_from_params(**deepcopy(params))

def _get_dataset_from_params(self, **params) -> "Dataset":
    """Instantiate a dataset from config ``params``.

    A nested ``transform`` config, when present, is resolved first via
    ``get_transform`` and handed to the dataset factory in its place.
    """
    params = deepcopy(params)
    transform_config = params.pop("transform", None)
    if transform_config is not None:
        params["transform"] = self.get_transform(**transform_config)
    return REGISTRY.get_from_params(**params)

def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
"""
Returns datasets for a given stage.
Expand All @@ -266,24 +226,14 @@ def get_datasets(self, stage: str) -> "OrderedDict[str, Dataset]":
Returns:
Dict: datasets objects
"""
params = deepcopy(self._stage_config[stage]["loaders"]["datasets"])
params = self._stage_config[stage]["loaders"]["datasets"]

datasets = [
(key, self._get_dataset_from_params(**dataset_params))
(key, REGISTRY.get_from_params(**dataset_params))
for key, dataset_params in params.items()
]
return OrderedDict(datasets)

def _get_sampler_from_params(self, **params) -> Sampler:
    """Build a sampler from config ``params``.

    Keys listed under ``_samplers_`` (default:
    ``("sampler", "base_sampler")``) hold nested sampler configs and
    are instantiated recursively before the outer factory is called.
    """
    for nested_key in params.pop("_samplers_", ("sampler", "base_sampler")):
        if nested_key in params:
            params[nested_key] = self._get_sampler_from_params(**params[nested_key])
    return REGISTRY.get_from_params(**params)

def get_samplers(self, stage: str) -> "OrderedDict[str, Sampler]":
"""
Returns samplers for a given stage.
Expand All @@ -294,10 +244,10 @@ def get_samplers(self, stage: str) -> "OrderedDict[str, Sampler]":
Returns:
Dict of samplers
"""
params = deepcopy(self._stage_config[stage]["loaders"].get("samplers", {}))
params = get_by_keys(self._stage_config, stage, "loaders", "samplers", default={})

samplers = [
(key, self._get_sampler_from_params(**sampler_params))
(key, REGISTRY.get_from_params(**sampler_params))
for key, sampler_params in params.items()
]
return OrderedDict(samplers)
Expand All @@ -314,7 +264,7 @@ def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":
"""
loaders_params = deepcopy(self._stage_config[stage]["loaders"])

# config parsed manually in `get_datasets` and `get_samplers` methods
# config is parsed manually in `get_datasets` and `get_samplers` methods
loaders_params.pop("datasets", None)
loaders_params.pop("samplers", None)

Expand Down Expand Up @@ -399,8 +349,10 @@ def _get_optimizer_from_params(
no_bias_weight_decay=no_bias_weight_decay,
lr_scaling=lr_scaling,
)

# instantiate optimizer
optimizer = REGISTRY.get_from_params(**params, params=model_params)
# use `shared_params` to pass model params to the nested optimizers
optimizer = REGISTRY.get_from_params(**params, shared_params={"params": model_params})
return optimizer

def get_optimizer(self, model: RunnerModel, stage: str) -> RunnerOptimizer:
Expand Down Expand Up @@ -462,23 +414,13 @@ def get_scheduler(self, optimizer: RunnerOptimizer, stage: str) -> RunnerSchedul
scheduler = self._get_scheduler_from_params(optimizer=optimizer, **scheduler_params)
return scheduler

@staticmethod
def _get_callback_from_params(**params):
    """Create a callback from config ``params``.

    If a ``_wrapper`` config is present, the freshly created callback
    is injected into it as ``base_callback`` and the wrapper itself is
    built recursively, so wrappers can be nested arbitrarily deep.
    """
    params = deepcopy(params)
    wrapper_config = params.pop("_wrapper", None)
    callback = REGISTRY.get_from_params(**params)
    if wrapper_config is None:
        return callback
    wrapper_config["base_callback"] = callback
    return ConfigRunner._get_callback_from_params(**wrapper_config)  # noqa: WPS437

def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
"""Returns the callbacks for a given stage."""
callbacks_params = get_by_keys(self._stage_config, stage, "callbacks", default={})

callbacks = OrderedDict(
[
(key, self._get_callback_from_params(**callback_params))
(key, REGISTRY.get_from_params(**callback_params))
for key, callback_params in callbacks_params.items()
]
)
Expand Down
50 changes: 45 additions & 5 deletions catalyst/tools/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Type, Union
import collections
import copy
import functools
import inspect
import warnings
Expand Down Expand Up @@ -37,11 +38,16 @@ class Registry(collections.MutableMapping):
that calls factory. Optional. Default just calls factory.
"""

def __init__(self, default_meta_factory: MetaFactory = _default_meta_factory):
def __init__(
self,
default_meta_factory: MetaFactory = _default_meta_factory,
name_key: str = "_target_",
):
"""Init."""
self.meta_factory = default_meta_factory
self._factories: Dict[str, Factory] = {}
self._late_add_callbacks: List[LateAddCallbak] = []
self.name_key = name_key

@staticmethod
def _get_factory_name(f, provided_name=None) -> str:
Expand Down Expand Up @@ -221,8 +227,39 @@ def get_instance(self, name: str, *args, meta_factory=None, **kwargs):
f"Factory '{name}' call failed: args={args} kwargs={kwargs}"
) from e

def _recursive_get_from_params(
self,
params: Union[Dict[str, Any], Any],
shared_params: Optional[Dict[str, Any]] = None,
meta_factory=None,
) -> Union[Any, Tuple[Any, Mapping[str, Any]]]:
if not (isinstance(params, dict) and self.name_key in params):
return params

shared_params = shared_params or {}

# make a copy of params since we don't want to modify them directly
params = copy.deepcopy(params)
for key, param in params.items():
if isinstance(param, dict):
params[key] = self._recursive_get_from_params(
params=param, meta_factory=meta_factory, shared_params=shared_params
)
# limit to `list` and `tuple` only to avoid processing of generators
elif isinstance(param, (list, tuple)):
params[key] = [
self._recursive_get_from_params(
params=value, meta_factory=meta_factory, shared_params=shared_params
)
for value in param
]

name = params.pop(self.name_key)
instance = self.get_instance(name, meta_factory=meta_factory, **shared_params, **params)
return instance

def get_from_params(
self, *, meta_factory=None, **kwargs
self, *, meta_factory=None, shared_params: Optional[Dict[str, Any]] = None, **kwargs,
) -> Union[Any, Tuple[Any, Mapping[str, Any]]]:
"""
Creates an instance based on the configuration dict with ``instantiation_fn``.
Expand All @@ -231,14 +268,17 @@ def get_from_params(
Args:
meta_factory: Function that calls factory the right way.
If not provided, default is used.
shared_params: params to pass on all levels in case of recursive creation
**kwargs: additional kwargs for factory
Returns:
result of calling ``instantiate_fn(factory, **config)``
"""
name = kwargs.pop("_target_", None)
if name:
return self.get_instance(name, meta_factory=meta_factory, **kwargs)
if self.name_key in kwargs:
instance = self._recursive_get_from_params(
params=kwargs, shared_params=shared_params, meta_factory=meta_factory
)
return instance

def all(self) -> List[str]:
"""
Expand Down
1 change: 0 additions & 1 deletion examples/mnist_stages/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ stages:
transform:
_target_: transform.ToTensor
download: True
num_samples_per_class: 320
valid:
_target_: MNIST
root: *dataset_root
Expand Down
7 changes: 1 addition & 6 deletions examples/mnist_stages/config_tune.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ stages:
train: False
transform: *transform
download: True
#

# samplers:
# train:
# _target_: MiniEpochSampler
Expand All @@ -88,11 +88,6 @@ stages:
# drop_last: True
# shuffle: per_epoch
#
# transforms:
# _target_: albumentations.Compose
# transforms:
# - _target_: albumentations.Normalize
# - _target_: catalyst.ImageToTensor

criterion:
_target_: CrossEntropyLoss
Expand Down
39 changes: 17 additions & 22 deletions examples/mnist_stages/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,24 @@ def get_model(self, stage: str):
utils.set_requires_grad(layer, requires_grad=False)
return model

def get_datasets(
    self, stage: str, num_samples_per_class: int = None
) -> "OrderedDict[str, Dataset]":
    """Provide train/validation MNIST datasets.

    The train dataset is wrapped together with a ``BalanceClassSampler``
    that draws ``num_samples_per_class`` samples per class (320 by
    default, also used when a falsy value is passed).
    """
    if not num_samples_per_class:
        num_samples_per_class = 320

    datasets = super().get_datasets(stage=stage)
    train_dataset = datasets["train"]
    balanced_sampler = BalanceClassSampler(
        labels=train_dataset.targets, mode=num_samples_per_class
    )
    datasets["train"] = {"dataset": train_dataset, "sampler": balanced_sampler}
    return datasets

class CustomSupervisedConfigRunner(IRunnerMixin, SupervisedConfigRunner):
def _get_dataset_from_params(self, num_samples_per_class=320, **kwargs):
dataset = super()._get_dataset_from_params(**kwargs)
if kwargs.get("train", True):
dataset = {
"dataset": dataset,
"sampler": BalanceClassSampler(labels=dataset.targets, mode=num_samples_per_class),
}

return dataset
class CustomSupervisedConfigRunner(IRunnerMixin, SupervisedConfigRunner):
pass


if SETTINGS.hydra_required:
Expand All @@ -44,19 +51,7 @@ def _get_dataset_from_params(self, num_samples_per_class=320, **kwargs):
from catalyst.dl import SupervisedHydraRunner

class CustomSupervisedHydraRunner(IRunnerMixin, SupervisedHydraRunner):
def _get_dataset_from_params(self, params):
num_samples_per_class = params.pop("num_samples_per_class", 320)

dataset = super()._get_dataset_from_params(params)
if params["train"]:
dataset = {
"dataset": dataset,
"sampler": BalanceClassSampler(
labels=dataset.targets, mode=num_samples_per_class
),
}

return dataset
pass

__all__ = ["CustomSupervisedConfigRunner", "CustomSupervisedHydraRunner"]
else:
Expand Down
31 changes: 31 additions & 0 deletions tests/catalyst/tools/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,37 @@ def test_from_config():
assert res is None


def test_name_key():
    """Check that a custom ``name_key`` replaces the default ``_target_``."""
    registry = Registry(name_key="_key_")

    registry.add(foo)

    # the custom key resolves and calls the factory...
    created = registry.get_from_params(**{"_key_": "foo", "a": 1, "b": 2})()
    assert created == {"a": 1, "b": 2}

    # ...while the default key is now ignored
    assert registry.get_from_params(**{"_target_": "foo", "a": 1, "b": 2}) is None


def test_recursive_get_from_config():
    """Nested ``_target_`` configs (dicts and lists) are built bottom-up."""

    def plain_meta_factory(factory, args, kwargs):
        return factory(*args, **kwargs)

    registry = Registry(default_meta_factory=plain_meta_factory)

    registry.add(foo)

    config = {
        "_target_": "foo",
        "a": {"_target_": "foo", "a": {"_target_": "foo", "a": 1, "b": 2}, "b": 2},
        "b": [{"_target_": "foo", "a": 1, "b": 2}, {"_target_": "foo", "a": 1, "b": 2}],
    }
    expected = {"a": {"a": {"a": 1, "b": 2}, "b": 2}, "b": [{"a": 1, "b": 2}, {"a": 1, "b": 2}]}
    assert registry.get_from_params(**config) == expected


def test_meta_factory():
"""@TODO: Docs. Contribution is welcome.""" # noqa: D202

Expand Down
Loading

0 comments on commit 32568dc

Please sign in to comment.