Skip to content

Commit

Permalink
[bugfix] Minor improvements to apply_to_collection and type signatu…
Browse files Browse the repository at this point in the history
…re of `log_dict` (#7851)

* minor fixeS

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
(cherry picked from commit 009e05d)
  • Loading branch information
SeanNaren authored and lexierule committed Jun 9, 2021
1 parent 8a5a56b commit c292788
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from argparse import Namespace
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import ScriptModule, Tensor
Expand Down Expand Up @@ -347,7 +347,7 @@ def log(

def log_dict(
self,
dictionary: dict,
dictionary: Mapping[str, Any],
prog_bar: bool = False,
logger: bool = True,
on_step: Optional[bool] = None,
Expand Down
11 changes: 7 additions & 4 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import operator
from abc import ABC
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from copy import copy
from functools import partial
Expand Down Expand Up @@ -85,10 +86,12 @@ def apply_to_collection(

# Recursively apply to collection items
if isinstance(data, Mapping):
return elem_type({
k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
for k, v in data.items()
})
return elem_type(
OrderedDict({
k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
for k, v in data.items()
})
)

if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(
Expand Down
18 changes: 17 additions & 1 deletion tests/utilities/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from collections import namedtuple
from collections import namedtuple, OrderedDict

import numpy as np
import torch
Expand Down Expand Up @@ -76,3 +76,19 @@ def test_recursive_application_to_collection():

assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor'
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'

# mapping support
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
assert reduced == {'a': '1', 'b': '2'}
reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x))
assert reduced == OrderedDict([('b', '2'), ('a', '1')])

# custom mappings
class _CustomCollection(dict):

def __init__(self, initial_dict):
super().__init__(initial_dict)

to_reduce = _CustomCollection({'a': 1, 'b': 2, 'c': 3})
reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
assert reduced == _CustomCollection({'a': '1', 'b': '2', 'c': '3'})

0 comments on commit c292788

Please sign in to comment.