Skip to content

Commit

Permalink
fix incomplete RunningMean (Lightning-AI#1309)
Browse files Browse the repository at this point in the history
* fix RunningMean

* changelog

* fix none

* Update supporters.py

just needed to multiply by zero for init

* Revert "Update supporters.py"

This reverts commit 7e0da6c

* fix NaN

* formatting

Co-authored-by: William Falcon <waf2107@columbia.edu>
  • Loading branch information
2 people authored and akarnachev committed Apr 3, 2020
1 parent 09e5a87 commit 48b4b62
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))

## [0.7.1] - 2020-03-07

Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,9 +1525,10 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
Dictionary with the items to be displayed in the progress bar.
"""
# call .item() only once but store elements without graphs
running_training_loss = self.trainer.running_loss.mean().cpu().item()
running_train_loss = self.trainer.running_loss.mean()
avg_training_loss = running_train_loss.cpu().item() if running_train_loss is not None else float('NaN')
tqdm_dict = {
'loss': '{:.3f}'.format(running_training_loss)
'loss': '{:.3f}'.format(avg_training_loss)
}

if self.trainer.truncated_bptt_steps is not None:
Expand Down
58 changes: 58 additions & 0 deletions pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import torch


class TensorRunningMean(object):
"""
Tracks a running mean without graph references.
Round robbin for the mean
Examples:
>>> accum = TensorRunningMean(5)
>>> accum.last(), accum.mean()
(None, None)
>>> accum.append(torch.tensor(1.5))
>>> accum.last(), accum.mean()
(tensor(1.5000), tensor(1.5000))
>>> accum.append(torch.tensor(2.5))
>>> accum.last(), accum.mean()
(tensor(2.5000), tensor(2.))
>>> accum.reset()
>>> _= [accum.append(torch.tensor(i)) for i in range(13)]
>>> accum.last(), accum.mean()
(tensor(12.), tensor(10.))
"""
def __init__(self, window_length: int):
self.window_length = window_length
self.memory = torch.Tensor(self.window_length)
self.current_idx: int = 0
self.last_idx: int = None
self.rotated: bool = False

def reset(self) -> None:
self = TensorRunningMean(self.window_length)

def last(self):
if self.last_idx is not None:
return self.memory[self.last_idx]

def append(self, x):
# map proper type for memory if they don't match
if self.memory.type() != x.type():
self.memory.type_as(x)

# store without grads
with torch.no_grad():
self.memory[self.current_idx] = x
self.last_idx = self.current_idx

# increase index
self.current_idx += 1

# reset index when hit limit of tensor
self.current_idx = self.current_idx % self.window_length
if self.current_idx == 0:
self.rotated = True

def mean(self):
if self.last_idx is not None:
return self.memory.mean() if self.rotated else self.memory[:self.current_idx].mean()
39 changes: 0 additions & 39 deletions pytorch_lightning/trainer/supporting_classes.py

This file was deleted.

2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
from apex import amp
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.trainer.supporting_classes import TensorRunningMean
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
from apex import amp
Expand Down
2 changes: 1 addition & 1 deletion tests/collect_env_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def info_system():

def info_cuda():
return {
'GPU': set([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]),
'GPU': [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())],
# 'nvidia_driver': get_nvidia_driver_version(run_lambda),
'available': torch.cuda.is_available(),
'version': torch.version.cuda,
Expand Down

0 comments on commit 48b4b62

Please sign in to comment.