Skip to content

Commit

Permalink
fux bug
Browse files Browse the repository at this point in the history
  • Loading branch information
tvdboom committed Jun 13, 2023
1 parent 858116c commit a56393d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
12 changes: 4 additions & 8 deletions atom/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,14 +890,10 @@ def fit_model(
# the estimator (often changed by _get_parameters in models.py),
# we implement this hacky method to overwrite the params in storage
trial._cached_frozen_trial.params = params
for name, value in params.items():
distribution = trial.distributions[name]
trial.storage.set_trial_param(
trial_id=trial.number,
param_name=name,
param_value_internal=distribution.to_internal_repr(value),
distribution=distribution,
)
frozen_trial = self.study._storage._get_trial(trial.number)
frozen_trial.params = params
frozen_trial.distributions = self._ht["distributions"]
self.study._storage._set_trial(trial.number, frozen_trial)

# Store user defined tags
for key, value in self._ht["tags"].items():
Expand Down
8 changes: 4 additions & 4 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,8 @@ def test_plot_hyperparameter_importance():

def test_plot_hyperparameters():
"""Assert that the plot_hyperparameters method works."""
atom = ATOMRegressor(X_reg, y_reg, random_state=1)
atom.run("Tree", n_trials=3)
atom = ATOMClassifier(X_bin, y_bin, random_state=1)
atom.run("lr", n_trials=3)

# Only one hyperparameter
with pytest.raises(ValueError, match=".*minimum of two parameters.*"):
Expand Down Expand Up @@ -569,8 +569,8 @@ def test_plot_pareto_front():

def test_plot_slice():
"""Assert that the plot_slice method works."""
atom = ATOMRegressor(X_reg, y_reg, random_state=1)
atom.run("tree", metric=["mae", "mse"], n_trials=3)
atom = ATOMClassifier(X_bin, y_bin, random_state=1)
atom.run("lr", metric=["f1", "recall"], n_trials=3)
atom.plot_slice(display=False)


Expand Down

0 comments on commit a56393d

Please sign in to comment.