Commit

make classifier_params a property of cnn.BaseModule
this makes it easier to override how the parameters that should be treated as the classifier are retrieved, without having to override the entire configure_optimizers() function
sammlapp committed Jan 25, 2025
1 parent de19cdc commit 6b8c2ca
Showing 1 changed file with 19 additions and 15 deletions.
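To illustrate what the commit message describes, here is a minimal sketch of how a downstream model might use the new hook. Only the classifier_params property and cnn.BaseModule come from this commit; the subclass name, the head_a/head_b attributes, and the idea of a two-headed classifier are hypothetical.

import itertools

from opensoundscape.ml import cnn


class TwoHeadModule(cnn.BaseModule):
    """Hypothetical model whose classifier is two heads rather than self.classifier."""

    # __init__, the backbone, and the two head modules are omitted from this sketch

    @property
    def classifier_params(self):
        # Override the new BaseModule property so that configure_optimizers()
        # places these parameters (instead of self.classifier.parameters())
        # into the separate param group when optimizer_params["classifier_lr"] is set
        return itertools.chain(self.head_a.parameters(), self.head_b.parameters())

With an override like this, setting optimizer_params["classifier_lr"] (the key read in configure_optimizers below) applies the custom learning rate to exactly these parameters, while the rest of the network keeps the default learning rate.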
opensoundscape/ml/cnn.py (34 changes: 19 additions & 15 deletions)
@@ -295,6 +295,14 @@ def training_step(self, samples, batch_idx):
 
     # def predict_step(self, batch): #runs forward() if we don't override default
 
+    @property
+    def classifier_params(self):
+        """return the parameters of the classifier layer of the network
+
+        override this method if the classifier parameters should be retrieved in a different way
+        """
+        return self.classifier.parameters()
+
     def configure_optimizers(
         self,
         reset_optimizer=False,
@@ -337,6 +345,7 @@ def configure_optimizers(
             dictionary with keys "optimizer" and "scheduler" containing the
             optimizer and learning rate scheduler objects to use during training
         """
+
         if reset_optimizer:
             self.optimizer = None
         if restart_scheduler:
@@ -354,28 +363,26 @@ def configure_optimizers(
         if self.optimizer_params["classifier_lr"] is not None:
             # customize the learning rate of the classifier layer
             try:
-                classifier_params = list(self.classifier.parameters())
-                # for some reason, I get tensor mismatch if I check whether
-                # parameters are in list. Instead, compare the objects' ids.
+                # Cannot check `param in param_list`. Instead, compare the objects' ids.
                 # see https://discuss.pytorch.org/t/confused-by-runtimeerror-when-checking-for-parameter-in-list/211308
-                classifier_param_ids = {id(p) for p in classifier_params}
+                classifier_param_ids = {id(p) for p in self.classifier_params}
-                # remove these parameters from their current group
-                for param_group in optimizer.param_groups:
-                    param_group["params"] = [
-                        p
-                        for p in param_group["params"]
-                        if id(p) not in classifier_param_ids
-                    ]
             except Exception as e:
                 raise ValueError(
                     "Could not access self.classifier.parameters(). "
                     "Make sure self.classifier propoerty returns a torch.nn.Module object."
                 ) from e
+            # remove these parameters from their current group
+            for param_group in optimizer.param_groups:
+                param_group["params"] = [
+                    p
+                    for p in param_group["params"]
+                    if id(p) not in classifier_param_ids
+                ]
 
             # add them to a new group with custom learning rate
             optimizer.add_param_group(
                 {
-                    "params": classifier_params,
+                    "params": self.classifier_params,
                     "lr": self.optimizer_params["classifier_lr"],
                 }
             )
@@ -2303,9 +2310,6 @@ def use_resample_loss(model, train_df):
     model.loss_fn = ResampleLoss(class_frequency)
 
 
-# TODO: implement `classifier_lr` key in self.optimizer_params and use in configure_optimizers
-
-
 @register_model_cls
 class InceptionV3(SpectrogramClassifier):
     """Child of SpectrogramClassifier class for InceptionV3 architecture"""
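For readers less familiar with PyTorch param groups, the following standalone sketch (plain PyTorch, not OpenSoundscape; the toy network and learning-rate values are made up) walks through the same id-based removal and add_param_group pattern that the updated configure_optimizers() uses to give the classifier its own learning rate:

import torch
from torch import nn

# toy model: pretend the last Linear layer is the "classifier"
net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
classifier = net[2]

# one optimizer over all parameters, using the default learning rate
optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

# `param in param_list` fails because tensor __eq__ is elementwise,
# so membership is checked by comparing object ids instead
classifier_param_ids = {id(p) for p in classifier.parameters()}

# remove the classifier parameters from every existing param group...
for param_group in optimizer.param_groups:
    param_group["params"] = [
        p for p in param_group["params"] if id(p) not in classifier_param_ids
    ]

# ...and re-add them in their own group with a custom learning rate
optimizer.add_param_group({"params": classifier.parameters(), "lr": 0.01})

for i, group in enumerate(optimizer.param_groups):
    print(f"group {i}: lr={group['lr']}, n_params={len(group['params'])}")

Wrapping the classifier_params lookup in try/except, as the diff does, turns an unexpected classifier_params implementation into a clear ValueError rather than a silently misconfigured optimizer.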
