Skip to content

Commit

Permalink
Merge pull request #129 from point-cloud-radar/scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
lyashevska authored Jul 10, 2023
2 parents e3587a0 + 3c3676a commit 0ad3dfc
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions bird_cloud_gnn/gnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def fit_and_evaluate(
callback=None,
learning_rate=0.01,
num_epochs=20,
sch_explr_gamma=0.99,
sch_multisteplr_milestones=None,
sch_multisteplr_gamma=0.1,
):
"""Fit the model while evaluating every iteraction.
Expand All @@ -135,9 +138,24 @@ def fit_and_evaluate(
Defaults to None.
learning_rate (float, optional): Learning rate. Defaults to 0.01.
num_epochs (int, optional): Number of training epochs. Defaults to 20.
sch_explr_gamma (float): The exponential decay rate of the learning rate.
sch_multisteplr_milestones (list): epoch numbers at which the learning rate is decreased
    by a factor of sch_multisteplr_gamma. If None, this is done at epoch min(num_epochs, 100).
sch_multisteplr_gamma (float): Multiplicative factor applied to the learning rate
    at each milestone epoch.
"""
if sch_multisteplr_milestones is None:
sch_multisteplr_milestones = [min(num_epochs, 100)]
progress_bar = tqdm(total=num_epochs)
optimizer = optim.Adam(self.parameters(), lr=learning_rate)
schedulers = [
optim.lr_scheduler.ExponentialLR(optimizer, gamma=sch_explr_gamma),
optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=sch_multisteplr_milestones,
gamma=sch_multisteplr_gamma,
),
]
epoch_values = {}
for epoch in range(num_epochs):
epoch_values["epoch"] = epoch
Expand Down Expand Up @@ -204,13 +222,18 @@ def fit_and_evaluate(
epoch_values["Accuracy/test"] = num_correct / num_total
epoch_values["Layer/conv1"] = self.conv1.weight.detach()
epoch_values["Layer/conv2"] = self.conv2.weight.detach()
for i, pg in enumerate(optimizer.param_groups):
epoch_values[f"LearningRate/ParGrp{i}"] = pg["lr"]
if self.num_classes == 2:
epoch_values["FalseNegativeRate/test"] = num_false_negative / num_total
epoch_values["FalsePositiveRate/test"] = num_false_positive / num_total

progress_bar.set_postfix({"Epoch": epoch})
progress_bar.update(1)

for scheduler in schedulers:
scheduler.step()

if callback is not None:
user_request_stop = callback(epoch_values)
if user_request_stop is True: # Check for explicit True
Expand Down

0 comments on commit 0ad3dfc

Please sign in to comment.