Simplify variables: step, epoch, max_epochs, min_epochs #589

Merged: 1 commit merged on Dec 7, 2019
8 changes: 4 additions & 4 deletions README.md
@@ -154,16 +154,16 @@ use something other than tensorboard).
Here are more advanced examples
```python
# train on cpu using only 10% of the data (for demo purposes)
trainer = Trainer(max_num_epochs=1, train_percent_check=0.1)
trainer = Trainer(max_epochs=1, train_percent_check=0.1)

# train on 4 gpus (lightning chooses GPUs for you)
# trainer = Trainer(max_num_epochs=1, gpus=4, distributed_backend='ddp')
# trainer = Trainer(max_epochs=1, gpus=4, distributed_backend='ddp')

# train on 4 gpus (you choose GPUs)
# trainer = Trainer(max_num_epochs=1, gpus=[0, 1, 3, 7], distributed_backend='ddp')
# trainer = Trainer(max_epochs=1, gpus=[0, 1, 3, 7], distributed_backend='ddp')

# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
# trainer = Trainer(max_num_epochs=1, gpus=8, num_gpu_nodes=4, distributed_backend='ddp')
# trainer = Trainer(max_epochs=1, gpus=8, num_gpu_nodes=4, distributed_backend='ddp')

# train (1 epoch only here for demo)
trainer.fit(model)
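For scripts written against the old keyword, the rename is the only change needed. A minimal sketch (the model itself is omitted; `model` stands for any LightningModule):

```python
from pytorch_lightning import Trainer

# new spelling of the epoch limits
trainer = Trainer(max_epochs=1, min_epochs=1, train_percent_check=0.1)

# the old spelling still works during the deprecation window, but emits a
# DeprecationWarning and is scheduled for removal in v0.8.0
trainer = Trainer(max_nb_epochs=1)

# trainer.fit(model)  # `model` is any LightningModule, not defined in this sketch
```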
2 changes: 1 addition & 1 deletion pl_examples/full_examples/imagenet/imagenet_example.py
@@ -234,7 +234,7 @@ def main(hparams):
trainer = pl.Trainer(
default_save_path=hparams.save_path,
gpus=hparams.gpus,
max_num_epochs=hparams.epochs,
max_epochs=hparams.epochs,
distributed_backend=hparams.distributed_backend,
use_amp=hparams.use_16bit
)
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -694,10 +694,10 @@ def configure_optimizers(self):
"""
raise NotImplementedError

def optimizer_step(self, epoch_idx, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
"""Do something instead of the standard optimizer behavior

:param int epoch_idx:
:param int epoch:
:param int batch_idx:
:param optimizer:
:param optimizer_idx:
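The rename also touches the `optimizer_step` hook. A hedged sketch of an override using the new parameter name, simply reproducing what the default behaviour looks like (the rest of the module is omitted for brevity):

```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    # training_step, configure_optimizers, dataloaders, etc. omitted

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       second_order_closure=None):
        # `epoch` replaces the old `epoch_idx` argument; this override just
        # mirrors the standard update
        if second_order_closure is not None:
            optimizer.step(second_order_closure)  # e.g. LBFGS needs the closure
        else:
            optimizer.step()
        optimizer.zero_grad()
```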
2 changes: 1 addition & 1 deletion pytorch_lightning/logging/__init__.py
@@ -52,7 +52,7 @@ def log_hyperparams(self, params):
pass

@rank_zero_only
def log_metrics(self, metrics, step_idx):
def log_metrics(self, metrics, step):
# metrics is a dictionary of metric names and values
# your code to record metrics goes here
pass
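Custom loggers pick up the new keyword automatically as long as their `log_metrics` accepts `step`. A fuller illustrative logger built on the same API (a sketch, not part of this PR; it assumes the usual imports from `pytorch_lightning.logging` and that `params` is an `argparse.Namespace`, as in the repo's examples):

```python
from pytorch_lightning.logging import LightningLoggerBase, rank_zero_only


class InMemoryLogger(LightningLoggerBase):
    """Toy logger that keeps everything in memory; illustrative only."""

    def __init__(self):
        super().__init__()
        self.hparams = {}
        self.history = []

    @rank_zero_only
    def log_hyperparams(self, params):
        if params is not None:
            self.hparams.update(vars(params))  # assumes an argparse.Namespace

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        # `step` replaces the old `step_idx` keyword; the trainer passes
        # its global_step here
        self.history.append((step, dict(metrics)))
```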
4 changes: 2 additions & 2 deletions pytorch_lightning/logging/base.py
@@ -21,11 +21,11 @@ class LightningLoggerBase(object):
def __init__(self):
self._rank = 0

def log_metrics(self, metrics, step_idx):
def log_metrics(self, metrics, step):
"""Record metrics.

:param float metric: Dictionary with metric names as keys and measured quantities as values
:param int|None step_idx: Step number at which the metrics should be recorded
:param int|None step: Step number at which the metrics should be recorded
"""
raise NotImplementedError()

4 changes: 2 additions & 2 deletions pytorch_lightning/logging/comet.py
@@ -145,13 +145,13 @@ def log_hyperparams(self, params):
self.experiment.log_parameters(vars(params))

@rank_zero_only
def log_metrics(self, metrics, step_idx=None):
def log_metrics(self, metrics, step=None):
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
metrics[key] = val.cpu().detach()

self.experiment.log_metrics(metrics, step=step_idx)
self.experiment.log_metrics(metrics, step=step)

@rank_zero_only
def finalize(self, status):
4 changes: 2 additions & 2 deletions pytorch_lightning/logging/mlflow.py
@@ -68,15 +68,15 @@ def log_hyperparams(self, params):
self.experiment.log_param(self.run_id, k, v)

@rank_zero_only
def log_metrics(self, metrics, step_idx=None):
def log_metrics(self, metrics, step=None):
timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
if isinstance(v, str):
logger.warning(
f"Discarding metric with string value {k}={v}"
)
continue
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step_idx)
self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

def save(self):
pass
4 changes: 2 additions & 2 deletions pytorch_lightning/logging/test_tube.py
@@ -76,10 +76,10 @@ def log_hyperparams(self, params):
self.experiment.argparse(params)

@rank_zero_only
def log_metrics(self, metrics, step_idx=None):
def log_metrics(self, metrics, step=None):
# TODO: HACK figure out where this is being set to true
self.experiment.debug = self.debug
self.experiment.log(metrics, global_step=step_idx)
self.experiment.log(metrics, global_step=step)

@rank_zero_only
def save(self):
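All of the bundled backends (Comet, MLflow, test-tube) now receive the metrics under the `step` keyword, so attaching one to the trainer is unchanged. A hedged usage sketch (only `save_dir` is shown; check the backend's constructor for its exact signature):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.logging import TestTubeLogger

# the trainer calls logger.log_metrics(metrics, step=self.global_step)
logger = TestTubeLogger(save_dir='lightning_logs')
trainer = Trainer(logger=logger, max_epochs=1)
```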
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/logging.py
@@ -43,7 +43,7 @@ def log_metrics(self, metrics, grad_norm_dic):

# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
self.logger.log_metrics(scalar_metrics, step_idx=self.global_step)
self.logger.log_metrics(scalar_metrics, step=self.global_step)
self.logger.save()

def add_tqdm_metrics(self, metrics):
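The `proc_rank == 0` guard above is the same idea the `rank_zero_only` decorator expresses on the logger side. A sketch of such a decorator for illustration (not the library's own implementation; it leans on the `self._rank` attribute set in `LightningLoggerBase`):

```python
from functools import wraps


def rank_zero_only_sketch(fn):
    """Run the wrapped logger method only on rank 0; illustrative only."""
    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        # LightningLoggerBase stores the process rank in self._rank (see base.py)
        if getattr(self, '_rank', 0) == 0:
            return fn(self, *args, **kwargs)
    return wrapped
```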
26 changes: 13 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -72,8 +72,8 @@ def __init__(
accumulate_grad_batches=1,
max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
max_num_epochs=1000,
min_num_epochs=1,
max_epochs=1000,
min_epochs=1,
train_percent_check=1.0,
val_percent_check=1.0,
test_percent_check=1.0,
@@ -111,8 +111,8 @@ def __init__(
:param int check_val_every_n_epoch: check val every n train epochs
:param bool fast_dev_run: runs full iteration over everything to find bugs
:param int accumulate_grad_batches: Accumulates grads every k batches
:param int max_num_epochs:
:param int min_num_epochs:
:param int max_epochs:
:param int min_epochs:
:param int train_percent_check: How much of train set to check
:param int val_percent_check: How much of val set to check
:param int test_percent_check: How much of test set to check
@@ -158,17 +158,17 @@ def __init__(
self.process_position = process_position
self.weights_summary = weights_summary
if max_nb_epochs is not None: # Backward compatibility
warnings.warn("`max_nb_epochs` has renamed to `max_num_epochs` since v0.5.0"
warnings.warn("`max_nb_epochs` has renamed to `max_epochs` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not max_num_epochs: # in case you did not set the proper value
max_num_epochs = max_nb_epochs
self.max_num_epochs = max_num_epochs
if not max_epochs: # in case you did not set the proper value
max_epochs = max_nb_epochs
self.max_epochs = max_epochs
if min_nb_epochs is not None: # Backward compatibility
warnings.warn("`min_nb_epochs` has renamed to `min_num_epochs` since v0.5.0"
warnings.warn("`min_nb_epochs` has renamed to `min_epochs` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
if not min_num_epochs: # in case you did not set the proper value
min_num_epochs = min_nb_epochs
self.min_num_epochs = min_num_epochs
if not min_epochs: # in case you did not set the proper value
min_epochs = min_nb_epochs
self.min_epochs = min_epochs
if nb_sanity_val_steps is not None: # Backward compatibility
warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0"
" and will be removed in v0.8.0", DeprecationWarning)
@@ -183,7 +183,7 @@ def __init__(
self.fast_dev_run = fast_dev_run
if self.fast_dev_run:
self.num_sanity_val_steps = 1
self.max_num_epochs = 1
self.max_epochs = 1
m = '''
Running in fast_dev_run mode: will run a full train,
val loop using a single batch
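The backward-compatibility handling above follows a simple pattern: warn, then fall back to the old value only when the new keyword was left unset. A standalone sketch of that pattern (illustrative, not the Trainer code itself):

```python
import warnings


def resolve_epoch_limit(max_nb_epochs=None, max_epochs=1000):
    """Mirror of the deprecation shim above, extracted for illustration."""
    if max_nb_epochs is not None:  # old keyword was supplied
        warnings.warn("`max_nb_epochs` has been renamed to `max_epochs` since v0.5.0"
                      " and will be removed in v0.8.0", DeprecationWarning)
        if not max_epochs:  # only adopt the old value if the new one is unset/falsy
            max_epochs = max_nb_epochs
    return max_epochs


# both spellings resolve to the same setting during the deprecation window
assert resolve_epoch_limit(max_nb_epochs=5, max_epochs=None) == 5
assert resolve_epoch_limit(max_epochs=20) == 20
```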
18 changes: 9 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -23,7 +23,7 @@
.. code-block:: python

# DEFAULT
trainer = Trainer(min_num_epochs=1, max_num_epochs=1000)
trainer = Trainer(min_epochs=1, max_epochs=1000)

Early stopping
--------------
@@ -259,17 +259,17 @@ def process_output(self, output, train):

def train(self):
# run all epochs
for epoch_idx in range(self.current_epoch, self.max_num_epochs):
for epoch in range(self.current_epoch, self.max_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
self.get_train_dataloader().sampler.set_epoch(epoch_idx)
self.get_train_dataloader().sampler.set_epoch(epoch)

# get model
model = self.get_model()

# update training progress in trainer and model
model.current_epoch = epoch_idx
self.current_epoch = epoch_idx
model.current_epoch = epoch
self.current_epoch = epoch

# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
@@ -294,11 +294,11 @@ def train(self):
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch_idx + 1}' if not self.is_iterable_train_dataloader else ''
desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin(epoch_idx, self)
self.accumulation_scheduler.on_epoch_begin(epoch, self)

# -----------------
# RUN TNG EPOCH
@@ -319,9 +319,9 @@ def train(self):
self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

# early stopping
met_min_epochs = epoch_idx > self.min_num_epochs
met_min_epochs = epoch > self.min_epochs
if self.enable_early_stop and (met_min_epochs or self.fast_dev_run):
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch_idx,
should_stop = self.early_stop_callback.on_epoch_end(epoch=epoch,
logs=self.callback_metrics)
# stop training
stop = should_stop and met_min_epochs
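The interaction between `min_epochs` and early stopping in the loop above can be summarised in a few lines; the helper below is a sketch for illustration, not code from the repository:

```python
def should_stop(epoch, min_epochs, enable_early_stop, early_stop_triggered,
                fast_dev_run=False):
    """Early stopping is only honoured once `epoch` has passed `min_epochs`."""
    met_min_epochs = epoch > min_epochs
    if enable_early_stop and (met_min_epochs or fast_dev_run):
        return early_stop_triggered and met_min_epochs
    return False


# with min_epochs=5, an early-stop signal at epoch 3 is ignored ...
assert should_stop(3, 5, True, True) is False
# ... but the same signal at epoch 7 ends training
assert should_stop(7, 5, True, True) is True
```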
6 changes: 4 additions & 2 deletions pytorch_lightning/utilities/arg_parse.py
@@ -15,8 +15,10 @@ def add_default_args(parser, root_dir, rand_seed=None, possible_model_names=None
parser.opt_list('--accumulate_grad_batches', default=1, type=int, tunable=False,
help='accumulates gradients k times before applying update.'
' Simulates huge batch size')
parser.add_argument('--max_num_epochs', default=200, type=int, help='cap epochs')
parser.add_argument('--min_num_epochs', default=2, type=int, help='min epochs')
parser.add_argument('--max_epochs', default=200, type=int,
help='maximum number of epochs')
parser.add_argument('--min_epochs', default=2, type=int,
help='minimum number of epochs')
parser.add_argument('--train_percent_check', default=1.0, type=float,
help='how much of training set to check')
parser.add_argument('--val_percent_check', default=1.0, type=float,
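Command-line scripts only need to pass the new flag names through. A plain-`argparse` sketch of the renamed flags (the repository itself appears to build on test_tube's `HyperOptArgumentParser`, which is where the `opt_list` calls above come from):

```python
import argparse

from pytorch_lightning import Trainer

parser = argparse.ArgumentParser()
parser.add_argument('--max_epochs', default=200, type=int, help='maximum number of epochs')
parser.add_argument('--min_epochs', default=2, type=int, help='minimum number of epochs')
args = parser.parse_args(['--max_epochs', '10'])  # as if run: python train.py --max_epochs 10

trainer = Trainer(max_epochs=args.max_epochs, min_epochs=args.min_epochs)
```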
12 changes: 6 additions & 6 deletions tests/test_amp.py
@@ -23,7 +23,7 @@ def test_amp_single_gpu(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=1,
distributed_backend='ddp',
use_amp=True
@@ -45,7 +45,7 @@ def test_no_amp_single_gpu(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=1,
distributed_backend='dp',
use_amp=True
@@ -69,7 +69,7 @@ def test_amp_gpu_ddp(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=2,
distributed_backend='ddp',
use_amp=True
@@ -94,7 +94,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):

trainer_options = dict(
show_progress_bar=True,
max_num_epochs=1,
max_epochs=1,
gpus=[0],
distributed_backend='ddp',
use_amp=True
@@ -153,7 +153,7 @@ def test_cpu_model_with_amp(tmpdir):
default_save_path=tmpdir,
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4,
use_amp=True
@@ -175,7 +175,7 @@ def test_amp_gpu_dp(tmpdir):
model, hparams = tutils.get_model()
trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
gpus='0, 1', # test init with gpu string
distributed_backend='dp',
use_amp=True
18 changes: 9 additions & 9 deletions tests/test_cpu_models.py
@@ -46,7 +46,7 @@ def test_lbfgs_cpu_model(tmpdir):

trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
print_nan_grads=True,
show_progress_bar=False,
weights_summary='top',
@@ -64,7 +64,7 @@ def test_default_logger_callbacks_cpu_model(tmpdir):

trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
gradient_clip_val=1.0,
overfit_pct=0.20,
print_nan_grads=True,
@@ -97,7 +97,7 @@ def test_running_test_after_fitting(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=False,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
test_percent_check=0.2,
@@ -135,7 +135,7 @@ class CurrentTestModel(LightningTestMixin, LightningTestModelBase):

trainer_options = dict(
show_progress_bar=False,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.2,
test_percent_check=0.2,
@@ -209,7 +209,7 @@ def test_simple_cpu(tmpdir):
# logger file to get meta
trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.1,
)
@@ -230,7 +230,7 @@ def test_cpu_model(tmpdir):
default_save_path=tmpdir,
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4
)
@@ -253,7 +253,7 @@ def test_all_features_cpu_model(tmpdir):
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
accumulate_grad_batches=2,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4
)
@@ -314,7 +314,7 @@ def train_dataloader(self):

trainer_options = dict(
default_save_path=tmpdir,
max_num_epochs=1,
max_epochs=1,
truncated_bptt_steps=truncated_bptt_steps,
val_percent_check=0,
weights_summary=None,
@@ -348,7 +348,7 @@ def test_single_gpu_model(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
show_progress_bar=False,
max_num_epochs=1,
max_epochs=1,
train_percent_check=0.1,
val_percent_check=0.1,
gpus=1