Implementing loss scaling scheduler callback and schedulers #270

Merged Aug 9, 2024 · 38 commits

Commits
All commits are by laserkelvin.

9d933d4  feat: implemented loss scaling base class (Aug 7, 2024)
9f5f062  feat: added property support for schedule grid (Aug 7, 2024)
6d3b910  feat: added key property for base class (Aug 7, 2024)
ba2995f  feat: adding linear schedule child (Aug 7, 2024)
dd32486  refactor: added linear schedule ramp (Aug 7, 2024)
2050da1  feat: implemented setup method for linear schedule (Aug 7, 2024)
ad706d8  refactor: defining __all__ for loss scaling module (Aug 7, 2024)
d51ea6d  refactor: making schedule a cached property (Aug 7, 2024)
0581f9a  fix: correcting linear step logic by applying negative sign as needed (Aug 7, 2024)
ce2f30f  test: added parametrized linear scheduler test (Aug 7, 2024)
764477a  feat: implemented loss schedule callback class (Aug 8, 2024)
67c8d3d  test: added unit test to ensure end-to-end linear scaling passes (Aug 8, 2024)
778b692  fix: correcting type annotation for callback (Aug 8, 2024)
b5a70d1  refactor: add exception handling for unexpected task keys (Aug 8, 2024)
93cdce3  test: added unit test to catch unexpected task keys (Aug 8, 2024)
c967e31  refactor: setting initial value of scaling as part of setup (Aug 8, 2024)
eae15c3  refactor: making setup check task keys, not scaling keys (Aug 8, 2024)
d5e1ef7  refactor & test: adding check to make sure task scaling has changed (Aug 8, 2024)
a81eae2  test: added end value check to linear schedule as well (Aug 8, 2024)
bd1f675  test: added unit test for epoch stepping (Aug 8, 2024)
37d7393  docs: updated unit test docstrings (Aug 8, 2024)
240ed17  refactor: moved setup to base class (Aug 8, 2024)
11812f3  feat: implemented sigmoid schedule (Aug 8, 2024)
227c059  docs: added docstring for sigmoid scaling schedule (Aug 8, 2024)
5e50058  feat: adding sigmoid scaling to __all__ (Aug 8, 2024)
26ee19a  refactor: added messages and assertions for sigmoid values (Aug 8, 2024)
38c1c76  test: added unit test for sigmoid schedule (Aug 8, 2024)
f08a2a4  test: added unit test for sigmoid scaling in training loop (Aug 8, 2024)
28e8ba5  refactor & tests: making egnn much smaller (Aug 8, 2024)
61ba436  refactor: adding __all__ definition in callbacks (Aug 8, 2024)
232837c  refactor: making sure that trainer sets up dataloaders (Aug 8, 2024)
ccb4174  refactor & fix: fixed step count condition (Aug 8, 2024)
bf9cbf6  fix: remapping initial and end values in sigmoid equation (Aug 8, 2024)
a270156  refactor: added step value logging (Aug 8, 2024)
b8acf7b  feat: added script that demonstrates loss schedule usage (Aug 8, 2024)
d3b97da  refactor: combining functionality into unified function (Aug 9, 2024)
4887b1c  fix & test: correcting assertion for end scheduler value check (Aug 9, 2024)
75ed137  feat: added repr method for schedules (Aug 9, 2024)
75 changes: 75 additions & 0 deletions examples/callbacks/loss_scheduling.py
@@ -0,0 +1,75 @@
from __future__ import annotations

import pytorch_lightning as pl

from matsciml.datasets.transforms import (
    PeriodicPropertiesTransform,
    DistancesTransform,
    PointCloudToGraphTransform,
)
from matsciml.lightning.data_utils import MatSciMLDataModule
from matsciml.models import SchNet
from matsciml.models.base import ScalarRegressionTask

from matsciml.lightning.callbacks import LossScalingScheduler
from matsciml.lightning.loss_scaling import SigmoidScalingSchedule

"""
This script demonstrates how to add loss scaling schedules
to training runs.
"""

# construct a scalar regression task with a SchNet encoder
task = ScalarRegressionTask(
    encoder_class=SchNet,
    # kwargs to be passed into the creation of the SchNet model
    encoder_kwargs={
        "encoder_only": True,
        "hidden_feats": [128, 128, 128],
        "atom_embedding_dim": 128,
    },
    output_kwargs={"lazy": False, "hidden_dim": 128, "input_dim": 128},
    # which keys to use as targets
    task_keys=["energy_relaxed"],
)

# use the IS2RE devset to test the workflow
# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances
dm = MatSciMLDataModule.from_devset(
    "IS2REDataset",
    dset_kwargs={
        "transforms": [
            PeriodicPropertiesTransform(6.0, True),
            PointCloudToGraphTransform(
                "dgl",
                node_keys=["pos", "atomic_numbers"],
            ),
            DistancesTransform(),
        ],
    },
)

# run several epochs with a limited number of train batches
# to make sure nothing breaks between updates
trainer = pl.Trainer(
    max_epochs=10,
    limit_train_batches=10,
    logger=False,
    enable_checkpointing=False,
    callbacks=[
        LossScalingScheduler(
            SigmoidScalingSchedule(
                "energy_relaxed",
                initial_value=10.0,  # the first value will not be exactly this,
                end_value=1.0,  # but close to it, due to the nature of the sigmoid
                center_frac=0.5,  # the sigmoid flips at half the total number of steps
                curvature=1e-7,  # can be modified to change the ramping behavior
                step_frequency="step",
            ),
            log_level="DEBUG",  # verbose; setting it to INFO will suppress most messages
        )
    ],
)
trainer.fit(task, datamodule=dm)
# print out the final scaling rates
print(task.task_loss_scaling)
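The sigmoid equation itself lives in matsciml.lightning.loss_scaling and is not part of this diff. As a rough mental model only (the function below is an assumption, not the PR's code), a logistic interpolation between the two endpoints reproduces the behavior the comments above describe:

import math

def sigmoid_scale(
    step: int,
    total_steps: int,
    initial_value: float,
    end_value: float,
    center_frac: float,
    curvature: float,
) -> float:
    """Hypothetical sigmoid ramp; the library's exact equation may differ."""
    center = center_frac * total_steps
    # logistic ramp running from ~0 to ~1, flipping at `center`
    ramp = 1.0 / (1.0 + math.exp(-curvature * (step - center)))
    # interpolate between the endpoints; because the logistic never reaches
    # 0 or 1 exactly, the first and last values only approximate the endpoints
    return initial_value + (end_value - initial_value) * ramp

Under this model, the comments in the script follow directly: the first scaling value is close to, but not exactly, initial_value, and curvature controls how sharply the ramp transitions around the center step.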
102 changes: 102 additions & 0 deletions matsciml/lightning/callbacks.py
@@ -30,6 +30,25 @@
from matsciml.datasets.utils import concatenate_keys
from matsciml.models.base import BaseTaskModule
from matsciml.common.types import Embeddings, BatchDict
from matsciml.lightning.loss_scaling import BaseScalingSchedule

__all__ = [
    "LeaderboardWriter",
    "GradientCheckCallback",
    "UnusedParametersCallback",
    "ThroughputCallback",
    "ForwardNaNDetection",
    "ManualGradientClip",
    "MonitorGradients",
    "GarbageCallback",
    "InferenceWriter",
    "CodeCarbonCallback",
    "SAM",
    "TrainingHelperCallback",
    "ModelAutocorrelation",
    "ExponentialMovingAverageCallback",
    "LossScalingScheduler",
]


class LeaderboardWriter(BasePredictionWriter):
@@ -1492,3 +1511,86 @@ def on_fit_end(
        loader = trainer.train_dataloader
        self.logger.info("Fit finished - updating EMA batch normalization state.")
        update_bn(loader, pl_module.ema_module)


class LossScalingScheduler(Callback):
    def __init__(
        self,
        *schedules: BaseScalingSchedule,
        log_level: Literal["INFO", "DEBUG", "WARNING", "CRITICAL"] = "INFO",
    ) -> None:
        """
        Callback for dynamically adjusting loss scaling values over
        the course of training, a la curriculum learning.

        This class is configured by supplying a list of schedules
        as args; see the `matsciml.lightning.loss_scaling` module for
        available schedules. Each schedule instance has a `key`
        attribute that points it to the corresponding task key
        as set in the Lightning task module (e.g. `energy`, `force`).

        Parameters
        ----------
        *schedules : BaseScalingSchedule
            One scaling schedule per task being performed.
        """
        super().__init__()
        assert len(schedules) > 0, "Must pass individual schedules to loss scheduler!"
        self.schedules = schedules
        self._logger = getLogger("matsciml.loss_scaling_scheduler")
        self._logger.setLevel(log_level)
        self._logger.debug(f"Configured {len(self.schedules)} schedules.")
        self._logger.debug(
            f"Schedules have {[s.key for s in self.schedules]} task keys."
        )

    def on_fit_start(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        trainer.datamodule.setup("fit")
        for schedule in self.schedules:
            # check to make sure the schedule key actually exists in the task
            if schedule.key not in pl_module.task_keys:
                raise KeyError(
                    f"Schedule for {schedule.key} expected, but not specified as a task key!"
                )
            # schedules grab what information they need from the
            # trainer and task modules
            schedule.setup(trainer, pl_module)
            self._logger.debug(f"Configured {schedule.key} schedule.")

    def _step_schedules(
        self, pl_module: "pl.LightningModule", stage: Literal["step", "epoch"]
    ) -> None:
        """Base function to step schedules according to what stage we are in."""
        for schedule in self.schedules:
            if schedule.step_frequency == stage:
                target_key = schedule.key
                self._logger.debug(
                    f"Attempting to advance {target_key} schedule on {stage}."
                )
                try:
                    new_scaling_value = schedule.step()
                    pl_module.task_loss_scaling[target_key] = new_scaling_value
                    self._logger.debug(
                        f"Advanced {target_key} to new value: {new_scaling_value}"
                    )
                except StopIteration:
                    self._logger.warning(
                        f"{target_key} has run out of scheduled values; this may be unintentional."
                    )

    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Any,
        batch: Any,
        batch_idx: int,
    ) -> None:
        self._step_schedules(pl_module, "step")

    def on_train_epoch_end(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
    ) -> None:
        self._step_schedules(pl_module, "epoch")
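To make the multi-task wiring the docstring describes concrete, here is a minimal usage sketch built only from names shown in this PR; the `energy` and `force` keys are the docstring's own examples and must match the task keys actually set on the LightningModule.

from matsciml.lightning.callbacks import LossScalingScheduler
from matsciml.lightning.loss_scaling import SigmoidScalingSchedule

# one schedule per task key; each schedule's `key` must match a task key
# on the LightningModule, otherwise on_fit_start raises a KeyError
scheduler = LossScalingScheduler(
    SigmoidScalingSchedule(
        "energy",
        initial_value=10.0,
        end_value=1.0,
        center_frac=0.5,
        curvature=1e-7,
        step_frequency="step",  # advanced by on_train_batch_end
    ),
    SigmoidScalingSchedule(
        "force",
        initial_value=1.0,
        end_value=10.0,
        center_frac=0.5,
        curvature=1e-7,
        step_frequency="epoch",  # advanced by on_train_epoch_end
    ),
    log_level="INFO",
)
# then attach it like any other callback: pl.Trainer(callbacks=[scheduler], ...)

Because each schedule carries its own step_frequency, one callback instance can mix per-step and per-epoch ramps across tasks; schedules whose frequency does not match the current stage are simply skipped by _step_schedules.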