Skip to content

Commit

Permalink
[build_base] [Tune] Add more comprehensive support for remote `upload…
Browse files Browse the repository at this point in the history
…_dir` w/ endpoint and params (ray-project#32479)

Currently, URI handling with parameters is done in multiple places in different ways (using `urllib.parse` or splitting by `'?'` directly). In some places, it's not done at all, which **causes errors when performing cloud checkpointing.** In particular, `Trial.remote_checkpoint_dir` and `Trainable._storage_path` do not handle URI path appends correctly when URL params are present.

Signed-off-by: Justin Yu <justinvyu@berkeley.edu>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
justinvyu authored and edoakes committed Mar 22, 2023
1 parent cee30c1 commit 9931d00
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 79 deletions.
56 changes: 24 additions & 32 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,46 +1572,38 @@ def no_resource_leaks_excluding_node_resources():


@contextmanager
def simulate_storage(storage_type, root=None):
def simulate_storage(
storage_type: str,
root: Optional[str] = None,
port: int = 5002,
region: str = "us-west-2",
):
"""Context that simulates a given storage type and yields the URI.
Args:
storage_type: The storage type to simulate ("fs" or "s3")
root: Root directory of the URI to return (e.g., s3 bucket name)
port: The port of the localhost endpoint where s3 is being served (s3 only)
region: The s3 region (s3 only)
"""
if storage_type == "fs":
if root is None:
with tempfile.TemporaryDirectory() as d:
yield "file://" + d
else:
yield "file://" + root
elif storage_type == "s3":
import uuid

from moto import mock_s3

from ray.tests.mock_s3_server import start_service, stop_process

@contextmanager
def aws_credentials():
old_env = os.environ
os.environ["AWS_ACCESS_KEY_ID"] = "testing"
os.environ["AWS_SECRET_ACCESS_KEY"] = "testing"
os.environ["AWS_SECURITY_TOKEN"] = "testing"
os.environ["AWS_SESSION_TOKEN"] = "testing"
yield
os.environ = old_env

@contextmanager
def moto_s3_server():
host = "localhost"
port = 5002
url = f"http://{host}:{port}"
process = start_service("s3", host, port)
yield url
stop_process(process)

if root is None:
root = uuid.uuid4().hex
with moto_s3_server() as s3_server, aws_credentials(), mock_s3():
url = f"s3://{root}?region=us-west-2&endpoint_override={s3_server}"
yield url
from moto.server import ThreadedMotoServer

root = root or uuid.uuid4().hex
s3_server = f"http://localhost:{port}"
server = ThreadedMotoServer(port=port)
server.start()
url = f"s3://{root}?region={region}&endpoint_override={s3_server}"
yield url
server.stop()
else:
raise ValueError(f"Unknown storage type: {storage_type}")
raise NotImplementedError(f"Unknown storage type: {storage_type}")


def job_hook(**kwargs):
Expand Down
61 changes: 61 additions & 0 deletions python/ray/air/_internal/uri_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pathlib import Path
import urllib.parse
import os
from typing import Union


class URI:
    """Represents a URI, supporting path appending and retrieving parent URIs.

    The network location and path of the URI are merged into a single
    normalized path (via ``os.path.normpath``) so that they can be manipulated
    together. The remaining components of the parsed URI (scheme, params,
    query string, fragment) are carried over unchanged whenever a new URI is
    derived with ``/`` or ``.parent``.

    Example Usage:

        >>> s3_uri = URI("s3://bucket/a?scheme=http&endpoint_override=localhost%3A900")
        >>> s3_uri
        URI<s3://bucket/a?scheme=http&endpoint_override=localhost%3A900>
        >>> str(s3_uri / "b" / "c")
        's3://bucket/a/b/c?scheme=http&endpoint_override=localhost%3A900'
        >>> str(s3_uri.parent)
        's3://bucket?scheme=http&endpoint_override=localhost%3A900'
        >>> str(s3_uri)
        's3://bucket/a?scheme=http&endpoint_override=localhost%3A900'
        >>> s3_uri.parent.name, s3_uri.name
        ('bucket', 'a')

    Args:
        uri: The URI to represent.
            Ex: s3://bucket?scheme=http&endpoint_override=localhost%3A900
            Ex: file:///a/b/c/d

    Raises:
        ValueError: If ``uri`` has no scheme.
    """

    def __init__(self, uri: str):
        self._parsed = urllib.parse.urlparse(uri)
        if not self._parsed.scheme:
            raise ValueError(f"Invalid URI: {uri}")
        # Treat "<netloc><path>" as one path so that the bucket/host is just
        # the first path component when appending or taking the parent.
        self._path = Path(os.path.normpath(self._parsed.netloc + self._parsed.path))

    @property
    def name(self) -> str:
        """The final component of the URI path (query string excluded)."""
        return self._path.name

    @property
    def parent(self) -> "URI":
        """The URI one level up, keeping scheme and query string.

        Raises:
            ValueError: If the URI consists of a single component
                (e.g. ``s3://bucket``) and therefore has no parent.
        """
        parent_path = self._path.parent
        # `Path("bucket").parent` is `Path(".")`. It must be compared against
        # a Path: a Path never equals the string ".", so a string comparison
        # would let "s3://." through as a bogus parent.
        if parent_path == Path("."):
            raise ValueError(f"{str(self)} has no valid parent URI")
        return URI(self._get_str_representation(self._parsed, parent_path))

    def __truediv__(self, path_to_append: str) -> "URI":
        """Return a new URI with ``path_to_append`` joined onto the path."""
        assert isinstance(path_to_append, str)
        return URI(
            self._get_str_representation(self._parsed, self._path / path_to_append)
        )

    @classmethod
    def _get_str_representation(
        cls, parsed_uri: urllib.parse.ParseResult, path: Union[str, Path]
    ) -> str:
        """Rebuild a URI string from ``parsed_uri`` with its location replaced.

        ``path`` holds the merged netloc + path, so it is written into the
        netloc slot and the path slot is cleared; every other component of
        ``parsed_uri`` is kept as-is.
        """
        return parsed_uri._replace(netloc=str(path), path="").geturl()

    def __repr__(self):
        return f"URI<{str(self)}>"

    def __str__(self):
        return self._get_str_representation(self._parsed, self._path)
11 changes: 3 additions & 8 deletions python/ray/tune/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)

