Horovod: fixed early stopping and added metrics aggregation #3775

Merged · 42 commits · Nov 5, 2020

Commits
1c1a83e
Fixed early stopping for Horovod
tgaddair Oct 1, 2020
f808d33
Refactored to sync_dist_if_available
tgaddair Oct 1, 2020
7774093
Bump min Horovod version to support hvd.is_initialized
tgaddair Oct 1, 2020
fd6fb80
Changelog
tgaddair Oct 1, 2020
239fcd6
Added back change for Horovod
tgaddair Oct 6, 2020
660651c
Removed redundant checks for initialization
tgaddair Oct 6, 2020
f06d7e4
Implement metrics gathering for Horovod
tgaddair Oct 6, 2020
49f583d
Added test for EvalResult
tgaddair Oct 6, 2020
3da46e7
Renamed ddp_sync_on_step -> dist_sync_on_step
tgaddair Oct 6, 2020
4f770d4
Added metric test for Horovod
tgaddair Oct 6, 2020
905bdd1
Added option to pass callable allgather function to metric base class
tgaddair Oct 8, 2020
39cedfc
Added dist_sync_fn
tgaddair Oct 8, 2020
5818466
Fixed calls to private _sync_dist
tgaddair Oct 8, 2020
857ab82
Fixed Horovod test
tgaddair Oct 8, 2020
98d5325
Added sync_tensor to the distributed backend
tgaddair Oct 8, 2020
1596f93
Skip Windows
tgaddair Oct 8, 2020
2350ba0
Insert test path
tgaddair Oct 8, 2020
5f08689
Removed redundant import
tgaddair Oct 9, 2020
4e3ce48
Updated drone
tgaddair Oct 9, 2020
1c88684
Unset HOROVOD_GPU_ALLREDUCE
tgaddair Oct 9, 2020
825278c
Unset
tgaddair Oct 9, 2020
ae45024
No cache dir
tgaddair Oct 9, 2020
a4ef8d0
No uninstall
tgaddair Oct 9, 2020
4aef1a1
Unset variables
tgaddair Oct 9, 2020
c42666f
Uninstall Horovod during initialization
tgaddair Oct 9, 2020
46d0732
Replaced more references to ddp_sync_on_step
tgaddair Oct 9, 2020
b416730
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
teddykoker Oct 22, 2020
0326b2a
Resolved merge conflicts
tgaddair Nov 3, 2020
ab8dd34
Fixed imports
tgaddair Nov 3, 2020
3590051
Fixed attribute
tgaddair Nov 3, 2020
80532eb
Added back default
tgaddair Nov 3, 2020
d7409ea
Merge branch 'master' into hvd_early_stop
tgaddair Nov 4, 2020
2a00e1a
Lint
tgaddair Nov 4, 2020
0fd1e9b
Merge branch 'master' into hvd_early_stop
tgaddair Nov 4, 2020
603724a
Added back docstring
tgaddair Nov 4, 2020
bafdb9d
Made gather_all_tensors default
tgaddair Nov 4, 2020
1177a96
Added whitespace
tgaddair Nov 4, 2020
63d8bdb
Merge branch 'master' into hvd_early_stop
SeanNaren Nov 4, 2020
2b19746
Update tests/models/test_horovod.py
tgaddair Nov 4, 2020
77ee13f
Update pytorch_lightning/metrics/metric.py
tgaddair Nov 4, 2020
855fab7
Update CHANGELOG.md
Borda Nov 4, 2020
8b1ee77
Merge branch 'master' into hvd_early_stop
SeanNaren Nov 4, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458))


- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))


### Changed


24 changes: 23 additions & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -14,7 +14,7 @@
import os
import math
from enum import Enum
from typing import Any, Optional
from typing import Any, Optional, Union

import torch

@@ -30,6 +30,12 @@
except ImportError:
amp = None

if torch.distributed.is_available():
from torch.distributed import ReduceOp
else:
class ReduceOp:
SUM = None

EPSILON = 1e-6
EPSILON_FP16 = 1e-5

@@ -209,6 +215,22 @@ def init_ddp_connection(
torch_backend, rank=global_rank, world_size=world_size
)

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
"""
Function to reduce a tensor from several distributed processes to one aggregated tensor.
Args:
tensor: the tensor to sync and reduce
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum.
Can also be the string 'avg' or 'mean' to compute the mean during reduction.
Return:
reduced value
"""
raise NotImplementedError()

def __getstate__(self):
return {
'trainer': self.trainer,
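To make the contract above concrete: every backend overrides `sync_tensor` with its own collective primitive, and a backend with no peers can satisfy it trivially. A minimal sketch, assuming the `Accelerator`/`ReduceOp` imports added in this file (the subclass name is hypothetical):

import torch
from typing import Any, Optional, Union

from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp

class SingleProcessAccelerator(Accelerator):
    """Hypothetical backend with world size 1: reduction is the identity."""

    def sync_tensor(self,
                    tensor: torch.Tensor,
                    group: Optional[Any] = None,
                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
        # With a single process, the sum and the mean over the "group"
        # are the tensor itself, so no collective call is needed.
        return tensor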
11 changes: 9 additions & 2 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -18,17 +18,18 @@
import sys
from os.path import abspath
from time import sleep
from typing import Optional, List
from typing import Any, Optional, List, Union

import numpy as np

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
from torch.nn.parallel import DistributedDataParallel
@@ -298,3 +299,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)
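For reference, the `sync_ddp_if_available` call used here is essentially a guarded `all_reduce`. A simplified sketch of its behavior (an approximation, not the utility's exact source):

import torch
import torch.distributed as torch_distrib

def sync_ddp_sketch(tensor: torch.Tensor, group=None, reduce_op=None) -> torch.Tensor:
    # Outside an initialized process group there is nothing to reduce.
    if not (torch_distrib.is_available() and torch_distrib.is_initialized()):
        return tensor

    if group is None:
        group = torch_distrib.group.WORLD

    divide_by_world_size = isinstance(reduce_op, str) and reduce_op.lower() in ('avg', 'mean')
    op = reduce_op if isinstance(reduce_op, torch_distrib.ReduceOp) else torch_distrib.ReduceOp.SUM

    # all_reduce mutates the tensor in place; afterwards every rank holds the same value
    torch_distrib.all_reduce(tensor, op=op, group=group)
    if divide_by_world_size:
        tensor = tensor / torch_distrib.get_world_size(group)
    return tensor

The DDP-style accelerators below add the same one-line override, which is why each of them can simply delegate to this one utility.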
11 changes: 9 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py
@@ -12,18 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import List, Optional
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.distributed.dist import LightningDistributed


@@ -199,3 +200,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)
12 changes: 9 additions & 3 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import List, Optional
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
@@ -21,11 +21,11 @@
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.utilities.distributed import find_free_network_port, sync_ddp_if_available
from pytorch_lightning.distributed.dist import LightningDistributed

try:
@@ -229,3 +229,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)
(next file; header not captured)
@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import List, Optional
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import sync_ddp_if_available

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -198,3 +199,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)
12 changes: 9 additions & 3 deletions pytorch_lightning/accelerators/ddp_slurm_accelerator.py
@@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import List
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.seed import seed_everything

try:
@@ -205,3 +205,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)
11 changes: 9 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -13,7 +13,7 @@
# limitations under the License
import os
import re
from typing import List, Optional
from typing import Any, List, Optional, Union

import torch
import torch.multiprocessing as mp
@@ -22,11 +22,12 @@
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, find_free_network_port
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

@@ -254,3 +255,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)
11 changes: 9 additions & 2 deletions pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py
@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
from typing import List, Optional
from typing import Any, List, Optional, Union

import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import sync_ddp_if_available


try:
@@ -201,3 +202,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None)

return model

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)
42 changes: 40 additions & 2 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Optional
from typing import Any, Optional, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only

@@ -161,3 +161,41 @@ def barrier(self, name: Optional[str] = None):
def broadcast(self, obj, src=0):
obj = hvd.broadcast_object(obj, src)
return obj

def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None):
if group is not None:
raise ValueError(
"Horovod does not support allgather using a subcommunicator at this time. "
"Unset `group`."
)

if len(result.shape) == 0:
# Convert scalars to single dimension tensors
result = result.reshape(1)

# sync and gather all
hvd.join()
gathered = hvd.allgather(result)
gathered_result = list(gathered.split(1, dim=0))
return gathered_result

def sync_tensor(self,
tensor: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
if group is not None:
raise ValueError(
"Horovod does not support allreduce using a subcommunicator at this time. "
"Unset `group`."
)

if reduce_op is None or reduce_op == "sum":
reduce_op = hvd.Sum
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
reduce_op = hvd.Average
else:
raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")

# sync all processes before reduction
hvd.join()
return hvd.allreduce(tensor, op=reduce_op)
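A quick usage sketch of the two Horovod primitives these methods wrap (assumes a working Horovod install; the script name and values are illustrative):

# Launch with, e.g.: horovodrun -np 2 python sync_demo.py
import torch
import horovod.torch as hvd

hvd.init()
local_value = torch.tensor([float(hvd.rank())])  # rank 0 -> 0.0, rank 1 -> 1.0

# Mirrors sync_tensor(..., reduce_op='mean'): every rank receives the average
mean_value = hvd.allreduce(local_value, op=hvd.Average)

# Mirrors gather_all_tensors: a list with one entry per worker, in rank order
gathered = list(hvd.allgather(local_value).split(1, dim=0))

print(hvd.rank(), mean_value.item(), [t.item() for t in gathered])

The `hvd.join()` before each collective keeps a fast rank from entering the reduction while slower ranks are still submitting work.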
3 changes: 3 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -258,6 +258,8 @@ def log(
raise MisconfigurationException(
f"Logged key: {name} should not contain information about dataloader_idx.")

accelerator = self.trainer.accelerator_backend

self._results.log(
name,
value,
@@ -272,6 +274,7 @@
sync_dist,
sync_dist_op,
sync_dist_group,
accelerator.sync_tensor,
self._current_dataloader_idx,
)

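From the LightningModule side nothing new is required: `sync_dist=True` now routes through the active accelerator's `sync_tensor` rather than assuming `torch.distributed`. A minimal sketch (the module and its helper are hypothetical):

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # assumed helper
        # Under the Horovod backend this now reduces via hvd.allreduce
        # instead of silently skipping the sync.
        self.log('val_loss', loss, sync_dist=True, sync_dist_op='mean')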
6 changes: 4 additions & 2 deletions pytorch_lightning/core/step_result.py
@@ -124,15 +124,17 @@ def log(
sync_dist: bool = False,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
sync_fn: Callable = None,
dataloader_idx: Optional[int] = None,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
value = value.detach()

# sync across ddp
# sync across workers when using distributed training
sync_fn = sync_fn or sync_ddp_if_available
if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
value = sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op)
value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

if 'meta' not in self:
self.__setitem__('meta', {})
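This default-plus-override is also what fixes early stopping: when each worker monitors its own unsynced metric, ranks can disagree about stopping and the job can deadlock in the next collective op. A toy single-process illustration of the failure mode (numbers are made up):

import torch

# Pretend per-rank validation losses for one epoch
rank_losses = [torch.tensor(0.49), torch.tensor(0.51)]
threshold = 0.50

# Unsynced: rank 0 would stop while rank 1 keeps training -> hang
print([float(loss) < threshold for loss in rank_losses])    # [True, False]

# Synced with reduce_op='mean': both ranks see 0.50 and decide identically
mean_loss = torch.stack(rank_losses).mean()
print([float(mean_loss) < threshold for _ in rank_losses])  # [False, False]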