Remove logger_connector legacy code #6733

Merged (3 commits) on Mar 30, 2021
@@ -244,10 +244,6 @@ def add_progress_bar_metrics(self, metrics):
 
         self.trainer.dev_debugger.track_pbar_metrics_history(metrics)
 
-    def track_metrics_deprecated(self, deprecated_eval_results):
-        self._track_callback_metrics(deprecated_eval_results)
-        self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results)
-
     def evaluation_epoch_end(self):
         # reset dataloader idx
         model_ref = self.trainer.lightning_module
@@ -331,32 +327,6 @@ def _track_callback_metrics(self, eval_results):
             if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
                 self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
 
-    def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
-        if self.trainer.sanity_checking:
-            return
-
-        if eval_results is not None and len(eval_results) > 0:
-
-            # in eval, the user may return something at every validation step without final reduction
-            if not isinstance(eval_results, list):
-                eval_results = [eval_results]
-
-            for result_idx, result in enumerate(eval_results):
-                _, prog_bar_metrics, log_metrics, _ = self.trainer.process_dict_result(result)
-
-                # eval loop returns all metrics
-                dataloader_result_metrics = {**prog_bar_metrics, **log_metrics}
-
-                # add metrics to prog bar
-                self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
-
-                # log metrics
-                if len(log_metrics) > 0:
-                    self.trainer.logger_connector.log_metrics(log_metrics, {})
-
-                if len(dataloader_result_metrics) > 0:
-                    self.eval_loop_results.append(dataloader_result_metrics)
-
     def on_train_epoch_end(self):
         # inform cached logger connector epoch finished
         self.cached_results.has_batch_loop_finished = True
@@ -368,36 +338,11 @@ def log_train_epoch_end_metrics(self, epoch_output, num_optimizers):
 
         model = self.trainer.lightning_module
 
-        # ------------------------
-        # determine if using a result obj
-        # ------------------------
-        # [optimizer_idx][training_step_idx][tbptt_index]
-        opt_idx_outputs = epoch_output[0]
-
-        # TODO: deprecate 1.0
-        try:
-            sample_obj = opt_idx_outputs[0][0] if isinstance(opt_idx_outputs[0], list) else opt_idx_outputs[0]
-            is_result_obj = len(epoch_output) > 0 and isinstance(sample_obj, Result)
-            is_1_0_result = is_result_obj and 'extra' in sample_obj
-        except IndexError:
-            is_result_obj = False
-            is_1_0_result = False
-
-        # ------------------
-        # NEW 1.0.0 PATH
-        # ------------------
-        if is_1_0_result:
-            # lightning module hook
-            self.training_epoch_end(model, epoch_output, num_optimizers)
-
-            # log/aggregate metrics automatically
-            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
-
-        # TODO: deprecate 1.0
-        else:
-            epoch_log_metrics, epoch_progress_bar_metrics = self.__run_legacy_training_epoch_end(
-                num_optimizers, epoch_output, model, is_result_obj
-            )
+        # lightning module hook
+        self.training_epoch_end(model, epoch_output, num_optimizers)
+
+        # log/aggregate metrics automatically
+        epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
 
         # it will perform reduction over epoch and return log metrics
         cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
@@ -446,46 +391,6 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
         # capture logging
         self.trainer.logger_connector.cache_logged_metrics()
 
-    def __run_legacy_training_epoch_end(self, num_optimizers, epoch_output, model, is_result_obj):
-
-        epoch_log_metrics = {}
-        epoch_progress_bar_metrics = {}
-
-        # --------------------------
-        # EPOCH END STEP IF DEFINED
-        # --------------------------
-        if is_overridden('training_epoch_end', model=model):
-            if is_result_obj:
-                # with result object gather across time and training steps so each opt idx has a single result obj
-                epoch_output = self.__gather_result_across_time_and_optimizers(epoch_output)
-
-            if num_optimizers == 1:
-                epoch_output = epoch_output[0]
-
-            # run training_epoch_end
-            # a list with a result per optimizer index
-            model._current_fx_name = 'training_epoch_end'
-            epoch_output = model.training_epoch_end(epoch_output)
-
-            # capture logging
-            self.trainer.logger_connector.cache_logged_metrics()
-
-            if isinstance(epoch_output, Result):
-                epoch_log_metrics = epoch_output.epoch_log_metrics
-                epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
-            else:
-                _processed_outputs = self.trainer.process_dict_result(epoch_output)
-                epoch_progress_bar_metrics = _processed_outputs[1]
-                epoch_log_metrics = _processed_outputs[2]
-
-        # --------------------------
-        # Structured Result (auto epoch end)
-        # --------------------------
-        elif is_result_obj:
-            epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
-
-        return epoch_log_metrics, epoch_progress_bar_metrics
-
     def __auto_reduce_results_on_epoch_end(self, epoch_output):
         epoch_log_metrics = {}
         epoch_progress_bar_metrics = {}
@@ -538,29 +443,6 @@ def __prepare_epoch_end_inputs(self, epoch_output):
 
         return gathered_epoch_outputs
 
-    def __gather_result_across_time_and_optimizers(self, epoch_output):
-        """
-        Gather results into a single padded tensor per metric where each tensor is gathered across
-        time and across time steps.
-
-        Returns:
-            a list where each element is a Result with the tensors gathered
-        """
-        gathered_epoch_outputs = []
-        for opt_outputs in epoch_output:
-            # gather across time first
-            time_gathered_outputs = []
-            for tbptt_outs in opt_outputs:
-                tbptt_outs = tbptt_outs[0].__class__.gather(tbptt_outs)
-                time_gathered_outputs.append(tbptt_outs)
-
-            # gather across training steps
-            # each metric has dimensions (training_steps, seq_len) (seq_len=1 when no tbptt is used)
-            gathered_opt_output = time_gathered_outputs[0].__class__.padded_gather(time_gathered_outputs)
-            gathered_epoch_outputs.append(gathered_opt_output)
-
-        return gathered_epoch_outputs
-
     def log_train_step_metrics(self, batch_output):
         if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization:
             return
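For orientation: the epoch_output reaching log_train_epoch_end_metrics is nested as [optimizer_idx][training_step_idx][tbptt_index] (per the removed comment), and after this PR the method always calls the training_epoch_end hook and then auto-reduces the collected metrics. The standalone sketch below illustrates that kind of reduction on plain dicts; the mean reduction and the function name are assumptions for illustration, not code from this diff.

from statistics import mean

# Hypothetical sketch (not Lightning code): reduce an epoch_output nested as
# [optimizer_idx][training_step_idx][tbptt_index] into one metrics dict per optimizer.
# Mean reduction is an assumption for illustration.
def auto_reduce_epoch_metrics(epoch_output):
    reduced = []
    for opt_outputs in epoch_output:              # one list per optimizer
        collected = {}
        for step_outputs in opt_outputs:          # one list per training step
            for tbptt_out in step_outputs:        # one metrics dict per tbptt split
                for name, value in tbptt_out.items():
                    collected.setdefault(name, []).append(value)
        reduced.append({name: mean(values) for name, values in collected.items()})
    return reduced

if __name__ == "__main__":
    fake_epoch_output = [[  # one optimizer, two steps, no tbptt splitting
        [{"loss": 1.0, "acc": 0.5}],
        [{"loss": 3.0, "acc": 0.75}],
    ]]
    print(auto_reduce_epoch_metrics(fake_epoch_output))  # [{'loss': 2.0, 'acc': 0.625}]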
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -242,8 +242,7 @@ def __run_eval_epoch_end(self, num_dataloaders):
         if not isinstance(eval_results, list):
             eval_results = [eval_results]
 
-        # track depreceated metrics
-        self.trainer.logger_connector.track_metrics_deprecated(eval_results)
+        self.trainer.logger_connector._track_callback_metrics(eval_results)
 
         return eval_results
 
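With the legacy helper gone, __run_eval_epoch_end now only normalizes the epoch-end output to a list and hands it to _track_callback_metrics; the per-dataloader split into progress-bar and log metrics is no longer done here. A rough standalone sketch of that normalize-then-track pattern follows; the function name and the flattening rule are assumptions for illustration, not Lightning's implementation.

# Hypothetical sketch: normalize eval results to a list and fold dict results
# into a flat callback-metrics mapping, mirroring the pattern this PR keeps.
def track_callback_metrics(eval_results, callback_metrics):
    if not isinstance(eval_results, list):
        eval_results = [eval_results]
    for result in eval_results:       # one entry per dataloader
        if isinstance(result, dict):
            callback_metrics.update(result)
    return callback_metrics

if __name__ == "__main__":
    metrics = {}
    track_callback_metrics({"val_loss": 0.25, "val_acc": 0.9}, metrics)
    print(metrics)  # {'val_loss': 0.25, 'val_acc': 0.9}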
35 changes: 0 additions & 35 deletions tests/helpers/deterministic_model.py
@@ -116,19 +116,6 @@ def training_step__dict_return(self, batch, batch_idx):
         self.training_step_called = True
         return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_test': torch.tensor(549).type_as(acc)}
 
-    def training_step__for_step_end_dict(self, batch, batch_idx):
-        """sends outputs to training_batch_end"""
-        acc = self.step(batch, batch_idx)
-
-        logs = {'log_acc1': torch.tensor(12).type_as(acc), 'log_acc2': torch.tensor(7).type_as(acc)}
-        pbar = {'pbar_acc1': torch.tensor(17).type_as(acc), 'pbar_acc2': torch.tensor(19).type_as(acc)}
-
-        self.training_step_called = True
-        result = {'loss': acc}
-        result.update(logs)
-        result.update(pbar)
-        return result
-
     def training_step_end__dict(self, output):
         self.training_step_end_called = True
 
@@ -151,28 +138,6 @@ def training_step_end__dict(self, output):
         acc = output['loss']
         return {'loss': acc, 'log': logs, 'progress_bar': pbar, 'train_step_end': acc}
 
-    def training_epoch_end__dict(self, outputs):
-        self.training_epoch_end_called = True
-
-        if self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-            pass
-        else:
-            # only saw 4 batches
-            assert len(outputs) == 4
-            for batch_out in outputs:
-                assert len(batch_out.keys()) == 4
-                assert self.count_num_graphs(batch_out) == 0
-                last_key = 'train_step_end' if self.training_step_end_called else 'train_step_test'
-                keys = ['loss', 'log', 'progress_bar', last_key]
-                for key in keys:
-                    assert key in batch_out
-
-        prototype_loss = outputs[0]['loss']
-        logs = {'epoch_end_log_1': torch.tensor(178).type_as(prototype_loss)}
-        pbar = {'epoch_end_pbar_1': torch.tensor(234).type_as(prototype_loss)}
-
-        return {'log': logs, 'progress_bar': pbar}
-
     def validation_step__no_return(self, batch, batch_idx):
         self.validation_step_called = True
         self.step(batch, batch_idx)
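The helpers deleted from this file exercised the legacy dict-return interface, where training_step and training_epoch_end returned 'log' and 'progress_bar' dictionaries. For context, here is a minimal sketch of how equivalent metrics are surfaced through self.log in the 1.x LightningModule API that the remaining code path relies on; the module below is illustrative only and not part of this diff.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule

# Illustrative module only: metrics go through self.log instead of the removed
# {'log': ..., 'progress_bar': ...} dict returns.
class DictReturnFreeModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.log("train_loss", loss)                      # sent to the logger
        self.log("train_loss_pbar", loss, prog_bar=True)  # shown in the progress bar
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        x = torch.randn(64, 32)
        y = torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=16)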
@@ -14,8 +14,6 @@
 """
 Tests to ensure that the training loop works with a dict
 """
-import os
-from unittest import mock
 
 from pytorch_lightning import Trainer
 from tests.helpers.deterministic_model import DeterministicModel
@@ -104,112 +102,3 @@ def training_step_with_step_end(tmpdir):
     assert 'train_step_end' in train_step_end_out
     assert pbar_metrics['pbar_acc1'] == 19.0
     assert pbar_metrics['pbar_acc2'] == 21.0
-
-
-def test_full_training_loop_dict(tmpdir):
-    """
-    Checks train_step + training_step_end + training_epoch_end
-    """
-    model = DeterministicModel()
-    model.training_step = model.training_step__for_step_end_dict
-    model.training_step_end = model.training_step_end__dict
-    model.training_epoch_end = model.training_epoch_end__dict
-    model.val_dataloader = None
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    # make sure correct steps were called
-    assert model.training_step_called
-    assert model.training_step_end_called
-    assert model.training_epoch_end_called
-
-    # assert epoch end metrics were added
-    assert trainer.logger_connector.callback_metrics['epoch_end_log_1'] == 178
-    assert trainer.logger_connector.progress_bar_metrics['epoch_end_pbar_1'] == 234
-
-    # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-
-    out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
-    assert out.signal == 0
-    assert trainer.logger_connector.logged_metrics['log_acc1'] == 14.0
-    assert trainer.logger_connector.logged_metrics['log_acc2'] == 9.0
-
-    # get the output of the first optimizer
-    train_step_end_out = out.training_step_output_for_epoch_end
-    assert len(train_step_end_out) == 1
-    train_step_end_out = train_step_end_out[0][0]
-    pbar_metrics = train_step_end_out['progress_bar']
-    assert pbar_metrics['pbar_acc1'] == 19.0
-    assert pbar_metrics['pbar_acc2'] == 21.0
-
-
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-def test_result_obj_lr_scheduler_epoch(tmpdir):
-    """
-    test that the LR scheduler was called at the correct time with the correct metrics
-    """
-    model = DeterministicModel()
-    model.training_step = model.training_step__for_step_end_dict
-    model.training_step_end = model.training_step_end__dict
-    model.training_epoch_end = model.training_epoch_end__dict
-    model.val_dataloader = None
-    model.configure_optimizers = model.configure_optimizers__lr_on_plateau_epoch
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=3,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == 3
-
-
-def test_train_step_epoch_end(tmpdir):
-    """
-    Checks train_step + training_epoch_end (NO training_step_end)
-    """
-    model = DeterministicModel()
-    model.training_step = model.training_step__dict_return
-    model.training_step_end = None
-    model.training_epoch_end = model.training_epoch_end__dict
-    model.val_dataloader = None
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        weights_summary=None,
-    )
-    trainer.fit(model)
-
-    # make sure correct steps were called
-    assert model.training_step_called
-    assert not model.training_step_end_called
-    assert model.training_epoch_end_called
-
-    # assert epoch end metrics were added
-    assert trainer.logger_connector.callback_metrics['epoch_end_log_1'] == 178
-    assert trainer.logger_connector.progress_bar_metrics['epoch_end_pbar_1'] == 234
-
-    # make sure training outputs what is expected
-    batch_idx, batch = 0, next(iter(model.train_dataloader()))
-
-    out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
-    assert out.signal == 0
-    assert trainer.logger_connector.logged_metrics['log_acc1'] == 12.0
-    assert trainer.logger_connector.logged_metrics['log_acc2'] == 7.0
-
-    # outputs are for 1 optimizer and no tbptt
-    train_step_end_out = out.training_step_output_for_epoch_end
-    assert len(train_step_end_out) == 1
-    train_step_end_out = train_step_end_out[0][0]
-
-    pbar_metrics = train_step_end_out['progress_bar']
-    assert pbar_metrics['pbar_acc1'] == 17.0
-    assert pbar_metrics['pbar_acc2'] == 19.0