Rename and clean variables
carmocca committed Jul 15, 2021
1 parent fc18c16 commit e550e6d
Showing 1 changed file with 102 additions and 71 deletions.
tests/loops/test_loops.py — 173 changes: 102 additions & 71 deletions
@@ -33,7 +33,9 @@ class CustomException(Exception):


def test_loop_restore():

class Simple(Loop):

def __init__(self, dataset: Iterator):
super().__init__()
self.dataset = dataset
@@ -92,11 +94,13 @@ def load_state_dict(self, state_dict: Dict) -> None:


def test_loop_hierarchy():

@dataclass
class SimpleProgress(BaseProgress):
increment: int = 0

class Simple(Loop):

def __init__(self, a):
super().__init__()
self.a = a
@@ -134,10 +138,18 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:

state_dict = loop_parent.state_dict()
assert state_dict == {
'state_dict': {'a': 1},
'progress': {'increment': 0},
'loop_child.state_dict': {'a': 2},
'loop_child.progress': {'increment': 0},
'state_dict': {
'a': 1
},
'progress': {
'increment': 0
},
'loop_child.state_dict': {
'a': 2
},
'loop_child.progress': {
'increment': 0
},
}

state_dict["loop_child.state_dict"]["a"] = 3
@@ -150,10 +162,18 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
# check the new state after `run`
state_dict = loop_parent.state_dict()
assert state_dict == {
'state_dict': {'a': 1},
'progress': {'increment': 1},
'loop_child.state_dict': {'a': 3},
'loop_child.progress': {'increment': 1},
'state_dict': {
'a': 1
},
'progress': {
'increment': 1
},
'loop_child.state_dict': {
'a': 3
},
'loop_child.progress': {
'increment': 1
},
}

loop_parent_copy = deepcopy(loop_parent)
@@ -182,6 +202,7 @@ def test_loop_restart_progress_multiple_dataloaders(tmpdir):
n_epochs = 2

class ValidationModel(BoringModel):

def __init__(self):
super().__init__()

@@ -216,7 +237,12 @@ def val_dataloader(self):

