Metric aggregation #3321

Merged · 23 commits · Sep 14, 2020
Changes from 5 commits
75 changes: 48 additions & 27 deletions pytorch_lightning/metrics/converters.py
@@ -18,11 +18,13 @@
sync tensors between different processes in a DDP scenario, when needed.
"""

from functools import reduce
import numbers
from typing import Any, Callable, Optional, Union

import numpy as np
import torch
from torch.distributed.distributed_c10d import reduce_op
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from pytorch_lightning.utilities import rank_zero_warn
@@ -31,10 +33,20 @@
try:
from torch.distributed import ReduceOp
except ImportError:

class ReduceOp:
SUM = None

rank_zero_warn('Unsupported `ReduceOp` for distributed computing')
rank_zero_warn("Unsupported `ReduceOp` for distributed computing")

try:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.core.functions as xf
except ImportError:
XLA_AVAILABLE = False
else:
XLA_AVAILABLE = True
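
The `try`/`except` block above is the usual optional-dependency guard: `XLA_AVAILABLE` records whether `torch_xla` can be imported, so TPU-only collectives can be gated at call time. A minimal sketch of how the flag might be consumed (hypothetical helper, not part of this diff):

```python
import torch


def gather_if_tpu(tensor: torch.Tensor) -> torch.Tensor:
    # On TPU, gather the tensor from every replica along dim 0;
    # on other accelerators, return the input unchanged.
    if XLA_AVAILABLE:
        return xf.all_gather(tensor, dim=0)
    return tensor
```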


def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
@@ -138,8 +150,9 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate)
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(
func_to_decorate
)


def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@@ -185,8 +198,9 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate)
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
func_to_decorate
)


def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@@ -199,8 +213,9 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor)(func_to_decorate)
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
func_to_decorate
)


def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
@@ -240,10 +255,9 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable:
return _tensor_collection_metric_output_conversion(func_convert_inputs)


def sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None
) -> torch.Tensor:
def sync_ddp_if_available(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process

@@ -265,23 +279,34 @@

if reduce_op is None:
reduce_op = torch.distributed.ReduceOp.SUM
elif isinstance(reduce_op, str) and reduce_op in ('avg', 'mean'):
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
reduce_op = torch.distributed.ReduceOp.SUM
divide_by_world_size = True

# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group,
async_op=False)
torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)

if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)

return result
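
For orientation, a hedged usage sketch (process-group setup is assumed and not shown; the tensor device depends on the backend):

```python
import torch

# Assumes torch.distributed.init_process_group(...) has already run on
# every rank. Each rank holds its own local loss value.
loss = torch.tensor([0.5])

# 'mean'/'avg' maps to a SUM reduction followed by division by the
# world size, as implemented above.
mean_loss = sync_ddp_if_available(loss, reduce_op='mean')
```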

def at_least_1d(tensor: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""Makes sure the tensor is at least of 1d shape

Args:
tensor: the tensor or array to check the shape for

Returns:
the optionally reshaped tensor
"""
if tensor.shape == ():
tensor = tensor.reshape(1,)
return tensor
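
A quick illustration: only 0-d (scalar) inputs are reshaped, everything else passes through untouched:

```python
import torch

assert at_least_1d(torch.tensor(3.0)).shape == (1,)    # scalar -> (1,)
assert at_least_1d(torch.zeros(4)).shape == (4,)       # 1d unchanged
assert at_least_1d(torch.zeros(2, 3)).shape == (2, 3)  # nd unchanged
```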


def gather_all_tensors_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None):
def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes
@@ -312,8 +337,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor],
return result
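
A usage sketch (most of the body is collapsed in this diff view; process-group setup assumed): every rank contributes its local tensor and receives the full list back:

```python
import torch
import torch.distributed

# On every rank, after init_process_group(...):
local = torch.tensor([float(torch.distributed.get_rank())])
gathered = gather_all_tensors_if_available(local)
# `gathered` is a list with one entry per process, identical on all ranks.
```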


def sync_ddp(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator syncs a function's outputs across different processes for DDP.

@@ -327,15 +351,14 @@
"""

def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor,
sync_ddp_if_available, group=group,
reduce_op=reduce_op)(func_to_decorate)
return _apply_to_outputs(
apply_to_collection, torch.Tensor, sync_ddp_if_available, group=group, reduce_op=reduce_op
)(func_to_decorate)

return decorator_fn
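
As a sketch, the decorator makes any tensor-returning function DDP-aware (hypothetical metric; distributed setup assumed):

```python
import torch


@sync_ddp(reduce_op='mean')
def batch_accuracy(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Computed locally on each rank; the decorator averages the
    # result across all processes.
    return (pred.argmax(dim=-1) == target).float().mean()
```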


def numpy_metric(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def numpy_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on numpy arrays.
It handles the argument conversion and DDP reduction for metrics working on numpy.
@@ -357,8 +380,7 @@ def decorator_fn(func_to_decorate):
return decorator_fn
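
For example, a hypothetical numpy-based metric can be decorated so callers may pass tensors and receive a DDP-reduced tensor back (a sketch, not taken from this PR):

```python
import numpy as np


@numpy_metric(reduce_op='mean')
def mse(pred: np.ndarray, target: np.ndarray):
    # Arguments arrive converted to numpy arrays; the return value is
    # converted back to a tensor and reduced across processes.
    return ((pred - target) ** 2).mean()
```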


def tensor_metric(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def tensor_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors.
It handles the argument conversion and DDP reduction for metrics working on tensors.
@@ -379,8 +401,7 @@ def decorator_fn(func_to_decorate):
return decorator_fn
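
The tensor variant is used the same way; a hedged sketch with a hypothetical metric:

```python
import torch


@tensor_metric(reduce_op='mean')
def mae(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Inputs are converted to tensors where needed; the output is
    # synced across DDP processes by the decorator.
    return (pred - target).abs().mean()
```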


def tensor_collection_metric(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def tensor_collection_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors and returning collections
that cannot be converted to tensors.
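
The remainder of this function is collapsed in the diff view. Conceptually (a sketch assuming the same `group`/`reduce_op` signature as the decorators above), it targets metrics whose return value stays a collection of tensors:

```python
import torch


@tensor_collection_metric()
def confusion_counts(pred: torch.Tensor, target: torch.Tensor) -> dict:
    # The dict itself is preserved; only the tensor leaves inside it
    # are converted and synced.
    return {
        'tp': ((pred == 1) & (target == 1)).sum(),
        'fp': ((pred == 1) & (target == 0)).sum(),
    }
```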