[Tune] let categorical values return indices that get resolved in a separate step #31927

Merged (34 commits, Feb 8, 2023)

Commits (all changes)
10e6e68  [Tune] Replace reference values in a config dict with placeholders. (Jan 29, 2023)
8352fdf  Update python/ray/tune/search/placeholder.py (Jan 30, 2023)
3ae088a  Update python/ray/tune/tests/test_trial_runner_3.py (Jan 30, 2023)
d08d245  Update python/ray/tune/search/placeholder.py (Jan 30, 2023)
00dc4b1  Update python/ray/tune/search/placeholder.py (Jan 30, 2023)
5764140  Update python/ray/tune/execution/trial_runner.py (Jan 30, 2023)
4c7c068  address review comments (Jan 31, 2023)
7123c59  move to python/ray/tune/impl/ (Jan 31, 2023)
df03fe9  lint (Jan 31, 2023)
06750ef  make sample_from take only the config dict (Jan 31, 2023)
28173a6  fix unit tests (Jan 31, 2023)
df16f84  make sure config is also resolved during restore (Feb 1, 2023)
6fd25a0  Fix tuple value. Also fix placement group factory creation. (Feb 1, 2023)
578493e  add a restore unit test (Feb 1, 2023)
0687d9f  revert tuple to list (Feb 1, 2023)
ce226f3  comments (Feb 1, 2023)
5ee2c0d  Fix tuple value replacement (Feb 1, 2023)
1c6ea12  fix other Domains (Feb 1, 2023)
5469d1a  fixes (Feb 1, 2023)
2f59ec4  only replace obj refs (Feb 1, 2023)
2d32989  fix (Feb 1, 2023)
43a216d  allow disable placehold replacement (Feb 1, 2023)
0ef9368  do not cache placement_group_factor (Feb 1, 2023)
9752ba6  fix placement_group_creation (Feb 1, 2023)
0a9dd68  get rid of local_mode (Feb 1, 2023)
cc4eefd  handle simple nested search spaces, and fix everything (Feb 2, 2023)
f580276  lint (Feb 2, 2023)
cad449d  really handle nested or complex Categorical options (Feb 2, 2023)
eab008e  lint (Feb 2, 2023)
0dfba7e  minor fix (Feb 2, 2023)
9794732  Update python/ray/tune/tests/test_ray_trial_executor.py (Feb 6, 2023)
373fff8  Update python/ray/tune/impl/placeholder.py (Feb 6, 2023)
247d5f0  minor update (Feb 6, 2023)
52935d2  ci (Feb 7, 2023)
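The idea of the PR in miniature: rather than letting a categorical sample hand a complex, possibly non-serializable object (such as a Ray object reference) straight to the searcher, the sampled value is reduced to an index placeholder, and a separate resolver table restores the real object in a second step when the trial is created. Below is a minimal sketch of that two-step flow; the `_FakePlaceholder` class and `resolve` function are illustrative assumptions, not the actual `ray.tune.impl.placeholder` API.

```python
from typing import Any, Dict, Tuple

Resolvers = Dict[Tuple[str, ...], Any]

class _FakePlaceholder:
    """Illustrative marker left where a complex categorical value was."""
    def __init__(self, path: Tuple[str, ...], index: int):
        self.path = path    # where the value lives in the config dict
        self.index = index  # which categorical option was sampled

def resolve(config: Dict[str, Any], resolvers: Resolvers) -> None:
    """Walk the config in place, swapping index placeholders for real values."""
    for key, value in config.items():
        if isinstance(value, _FakePlaceholder):
            config[key] = resolvers[value.path][value.index]
        elif isinstance(value, dict):
            resolve(value, resolvers)

# The searcher only ever sees the lightweight placeholder; the real objects
# (e.g. Ray object references) come back just before the trial is added.
options = [object(), object()]  # stand-ins for non-serializable values
resolvers: Resolvers = {("model",): options}
config = {"lr": 0.01, "model": _FakePlaceholder(("model",), index=1)}
resolve(config, resolvers)
assert config["model"] is options[1]
```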
8 changes: 8 additions & 0 deletions python/ray/tune/BUILD
@@ -240,6 +240,14 @@ py_test(
    tags = ["team:ml", "exclusive", "medium_instance"],
)

+py_test(
+    name = "test_placeholder",
+    size = "small",
+    srcs = ["tests/test_placeholder.py"],
+    deps = [":tune_lib"],
+    tags = ["team:ml", "exclusive"],
+)

py_test(
    name = "test_searchers",
    size = "large",
4 changes: 2 additions & 2 deletions python/ray/tune/execution/ray_trial_executor.py
@@ -13,6 +13,7 @@
from typing import Callable, Dict, Iterable, List, Optional, Set, Union, Tuple

import ray
+from ray.actor import ActorHandle
from ray.air import Checkpoint, AcquiredResources, ResourceRequest
from ray.air._internal.checkpoint_manager import CheckpointStorage, _TrackedCheckpoint
from ray.air.constants import COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV
@@ -351,7 +352,7 @@ def get_ready_trial(self) -> Optional[Trial]:

        return None

-    def _maybe_use_cached_actor(self, trial, logger_creator) -> Optional:
+    def _maybe_use_cached_actor(self, trial, logger_creator) -> Optional[ActorHandle]:
        if not self._reuse_actors:
            return None

@@ -426,7 +427,6 @@ def _setup_remote_runner(self, trial):
        # configure the remote runner to use a noop-logger.
        trial_config = copy.deepcopy(trial.config)
        trial_config[TRIAL_INFO] = _TrialInfo(trial)
-
        stdout_file, stderr_file = trial.log_to_file
        trial_config[STDOUT_FILE] = stdout_file
        trial_config[STDERR_FILE] = stderr_file
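Side note on the annotation fix above: a bare `Optional` carries no type information and is rejected by type checkers, whereas `Optional[ActorHandle]` states that the method returns either a cached actor handle or None. A generic illustration (not Ray code):

```python
from typing import Optional

def cached_value() -> Optional:  # bare Optional: mypy error, conveys nothing
    return None

def cached_count() -> Optional[int]:  # precise: an int, or None on cache miss
    return None
```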
15 changes: 13 additions & 2 deletions python/ray/tune/execution/trial_runner.py
@@ -1,6 +1,6 @@
from collections import defaultdict
from dataclasses import dataclass
-from typing import DefaultDict, List, Optional, Union, Tuple, Set
+from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple, Set

import click
from datetime import datetime
@@ -331,7 +331,9 @@ class TrialRunner:

    def __init__(
        self,
+        *,
        search_alg: Optional[SearchAlgorithm] = None,
+        placeholder_resolvers: Optional[Dict[Tuple, Any]] = None,
        scheduler: Optional[TrialScheduler] = None,
        local_checkpoint_dir: Optional[str] = None,
        sync_config: Optional[SyncConfig] = None,
@@ -349,6 +351,7 @@ def __init__(
        driver_sync_trial_checkpoints: bool = False,
    ):
        self._search_alg = search_alg or BasicVariantGenerator()
+        self._placeholder_resolvers = placeholder_resolvers
        self._scheduler_alg = scheduler or FIFOScheduler()
        self.trial_executor = trial_executor or RayTrialExecutor()
        self._callbacks = CallbackList(callbacks or [])
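A small but deliberate detail in the constructor change above: the added `*` makes every TrialRunner argument keyword-only, so inserting `placeholder_resolvers` near the front of the signature cannot silently re-bind existing positional call sites. A generic illustration of the pattern (not Tune code):

```python
class Runner:
    def __init__(self, *, search_alg=None, placeholder_resolvers=None, scheduler=None):
        self.search_alg = search_alg
        self.placeholder_resolvers = placeholder_resolvers
        self.scheduler = scheduler

Runner(search_alg="hyperopt", scheduler="fifo")  # OK: keyword arguments only
# Runner("hyperopt", "fifo")  # TypeError: positional arguments are rejected
```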
@@ -823,7 +826,6 @@ def resume(
            if not ray.util.client.ray.is_connected():
                trial.init_logdir()  # Create logdir if it does not exist

-            trial.refresh_default_resource_request()
            trials.append(trial)

        # 4. Set trial statuses according to the resume configuration
@@ -1085,6 +1087,14 @@ def add_trial(self, trial: Trial):
        Args:
            trial: Trial to queue.
        """
+        # If the config map has had all the references replaced with placeholders,
+        # resolve them before adding the trial.
+        if self._placeholder_resolvers:
+            trial.resolve_config_placeholders(self._placeholder_resolvers)
+
+        # With trial.config resolved, create placement group factory if needed.
+        trial.create_placement_group_factory()
+
        self._trials.append(trial)
        if trial.status != Trial.TERMINATED:
            self._live_trials.add(trial)
@@ -1586,6 +1596,7 @@ def __getstate__(self):
            "_stop_queue",
            "_server",
            "_search_alg",
+            "_placeholder_resolvers",
            "_scheduler_alg",
            "_pending_trial_queue_times",
            "trial_executor",
101 changes: 63 additions & 38 deletions python/ray/tune/experiment/trial.py
@@ -9,7 +9,7 @@
import re
import shutil
import time
-from typing import Dict, Optional, Sequence, Union, Callable, List, Tuple
+from typing import Any, Dict, Optional, Sequence, Union, Callable, List, Tuple
import uuid

import ray
@@ -239,6 +239,8 @@ class Trial:
        "param_config",
        "extra_arg",
        "placement_group_factory",
+        "_resources",
+        "_default_placement_group_factory",
    ]

    PENDING = "PENDING"
@@ -292,42 +294,25 @@ def __init__(
        # Trial config
        self.trainable_name = trainable_name
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
-        self.config = config or {}
        self._local_dir = local_dir  # This remains unexpanded for syncing.

+        self.config = config or {}
+        # Save a copy of the original unresolved config so that we can swap
+        # out and update any reference config values after restoration.
+        self.__unresolved_config = self.config
Contributor: OOC, why the double underscore?

Member Author: Very private. Nobody should touch or have access to this variable outside of Trial. Under the hood, Python name-mangles the attribute to _Trial__unresolved_config.
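For reference, a quick demonstration of the name mangling the author describes; this is standard Python behavior, not Tune-specific:

```python
class Trial:
    def __init__(self):
        self.__unresolved_config = {"lr": 0.01}  # double leading underscore

t = Trial()
# The attribute is stored under a mangled name, so outside access via
# t.__unresolved_config raises AttributeError, while the mangled name works:
print(t._Trial__unresolved_config)  # {'lr': 0.01}
```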


        # Parameters that Tune varies across searches.
        self.evaluated_params = evaluated_params or {}
        self.experiment_tag = experiment_tag
        self.location = _Location()
-        trainable_cls = self.get_trainable_cls()
-        if trainable_cls and _setup_default_resource:
-            default_resources = trainable_cls.default_resource_request(self.config)
-
-            # If Trainable returns resources, do not allow manual override via
-            # `resources_per_trial` by the user.
-            if default_resources:
-                if resources or placement_group_factory:
-                    raise ValueError(
-                        "Resources for {} have been automatically set to {} "
-                        "by its `default_resource_request()` method. Please "
-                        "clear the `resources_per_trial` option.".format(
-                            trainable_cls, default_resources
-                        )
-                    )
-
-                if isinstance(default_resources, PlacementGroupFactory):
-                    placement_group_factory = default_resources
-                    resources = None
-                else:
-                    placement_group_factory = None
-                    resources = default_resources
-
-        self.placement_group_factory = _to_pg_factory(
-            resources, placement_group_factory
-        )
-
        self.stopping_criterion = stopping_criterion or {}

+        self._setup_default_resource = _setup_default_resource
+        self._resources = resources
+        self._default_placement_group_factory = placement_group_factory
+        # Will be created in create_placement_group_factory().
+        self.placement_group_factory = None

        self.log_to_file = log_to_file
        # Make sure `stdout_file, stderr_file = Trial.log_to_file` works
        if (
Expand Down Expand Up @@ -417,6 +402,48 @@ def __init__(
self._state_json = None
self._state_valid = False

def create_placement_group_factory(self):
Contributor: I'm wondering if we should call this implicitly when Trial.placement_group_factory is accessed. But it's also OK to keep this for now.

Member Author: Good idea. Added a TODO for now; if there are tests failing because of this, I will turn it into a getter.

Contributor: I think we should do this in any case (turn it into a getter and create the PGF on first call). Otherwise it increases the complexity of the Trial class.

+        """Compute placement group factory if needed.
+
+        Note: this must be called after all the placeholders in
+        self.config are resolved.
+        """
+        trainable_cls = self.get_trainable_cls()
+        if not trainable_cls or not self._setup_default_resource:
+            # Create placement group factory using default resources.
+            self.placement_group_factory = _to_pg_factory(
+                self._resources, self._default_placement_group_factory
+            )
+            return
+
+        default_resources = trainable_cls.default_resource_request(self.config)
+
+        # If Trainable returns resources, do not allow manual override via
+        # `resources_per_trial` by the user.
+        if default_resources:
+            if self._resources or self._default_placement_group_factory:
+                raise ValueError(
+                    "Resources for {} have been automatically set to {} "
+                    "by its `default_resource_request()` method. Please "
+                    "clear the `resources_per_trial` option.".format(
+                        trainable_cls, default_resources
+                    )
+                )
+
+            if isinstance(default_resources, PlacementGroupFactory):
+                default_placement_group_factory = default_resources
+                resources = None
+            else:
+                default_placement_group_factory = None
+                resources = default_resources
+        else:
+            default_placement_group_factory = self._default_placement_group_factory
+            resources = self._resources
+
+        self.placement_group_factory = _to_pg_factory(
+            resources, default_placement_group_factory
+        )
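The deferral enforced by the note in the docstring matters because `default_resource_request` may inspect config values. A hedged sketch of a trainable whose resource demand depends on its config; `num_workers` is an illustrative parameter, not something defined in this PR:

```python
from ray import tune
from ray.tune.execution.placement_groups import PlacementGroupFactory

class ScalableTrainable(tune.Trainable):
    @classmethod
    def default_resource_request(cls, config):
        # If "num_workers" were still an unresolved placeholder object here,
        # this arithmetic would fail, which is why placement group factory
        # creation is deferred until after placeholders are resolved.
        num_workers = config.get("num_workers", 0)
        return PlacementGroupFactory([{"CPU": 1}] * (1 + num_workers))
```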

    def _get_default_result_or_future(self) -> Optional[dict]:
        """Calls ray.get on self._default_result_or_future and assigns back.

@@ -439,6 +466,13 @@ def _get_default_result_or_future(self) -> Optional[dict]:
        )
        return self._default_result_or_future

+    def resolve_config_placeholders(self, placeholder_resolvers: Dict[Tuple, Any]):
+        from ray.tune.impl.placeholder import resolve_placeholders
+
+        # Make a copy of the unresolved config before resolving it.
+        self.config = copy.deepcopy(self.__unresolved_config)
+        resolve_placeholders(self.config, placeholder_resolvers)
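Why deep-copy from `__unresolved_config` instead of resolving `self.config` in place once: resolution mutates the dict, and after a restore the same placeholders may need to be re-resolved against a fresh resolver table (for example, newly re-created object references). A small sketch of the pattern; this `resolve_placeholders` is a stand-in for illustration, not the real `ray.tune.impl.placeholder` function:

```python
import copy

def resolve_placeholders(config, resolvers):
    # Stand-in resolver: tuple values act as path-keyed placeholders.
    for key, value in config.items():
        if isinstance(value, tuple) and value in resolvers:
            config[key] = resolvers[value]

unresolved = {"lr": 0.01, "model": ("model",)}  # pristine, never mutated

# Initial resolution at trial creation.
config = copy.deepcopy(unresolved)
resolve_placeholders(config, {("model",): "model-object"})

# After restore: start again from the pristine copy with a fresh table.
config = copy.deepcopy(unresolved)
resolve_placeholders(config, {("model",): "restored-model-object"})
```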

    @property
    def last_result(self) -> dict:
        # The logic in here is as follows:
@@ -654,15 +688,6 @@ def update_resources(

        self.invalidate_json_state()

-    def refresh_default_resource_request(self):
-        """Update trial resources according to the trainable's default resource
-        request, if it is provided."""
-        trainable_cls = self.get_trainable_cls()
-        if trainable_cls:
-            default_resources = trainable_cls.default_resource_request(self.config)
-            if default_resources:
-                self.update_resources(default_resources)

    def set_runner(self, runner):
        self.runner = runner
        if runner: