Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

State reduce #3

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Most recent change on the bottom.

## [Unreleased]
### Added
- Added `get_state` and `set_state` methods
- Added `accumulate_state` method

## [0.2.0] - 2021-11-22

Expand Down
82 changes: 82 additions & 0 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,88 @@ def test_batching(do_accumulate_by, nan_attrs, allclose):
runstats.reset(reset_n_bins=True)


@pytest.mark.parametrize(
    "dim,reduce_dims",
    [
        (1, tuple()),
        (1, (0,)),
        (3, tuple()),
        (3, (0,)),
        ((2, 3), tuple()),
        (torch.Size((1, 2, 1)), tuple()),
        (torch.Size((1, 2, 1)), (1,)),
        (torch.Size((3, 2, 4)), (0, 2)),
        (torch.Size((3, 2, 4)), (0, 1, 2)),
    ],
)
@pytest.mark.parametrize("do_accumulate_by", [True, False])
@pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS])
def test_state(dim, reduce_dims, do_accumulate_by, reduction, allclose):
    """Transplanting state with get_state/set_state reproduces the donor's result."""
    stats_a = RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims)
    stats_b = RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims)
    batches = [
        torch.randn((random.randint(1, 10),) + stats_a.dim) for _ in range(2)
    ]
    if do_accumulate_by:
        bys = [
            torch.randint(0, random.randint(1, 5), size=(batch.shape[0],))
            for batch in batches
        ]
    else:
        bys = [None, None]
    stats_a.accumulate_batch(batches[0], accumulate_by=bys[0])
    stats_b.accumulate_batch(batches[1], accumulate_by=bys[1])
    stats_a.current_result()
    expected = stats_b.current_result()
    # now, load the state of b -> a
    stats_a.set_state(stats_b.get_state())
    # should be the same since we moved the state over
    assert allclose(stats_a.current_result(), expected)


@pytest.mark.parametrize(
    "dim,reduce_dims",
    [
        (1, tuple()),
        (1, (0,)),
        (3, tuple()),
        (3, (0,)),
        ((2, 3), tuple()),
        (torch.Size((1, 2, 1)), tuple()),
        (torch.Size((1, 2, 1)), (1,)),
        (torch.Size((3, 2, 4)), (0, 2)),
        (torch.Size((3, 2, 4)), (0, 1, 2)),
    ],
)
@pytest.mark.parametrize("do_accumulate_by", [True, False])
@pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS])
def test_accumulate_state(dim, reduce_dims, do_accumulate_by, reduction, allclose):
    """Merging another object's state matches accumulating its batch directly."""
    stats_a, stats_b, stats_truth = [
        RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims)
        for _ in range(3)
    ]
    batches = [
        torch.randn((random.randint(1, 10),) + stats_a.dim) for _ in range(2)
    ]
    if do_accumulate_by:
        bys = [
            torch.randint(0, random.randint(1, 5), size=(batch.shape[0],))
            for batch in batches
        ]
    else:
        bys = [None, None]
    stats_a.accumulate_batch(batches[0], accumulate_by=bys[0])
    stats_b.accumulate_batch(batches[1], accumulate_by=bys[1])
    # now accumulate b's samples into a through the state interface
    stats_a.accumulate_state(stats_b.get_state())
    # and make a truth baseline by accumulating both batches directly
    for batch, by in zip(batches, bys):
        stats_truth.accumulate_batch(batch, accumulate_by=by)
    # and check:
    assert allclose(stats_a.current_result(), stats_truth.current_result())


@pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS])
def test_zeros(reduction, allclose):
dim = (4,)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
import random

import torch

from torch_runstats._util import _pad_dim0


@pytest.mark.parametrize("ndim", [1, 2, 4])
def test_pad_dim0(ndim):
    """_pad_dim0 appends zero rows along dim 0 and leaves other dims untouched."""
    shape = tuple(random.randint(1, 5) for _ in range(ndim))
    original = torch.ones(shape)
    n_pad = 3
    result = _pad_dim0(original, n_pad)
    # trailing dims unchanged, dim 0 grown by exactly n_pad
    assert result.shape == (shape[0] + n_pad,) + shape[1:]
    # the original data is preserved at the front...
    assert torch.equal(original, result[: shape[0]])
    # ...and the appended rows are all zero
    assert result[shape[0]:].abs().max() == 0
87 changes: 66 additions & 21 deletions torch_runstats/_runstats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch
from .scatter import scatter
from ._util import _pad_dim0