total = (n_epochs - 1) * n_dataloaders + stop_dataloader
expected = {
"total": {"ready": total + 1, "started": None, "processed": None, "completed": total},
"total": {
"ready": total + 1,
"started": None,
"processed": None,
"completed": total
},
"current": {
"ready": stop_dataloader + 1,
"started": None,
@@ -229,7 +255,12 @@ def val_dataloader(self):
trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False)
total = n_dataloaders * n_batches + n_batches + stop_epoch
expected = {
"total": {"ready": total + 1, "started": total + 1, "processed": total, "completed": total},
"total": {
"ready": total + 1,
"started": total + 1,
"processed": total,
"completed": total
},
"current": {
"ready": stop_batch + 1,
"started": stop_batch + 1,
@@ -241,8 +272,18 @@ def val_dataloader(self):

trainer.fit_loop.load_state_dict(checkpoint)
expected = {
"total": {"ready": total, "started": total, "processed": total, "completed": total},
"current": {"ready": stop_batch, "started": stop_batch, "processed": stop_batch, "completed": stop_batch},
"total": {
"ready": total,
"started": total,
"processed": total,
"completed": total
},
"current": {
"ready": stop_batch,
"started": stop_batch,
"processed": stop_batch,
"completed": stop_batch
},
}
assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected

@@ -251,14 +292,15 @@ def val_dataloader(self):
@pytest.mark.parametrize("accumulate_grad_batches", (1, 2)) # FIXME: 3 is broken
@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
@pytest.mark.parametrize("stop_epoch", (1, 2))
@pytest.mark.parametrize("stop_batch", (1,)) # FIXME: 2 is broken
@pytest.mark.parametrize("stop_batch", (1, )) # FIXME: 2 is broken
@pytest.mark.parametrize("stop_optimizer", (1, 2))
def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir):
stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0
n_epochs = 3
n_batches = 3

class TestModel(BoringModel):

def __init__(self):
super().__init__()
if n_optimizers > 1:
@@ -304,57 +346,46 @@ def configure_optimizers_multiple(self):
optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress
scheduler_progress = trainer.fit_loop.epoch_loop.scheduler_progress

non_breaking_epoch_batches_completed = stop_epoch * n_batches
breaking_epoch_batches_completed = stop_batch
breaking_epoch_batches_ready = stop_batch + 1
# `nb_`: non-breaking, as in, no exception will be raised. `b_`: breaking
nb_epoch_batches_completed = stop_epoch * n_batches
b_epoch_batches_completed = stop_batch
b_epoch_batches_ready = stop_batch + 1
# lightning applies leftover accumulated gradients when the epoch ends
has_leftover_accumulation_batches = n_batches % accumulate_grad_batches != 0
non_breaking_stepping_batches = non_breaking_epoch_batches_completed // accumulate_grad_batches
breaking_stepping_batches = breaking_epoch_batches_completed // accumulate_grad_batches

non_breaking_total_optimizer_steps = (
non_breaking_stepping_batches + has_leftover_accumulation_batches
) * n_optimizers
should_last_batch_step = breaking_epoch_batches_ready % accumulate_grad_batches == 0
breaking_total_optimizer_steps = breaking_stepping_batches * n_optimizers + should_last_batch_step * stop_optimizer
total_optimizer_steps = non_breaking_total_optimizer_steps + breaking_total_optimizer_steps
current_optimizer_steps = breaking_total_optimizer_steps
has_optimizer_step_in_breaking_epoch = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0
assert optim_progress.optimizer_steps == total_optimizer_steps
assert optim_progress.optimizer.step.current.completed == current_optimizer_steps

non_breaking_total_zero_grad = (non_breaking_stepping_batches + has_leftover_accumulation_batches) * n_optimizers
nb_stepping_batches = nb_epoch_batches_completed // accumulate_grad_batches
b_stepping_batches = b_epoch_batches_completed // accumulate_grad_batches

nb_total_optimizer_steps = (nb_stepping_batches + has_leftover_accumulation_batches) * n_optimizers
should_last_batch_step = b_epoch_batches_ready % accumulate_grad_batches == 0
b_total_optimizer_steps = b_stepping_batches * n_optimizers + should_last_batch_step * stop_optimizer
has_optimizer_step_in_b_epoch = accumulate_grad_batches == 1 or n_batches % accumulate_grad_batches != 0
assert optim_progress.optimizer_steps == nb_total_optimizer_steps + b_total_optimizer_steps
assert optim_progress.optimizer.step.current.completed == b_total_optimizer_steps

nb_total_zero_grad = (nb_stepping_batches + has_leftover_accumulation_batches) * n_optimizers
# FIXME: What the hell
if accumulate_grad_batches > 1:
# FIXME: ready or completed? 0 or stop_optimizer?
breaking_total_zero_grad = (
n_optimizers
+ (breaking_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1))
* (n_optimizers - 1)
+ 0
b_total_zero_grad = (
n_optimizers + (b_epoch_batches_ready // accumulate_grad_batches - (accumulate_grad_batches > 1)) *
(n_optimizers - 1) + 0
)
# breaking_total_zero_grad = breaking_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0
# b_total_zero_grad = b_epoch_batches_ready // accumulate_grad_batches * n_optimizers + 0
else:
breaking_total_zero_grad = breaking_stepping_batches * n_optimizers + stop_optimizer
total_zero_grad = non_breaking_total_zero_grad + breaking_total_zero_grad
current_zero_grad = breaking_total_zero_grad
assert optim_progress.optimizer.zero_grad.total.completed == total_zero_grad
assert optim_progress.optimizer.zero_grad.current.completed == current_zero_grad

non_breaking_scheduler_steps = stop_epoch
breaking_scheduler_steps = 0 # the current epoch did not complete
b_total_zero_grad = b_stepping_batches * n_optimizers + stop_optimizer
assert optim_progress.optimizer.zero_grad.total.completed == nb_total_zero_grad + b_total_zero_grad
assert optim_progress.optimizer.zero_grad.current.completed == b_total_zero_grad

nb_scheduler_steps = stop_epoch
b_scheduler_steps = 0 # the current epoch did not complete
if n_optimizers > 1:
# assumes that the scheduler config is unchanged
# `* 1` because there is only one step-level scheduler
non_breaking_scheduler_steps = (
stop_epoch + non_breaking_stepping_batches + has_leftover_accumulation_batches * 1
)
nb_scheduler_steps = stop_epoch + nb_stepping_batches + has_leftover_accumulation_batches * 1
# `0 +` for the epoch-level scheduler
breaking_scheduler_steps = 0 + breaking_stepping_batches
total_scheduler_steps = non_breaking_scheduler_steps + breaking_scheduler_steps
current_scheduler_steps = breaking_scheduler_steps
assert scheduler_progress.total.completed == total_scheduler_steps
assert scheduler_progress.current.completed == current_scheduler_steps
b_scheduler_steps = 0 + b_stepping_batches
assert scheduler_progress.total.completed == nb_scheduler_steps + b_scheduler_steps
assert scheduler_progress.current.completed == b_scheduler_steps

# yapf: disable
expected = {
@@ -376,10 +407,10 @@ def configure_optimizers_multiple(self):
"epoch_loop.state_dict": {},
"epoch_loop.batch_progress": {
"total": {
"ready": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed + 1,
"started": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed + 1,
"processed": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed,
"completed": non_breaking_epoch_batches_completed + breaking_epoch_batches_completed,
"ready": nb_epoch_batches_completed + b_epoch_batches_completed + 1,
"started": nb_epoch_batches_completed + b_epoch_batches_completed + 1,
"processed": nb_epoch_batches_completed + b_epoch_batches_completed,
"completed": nb_epoch_batches_completed + b_epoch_batches_completed,
},
"current": {
"ready": stop_batch + 1,
Expand All @@ -390,16 +421,16 @@ def configure_optimizers_multiple(self):
},
"epoch_loop.scheduler_progress": {
"total": {
"ready": total_scheduler_steps,
"ready": nb_scheduler_steps + b_scheduler_steps,
"started": None,
"processed": None,
"completed": total_scheduler_steps,
"completed": nb_scheduler_steps + b_scheduler_steps,
},
"current": {
"ready": current_scheduler_steps,
"ready": b_scheduler_steps,
"started": None,
"processed": None,
"completed": current_scheduler_steps,
"completed": b_scheduler_steps,
},
},
"epoch_loop.batch_loop.state_dict": {},
@@ -408,30 +439,30 @@ def configure_optimizers_multiple(self):
"optimizer": {
"step": {
"total": {
"ready": total_optimizer_steps + has_optimizer_step_in_breaking_epoch,
"ready": nb_total_optimizer_steps + b_total_optimizer_steps + has_optimizer_step_in_b_epoch,
"started": None,
"processed": None,
"completed": total_optimizer_steps,
"completed": nb_total_optimizer_steps + b_total_optimizer_steps,
},
"current": {
"ready": current_optimizer_steps + has_optimizer_step_in_breaking_epoch,
"ready": b_total_optimizer_steps + has_optimizer_step_in_b_epoch,
"started": None,
"processed": None,
"completed": current_optimizer_steps,
"completed": b_total_optimizer_steps,
},
},
"zero_grad": {
"total": {
"ready": total_zero_grad,
"started": total_zero_grad,
"ready": nb_total_zero_grad + b_total_zero_grad,
"started": nb_total_zero_grad + b_total_zero_grad,
"processed": None,
"completed": total_zero_grad,
"completed": nb_total_zero_grad + b_total_zero_grad,
},
"current": {
"ready": current_zero_grad,
"started": current_zero_grad,
"ready": b_total_zero_grad,
"started": b_total_zero_grad,
"processed": None,
"completed": current_zero_grad,
"completed": b_total_zero_grad,
},
},
},