diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 109b8fd8104b5..b324582fe7602 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -25,7 +25,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import ScriptModule, Tensor @@ -347,7 +347,7 @@ def log( def log_dict( self, - dictionary: dict, + dictionary: Mapping[str, Any], prog_bar: bool = False, logger: bool = True, on_step: Optional[bool] = None, diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 1cbab2fb8dee9..61739cd25d1d2 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -13,6 +13,7 @@ # limitations under the License. import operator from abc import ABC +from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import copy from functools import partial @@ -85,10 +86,12 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({ - k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - for k, v in data.items() - }) + return elem_type( + OrderedDict({ + k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for k, v in data.items() + }) + ) if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple return elem_type( diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index a7eea3a749f26..7454ce01d3bee 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers -from collections import namedtuple +from collections import namedtuple, OrderedDict import numpy as np import torch @@ -76,3 +76,19 @@ def test_recursive_application_to_collection(): assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor' assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result' + + # mapping support + reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x)) + assert reduced == {'a': '1', 'b': '2'} + reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x)) + assert reduced == OrderedDict([('b', '2'), ('a', '1')]) + + # custom mappings + class _CustomCollection(dict): + + def __init__(self, initial_dict): + super().__init__(initial_dict) + + to_reduce = _CustomCollection({'a': 1, 'b': 2, 'c': 3}) + reduced = apply_to_collection(to_reduce, int, lambda x: str(x)) + assert reduced == _CustomCollection({'a': '1', 'b': '2', 'c': '3'})