[Major] Support Re-Training #1635
base: main
Conversation
Model Benchmark
Branch updated from 30c106f to 6a74680 (compare).
@@ -104,18 +122,21 @@ class Train:
    n_data: int = field(init=False)
    loss_func_name: str = field(init=False)
    lr_finder_args: dict = field(default_factory=dict)
    optimizer_state: dict = field(default_factory=dict)
Move the `optimizer_state` field to a separate PR.
"three_phase": True, | ||
} | ||
) | ||
if self.continue_training: |
Move the `continue_training` handling to the other PR.
@@ -239,6 +304,9 @@ def get_reg_delay_weight(self, e, iter_progress, reg_start_pct: float = 0.66, re
            delay_weight = 1
        return delay_weight

    def set_optimizer_state(self, optimizer_state: dict):
Move `set_optimizer_state` to the other PR.
        --------
        >>> from neuralprophet import NeuralProphet
        >>> # Step Learning Rate scheduler
        >>> m = NeuralProphet(scheduler="StepLR")
Please also add an example that passes scheduler arguments, not just the scheduler name.
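Something along these lines could work (a sketch only; the `scheduler_args` parameter name is an assumption and should be aligned with the actual argument added in this PR, while `step_size` and `gamma` are the standard torch `StepLR` options):

        >>> # Step Learning Rate scheduler with explicit arguments
        >>> # (scheduler_args name is assumed; adjust to the PR's actual API)
        >>> m = NeuralProphet(scheduler="StepLR", scheduler_args={"step_size": 10, "gamma": 0.5})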
            newer_samples_start=newer_samples_start,
            trend_reg_threshold=self.config_trend.trend_reg_threshold,
        )
        self.learning_rate = learning_rate
Rework how `Train` is initialized here instead of re-assigning these attributes on the model.
        self.n_forecasts = n_forecasts

        # Lightning Config
        self.config_train = config_train
        self.config_normalization = config_normalization
        self.compute_components_flag = compute_components_flag

        # Continued training
Move the continued-training block to a separate PR.
        # Optimizer
        optimizer = self._optimizer(self.parameters(), lr=self.learning_rate, **self.config_train.optimizer_args)

        if self.continue_training:
Move this branch to the separate PR as well.
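For reference when it does move, a minimal sketch of what this branch could do; `self.config_train.optimizer_state` follows the field added to `Train` above, but the exact wiring is an assumption:

        if self.continue_training and self.config_train.optimizer_state:
            # Restore the optimizer state captured from the previous run so that
            # momentum and adaptive statistics carry over into continued training.
            optimizer.load_state_dict(self.config_train.optimizer_state)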
    return config_train_params


def test_continue_training():
modify to do without checkpoints
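A rough sketch of a checkpoint-free version, assuming the `continue_training` flag on `fit()` that this PR introduces (the flag name and the chosen epochs are assumptions):

    import pandas as pd
    from neuralprophet import NeuralProphet

    def test_continue_training():
        # Small synthetic dataset so the test stays fast.
        df = pd.DataFrame(
            {"ds": pd.date_range("2023-01-01", periods=100, freq="D"), "y": range(100)}
        )
        m = NeuralProphet(epochs=5, learning_rate=0.01)
        metrics = m.fit(df, freq="D")
        # Continue training on the same in-memory model; no checkpoint files involved.
        metrics2 = m.fit(df, freq="D", continue_training=True)
        assert metrics["Loss"].min() >= metrics2["Loss"].min()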
    assert metrics["Loss"].min() >= metrics2["Loss"].min()


def test_continue_training_with_scheduler_selection():
Modify this test so it works without checkpoints as well.
    assert metrics["Loss"].min() >= metrics2["Loss"].min()


def test_save_load_continue_training():
Same here: modify the test so it works without checkpoints.
Continuation of #1605, moved to a repository branch so the metrics CI can run.