Add knob (IGNITE_DISABLE_DISTRIBUTED_METRICS=1) to disable distributed metrics reduction #2895

Status: Open. Wants to merge 1 commit into base: master.
2 changes: 1 addition & 1 deletion ignite/contrib/metrics/precision_recall_curve.py
@@ -92,7 +92,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:  # type: i
_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

- ws = idist.get_world_size()
+ ws = idist.get_metrics_computation_world_size()
if ws > 1:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
10 changes: 10 additions & 0 deletions ignite/distributed/utils.py
@@ -1,3 +1,4 @@
+ import os
import socket
from functools import wraps
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
@@ -20,6 +21,7 @@
"available_backends",
"model_name",
"get_world_size",
"get_metrics_computation_world_size",
"get_rank",
"get_local_rank",
"get_nproc_per_node",
@@ -141,6 +143,14 @@ def get_world_size() -> int:
return _model.get_world_size()


+ def get_metrics_computation_world_size() -> int:
+     """Returns world size of current distributed configuration for metrics computation. Returns 1 if no distributed configuration."""
+     if os.environ.get("IGNITE_DISABLE_DISTRIBUTED_METRICS") == "1":
+         return 1
+
+     return get_world_size()


def get_rank() -> int:
"""Returns process rank within current distributed configuration. Returns 0 if no distributed configuration."""
if _need_to_sync and isinstance(_model, _SerialModel):
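A minimal sketch of how the new knob would be used, assuming this branch is installed. The environment variable name and the helper come from the diff above; the surrounding script context is hypothetical, and the variable can be set at any point because the helper reads it on every call.

```python
import os

# Opt out of distributed metrics reduction (hypothetical placement in a
# training script; any point before metric computation should work).
os.environ["IGNITE_DISABLE_DISTRIBUTED_METRICS"] = "1"

import ignite.distributed as idist

# Under a multi-process launcher the real world size is unchanged,
# but metrics computation now behaves as if it were a single process.
print(idist.get_world_size())                      # e.g. 4 on a 4-GPU run
print(idist.get_metrics_computation_world_size())  # 1 -> reductions are skipped
```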
2 changes: 1 addition & 1 deletion ignite/metrics/epoch_metric.py
@@ -144,7 +144,7 @@ def compute(self) -> float:
_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

- ws = idist.get_world_size()
+ ws = idist.get_metrics_computation_world_size()
if ws > 1:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
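The effect on `EpochMetric` subclasses (and on `precision_recall_curve.py` above, which follows the same pattern) is that the all-gather branch is skipped when the knob is set. A rough sketch of the resulting control flow, not the actual class, assuming the helper from `ignite/distributed/utils.py`:

```python
import torch
import ignite.distributed as idist

def compute_like_epoch_metric(predictions, targets, compute_fn):
    # Mirrors the compute() pattern in the diff: concatenate local batches first.
    prediction_tensor = torch.cat(predictions, dim=0)
    target_tensor = torch.cat(targets, dim=0)

    ws = idist.get_metrics_computation_world_size()
    if ws > 1:
        # Default behaviour: gather predictions/targets from every process.
        prediction_tensor = idist.all_gather(prediction_tensor)
        target_tensor = idist.all_gather(target_tensor)
    # With IGNITE_DISABLE_DISTRIBUTED_METRICS=1, ws == 1, so each process
    # evaluates compute_fn on its local data only.
    return compute_fn(prediction_tensor, target_tensor)
```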
4 changes: 2 additions & 2 deletions ignite/metrics/frequency.py
@@ -61,8 +61,8 @@ def update(self, output: int) -> None:
def compute(self) -> float:
time_divisor = 1.0

- if idist.get_world_size() > 1:
-     time_divisor *= idist.get_world_size()
+ if idist.get_metrics_computation_world_size() > 1:
+     time_divisor *= idist.get_metrics_computation_world_size()

# Returns the average processed objects per second across all workers
return self._n / self._elapsed * time_divisor
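A hedged worked example of what the `time_divisor` change means for `Frequency`, assuming `_n` and `_elapsed` are all-reduced (summed) across processes before `compute()` runs when reduction is enabled, as the `sync_all_reduce` change in `metric.py` below suggests; the numbers are made up:

```python
# 4 workers, each processed 1000 items in 2.0 s.

# Default: _n and _elapsed are summed across workers and ws == 4.
n, elapsed, ws = 4 * 1000, 4 * 2.0, 4
print(n / elapsed * ws)    # 2000.0 items/s -- aggregate cluster throughput

# IGNITE_DISABLE_DISTRIBUTED_METRICS=1: no reduction and ws is reported as 1.
n, elapsed, ws = 1000, 2.0, 1
print(n / elapsed * ws)    # 500.0 items/s -- this process's throughput only
```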
2 changes: 1 addition & 1 deletion ignite/metrics/metric.py
@@ -554,7 +554,7 @@ def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
raise RuntimeError(
"Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only"
)
- ws = idist.get_world_size()
+ ws = idist.get_metrics_computation_world_size()
unreduced_attrs = {}
if len(attrs) > 0 and ws > 1:
for attr in attrs:
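Because `sync_all_reduce` now consults the metrics world size, any metric built on the decorator inherits the knob. A minimal illustrative metric, not taken from the repository, assuming the standard `Metric`, `sync_all_reduce`, and `reinit__is_reduced` API:

```python
from ignite.metrics import Metric
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce

class SumOfOutputs(Metric):
    """Toy metric: accumulates a running sum of scalar outputs."""

    @reinit__is_reduced
    def reset(self):
        self._total = 0.0

    @reinit__is_reduced
    def update(self, output):
        self._total += float(output)

    @sync_all_reduce("_total")
    def compute(self):
        # Default: _total is summed over all processes before this body runs.
        # With IGNITE_DISABLE_DISTRIBUTED_METRICS=1 the decorator sees ws == 1
        # and leaves _total as the local, per-process sum.
        return self._total
```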
2 changes: 1 addition & 1 deletion ignite/metrics/running_average.py
@@ -154,7 +154,7 @@ def _get_metric_value(self) -> Union[torch.Tensor, float]:
@sync_all_reduce("src")
def _get_output_value(self) -> Union[torch.Tensor, float]:
# we need to compute average instead of sum produced by @sync_all_reduce("src")
- output = cast(Union[torch.Tensor, float], self.src) / idist.get_world_size()
+ output = cast(Union[torch.Tensor, float], self.src) / idist.get_metrics_computation_world_size()
return output

def _metric_iteration_completed(self, engine: Engine) -> None:
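For `RunningAverage` over a source metric, this division converts the sum produced by `@sync_all_reduce("src")` back into a cross-process mean. A small hedged illustration of the two modes, with made-up values:

```python
# Source metric value is 0.90 on rank 0 and 0.80 on rank 1 (world size 2).

# Default: sync_all_reduce sums src to 1.70, then dividing by ws == 2
# yields the cross-process average fed into the running average.
print(1.70 / 2)   # 0.85

# IGNITE_DISABLE_DISTRIBUTED_METRICS=1: no reduction and the divisor is 1,
# so each process tracks only its own source value.
print(0.90 / 1)   # 0.90 on rank 0 (0.80 on rank 1)
```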