Implement CenteredClip in averager
borzunov committed Sep 8, 2021
1 parent b84f62b commit 6621bb5
Showing 5 changed files with 126 additions and 11 deletions.
107 changes: 107 additions & 0 deletions hivemind/averaging/accumulators.py
@@ -0,0 +1,107 @@
import dataclasses
from abc import ABC
from typing import Callable, Optional

import torch


class AccumulatorBase(ABC):
    def accumulate_part(self, tensor: torch.Tensor, weight: float) -> None:
        ...

    def reduce(self) -> torch.Tensor:
        ...


AccumulatorFactory = Callable[[torch.Size, int], AccumulatorBase]


class MeanAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, _n_peers: int):
        self._accumulator = torch.zeros(part_shape)
        self._denominator = 0.0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._accumulator.add_(tensor_part, alpha=weight)
        self._denominator += weight

    def reduce(self) -> torch.Tensor:
        return self._accumulator.div_(self._denominator)


class CenteredClipAccumulator(AccumulatorBase):
    def __init__(self, part_shape: torch.Size, n_peers: int, **kwargs):
        self._kwargs = kwargs

        self._tensors = torch.empty((n_peers, *part_shape))
        self._weights = torch.empty(n_peers)
        self._index = 0

    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
        self._tensors[self._index] = tensor_part
        self._weights[self._index] = weight
        self._index += 1

    def reduce(self) -> torch.Tensor:
        clipped = centered_clip(self._tensors, self._weights, **self._kwargs)
        return clipped.result


@dataclasses.dataclass(frozen=True)
class CenteredClipResult:
    result: torch.Tensor
    n_clipped: torch.Tensor
    last_step_delta: torch.Tensor


def centered_clip(input_tensors: torch.Tensor, weights: torch.Tensor,
                  tau: float = 1.0, n_iters: int = 20, stop_delta: Optional[float] = None) -> CenteredClipResult:
    """
    Optimized implementation of CenteredClip from [Karimireddy, 2021].
    Intended to be used in a decentralized fashion as in [Gorbunov, 2021].

    :param tau: the clipping radius: contributions whose distance from the current estimate exceeds ``tau`` are scaled down
    :param n_iters: the maximal number of fixed-point iterations
    :param stop_delta: stop iterations early if the ``L_inf`` norm of the last step is less than ``stop_delta``.
        Note: if this option is used, the step norm calculations may increase the time per iteration by ~25%.

    References:
        [Karimireddy, 2021] Karimireddy, Sai Praneeth, Lie He, and Martin Jaggi. "Learning from History for Byzantine
        Robust Optimization." International Conference on Machine Learning. PMLR, 2021.
        [Gorbunov, 2021] Gorbunov, Eduard, Alexander Borzunov, Michael Diskin, and Max Ryabinin.
        "Secure Distributed Training at Scale." arXiv preprint arXiv:2106.11257 (2021).
    """

    with torch.no_grad():
        n_peers = input_tensors.shape[0]
        result_shape = input_tensors.shape[1:]

        input_tensors = input_tensors.flatten(start_dim=1)
        weights /= weights.sum()

        # This finds medians faster than torch.median() and torch.quantile(q=0.5),
        # see https://github.com/pytorch/pytorch/issues/51450
        sorted_tensors = input_tensors.sort(dim=0).values
        result = sorted_tensors[n_peers // 2].clone()
        delta = None

        diff = torch.sub(input_tensors, result, out=sorted_tensors)  # Reuse memory from `sorted_tensors`
        for _ in range(n_iters):
            norms = diff.norm(dim=1)
            coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)

            if stop_delta is not None:
                # Save the pre-update value of diff[0], reusing memory from `result`
                prev_diff = result
                prev_diff[...] = diff[0]

            # CenteredClip step: v <- v + sum_i coeffs[i] * (x[i] - v). We only need to update
            # `diff` (not `result`) between iterations: diff[i] <- diff[i] - sum_j coeffs[j] * diff[j]
            diff.addmm_(-coeffs.repeat(n_peers, 1), diff)

            if stop_delta is not None:
                delta = prev_diff.sub_(diff[0]).abs_().max()
                if delta < stop_delta:
                    break
        torch.sub(input_tensors[0], diff[0], out=result)

        return CenteredClipResult(
            result=result.reshape(result_shape),
            n_clipped=(tau < norms).sum(),
            last_step_delta=delta,
        )
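
Below is a minimal, hypothetical sketch of calling centered_clip directly on a stack of flattened peer updates, based only on the signature above; the shapes, values, and the 1e-4 threshold are illustrative and not part of the commit.

    import torch
    from hivemind.averaging.accumulators import centered_clip

    updates = torch.randn(8, 1024)  # 8 peers, one flattened tensor part each
    weights = torch.ones(8)         # equal peer weights (normalized inside centered_clip)
    clipped = centered_clip(updates, weights, tau=1.0, n_iters=20, stop_delta=1e-4)
    robust_mean = clipped.result    # shape (1024,): robust aggregate of the 8 updates
    print(int(clipped.n_clipped), float(clipped.last_step_delta))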
5 changes: 4 additions & 1 deletion hivemind/averaging/allreduce.py
@@ -4,6 +4,7 @@

 import torch

+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
@@ -58,6 +59,7 @@ def __init__(
         tensors: Sequence[torch.Tensor],
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
+        accumulator_factory: AccumulatorFactory,
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
@@ -97,7 +99,8 @@ def __init__(
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
+            weights=self.sender_weights,
+            accumulator_factory=accumulator_factory,
         )

     def __repr__(self):
3 changes: 3 additions & 0 deletions hivemind/averaging/averager.py
@@ -15,6 +15,7 @@
 import numpy as np
 import torch

+from hivemind.averaging.accumulators import AccumulatorFactory, MeanAccumulator
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
@@ -112,6 +113,7 @@ def __init__(
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        accumulator_factory: AccumulatorFactory = MeanAccumulator,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -170,6 +172,7 @@ def __init__(
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            accumulator_factory=accumulator_factory,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
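
A hedged sketch (not part of the commit) of how the new accumulator_factory argument could switch an averager from plain averaging to CenteredClip; the surrounding DecentralizedAverager constructor arguments are shown only for illustration and may differ, and `tensors` / `dht` are assumed to be defined elsewhere.

    from functools import partial

    from hivemind.averaging.accumulators import CenteredClipAccumulator
    from hivemind.averaging.averager import DecentralizedAverager

    # Bind the CenteredClip hyperparameters; extra kwargs are forwarded to centered_clip()
    robust_factory = partial(CenteredClipAccumulator, tau=1.0, n_iters=20)

    averager = DecentralizedAverager(
        averaged_tensors=tensors,            # the local tensors to average (assumed defined)
        dht=dht,                             # a running hivemind DHT instance (assumed defined)
        prefix="my_experiment",
        accumulator_factory=robust_factory,  # defaults to MeanAccumulator, i.e. ordinary averaging
        start=True,
    )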
18 changes: 9 additions & 9 deletions hivemind/averaging/partition.py
@@ -8,6 +8,7 @@
 import numpy as np
 import torch

+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
@@ -171,16 +172,17 @@ class TensorPartReducer:
     :note: even if local peer is not sending data, local parts will be used for shape information
     """

-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int,
+                 *, weights: Optional[Sequence[float]], accumulator_factory: AccumulatorFactory):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.weights = tuple(weights or (1 for _ in range(num_senders)))
         assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
         assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator = None  # this will contain the sum of current tensor part from group peers
-        self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
+        self.accumulator_factory = accumulator_factory
+        self.accumulator = None
         self.finished = asyncio.Event()
         self.reset_accumulators()

@@ -194,8 +196,7 @@ def reset_accumulators(self):
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
-        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
-        self.denominator = 0.0
+        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)

     async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
@@ -211,21 +212,20 @@ async def accumulate_part(self, sender_index: int, part_index: int, tensor_part:

         current_part_future = self.current_part_future

-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
         self.current_part_accumulated_from += 1

         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            current_part_future.set_result(self.accumulator.reduce())
             self.reset_accumulators()
         return await current_part_future

     def finalize(self):
         if not self.finished.is_set():
             if hasattr(self, "current_part_future"):
                 self.current_part_future.cancel()
-                del self.accumulator
+                self.accumulator = None
             self.finished.set()

     def __del__(self):
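
To make the reducer's new control flow concrete, here is a small illustrative sketch of the accumulator protocol that TensorPartReducer now delegates to (one accumulator per tensor part, fed once per sender); the shapes and weights are made up.

    import torch
    from hivemind.averaging.accumulators import MeanAccumulator

    part_shape, num_senders = torch.Size([4]), 3
    accumulator = MeanAccumulator(part_shape, num_senders)
    for weight in (1.0, 1.0, 2.0):                        # one contribution per sender
        accumulator.accumulate_part(torch.randn(part_shape), weight)
    averaged_part = accumulator.reduce()                  # weighted mean, same shape as the part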
4 changes: 3 additions & 1 deletion tests/test_allreduce.py
@@ -7,6 +7,7 @@
 import torch

 from hivemind import Quantile8BitQuantization, aenumerate
+from hivemind.averaging.accumulators import MeanAccumulator
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
@@ -119,7 +120,7 @@ async def wait_synchronously():
 @pytest.mark.asyncio
 async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
     tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
-    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders, weights=None, accumulator_factory=MeanAccumulator)

     local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)] for j in range(num_senders)]

Expand Down Expand Up @@ -196,6 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
ordered_peer_ids=peers,
peer_fractions=peer_fractions,
accumulator_factory=MeanAccumulator,
modes=peer_modes,
weights=averaging_weights,
part_size_bytes=part_size_bytes,
