Implementing loss scaling scheduler callback and schedulers #270

Merged: 38 commits, Aug 9, 2024
9d933d4
feat: implemented loss scaling base class
laserkelvin Aug 7, 2024
9f5f062
feat: added property support for schedule grid
laserkelvin Aug 7, 2024
6d3b910
feat: added key property for base class
laserkelvin Aug 7, 2024
ba2995f
feat: adding linear schedule child
laserkelvin Aug 7, 2024
dd32486
refactor: added linear schedule ramp
laserkelvin Aug 7, 2024
2050da1
feat: implemented setup method for linear schedule
laserkelvin Aug 7, 2024
ad706d8
refactor: defining __all__ for loss scaling module
laserkelvin Aug 7, 2024
d51ea6d
refactor: making schedule a cached property
laserkelvin Aug 7, 2024
0581f9a
fix: correcting linear step logic by applying negative sign as needed
laserkelvin Aug 7, 2024
ce2f30f
test: added parametrized linear scheduler test
laserkelvin Aug 7, 2024
764477a
feat: implemented loss schedule callback class
laserkelvin Aug 8, 2024
67c8d3d
test: added unit test to ensure end-to-end linear scaling passes
laserkelvin Aug 8, 2024
778b692
fix: correcting type annotation for callback
laserkelvin Aug 8, 2024
b5a70d1
refactor: add exception handling for unexpected task keys
laserkelvin Aug 8, 2024
93cdce3
test: added unit test to catch unexpected task keys
laserkelvin Aug 8, 2024
c967e31
refactor: setting initial value of scaling as part of setup
laserkelvin Aug 8, 2024
eae15c3
refactor: making setup check task keys, not scaling keys
laserkelvin Aug 8, 2024
d5e1ef7
refactor & test: adding check to make sure task scaling has changed
laserkelvin Aug 8, 2024
a81eae2
test: added end value check to linear schedule as well
laserkelvin Aug 8, 2024
bd1f675
test: added unit test for epoch stepping
laserkelvin Aug 8, 2024
37d7393
docs: updated unit test docstrings
laserkelvin Aug 8, 2024
240ed17
refactor: moved setup to base class
laserkelvin Aug 8, 2024
11812f3
feat: implemented sigmoid schedule
laserkelvin Aug 8, 2024
227c059
docs: added docstring for sigmoid scaling schedule
laserkelvin Aug 8, 2024
5e50058
feat: adding sigmoid scaling to __all__
laserkelvin Aug 8, 2024
26ee19a
refactor: added messages and assertions for sigmoid values
laserkelvin Aug 8, 2024
38c1c76
test: added unit test for sigmoid schedule
laserkelvin Aug 8, 2024
f08a2a4
test: added unit test for sigmoid scaling in training loop
laserkelvin Aug 8, 2024
28e8ba5
refactor & tests: making egnn much smaller
laserkelvin Aug 8, 2024
61ba436
refactor: adding __all__ definition in callbacks
laserkelvin Aug 8, 2024
232837c
refactor: making sure that trainer sets up dataloaders
laserkelvin Aug 8, 2024
ccb4174
refactor & fix: fixed step count condition
laserkelvin Aug 8, 2024
bf9cbf6
fix: remapping initial and end values in sigmoid equation
laserkelvin Aug 8, 2024
a270156
refactor: added step value logging
laserkelvin Aug 8, 2024
b8acf7b
feat: added script that demonstrates loss schedule usage
laserkelvin Aug 8, 2024
d3b97da
refactor: combining functionality into unified function
laserkelvin Aug 9, 2024
4887b1c
fix & test: correcting assertion for end scheduler value check
laserkelvin Aug 9, 2024
75ed137
feat: added repr method for schedules
laserkelvin Aug 9, 2024
feat: adding linear schedule child
laserkelvin committed Aug 7, 2024
commit ba2995f36ab1348f1d18cd48cad90d8397ba244f
17 changes: 17 additions & 0 deletions matsciml/lightning/loss_scaling.py
@@ -64,3 +64,20 @@ def schedule(self) -> Generator[float, None, None]:
    def setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Configures the schedule by grabbing whatever is needed from trainer/module"""
        ...


class LinearScalingSchedule(BaseScalingSchedule):
    def __init__(
        self,
        key: str,
        initial_value: float,
        end_value: float | None = None,
        step_frequency: Literal["step", "epoch"] = "epoch",
    ) -> None:
        super().__init__()
        self.key = key
        self.initial_value = initial_value
        if not end_value:
            end_value = initial_value
        self.end_value = end_value
        self.step_frequency = step_frequency
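
Review note: the constructor falls back to `initial_value` whenever `end_value` is falsy, so an explicit `end_value=0.0` (ramping a task's loss weight down to zero) would be treated as if no end value had been given. The sketch below is not code from this PR. It is a minimal, standalone illustration, with the hypothetical helper names `resolve_end_value` and `linear_ramp`, of how an `is None` check would preserve an explicit zero and how a linear ramp between the two values might be generated. The step count and the inclusive endpoints are assumptions. The actual ramp logic lives in later commits that are not shown in this diff.

```python
from __future__ import annotations

from typing import Iterator


def resolve_end_value(initial_value: float, end_value: float | None) -> float:
    """Fall back to ``initial_value`` only when ``end_value`` was not given."""
    # `is None` keeps an explicit 0.0, which a truthiness check would discard.
    return initial_value if end_value is None else end_value


def linear_ramp(
    initial_value: float, end_value: float, num_steps: int
) -> Iterator[float]:
    """Yield ``num_steps`` evenly spaced values from initial to end, inclusive."""
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if num_steps == 1:
        yield initial_value
        return
    # The step is negative when end_value < initial_value, so the same
    # expression covers both increasing and decreasing schedules.
    delta = (end_value - initial_value) / (num_steps - 1)
    for index in range(num_steps):
        yield initial_value + index * delta


if __name__ == "__main__":
    end = resolve_end_value(1.0, 0.0)
    print(list(linear_ramp(1.0, end, 5)))  # [1.0, 0.75, 0.5, 0.25, 0.0]
```

With the truthiness check from the diff, the same call would resolve `end_value` to `1.0` and the schedule would stay constant.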