Skip to content

Commit

Permalink
refactor: combining functionality into unified function
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed Aug 9, 2024
1 parent b8acf7b commit d3b97da
Showing 1 changed file with 16 additions and 25 deletions.
41 changes: 16 additions & 25 deletions matsciml/lightning/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,19 +1559,15 @@ def on_fit_start(
schedule.setup(trainer, pl_module)
self._logger.debug("Configured {schedule.key} schedule.")

def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: Any,
batch: Any,
batch_idx: int,
def _step_schedules(
self, pl_module: "pl.LightningModule", stage: Literal["step", "epoch"]
) -> None:
"""Base function to step schedules according to what stage we are in."""
for schedule in self.schedules:
if schedule.step_frequency == "step":
if schedule.step_frequency == stage:
target_key = schedule.key
self._logger.debug(
f"Attempting to advance {target_key} schedule on step."
f"Attempting to advance {target_key} schedule on {stage}."
)
try:
new_scaling_value = schedule.step()
Expand All @@ -1584,22 +1580,17 @@ def on_train_batch_end(
f"{target_key} has run out of scheduled values; this may be unintentional."
)

def on_train_batch_end(
self,
trainer: "pl.Trainer",
pl_module: "pl.LightningModule",
outputs: Any,
batch: Any,
batch_idx: int,
) -> None:
self._step_schedules(pl_module, "step")

def on_train_epoch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
) -> None:
for schedule in self.schedules:
if schedule.step_frequency == "epoch":
target_key = schedule.key
self._logger.debug(
f"Attempting to advance {target_key} schedule on epoch."
)
try:
new_scaling_value = schedule.step()
pl_module.task_loss_scaling[target_key] = new_scaling_value
self._logger.debug(
f"Advanced {target_key} to new value: {new_scaling_value}"
)
except StopIteration:
self._logger.warning(
f"{target_key} has run out of scheduled values; this may be unintentional."
)
self._step_schedules(pl_module, "epoch")

0 comments on commit d3b97da

Please sign in to comment.