Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[build_base] [Tune] Add more comprehensive support for remote upload_dir w/ endpoint and params #32479

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 24 additions & 32 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,46 +1572,38 @@ def no_resource_leaks_excluding_node_resources():


@contextmanager
def simulate_storage(storage_type, root=None):
def simulate_storage(
storage_type: str,
root: Optional[str] = None,
port: int = 5002,
region: str = "us-west-2",
):
"""Context that simulates a given storage type and yields the URI.
Args:
storage_type: The storage type to simiulate ("fs" or "s3")
root: Root directory of the URI to return (e.g., s3 bucket name)
port: The port of the localhost endpoint where s3 is being served (s3 only)
region: The s3 region (s3 only)
"""
if storage_type == "fs":
if root is None:
with tempfile.TemporaryDirectory() as d:
yield "file://" + d
else:
yield "file://" + root
elif storage_type == "s3":
import uuid

from moto import mock_s3

from ray.tests.mock_s3_server import start_service, stop_process

@contextmanager
def aws_credentials():
old_env = os.environ
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
yield
os.environ = old_env

@contextmanager
def moto_s3_server():
host = "localhost"
port = 5002
url = f"http://{host}:{port}"
process = start_service("s3", host, port)
yield url
stop_process(process)

if root is None:
root = uuid.uuid4().hex
with moto_s3_server() as s3_server, aws_credentials(), mock_s3():
url = f"s3://{root}?region=us-west-2&endpoint_override={s3_server}"
yield url
from moto.server import ThreadedMotoServer

root = root or uuid.uuid4().hex
s3_server = f"http://localhost:{port}"
server = ThreadedMotoServer(port=port)
server.start()
url = f"s3://{root}?region={region}&endpoint_override={s3_server}"
yield url
server.stop()
else:
raise ValueError(f"Unknown storage type: {storage_type}")
raise NotImplementedError(f"Unknown storage type: {storage_type}")


def job_hook(**kwargs):
Expand Down
61 changes: 61 additions & 0 deletions python/ray/air/_internal/uri_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pathlib import Path
import urllib.parse
import os
from typing import Union


class URI:
"""Represents a URI, supporting path appending and retrieving parent URIs.

Example Usage:

>>> s3_uri = URI("s3://bucket/a?scheme=http&endpoint_override=localhost%3A900")
>>> s3_uri
URI<s3://bucket/a?scheme=http&endpoint_override=localhost%3A900>
>>> str(s3_uri / "b" / "c")
's3://bucket/a/b/c?scheme=http&endpoint_override=localhost%3A900'
>>> str(s3_uri.parent)
's3://bucket?scheme=http&endpoint_override=localhost%3A900'
>>> str(s3_uri)
's3://bucket/a?scheme=http&endpoint_override=localhost%3A900'
>>> s3_uri.parent.name, s3_uri.name
('bucket', 'a')

Args:
uri: The URI to represent.
Ex: s3://bucket?scheme=http&endpoint_override=localhost%3A900
Ex: file:///a/b/c/d
"""

def __init__(self, uri: str):
self._parsed = urllib.parse.urlparse(uri)
if not self._parsed.scheme:
raise ValueError(f"Invalid URI: {uri}")
self._path = Path(os.path.normpath(self._parsed.netloc + self._parsed.path))

@property
def name(self) -> str:
return self._path.name

@property
def parent(self) -> "URI":
assert self._path.parent != ".", f"{str(self)} has no valid parent URI"
return URI(self._get_str_representation(self._parsed, self._path.parent))

def __truediv__(self, path_to_append):
assert isinstance(path_to_append, str)
return URI(
self._get_str_representation(self._parsed, self._path / path_to_append)
)

@classmethod
def _get_str_representation(
cls, parsed_uri: urllib.parse.ParseResult, path: Union[str, Path]
) -> str:
return parsed_uri._replace(netloc=str(path), path="").geturl()

def __repr__(self):
return f"URI<{str(self)}>"

def __str__(self):
return self._get_str_representation(self._parsed, self._path)
11 changes: 3 additions & 8 deletions python/ray/tune/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

from ray.air import CheckpointConfig
from ray.air._internal.uri_utils import URI
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable, is_function_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR
Expand Down Expand Up @@ -443,14 +444,8 @@ def checkpoint_dir(self):
def remote_checkpoint_dir(self) -> Optional[str]:
if not self.sync_config.upload_dir or not self.dir_name:
return None