def _prod(x):
Expand Down Expand Up @@ -236,28 +237,10 @@ def accumulate_batch(
# do we need new bins?
N_to_add = new_sum.shape[0] - self._n_bins
if N_to_add > 0:

# time to expand
self._state = torch.cat(
(
self._state,
self._state.new_zeros((N_to_add,) + self._state.shape[1:]),
),
dim=0,
)
self._n = torch.cat(
(self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0
)

# assert self._state.shape == (self._n_bins + N_to_add,) + self._dim
self._n_bins += N_to_add

self._expand_state(N_to_add)
elif N_to_add < 0:

new_sum = torch.cat(
(new_sum, new_sum.new_zeros((-N_to_add,) + new_sum.shape[1:])), dim=0
)
N = torch.cat((N, N.new_zeros((-N_to_add,) + N.shape[1:])), dim=0)
new_sum = _pad_dim0(new_sum, -N_to_add)
N = _pad_dim0(N, -N_to_add)

self._state += (new_sum - N * self._state) / (self._n + N)
self._n += N
Expand Down Expand Up @@ -305,6 +288,50 @@ def current_result(self) -> torch.Tensor:
elif self._reduction == Reduction.RMS:
return self._state.sqrt()

def get_state(self) -> Tuple[torch.Tensor, ...]:
    """Get the current internal state of the object for later use.

    The contents of this tuple of tensors has no guaranteed format and should
    only be used within a program and with ``RunningStats`` objects that were
    constructed with exactly identical parameters. The format of the result
    is NOT guaranteed to be consistent across versions and should not be
    serialized.

    The returned tensors are copies of the internal state and are safe to
    mutate.

    Returns:
        a tuple of tensors.
    """
    return (self._n.clone(), self._state.clone())

def set_state(self, state: Tuple[torch.Tensor, ...]) -> None:
    """Set the internal state of this object to ``state`` from an earlier call to ``get_state``.

    Args:
        state: an internal state of a ``RunningStats`` object of identical
            parameters, retrieved by calling its ``get_state`` method.
    """
    n, new_state = state
    self._n_bins = len(n)
    # BUGFIX: `.to(...)` returns the *same* tensor when dtype and device
    # already match, which would alias the caller's tensors into this object;
    # later in-place accumulation (`+=`) would then silently mutate them.
    # `copy=True` forces a defensive copy in all cases.
    self._n = n.to(dtype=self._n.dtype, device=self._n.device, copy=True)
    self._state = new_state.to(
        dtype=self._state.dtype, device=self._state.device, copy=True
    )

def accumulate_state(self, state: Tuple[torch.Tensor, ...]) -> None:
    """Accumulate the state of another ``RunningStats`` object into this one.

    Equivalent to having accumulated into this object all of the samples that
    were accumulated into the other object. The other object must have been
    constructed with exactly identical parameters (see ``get_state``).

    NOTE(review): presumably intended for merging statistics computed on
    separate workers (a reviewer mentions Horovod) — confirm intended use.

    Args:
        state: an internal state of a ``RunningStats`` object of identical
            parameters, retrieved by calling its ``get_state`` method.
    """
    n, state = state
    N_to_add = len(n) - self.n_bins
    if N_to_add > 0:
        self._expand_state(N_to_add)
    elif N_to_add < 0:
        # need to expand the parameter state
        n = _pad_dim0(n, -N_to_add)
        state = _pad_dim0(state, -N_to_add)
    # Incremental mean update: state += n * (incoming - state) / total_count,
    # mirroring the weighted-mean update used for batches.
    self._state += n * (state - self._state) / (self._n + n)
    self._n += n
    # Make div by zero 0
    self._state = torch.nan_to_num_(self._state, nan=0.0)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

almost the same as accumulate_batch. So is it for the horovod interface?

@property
def n(self) -> torch.Tensor:
"""The number of samples processed so far in each bin.
Expand Down Expand Up @@ -338,3 +365,21 @@ def reduce_dims(self) -> Tuple[int, ...]:
def reduction(self) -> Reduction:
    """The reduction computed by this object.

    A ``Reduction`` enum value (e.g. ``Reduction.MEAN`` or ``Reduction.RMS``)
    as provided at construction.
    """
    return self._reduction

def _expand_state(self, N_to_add: int) -> None:
if N_to_add == 0:
return
elif N_to_add < 0:
raise ValueError
# time to expand
self._state = torch.cat(
(
self._state,
self._state.new_zeros((N_to_add,) + self._state.shape[1:]),
),
dim=0,
)
self._n = torch.cat(
(self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0
)
self._n_bins += N_to_add
9 changes: 9 additions & 0 deletions torch_runstats/_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import torch


def _pad_dim0(x: torch.Tensor, n: int) -> torch.Tensor:
if n == 0:
return
elif n < 0:
raise ValueError
return torch.nn.functional.pad(x, (0,) * ((x.ndim - 1) * 2) + (0, n))
2 changes: 1 addition & 1 deletion torch_runstats/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.0"
__version__ = "0.2.1"