[Fix] Move log value to cpu. #4592

Merged · 19 commits · Nov 10, 2020
6 changes: 6 additions & 0 deletions pytorch_lightning/core/step_result.py
@@ -395,6 +395,12 @@ def detach(self):
if isinstance(v, torch.Tensor):
self.__setitem__(k, v.detach())

def cpu(self):
"""Move all self attributes to CPU."""
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__setitem__(k, v.cpu())

def __repr__(self):
self_copy = self.copy()

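For context, Result behaves like a dictionary of logged values, so the new cpu() method walks its items exactly the way the existing detach() does. A minimal, self-contained sketch of that pattern, using a plain dict stand-in rather than the real Result class:

import torch

class ToyResult(dict):
    # toy stand-in for pytorch_lightning.core.step_result.Result
    def detach(self):
        # break the autograd graph for every logged tensor
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self[k] = v.detach()

    def cpu(self):
        # move every logged tensor off the accelerator, mirroring the new Result.cpu()
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self[k] = v.cpu()

result = ToyResult(loss=torch.tensor(0.5, requires_grad=True))
result.detach()
result.cpu()
assert not result["loss"].requires_grad and result["loss"].device.type == "cpu"
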
@@ -392,6 +392,10 @@ def cache_result(self) -> None:
# attach capture batch_size
Result.attach_batch_size(self._batch_size, hook_result)

hook_result.detach()
if self.trainer.move_metrics_to_cpu:
hook_result.cpu()

self._internals[fx_name].append(
hook_result,
dataloader_idx=dataloader_idx,
@@ -93,14 +93,16 @@ def cache_logged_metrics(self) -> Union[EpochResultStore, None]:
if self._current_stage is not None:
self._cached_results[self._current_stage].cache_result()

def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
# logging
self.configure_logger(logger)
# todo: IDE is complaining; these should be initialized in the Trainer init, at least as placeholders,
# and assigned the desired value here
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps

self.trainer.move_metrics_to_cpu = move_metrics_to_cpu

@property
def should_flush_logs(self):
should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
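The new argument simply travels from the Trainer constructor into the logger connector, which stores it back on the trainer so that later components (result stores, loops) read a single source of truth. A hedged sketch of that plumbing with toy classes; the names below are illustrative, not the real Lightning classes:

class ToyLoggerConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, move_metrics_to_cpu: bool):
        # stored on the trainer so any component holding a trainer reference can query it
        self.trainer.move_metrics_to_cpu = move_metrics_to_cpu

class ToyTrainer:
    def __init__(self, move_metrics_to_cpu: bool = False):
        self.logger_connector = ToyLoggerConnector(self)
        self.logger_connector.on_trainer_init(move_metrics_to_cpu)

assert ToyTrainer(move_metrics_to_cpu=True).move_metrics_to_cpu is True
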
32 changes: 27 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -60,6 +60,7 @@
from pytorch_lightning.plugins.plugin_connector import PluginConnector
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
from pytorch_lightning.utilities.memory import recursive_detach

# warnings to ignore in trainer
warnings.filterwarnings(
@@ -135,6 +136,7 @@ def __init__(
amp_level: str = 'O2',
distributed_backend: Optional[str] = None,
automatic_optimization: bool = True,
move_metrics_to_cpu: bool = False,
):
r"""
Customize every aspect of training via flags
@@ -272,6 +274,9 @@ def __init__(
stored in a different place than the logs written in `default_root_dir`.
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
Defaults to `default_root_dir`.

move_metrics_to_cpu: Whether to force internally logged metrics to be moved to CPU.
This can save some GPU memory, but can make training slower. Use with caution.
"""
super().__init__()
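
On the user side the flag is opt-in, and everything else about logging is unchanged. A minimal usage sketch; the model and dataloader are assumed to be defined elsewhere:

from pytorch_lightning import Trainer

# keep internally logged metrics in host memory, trading extra
# device-to-host transfers for lower GPU memory usage
trainer = Trainer(move_metrics_to_cpu=True)  # typically combined with gpus=1 or more
# trainer.fit(model, train_dataloader)  # model / dataloader defined elsewhere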

@@ -363,7 +368,12 @@ def __init__(
self.profile_connector.on_trainer_init(profiler)

# init logger flags
self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps)
self.logger_connector.on_trainer_init(
logger,
flush_logs_every_n_steps,
log_every_n_steps,
move_metrics_to_cpu
)

# init debugging flags
self.debugging_connector.on_init_start(
@@ -603,12 +613,11 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):
# log step metrics
step_metrics = self.evaluation_loop.log_evaluation_step_metrics(batch, batch_idx)

if step_metrics is not None:
dl_step_metrics.append(step_metrics)
# track epoch level outputs
dl_step_metrics = self.track_output_for_epoch_end(dl_step_metrics, step_metrics)

# track epoch level outputs
if output is not None:
dl_outputs.append(output)
dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)

self.evaluation_loop.outputs.append(dl_outputs)
self.evaluation_loop.step_metrics.append(dl_step_metrics)
@@ -634,6 +643,19 @@ def run_evaluation(self, test_mode: bool = False, max_batches=None):

return eval_loop_results, deprecated_eval_results

def track_output_for_epoch_end(self, outputs, output):
if output is not None:
if isinstance(output, Result):
output.detach()
if self.move_metrics_to_cpu:
output.cpu()
elif isinstance(output, dict):
output = recursive_detach(output, to_cpu=self.move_metrics_to_cpu)
elif isinstance(output, torch.Tensor) and output.is_cuda and self.move_metrics_to_cpu:
output = output.cpu()
outputs.append(output)
return outputs

def run_test(self):
# only load test dataloader for testing
# self.reset_test_dataloader(ref_model)
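Whatever form the step output takes (Result, plain dict, or bare tensor), track_output_for_epoch_end only changes where the values live, not what they are, so epoch-end reductions produce the same numbers. A small sketch of that assumption, runnable on CPU:

import torch

# step outputs as the evaluation loop would collect them
step_losses = [torch.tensor(0.4, requires_grad=True), torch.tensor(0.2, requires_grad=True)]

# what tracking stores: detached (and, with move_metrics_to_cpu, host-resident) copies
tracked = [loss.detach().cpu() for loss in step_losses]

# the epoch-level reduction is unaffected by the move
assert torch.stack(tracked).mean().item() == torch.stack(step_losses).mean().item()
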
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -426,6 +426,8 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
# track metrics without grads for epoch reduction
training_step_output_for_epoch_end = copy(result)
training_step_output_for_epoch_end.detach()
if self.trainer.move_metrics_to_cpu:
training_step_output_for_epoch_end.cpu()

# what flows back into the system
training_step_output = result
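The training loop copies the step result before detaching, so the copy kept for epoch-end aggregation is graph-free (and optionally on CPU) while the original result still carries the graph used for the backward pass. A hedged sketch of that separation with a plain dict in place of the Result object:

import torch
from copy import copy

loss = (torch.randn(4, requires_grad=True) ** 2).mean()
result = {"loss": loss}

# epoch-end copy: detached and, if move_metrics_to_cpu were set, moved to host memory
result_for_epoch_end = copy(result)
result_for_epoch_end = {k: v.detach().cpu() for k, v in result_for_epoch_end.items()}

# the original result still supports backpropagation
result["loss"].backward()
assert result_for_epoch_end["loss"].requires_grad is False
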
9 changes: 7 additions & 2 deletions pytorch_lightning/utilities/memory.py
@@ -17,7 +17,7 @@
import torch


def recursive_detach(in_dict: dict) -> dict:
def recursive_detach(in_dict: dict, to_cpu: bool = False) -> dict:
"""Detach all tensors in `in_dict`.

May operate recursively if some of the values in `in_dict` are dictionaries
@@ -26,6 +26,7 @@ def recursive_detach(in_dict: dict) -> dict:

Args:
in_dict:
to_cpu: Whether to move detached tensors to CPU

Return:
out_dict:
@@ -35,7 +36,11 @@ def recursive_detach(in_dict: dict) -> dict:
if isinstance(v, dict):
out_dict.update({k: recursive_detach(v)})
elif callable(getattr(v, 'detach', None)):
out_dict.update({k: v.detach()})
# detach
v = v.detach()
if to_cpu:
v = v.cpu()
out_dict.update({k: v})
else:
out_dict.update({k: v})
return out_dict
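A short usage sketch of the extended helper. Note that, as written in this diff, the recursive call for nested dictionaries does not forward to_cpu, so only top-level tensors are moved while nested ones are merely detached:

import torch
from pytorch_lightning.utilities.memory import recursive_detach

outputs = {
    "loss": torch.tensor(0.3, requires_grad=True),
    "extra": {"acc": torch.tensor(0.9)},
}
clean = recursive_detach(outputs, to_cpu=True)
assert clean["loss"].requires_grad is False          # detached
assert clean["loss"].device.type == "cpu"            # moved by to_cpu=True
assert clean["extra"]["acc"].requires_grad is False  # nested values are detached as well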