from ray.air import CheckpointConfig
from ray.air._internal.uri_utils import URI
from ray.tune.error import TuneError
from ray.tune.registry import register_trainable, is_function_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR
Expand Down Expand Up @@ -443,14 +444,8 @@ def checkpoint_dir(self):
def remote_checkpoint_dir(self) -> Optional[str]:
if not self.sync_config.upload_dir or not self.dir_name:
return None

# NOTE: `upload_dir` can contain query strings. For example:
# 's3://bucket?scheme=http&endpoint_override=localhost%3A9000'.
if "?" in self.sync_config.upload_dir:
path, query = self.sync_config.upload_dir.split("?")
return os.path.join(path, self.dir_name) + "?" + query

return os.path.join(self.sync_config.upload_dir, self.dir_name)
uri = URI(self.sync_config.upload_dir)
return str(uri / self.dir_name)

@property
def run_identifier(self):
Expand Down
8 changes: 4 additions & 4 deletions python/ray/tune/experiment/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ray
from ray.air import CheckpointConfig
from ray.air._internal.uri_utils import URI
from ray.air._internal.checkpoint_manager import _TrackedCheckpoint, CheckpointStorage
import ray.cloudpickle as cloudpickle
from ray.exceptions import RayActorError, RayTaskError
Expand Down Expand Up @@ -601,17 +602,16 @@ def generate_id(cls):
return str(uuid.uuid4().hex)[:8]

@property
def remote_checkpoint_dir(self):
def remote_checkpoint_dir(self) -> str:
"""This is the **per trial** remote checkpoint dir.
This is different from **per experiment** remote checkpoint dir.
"""
assert self.logdir, "Trial {}: logdir not initialized.".format(self)
if not self.sync_config.upload_dir or not self.experiment_dir_name:
return None
return os.path.join(
self.sync_config.upload_dir, self.experiment_dir_name, self.relative_logdir
)
uri = URI(self.sync_config.upload_dir)
return str(uri / self.experiment_dir_name / self.relative_logdir)

@property
def uses_cloud_checkpointing(self):
Expand Down
13 changes: 4 additions & 9 deletions python/ray/tune/impl/tuner_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Type, Union, TYPE_CHECKING, Tuple
import urllib.parse

import ray
import ray.cloudpickle as pickle
from ray.util import inspect_serializability
from ray.air._internal.uri_utils import URI
from ray.air._internal.remote_storage import download_from_uri, is_non_local_path_uri
from ray.air.config import RunConfig, ScalingConfig
from ray.tune import Experiment, TuneError, ExperimentAnalysis
Expand Down Expand Up @@ -359,14 +359,9 @@ def _restore_from_path_or_uri(
self._run_config.name = experiment_path.name
else:
# Set the experiment `name` and `upload_dir` according to the URI
parsed_uri = urllib.parse.urlparse(path_or_uri)
remote_path = Path(os.path.normpath(parsed_uri.netloc + parsed_uri.path))
upload_dir = parsed_uri._replace(
netloc="", path=str(remote_path.parent)
).geturl()

self._run_config.name = remote_path.name
self._run_config.sync_config.upload_dir = upload_dir
uri = URI(path_or_uri)
self._run_config.name = uri.name
self._run_config.sync_config.upload_dir = str(uri.parent)

# If we synced, `experiment_checkpoint_dir` will contain a temporary
# directory. Create an experiment checkpoint dir instead and move
Expand Down
17 changes: 14 additions & 3 deletions python/ray/tune/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,30 @@
import ray
from ray.air import CheckpointConfig
from ray.tune import register_trainable, SyncConfig
from ray.tune.experiment import Experiment, _convert_to_experiment_list
from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list
from ray.tune.error import TuneError
from ray.tune.utils import diagnose_serialization


def test_remote_checkpoint_dir_with_query_string():
def test_remote_checkpoint_dir_with_query_string(tmp_path):
sync_config = SyncConfig(syncer="auto", upload_dir="s3://bucket?scheme=http")
experiment = Experiment(
name="spam",
run=lambda config: config,
sync_config=SyncConfig(syncer="auto", upload_dir="s3://bucket?scheme=http"),
sync_config=sync_config,
)
assert experiment.remote_checkpoint_dir == "s3://bucket/spam?scheme=http"

trial = Trial(
"mock",
stub=True,
sync_config=sync_config,
experiment_dir_name="spam",
local_dir=str(tmp_path),
)
trial.relative_logdir = "trial_dirname"
assert trial.remote_checkpoint_dir == "s3://bucket/spam/trial_dirname?scheme=http"


class ExperimentTest(unittest.TestCase):
def tearDown(self):
Expand Down
Loading

0 comments on commit 9931d00

Please sign in to comment.