Merge pull request #270 from laserkelvin/loss-scaling-scheduler
Implementing loss scaling scheduler callback and schedulers
laserkelvin authored Aug 9, 2024
2 parents b7bc8fd + 75ed137 commit 3e8aa98
Showing 4 changed files with 560 additions and 0 deletions.
75 changes: 75 additions & 0 deletions examples/callbacks/loss_scheduling.py
@@ -0,0 +1,75 @@
from __future__ import annotations

import pytorch_lightning as pl

from matsciml.datasets.transforms import (
    PeriodicPropertiesTransform,
    DistancesTransform,
    PointCloudToGraphTransform,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models import SchNet
from matsciml.models.base import ScalarRegressionTask

from matsciml.lightning.callbacks import LossScalingScheduler
from matsciml.lightning.loss_scaling import SigmoidScalingSchedule

"""
This script demonstrates how to add loss scaling schedules
to training runs.
"""

# construct a scalar regression task with SchNet encoder
task = ScalarRegressionTask(
    encoder_class=SchNet,
    # kwargs to be passed into the creation of SchNet model
    encoder_kwargs={
        "encoder_only": True,
        "hidden_feats": [128, 128, 128],
        "atom_embedding_dim": 128,
    },
    output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128},
    # which keys to use as targets
    task_keys=["energy_relaxed"],
)

# Use IS2RE devset to test workflow
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
"IS2REDataset",
dset_kwargs={
"transforms": [
PeriodicPropertiesTransform(6.0, True),
PointCloudToGraphTransform(
"dgl",
node_keys=["pos", "atomic_numbers"],
),
DistancesTransform(),
],
},
)

# run several epochs with a limited number of train batches
# to make sure nothing breaks between updates
trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=10,
    logger=False,
    enable_checkpointing=False,
    callbacks=[
        LossScalingScheduler(
            SigmoidScalingSchedule(
                "energy_relaxed",
                initial_value=10.0,  # the first value will not be exactly this,
                end_value=1.0,  # but close to it, due to the nature of the sigmoid
                center_frac=0.5,  # this means the sigmoid flips at half the total steps
                curvature=1e-7,  # can be modified to change the ramping behavior
                step_frequency="step",
            ),
            log_level="DEBUG",  # this makes it verbose; setting it to INFO will suppress most messages
        )
    ],
)
trainer.fit(task, datamodule=dm)
# print out the final scaling rates
print(task.task_loss_scaling)
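
For intuition, the following is a minimal sketch of what a sigmoid-shaped scaling schedule computes from parameters like those passed above. This is not the library's SigmoidScalingSchedule implementation; the exact parameterization (in particular how curvature is scaled) is assumed here purely for illustration.

import math


def sigmoid_scaling(
    step: int,
    total_steps: int,
    initial_value: float,
    end_value: float,
    center_frac: float = 0.5,
    curvature: float = 0.02,
) -> float:
    # the sigmoid weight ramps from ~0 at the start of training to ~1 at the
    # end, flipping at center_frac of the total steps; curvature controls how
    # sharp that transition is (assumed scale, not the library's)
    center = center_frac * total_steps
    weight = 1.0 / (1.0 + math.exp(-curvature * (step - center)))
    # interpolate between the initial and final loss scaling values
    return initial_value + (end_value - initial_value) * weight


# e.g. ramp the "energy_relaxed" loss weight from ~10 down to ~1 over 1000 steps;
# the first value is close to, but not exactly, the initial value
values = [sigmoid_scaling(s, 1000, 10.0, 1.0) for s in range(1000)]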
102 changes: 102 additions & 0 deletions matsciml/lightning/callbacks.py
@@ -30,6 +30,25 @@
from matsciml.datasets.utils import concatenate_keys
from matsciml.models.base import BaseTaskModule
from matsciml.common.types import Embeddings, BatchDict
from matsciml.lightning.loss_scaling import BaseScalingSchedule

__all__ = [
"LeaderboardWriter",
"GradientCheckCallback",
"UnusedParametersCallback",
"ThroughputCallback",
"ForwardNaNDetection",
"ManualGradientClip",
"MonitorGradients",
"GarbageCallback",
"InferenceWriter",
"CodeCarbonCallback",
"SAM",
"TrainingHelperCallback",
"ModelAutocorrelation",
"ExponentialMovingAverageCallback",
"LossScalingScheduler",
]


class LeaderboardWriter(BasePredictionWriter):
@@ -1492,3 +1511,86 @@ def on_fit_end(
        loader = trainer.train_dataloader
        self.logger.info("Fit finished - updating EMA batch normalization state.")
        update_bn(loader, pl_module.ema_module)


class LossScalingScheduler(Callback):
    def __init__(
        self,
        *schedules: BaseScalingSchedule,
        log_level: Literal["INFO", "DEBUG", "WARNING", "CRITICAL"] = "INFO",
    ) -> None:
        """
        Callback for dynamically adjusting loss scaling values over
        the course of training, a la curriculum learning.

        This class is configured by supplying a list of schedules
        as args; see the `matsciml.lightning.loss_scaling` module for
        available schedules. Each schedule instance has a `key`
        attribute that points it to the corresponding task key
        as set in the Lightning task module (e.g. `energy`, `force`).

        Parameters
        ----------
        *schedules : BaseScalingSchedule
            Scaling schedules, one for each task being performed.
        """
        super().__init__()
        assert len(schedules) > 0, "Must pass individual schedules to loss scheduler!"
        self.schedules = schedules
        self._logger = getLogger("matsciml.loss_scaling_scheduler")
        self._logger.setLevel(log_level)
        self._logger.debug(f"Configured {len(self.schedules)} schedules.")
        self._logger.debug(
            f"Schedules have {[s.key for s in self.schedules]} task keys."
        )

    def on_fit_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        trainer.datamodule.setup("fit")
        for schedule in self.schedules:
            # check to make sure the schedule key actually exists in the task
            if schedule.key not in pl_module.task_keys:
                raise KeyError(
                    f"Schedule for {schedule.key} expected, but not specified as a task key!"
                )
            # schedules grab what information they need from the
            # trainer and task modules
            schedule.setup(trainer, pl_module)
            self._logger.debug(f"Configured {schedule.key} schedule.")

    def _step_schedules(
        self, pl_module: "pl.LightningModule", stage: Literal["step", "epoch"]
    ) -> None:
        """Base function to step schedules according to what stage we are in."""
        for schedule in self.schedules:
            if schedule.step_frequency == stage:
                target_key = schedule.key
                self._logger.debug(
                    f"Attempting to advance {target_key} schedule on {stage}."
                )
                try:
                    new_scaling_value = schedule.step()
                    pl_module.task_loss_scaling[target_key] = new_scaling_value
                    self._logger.debug(
                        f"Advanced {target_key} to new value: {new_scaling_value}"
                    )
                except StopIteration:
                    self._logger.warning(
                        f"{target_key} has run out of scheduled values; this may be unintentional."
                    )

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._step_schedules(pl_module, "step")

    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        self._step_schedules(pl_module, "epoch")