
Commit 6d5f843: fix hpo bug
AnFreTh committed Dec 2, 2024
1 parent ca9dff0 commit 6d5f843
Showing 3 changed files with 132 additions and 66 deletions.
66 changes: 44 additions & 22 deletions mambular/models/sklearn_base_classifier.py
@@ -625,6 +625,15 @@ def optimize_hparams(
         max_epochs=200,
         prune_by_epoch=True,
         prune_epoch=5,
+        fixed_params={
+            "pooling_method": "avg",
+            "head_skip_layers": False,
+            "head_layer_size_length": 0,
+            "cat_encoding": "int",
+            "head_skip_layer": False,
+            "use_cls": False,
+        },
+        custom_search_space=None,
         **optimize_kwargs,
     ):
         """
@@ -656,7 +665,11 @@ def optimize_hparams(
         """

         # Define the hyperparameter search space from the model config
-        param_names, param_space = get_search_space(self.config)
+        param_names, param_space = get_search_space(
+            self.config,
+            fixed_params=fixed_params,
+            custom_search_space=custom_search_space,
+        )

         # Initial model fitting to get the baseline validation loss
         self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs)
@@ -727,32 +740,41 @@ def _objective(hyperparams):
             self.task_model.pruning_epoch = prune_epoch

-            # Fit the model (limit epochs for faster optimization)
-            self.fit(
-                X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
-            )
+            try:
+                # Wrap the risky operation (model fitting) in a try-except block
+                self.fit(
+                    X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
+                )

-            # Retrieve the current validation loss
-            if X_val is not None and y_val is not None:
-                val_loss = self.evaluate(
-                    X_val, y_val, metrics={"Accuracy": (accuracy_score, False)}
-                )["Accuracy"]
-            else:
-                val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
-                    "val_loss"
-                ]
+                # Evaluate validation loss
+                if X_val is not None and y_val is not None:
+                    val_loss = self.evaluate(
+                        X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
+                    )["Mean Squared Error"]
+                else:
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
+                        0
+                    ]["val_loss"]

-            # Retrieve validation loss at the specified epoch (e.g., epoch 5)
-            epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch)
+                # Pruning based on validation loss at specific epoch
+                epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch)

-            # Update the best validation loss at the specified epoch
-            if prune_by_epoch and epoch_val_loss < best_epoch_val_loss:
-                best_epoch_val_loss = epoch_val_loss
+                if prune_by_epoch and epoch_val_loss < best_epoch_val_loss:
+                    best_epoch_val_loss = epoch_val_loss

-            # Update the best overall validation loss
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss

-            return val_loss
+                return val_loss
+            except Exception as e:
+                # Penalize the hyperparameter configuration with a large value
+                print(
+                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
+                )
+                return (
+                    best_val_loss * 100
+                )  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)
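Note: the new fixed_params and custom_search_space arguments are simply forwarded to get_search_space, whose body is not part of this commit. Below is a minimal sketch of how such a helper could pin fixed values and apply user overrides; the default dimensions and their names are assumptions for illustration, not the actual mambular implementation.

# Hypothetical sketch, not the mambular implementation: one way
# get_search_space could honor fixed_params and custom_search_space.
from skopt.space import Categorical, Integer, Real


def get_search_space(config, fixed_params=None, custom_search_space=None):
    fixed_params = fixed_params or {}
    custom_search_space = custom_search_space or {}

    # Illustrative default dimensions; real ones would come from `config`.
    space = {
        "lr": Real(1e-5, 1e-2, prior="log-uniform"),
        "n_layers": Integer(1, 8),
        "pooling_method": Categorical(["avg", "max", "sum"]),
    }

    # User-supplied dimensions override or extend the defaults.
    space.update(custom_search_space)

    # Pinned parameters are dropped from the space, so the optimizer
    # never varies them; the model keeps the fixed value instead.
    for name in fixed_params:
        space.pop(name, None)

    param_names = list(space)
    param_space = [space[name] for name in param_names]
    return param_names, param_space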
65 changes: 44 additions & 21 deletions mambular/models/sklearn_base_lss.py
@@ -612,6 +612,15 @@ def optimize_hparams(
         max_epochs=200,
         prune_by_epoch=True,
         prune_epoch=5,
+        fixed_params={
+            "pooling_method": "avg",
+            "head_skip_layers": False,
+            "head_layer_size_length": 0,
+            "cat_encoding": "int",
+            "head_skip_layer": False,
+            "use_cls": False,
+        },
+        custom_search_space=None,
         **optimize_kwargs,
     ):
         """
@@ -643,7 +652,11 @@ def optimize_hparams(
         """

         # Define the hyperparameter search space from the model config
-        param_names, param_space = get_search_space(self.config)
+        param_names, param_space = get_search_space(
+            self.config,
+            fixed_params=fixed_params,
+            custom_search_space=custom_search_space,
+        )

         # Initial model fitting to get the baseline validation loss
         self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs)
@@ -714,31 +727,41 @@ def _objective(hyperparams):
             self.task_model.early_pruning_threshold = early_pruning_threshold
             self.task_model.pruning_epoch = prune_epoch

-            # Fit the model (limit epochs for faster optimization)
-            self.fit(
-                X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
-            )
+            try:
+                # Wrap the risky operation (model fitting) in a try-except block
+                self.fit(
+                    X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
+                )

-            # Retrieve the current validation loss
-            if X_val is not None and y_val is not None:
-                val_loss = self.score(X_val, y_val)
-            else:
-                val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
-                    "val_loss"
-                ]
+                # Evaluate validation loss
+                if X_val is not None and y_val is not None:
+                    val_loss = self.evaluate(
+                        X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
+                    )["Mean Squared Error"]
+                else:
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
+                        0
+                    ]["val_loss"]

-            # Retrieve validation loss at the specified epoch (e.g., epoch 5)
-            epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch)
+                # Pruning based on validation loss at specific epoch
+                epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch)

-            # Update the best validation loss at the specified epoch
-            if prune_by_epoch and epoch_val_loss < best_epoch_val_loss:
-                best_epoch_val_loss = epoch_val_loss
+                if prune_by_epoch and epoch_val_loss < best_epoch_val_loss:
+                    best_epoch_val_loss = epoch_val_loss

-            # Update the best overall validation loss
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss

-            return val_loss
+                return val_loss
+            except Exception as e:
+                # Penalize the hyperparameter configuration with a large value
+                print(
+                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
+                )
+                return (
+                    best_val_loss * 100
+                )  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)
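Note: with this change a failing trial no longer aborts the whole search; the objective returns best_val_loss * 100 so gp_minimize scores the configuration as poor and moves on. A hedged usage sketch of the new keyword arguments follows; the MambularLSS class name and the dict form of custom_search_space are assumptions based on this diff, not confirmed API.

# Hypothetical usage; class name and custom_search_space format assumed.
from mambular.models import MambularLSS
from skopt.space import Real

model = MambularLSS()
model.optimize_hparams(
    X_train,
    y_train,
    X_val=X_val,
    y_val=y_val,
    max_epochs=50,
    fixed_params={"pooling_method": "avg", "use_cls": False},  # pinned, never searched
    custom_search_space={"lr": Real(1e-5, 1e-2, prior="log-uniform")},  # overrides default range
)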
67 changes: 44 additions & 23 deletions mambular/models/sklearn_base_regressor.py
@@ -525,6 +525,15 @@ def optimize_hparams(
         max_epochs=200,
         prune_by_epoch=True,
         prune_epoch=5,
+        fixed_params={
+            "pooling_method": "avg",
+            "head_skip_layers": False,
+            "head_layer_size_length": 0,
+            "cat_encoding": "int",
+            "head_skip_layer": False,
+            "use_cls": False,
+        },
+        custom_search_space=None,
         **optimize_kwargs,
     ):
         """
@@ -556,7 +565,11 @@ def optimize_hparams(
         """

         # Define the hyperparameter search space from the model config
-        param_names, param_space = get_search_space(self.config)
+        param_names, param_space = get_search_space(
+            self.config,
+            fixed_params=fixed_params,
+            custom_search_space=custom_search_space,
+        )

         # Initial model fitting to get the baseline validation loss
         self.fit(X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs)
@@ -626,33 +639,41 @@ def _objective(hyperparams):
             self.task_model.early_pruning_threshold = early_pruning_threshold
             self.task_model.pruning_epoch = prune_epoch

-            # Fit the model (limit epochs for faster optimization)
-            self.fit(
-                X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
-            )
+            try:
+                # Wrap the risky operation (model fitting) in a try-except block
+                self.fit(
+                    X, y, X_val=X_val, y_val=y_val, max_epochs=max_epochs, rebuild=False
+                )

-            # Retrieve the current validation loss
-            if X_val is not None and y_val is not None:
-                val_loss = self.evaluate(
-                    X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
-                )["Mean Squared Error"]
-            else:
-                val_loss = self.trainer.validate(self.task_model, self.data_module)[0][
-                    "val_loss"
-                ]
+                # Evaluate validation loss
+                if X_val is not None and y_val is not None:
+                    val_loss = self.evaluate(
+                        X_val, y_val, metrics={"Mean Squared Error": mean_squared_error}
+                    )["Mean Squared Error"]
+                else:
+                    val_loss = self.trainer.validate(self.task_model, self.data_module)[
+                        0
+                    ]["val_loss"]

-            # Retrieve validation loss at the specified epoch (e.g., epoch 5)
-            epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch)
+                # Pruning based on validation loss at specific epoch
+                epoch_val_loss = self.task_model.epoch_val_loss_at(prune_epoch)

-            # Update the best validation loss at the specified epoch
-            if prune_by_epoch and epoch_val_loss < best_epoch_val_loss:
-                best_epoch_val_loss = epoch_val_loss
+                if prune_by_epoch and epoch_val_loss < best_epoch_val_loss:
+                    best_epoch_val_loss = epoch_val_loss

-            # Update the best overall validation loss
-            if val_loss < best_val_loss:
-                best_val_loss = val_loss
+                if val_loss < best_val_loss:
+                    best_val_loss = val_loss

-            return val_loss
+                return val_loss
+            except Exception as e:
+                # Penalize the hyperparameter configuration with a large value
+                print(
+                    f"Error encountered during fit with hyperparameters {hyperparams}: {e}"
+                )
+                return (
+                    best_val_loss * 100
+                )  # Large value to discourage this configuration

         # Perform Bayesian optimization using scikit-optimize
         result = gp_minimize(_objective, param_space, n_calls=time, random_state=42)
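Note: all three objectives feed into scikit-optimize's gp_minimize, which calls the objective with a list of values ordered like param_space and minimizes the returned scalar (here, n_calls=time uses the method's trial budget). A self-contained toy example of that contract, not mambular code:

# Toy illustration of the gp_minimize contract; not mambular code.
from skopt import gp_minimize
from skopt.space import Integer, Real


def objective(params):
    lr, n_layers = params  # values arrive in the same order as the dimensions list
    return (lr - 1e-3) ** 2 + abs(n_layers - 4)  # stand-in for a validation loss


result = gp_minimize(
    objective,
    [Real(1e-5, 1e-1, prior="log-uniform"), Integer(1, 8)],
    n_calls=20,
    random_state=42,
)
print("best params:", result.x, "best loss:", result.fun)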
