
[Major] Support Re-Training #1635

Open · wants to merge 17 commits into main
Conversation

ourownstory (Owner):

Continuation of #1605 on a repository branch so that the metrics CI can run.

github-actions bot commented Aug 23, 2024

Model Benchmark

| Benchmark | Metric | main | current | diff |
|---|---|---|---|---|
| YosemiteTemps | MAE_val | 0.59734 | 0.58159 | -2.64% |
| YosemiteTemps | RMSE_val | 0.88884 | 0.86935 | -2.19% |
| YosemiteTemps | Loss_val | 0.00046 | 0.00044 | -4.33% |
| YosemiteTemps | train_loss | 0.00126 | 0.0012 | -4.44% |
| YosemiteTemps | reg_loss | 0 | 0 | 0.0% |
| YosemiteTemps | MAE | 0.97461 | 0.94811 | -2.72% |
| YosemiteTemps | RMSE | 1.70806 | 1.66847 | -2.32% |
| YosemiteTemps | Loss | 0.00126 | 0.0012 | -4.44% |
| YosemiteTemps | time | 141.75 | 139.93 | -1.28% |
| EnergyPriceDaily | MAE_val | 5.64247 | 5.42935 | -3.78% |
| EnergyPriceDaily | RMSE_val | 7.19972 | 6.88991 | -4.3% |
| EnergyPriceDaily | Loss_val | 0.02893 | 0.02655 | -8.24% 🎉 |
| EnergyPriceDaily | train_loss | 0.02957 | 0.02758 | -6.73% 🎉 |
| EnergyPriceDaily | reg_loss | 0 | 0 | 0.0% |
| EnergyPriceDaily | MAE | 6.39654 | 6.15242 | -3.82% |
| EnergyPriceDaily | RMSE | 8.56207 | 8.26192 | -3.51% |
| EnergyPriceDaily | Loss | 0.02936 | 0.02739 | -6.73% 🎉 |
| EnergyPriceDaily | time | 40.4214 | 40.35 | -0.18% |
| AirPassengers | MAE_val | 30.8306 | 30.081 | -2.43% |
| AirPassengers | RMSE_val | 31.8167 | 30.9826 | -2.62% |
| AirPassengers | Loss_val | 0.01301 | 0.01234 | -5.17% 🎉 |
| AirPassengers | train_loss | 0.00071 | 0.00071 | -0.86% |
| AirPassengers | reg_loss | 0 | 0 | 0.0% |
| AirPassengers | MAE | 6.88169 | 6.86203 | -0.29% |
| AirPassengers | RMSE | 8.92083 | 8.81789 | -1.15% |
| AirPassengers | Loss | 0.00076 | 0.00074 | -2.27% |
| AirPassengers | time | 10.2224 | 10.06 | -1.59% |
| PeytonManning | MAE_val | 0.35533 | 0.35447 | -0.24% |
| PeytonManning | RMSE_val | 0.50396 | 0.50324 | -0.14% |
| PeytonManning | Loss_val | 0.01802 | 0.01796 | -0.32% |
| PeytonManning | train_loss | 0.01466 | 0.01461 | -0.3% |
| PeytonManning | reg_loss | 0 | 0 | 0.0% |
| PeytonManning | MAE | 0.34756 | 0.34738 | -0.05% |
| PeytonManning | RMSE | 0.4945 | 0.49347 | -0.21% |
| PeytonManning | Loss | 0.01465 | 0.01461 | -0.3% |
| PeytonManning | time | 25.6879 | 25.48 | -0.81% |
Model training plots (training-curve images): PeytonManning, YosemiteTemps, AirPassengers, EnergyPriceDaily

ourownstory changed the title from [Major] Train continue to [Major] Support Custom Learning Rate Scheduler on Aug 26, 2024
@@ -104,18 +122,21 @@ class Train:
n_data: int = field(init=False)
loss_func_name: str = field(init=False)
lr_finder_args: dict = field(default_factory=dict)
optimizer_state: dict = field(default_factory=dict)
ourownstory (Owner Author): move to separate PR
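For context, a rough sketch of what a field like this enables: the Train config gains a slot where a torch optimizer `state_dict` can be parked between fit calls. The stand-in dataclass below is illustrative, not the PR's actual config:

```python
from dataclasses import dataclass, field

import torch

@dataclass
class Train:  # stand-in; the real config has many more fields
    learning_rate: float = 1e-3
    optimizer_state: dict = field(default_factory=dict)

cfg = Train()
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)

# After a first training run, park the optimizer state on the config so a
# continued run can restore momentum/Adam moments instead of starting cold.
cfg.optimizer_state = optimizer.state_dict()
```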

"three_phase": True,
}
)
if self.continue_training:
ourownstory (Owner Author): move to other PR

@@ -239,6 +304,9 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re
delay_weight = 1
return delay_weight

def set_optimizer_state(self, optimizer_state: dict):
ourownstory (Owner Author): move to other PR
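The method body is not visible in this hunk; presumably it just stashes the dict for later restoration when the optimizer is rebuilt. A minimal guess at its shape:

```python
class Train:  # stand-in for the real config class
    def __init__(self):
        self.optimizer_state: dict = {}

    def set_optimizer_state(self, optimizer_state: dict):
        # Sketch only: store a torch optimizer state_dict; configure_optimizers
        # can later apply it via optimizer.load_state_dict(self.optimizer_state).
        self.optimizer_state = optimizer_state
```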

--------
>>> from neuralprophet import NeuralProphet
>>> # Step Learning Rate scheduler
>>> m = NeuralProphet(scheduler="StepLR")
ourownstory (Owner Author): add scheduler args example
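Along the lines requested above, the docstring could also show scheduler arguments. Hedged sketch: the `scheduler_args` keyword is assumed from this PR's direction, and the keys mirror `torch.optim.lr_scheduler.StepLR`:

```python
>>> from neuralprophet import NeuralProphet
>>> # Step Learning Rate scheduler with custom arguments
>>> # (scheduler_args keyword assumed; keys follow torch's StepLR)
>>> m = NeuralProphet(
...     scheduler="StepLR",
...     scheduler_args={"step_size": 10, "gamma": 0.5},  # halve the LR every 10 epochs
... )
```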

newer_samples_start=newer_samples_start,
trend_reg_threshold=self.config_trend.trend_reg_threshold,
)
self.learning_rate = learning_rate
ourownstory (Owner Author): redo init of Train

self.n_forecasts = n_forecasts

# Lightning Config
self.config_train = config_train
self.config_normalization = config_normalization
self.compute_components_flag = compute_components_flag

# Continued training
ourownstory (Owner Author): separate PR

# Optimizer
optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

if self.continue_training:
ourownstory (Owner Author): separate PR
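A self-contained sketch of the pattern in this hunk, using plain torch outside Lightning: build a fresh optimizer each time, and load the saved state only when continuing (the real configure_optimizers also sets up the LR scheduler):

```python
import torch

model = torch.nn.Linear(3, 1)

def build_optimizer(continue_training: bool, optimizer_state: dict) -> torch.optim.Optimizer:
    # A fresh optimizer is created on every (re)configuration.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # When continuing training, restore the previous run's state so the
    # optimizer's moments and step counts survive the restart.
    if continue_training and optimizer_state:
        optimizer.load_state_dict(optimizer_state)
    return optimizer

first = build_optimizer(continue_training=False, optimizer_state={})
saved = first.state_dict()
resumed = build_optimizer(continue_training=True, optimizer_state=saved)
```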

return config_train_params


def test_continue_training():
ourownstory (Owner Author): modify to do without checkpoints
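A sketch of a checkpoint-free version, per the comment: fit once, then call fit again on the same model with the continue flag, and compare losses. The fixture path and the exact continue_training signature are assumptions here:

```python
import pandas as pd
from neuralprophet import NeuralProphet

def test_continue_training():
    df = pd.read_csv(AIR_FILE)  # hypothetical fixture path to a ds/y dataset
    m = NeuralProphet(epochs=5, learning_rate=0.01)
    metrics = m.fit(df, freq="MS")

    # Second fit on the same model instance, no Lightning checkpoint involved;
    # continue_training is the flag this PR is introducing.
    metrics2 = m.fit(df, freq="MS", continue_training=True)

    # Continued training should not end up worse than the first run.
    assert metrics["Loss"].min() >= metrics2["Loss"].min()
```

The scheduler-selection variant below would presumably differ only in passing a scheduler choice (e.g. `scheduler="StepLR"`) for the continued run.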

assert metrics["Loss"].min() >= metrics2["Loss"].min()


def test_continue_training_with_scheduler_selection():
ourownstory (Owner Author): modify to work without checkpoints

assert metrics["Loss"].min() >= metrics2["Loss"].min()


def test_save_load_continue_training():
ourownstory (Owner Author): modify to work without checkpoints
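For the save/load variant, a rough sketch, assuming neuralprophet's save/load helpers (exact import path may differ) and the same continue_training flag:

```python
import pandas as pd
from neuralprophet import NeuralProphet, save, load

def test_save_load_continue_training():
    df = pd.read_csv(AIR_FILE)  # hypothetical fixture path to a ds/y dataset
    m = NeuralProphet(epochs=5)
    metrics = m.fit(df, freq="MS")

    # Round-trip the fitted model through disk instead of a Lightning checkpoint.
    save(m, "test_model.np")
    m2 = load("test_model.np")

    # Resume training on the reloaded model.
    metrics2 = m2.fit(df, freq="MS", continue_training=True)
    assert metrics["Loss"].min() >= metrics2["Loss"].min()
```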

ourownstory changed the title from [Major] Support Custom Learning Rate Scheduler to [Major] Support Re-Training on Aug 28, 2024