# NOTE: `upload_dir` can contain query strings. For example:
# 's3://bucket?scheme=http&endpoint_override=localhost%3A9000'.
if "?" in self.sync_config.upload_dir:
path, query = self.sync_config.upload_dir.split("?")
return os.path.join(path, self.dir_name) + "?" + query

return os.path.join(self.sync_config.upload_dir, self.dir_name)
uri = URI(self.sync_config.upload_dir)
return str(uri / self.dir_name)

@property
def run_identifier(self):
Expand Down
8 changes: 4 additions & 4 deletions python/ray/tune/experiment/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ray
from ray.air import CheckpointConfig
from ray.air._internal.uri_utils import URI
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
import ray.cloudpickle as cloudpickle
from ray.exceptions import RayActorError, RayTaskError
Expand Down Expand Up @@ -601,17 +602,16 @@ def generate_id(cls):
return str(uuid.uuid4().hex)[:8]

@property
def remote_checkpoint_dir(self):
def remote_checkpoint_dir(self) -> str:
"""This is the **per trial** remote checkpoint dir.
This is different from **per experiment** remote checkpoint dir.
"""
assert self.logdir, "Trial {}: logdir not initialized.".format(self)
if not self.sync_config.upload_dir or not self.experiment_dir_name:
return None
return os.path.join(
self.sync_config.upload_dir, self.experiment_dir_name, self.relative_logdir
)
uri = URI(self.sync_config.upload_dir)
return str(uri / self.experiment_dir_name / self.relative_logdir)

@property
def uses_cloud_checkpointing(self):
Expand Down
13 changes: 4 additions & 9 deletions python/ray/tune/impl/tuner_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING, Tuple
import urllib.parse

import ray
import ray.cloudpickle as pickle
from ray.util import inspect_serializability
from ray.air._internal.uri_utils import URI
from ray.air._internal.remote_storage import download_from_uri, is_non_local_path_uri
from ray.air.config import RunConfig, ScalingConfig
from ray.tune import Experiment, TuneError, ExperimentAnalysis
Expand Down Expand Up @@ -347,14 +347,9 @@ def _restore_from_path_or_uri(
self._run_config.name = experiment_path.name
else:
# Set the experiment `name` and `upload_dir` according to the URI
parsed_uri = urllib.parse.urlparse(path_or_uri)
remote_path = Path(os.path.normpath(parsed_uri.netloc + parsed_uri.path))
upload_dir = parsed_uri._replace(
netloc="", path=str(remote_path.parent)
).geturl()

self._run_config.name = remote_path.name
self._run_config.sync_config.upload_dir = upload_dir
uri = URI(path_or_uri)
self._run_config.name = uri.name
self._run_config.sync_config.upload_dir = str(uri.parent)

# If we synced, `experiment_checkpoint_dir` will contain a temporary
# directory. Create an experiment checkpoint dir instead and move
Expand Down
17 changes: 14 additions & 3 deletions python/ray/tune/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,30 @@
import ray
from ray.air import CheckpointConfig
from ray.tune import register_trainable, SyncConfig
from ray.tune.experiment import Experiment, _convert_to_experiment_list
from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list
from ray.tune.error import TuneError
from ray.tune.utils import diagnose_serialization


def test_remote_checkpoint_dir_with_query_string():
def test_remote_checkpoint_dir_with_query_string(tmp_path):
sync_config = SyncConfig(syncer="auto", upload_dir="s3://bucket?scheme=http")
experiment = Experiment(
name="spam",
run=lambda config: config,
sync_config=SyncConfig(syncer="auto", upload_dir="s3://bucket?scheme=http"),
sync_config=sync_config,
)
assert experiment.remote_checkpoint_dir == "s3://bucket/spam?scheme=http"

trial = Trial(
"mock",
stub=True,
sync_config=sync_config,
experiment_dir_name="spam",
local_dir=str(tmp_path),
)
trial.relative_logdir = "trial_dirname"
assert trial.remote_checkpoint_dir == "s3://bucket/spam/trial_dirname?scheme=http"


class ExperimentTest(unittest.TestCase):
def tearDown(self):
Expand Down
Loading