Skip to content

Commit

Permalink
[tune] Upgrade ax-platform (#36452)
Browse files Browse the repository at this point in the history
Our current installation of ax-platform is outdated and doesn't work with more recent pandas versions. This PR upgrades the ax-platform version and re-enables previously disabled tests.

Signed-off-by: Kai Fricke <kai@anyscale.com>
  • Loading branch information
krfricke authored Jun 15, 2023
1 parent ba7d490 commit 4071f8e
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 14 deletions.
4 changes: 0 additions & 4 deletions python/ray/tune/tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,6 @@ def train(config):
self.assertSequenceEqual(integers_1, integers_2)
self.assertSequenceEqual(choices_1, choices_2)

# Todo: Upgrade ax. This will upgrade sub-dependencies that may break other parts.
@unittest.skip("ax tests currently failing (need to upgrade ax)")
def testConvertAx(self):
from ray.tune.search.ax import AxSearch
from ax.service.ax_client import AxClient
Expand Down Expand Up @@ -537,7 +535,6 @@ def testConvertAx(self):
self.assertTrue(5 <= config["a"] <= 6)
self.assertTrue(8 <= config["b"] <= 9)

@unittest.skip("ax tests currently failing (need to upgrade ax)")
def testSampleBoundsAx(self):
from ray.tune.search.ax import AxSearch
from ax.service.ax_client import AxClient
Expand Down Expand Up @@ -1689,7 +1686,6 @@ def _testPointsToEvaluate(self, cls, config, exact=True, **kwargs):
else:
self.assertDictEqual(trial_config_dict, points_to_evaluate[i])

@unittest.skip("ax tests currently failing (need to upgrade ax)")
def testPointsToEvaluateAx(self):
config = {
"metric": ray.tune.search.sample.Categorical([1, 2, 3, 4]).uniform(),
Expand Down
4 changes: 0 additions & 4 deletions python/ray/tune/tests/test_searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def check_searcher_checkpoint_errors_scope(self):
for x in buffer
), "Searcher checkpointing failed (unable to serialize)."

# Todo: Upgrade ax. This will upgrade sub-dependencies that may break other parts.
@unittest.skip("ax tests currently failing (need to upgrade ax)")
def testAxManualSetup(self):
from ray.tune.search.ax import AxSearch
from ax.service.ax_client import AxClient
Expand Down Expand Up @@ -109,7 +107,6 @@ def testAxManualSetup(self):
self.assertLess(out.best_trial.config["mixed_list"][1], 3)
self.assertEqual(out.best_trial.config["mixed_list"][2], 4)

@unittest.skip("ax tests currently failing (need to upgrade ax)")
def testAx(self):
from ray.tune.search.ax import AxSearch

Expand Down Expand Up @@ -607,7 +604,6 @@ def _restore(self, searcher):
if hasattr(searcher, "_live_trial_mapping"):
assert "not_completed" in searcher._live_trial_mapping

@unittest.skip("ax tests currently failing (need to upgrade ax)")
def testAx(self):
from ray.tune.search.ax import AxSearch
from ax.service.ax_client import AxClient
Expand Down
1 change: 0 additions & 1 deletion python/ray/tune/tests/test_tune_restore_warm_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ def cost(param, reporter):
return search_alg, cost


@unittest.skip("ax warm start tests currently failing (need to upgrade ax)")
class AxWarmStartTest(AbstractWarmStartTest, unittest.TestCase):
def set_basic_conf(self):
from ax.service.ax_client import AxClient
Expand Down
7 changes: 2 additions & 5 deletions python/requirements/ml/requirements_tune.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
-r requirements_dl.txt

aim==3.16.1
ax-platform[mysql]==0.2.6
# Newer version of gpytorch are incompatible with ax 0.2.6.
# Todo: Remove pin when upgrading ax
gpytorch==1.8.1

ax-platform[mysql]==0.2.6; python_version < '3.8'
ax-platform[mysql]==0.3.2; python_version >= '3.8'
bayesian-optimization==1.2.0
comet-ml==3.31.9
ConfigSpace==0.4.18
Expand Down

0 comments on commit 4071f8e

Please sign in to comment.