[MLFlowObjectStore] [2/2] Support checkpointing with MLFlow #2810

Merged on Jan 12, 2024 (19 commits)
Changes from 14 commits
30 changes: 28 additions & 2 deletions composer/callbacks/checkpoint_saver.py
@@ -15,12 +15,14 @@
 from typing import Callable, List, Optional, Union
 
 from composer.core import Callback, Event, State, Time
-from composer.loggers import Logger
+from composer.loggers import Logger, MLFlowLogger
 from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
                             checkpoint, create_interval_scheduler, create_symlink_file, dist,
                             ensure_folder_has_no_conflicting_files, format_name_with_dist,
-                            format_name_with_dist_and_time, is_model_deepspeed, reproducibility, using_torch_2)
+                            format_name_with_dist_and_time, is_model_deepspeed, partial_format, reproducibility,
+                            using_torch_2)
 from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
+from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY
 
 log = logging.getLogger(__name__)
@@ -270,6 +272,30 @@ def __init__(
         self.start_batch = None
 
     def init(self, state: State, logger: Logger) -> None:
+        # If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths.
+        # Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers.
+        for destination in logger.destinations:
+            if isinstance(destination, MLFlowLogger):
+                mlflow_format_kwargs = {
+                    MLFLOW_EXPERIMENT_ID_FORMAT_KEY: destination._experiment_id,
+                    MLFLOW_RUN_ID_FORMAT_KEY: destination._run_id
+                }
+                self.folder = partial_format(self.folder, **mlflow_format_kwargs)
+
+                self.filename.folder = self.folder
+                if self.latest_filename is not None:
+                    self.latest_filename.folder = self.folder
+
+                # The remote paths have the placeholders in their filename rather than folder
+                if self.remote_file_name is not None:
+                    self.remote_file_name.filename = partial_format(self.remote_file_name.filename,
+                                                                    **mlflow_format_kwargs)
+                if self.latest_remote_file_name is not None:
+                    self.latest_remote_file_name.filename = partial_format(self.latest_remote_file_name.filename,
+                                                                           **mlflow_format_kwargs)
+
+                break
+
         folder = format_name_with_dist(self.folder, state.run_name)
         os.makedirs(folder, exist_ok=True)
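
Note: `partial_format` is the helper this change leans on (added in composer/utils/string_helpers.py as part of this PR). A minimal sketch of its presumed contract, with a hypothetical folder string: the placeholders you pass are filled, and unknown ones survive untouched for later formatting passes.

# Sketch only: the real helper lives in composer/utils/string_helpers.py.
# This stand-in just shows the contract (unknown placeholders are left intact).
def sketch_partial_format(format_str: str, **kwargs: object) -> str:
    result = format_str
    for key, value in kwargs.items():
        result = result.replace('{' + key + '}', str(value))
    return result

folder = '{run_name}/checkpoints/{mlflow_experiment_id}/{mlflow_run_id}'  # hypothetical
print(sketch_partial_format(folder, mlflow_experiment_id='2177', mlflow_run_id='f6a1c3'))
# -> '{run_name}/checkpoints/2177/f6a1c3'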
32 changes: 27 additions & 5 deletions composer/loggers/mlflow_logger.py
@@ -93,6 +93,10 @@ def __init__(
         self._rank_zero_only = rank_zero_only
         self._last_flush_time = time.time()
         self._flush_interval = flush_interval
+
+        self._experiment_id = None
+        self._run_id = None
+
         if self._enabled:
             self.tracking_uri = str(tracking_uri or mlflow.get_tracking_uri())
             mlflow.set_tracking_uri(self.tracking_uri)
@@ -123,6 +127,10 @@ def init(self, state: State, logger: Logger) -> None:
         if self.run_name is None:
             self.run_name = state.run_name
 
+        # Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume.
+        self.tags = self.tags or {}
+        self.tags['composer_run_name'] = state.run_name
+
         # Adjust name and group based on `rank_zero_only`.
         if not self._rank_zero_only:
             self.run_name += f'-rank{dist.get_global_rank()}'
@@ -133,17 +141,31 @@ def init(self, state: State, logger: Logger) -> None:
             if env_run_id is not None:
                 self._run_id = env_run_id
             else:
-                new_run = self._mlflow_client.create_run(
-                    experiment_id=self._experiment_id,
-                    run_name=self.run_name,
-                )
-                self._run_id = new_run.info.run_id
+                # Search for an existing run tagged with this Composer run.
+                existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id],
+                                                   filter_string=f'tags.composer_run_name = "{state.run_name}"',
+                                                   output_format='list')
+                if len(existing_runs) > 0:
+                    self._run_id = existing_runs[0].info.run_id
+                else:
+                    new_run = self._mlflow_client.create_run(
+                        experiment_id=self._experiment_id,
+                        run_name=self.run_name,
+                    )
+                    self._run_id = new_run.info.run_id
             mlflow.start_run(
                 run_id=self._run_id,
                 tags=self.tags,
                 log_system_metrics=self.log_system_metrics,
             )
+
+        # If rank zero only, broadcast the MLFlow experiment and run IDs to other ranks, so the MLFlow run info is
+        # available to other ranks during runtime.
+        if self._rank_zero_only:
+            mlflow_ids_list = [self._experiment_id, self._run_id]
+            dist.broadcast_object_list(mlflow_ids_list, src=0)
+            self._experiment_id, self._run_id = mlflow_ids_list
 
     def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None:
         if self._enabled:
             try:
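
To illustrate the autoresume lookup above: `mlflow.search_runs` with `output_format='list'` returns `Run` objects, so a tag match lets the logger reattach to a previous run rather than create a new one. A sketch with hypothetical IDs and names:

import mlflow

experiment_id = '2177'          # hypothetical
run_name = 'my-composer-run'    # hypothetical

# Look for a prior run tagged with this Composer run name.
existing_runs = mlflow.search_runs(experiment_ids=[experiment_id],
                                   filter_string=f'tags.composer_run_name = "{run_name}"',
                                   output_format='list')
run_id = existing_runs[0].info.run_id if existing_runs else None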
45 changes: 37 additions & 8 deletions composer/loggers/remote_uploader_downloader.py
@@ -24,8 +24,10 @@
 
 from composer.loggers.logger import Logger
 from composer.loggers.logger_destination import LoggerDestination
-from composer.utils import (GCSObjectStore, LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, OCIObjectStore,
-                            S3ObjectStore, SFTPObjectStore, UCObjectStore, dist, format_name_with_dist, get_file, retry)
+from composer.utils import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
+                            ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore,
+                            dist, format_name_with_dist, get_file, retry)
+from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX
 
 if TYPE_CHECKING:
     from composer.core import State
@@ -37,19 +39,32 @@
 
 
 def _build_remote_backend(remote_backend_name: str, backend_kwargs: Dict[str, Any]):
+    remote_backend_cls = None
     remote_backend_name_to_cls = {
         's3': S3ObjectStore,
         'oci': OCIObjectStore,
         'sftp': SFTPObjectStore,
         'libcloud': LibcloudObjectStore,
         'gs': GCSObjectStore,
-        'dbfs': UCObjectStore,
     }
-    remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
-    if remote_backend_cls is None:
-        raise ValueError(
-            f'The remote backend {remote_backend_name} is not supported. Please use one of ({list(remote_backend_name_to_cls.keys())})'
-        )
+
+    # Handle `dbfs` backend as a special case, since it can map to either :class:`.UCObjectStore`
+    # or :class:`.MLFlowObjectStore`.
+    if remote_backend_name == 'dbfs':
+        path = backend_kwargs['path']
+        if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
+            remote_backend_cls = MLFlowObjectStore
+        else:
+            # Validate if the path conforms to the requirements for UC volume paths
+            UCObjectStore.validate_path(path)
+            remote_backend_cls = UCObjectStore
+    else:
+        remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
+        if remote_backend_cls is None:
+            supported_remote_backends = list(remote_backend_name_to_cls.keys()) + ['dbfs']
+            raise ValueError(
+                f'The remote backend {remote_backend_name} is not supported. Please use one of ({supported_remote_backends})'
+            )
 
     return remote_backend_cls(**backend_kwargs)
 
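A quick sanity check of the routing above, with hypothetical paths: only the `MLFLOW_DBFS_PATH_PREFIX` prefix selects `MLFlowObjectStore`; anything else under `dbfs` is validated as a UC volume path.

from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

candidates = [
    MLFLOW_DBFS_PATH_PREFIX + '{mlflow_experiment_id}/{mlflow_run_id}/artifacts',  # hypothetical
    'Volumes/my_catalog/my_schema/my_volume/checkpoints',  # hypothetical UC volume path
]
for path in candidates:
    target = 'MLFlowObjectStore' if path.startswith(MLFLOW_DBFS_PATH_PREFIX) else 'UCObjectStore'
    print(f'{path!r} -> {target}')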
@@ -319,6 +334,20 @@ def init(self, state: State, logger: Logger) -> None:
         if dist.get_global_rank() == 0:
             retry(ObjectStoreTransientError,
                   self.num_attempts)(lambda: _validate_credentials(self.remote_backend, file_name_to_test))()
+
+        # If the remote backend is an `MLFlowObjectStore`, the original path kwarg may have placeholders that can be
+        # updated with information generated at runtime, i.e., the MLFlow experiment and run IDs. This information
+        # must be propagated across all ranks before the workers are started so that all workers use the same
+        # MLFlow run.
+        if self.backend_kwargs.get('path', '').startswith(MLFLOW_DBFS_PATH_PREFIX):
+            if dist.get_global_rank() == 0:
+                assert isinstance(self.remote_backend, MLFlowObjectStore)
+                self.backend_kwargs['path'] = self.remote_backend.get_dbfs_path(self.backend_kwargs['path'])
+
+            path_list = [self.backend_kwargs['path']]
+            dist.broadcast_object_list(path_list, src=0)
+            self.backend_kwargs['path'] = path_list[0]
+
         assert len(self._workers) == 0, 'workers should be empty if self._worker_flag was None'
         for _ in range(self._num_concurrent_uploads):
             worker = self._proc_class(
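
The broadcast above is the usual collective-object pattern; a minimal sketch of its in-place semantics (path hypothetical), which is what guarantees every rank's upload workers target the same MLFlow run:

from composer.utils import dist

# Rank 0 holds the resolved path; other ranks contribute a placeholder that
# broadcast_object_list overwrites in place with rank 0's value.
resolved = 'databricks/mlflow-tracking/2177/f6a1c3/artifacts'  # hypothetical
path_list = [resolved if dist.get_global_rank() == 0 else None]
dist.broadcast_object_list(path_list, src=0)
path = path_list[0]  # identical on every rank after the call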
40 changes: 38 additions & 2 deletions composer/trainer/trainer.py
@@ -39,7 +39,7 @@
                            PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec,
                            ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching)
 from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU
-from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger,
+from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MLFlowLogger, MosaicMLLogger, ProgressBarLogger,
                               RemoteUploaderDownloader, WandBLogger)
 from composer.loggers.mosaicml_logger import MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR
 from composer.models import ComposerModel
@@ -54,8 +54,9 @@
                            ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist,
                            get_composer_env_dict, get_device, get_file, is_tpu_installed, map_collection,
                            maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri,
-                           model_eval_mode, parse_uri, reproducibility, using_torch_2)
+                           model_eval_mode, parse_uri, partial_format, reproducibility, using_torch_2)
 from composer.utils.misc import is_model_deepspeed
+from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY
 
 if is_tpu_installed():
     import torch_xla.core.xla_model as xm
@@ -1085,6 +1086,11 @@ def __init__(
             mosaicml_logger = MosaicMLLogger()
             loggers.append(mosaicml_logger)
 
+        # Remote Uploader Downloader
+        # Keep the ``RemoteUploaderDownloader`` below client-provided loggers so the loggers' init callbacks run before
+        # the ``RemoteUploaderDownloader`` init. This is necessary to use an ``MLFlowObjectStore`` to log objects to a
+        # run managed by an ``MLFlowLogger``, as the ``MLFlowObjectStore`` relies on the ``MLFlowLogger`` to initialize
+        # the active MLFlow run.
         if save_folder is not None:
             remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers)
             if remote_ud is not None:
@@ -1158,6 +1164,36 @@ def __init__(
         # Run Event.INIT
         self.engine.run_event(Event.INIT)
 
+        # If the experiment is being tracked with an `MLFlowLogger`, the MLFlow experiment and run are available
+        # after Event.INIT.
+        if save_folder is not None:
+            mlflow_logger = None
+            for destination in self.logger.destinations:
+                if isinstance(destination, MLFlowLogger):
+                    mlflow_logger = destination
+                    break
+
+            if mlflow_logger is not None:
+                mlflow_experiment_id = mlflow_logger._experiment_id
+                mlflow_run_id = mlflow_logger._run_id
+
+                # Log the experiment and run IDs as hyperparameters.
+                self.logger.log_hyperparameters({
+                    'mlflow_experiment_id': mlflow_experiment_id,
+                    'mlflow_run_id': mlflow_run_id
+                })
+
+                # The save folder and related paths/filenames may contain format placeholders for the MLFlow IDs, so
+                # populate them now.
+                mlflow_format_kwargs = {
+                    MLFLOW_EXPERIMENT_ID_FORMAT_KEY: mlflow_experiment_id,
+                    MLFLOW_RUN_ID_FORMAT_KEY: mlflow_run_id
+                }
+
+                save_folder = partial_format(save_folder, **mlflow_format_kwargs)
+                if latest_remote_file_name is not None:
+                    latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs)
+
         # Log hparams.
         if self.auto_log_hparams:
             self.local_hparams = extract_hparams(locals())
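
To make the two-stage formatting concrete (IDs hypothetical): the MLFlow placeholders are resolved here, right after Event.INIT, while placeholders like `{run_name}` survive for the later `format_name_with_dist` pass.

from composer.utils import partial_format

save_folder = 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/{run_name}'
save_folder = partial_format(save_folder, mlflow_experiment_id='2177', mlflow_run_id='f6a1c3')
# -> 'dbfs:/databricks/mlflow-tracking/2177/f6a1c3/artifacts/{run_name}'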
2 changes: 2 additions & 0 deletions composer/utils/__init__.py
@@ -26,6 +26,7 @@
                                         UCObjectStore)
 from composer.utils.retrying import retry
 from composer.utils.string_enum import StringEnum
+from composer.utils.string_helpers import partial_format
 
 __all__ = [
     'ensure_tuple',
@@ -90,4 +91,5 @@
     'LambdaEvalClient',
     'LocalEvalClient',
     'MosaicMLLambdaEvalClient',
+    'partial_format',
 ]
44 changes: 34 additions & 10 deletions composer/utils/file_helpers.py
@@ -20,7 +20,10 @@
 
 from composer.utils import dist
 from composer.utils.iter_helpers import iterate_with_callback
-from composer.utils.object_store import GCSObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore
+from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore,
+                                         UCObjectStore)
+from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX
+from composer.utils.string_helpers import partial_format
 
 if TYPE_CHECKING:
     from composer.core import Timestamp
@@ -166,7 +169,8 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
 
 
 def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object):  # noqa: D103
-    formatted_str = format_str.format(
+    formatted_str = partial_format(
+        format_str,
         run_name=run_name,
         **_get_dist_config(strict=False),
         **extra_format_kwargs,
@@ -259,7 +263,8 @@ def format_name_with_dist_and_time(
     timestamp: Timestamp,
     **extra_format_kwargs: object,
 ):  # noqa: D103
-    formatted_str = format_str.format(
+    formatted_str = partial_format(
+        format_str,
         run_name=run_name,
         epoch=int(timestamp.epoch),
         batch=int(timestamp.batch),
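
The swap from `str.format` matters because a format string may now carry placeholders that are unknown at call time. A sketch of the difference (format string hypothetical):

from composer.utils import partial_format

fmt = '{run_name}/checkpoints/{mlflow_run_id}'  # hypothetical

try:
    fmt.format(run_name='my-run')  # str.format demands every placeholder
except KeyError as err:
    print('str.format raises KeyError on', err)  # 'mlflow_run_id'

print(partial_format(fmt, run_name='my-run'))  # 'my-run/checkpoints/{mlflow_run_id}'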
@@ -350,9 +355,28 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
     elif backend == 'oci':
         return OCIObjectStore(bucket=bucket_name)
     elif backend == 'dbfs':
-        # validate if the path conforms to the requirements for UC volume paths
-        UCObjectStore.validate_path(path)
-        return UCObjectStore(path=path)
+        if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
+            store = None
+            if dist.get_global_rank() == 0:
+                store = MLFlowObjectStore(path)
+
+                # The path may have had placeholders, so update it with the experiment/run IDs initialized by the store
+                path = store.get_dbfs_path(path)
+
+            # Broadcast the rank 0 updated path to all ranks for their own object stores
+            path_list = [path]
+            dist.broadcast_object_list(path_list, src=0)
+            path = path_list[0]
+
+            # Create the object store for all other ranks
+            if dist.get_global_rank() != 0:
+                store = MLFlowObjectStore(path)
+
+            return store
+        else:
+            # validate if the path conforms to the requirements for UC volume paths
+            UCObjectStore.validate_path(path)
+            return UCObjectStore(path=path)
     else:
         raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
                                   'one of the supported object stores')
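
A hypothetical end-to-end use of the branch above. Note that actually constructing either store requires Databricks credentials, so this is a sketch of intent rather than something runnable in isolation:

from composer.utils import maybe_create_object_store_from_uri

# Hypothetical URIs: the first resolves to an MLFlowObjectStore (rank 0 creates
# it first and broadcasts the resolved path), the second to a UCObjectStore.
mlflow_store = maybe_create_object_store_from_uri(
    'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts')
uc_store = maybe_create_object_store_from_uri('dbfs:/Volumes/my_catalog/my_schema/my_volume/ckpts')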
Expand Down Expand Up @@ -388,13 +412,13 @@ def maybe_create_remote_uploader_downloader_from_uri(
if backend in ['s3', 'oci', 'gs']:
return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}')

elif backend == 'dbfs':
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

elif backend == 'wandb':
raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
'WandBLogger with log_artifacts set to True')
elif backend == 'dbfs':
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported RemoteUploaderDownloader object stores')
Expand Down
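
And the matching uploader helper (URI hypothetical): the dbfs branch now just forwards the path, deferring the MLFlow-vs-UC decision to `_build_remote_backend` at init time, when the path prefix can be inspected.

from composer.utils import maybe_create_remote_uploader_downloader_from_uri

remote_ud = maybe_create_remote_uploader_downloader_from_uri(
    'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts',
    loggers=[])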
15 changes: 9 additions & 6 deletions composer/utils/object_store/mlflow_object_store.py
@@ -21,8 +21,11 @@
 
 DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store'
 
-PLACEHOLDER_EXPERIMENT_ID = '{mlflow_experiment_id}'
-PLACEHOLDER_RUN_ID = '{mlflow_run_id}'
+MLFLOW_EXPERIMENT_ID_FORMAT_KEY = 'mlflow_experiment_id'
+MLFLOW_RUN_ID_FORMAT_KEY = 'mlflow_run_id'
+
+MLFLOW_EXPERIMENT_ID_PLACEHOLDER = '{' + MLFLOW_EXPERIMENT_ID_FORMAT_KEY + '}'
+MLFLOW_RUN_ID_PLACEHOLDER = '{' + MLFLOW_RUN_ID_FORMAT_KEY + '}'
 
 log = logging.getLogger(__name__)
 
@@ -132,9 +135,9 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 1024):
         mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set(multipart_upload_chunk_size)
 
         experiment_id, run_id, _ = MLFlowObjectStore.parse_dbfs_path(path)
-        if experiment_id == PLACEHOLDER_EXPERIMENT_ID:
+        if experiment_id == MLFLOW_EXPERIMENT_ID_PLACEHOLDER:
             experiment_id = None
-        if run_id == PLACEHOLDER_RUN_ID:
+        if run_id == MLFLOW_RUN_ID_PLACEHOLDER:
             run_id = None
 
         # Construct the `experiment_id` and `run_id` depending on whether format placeholders were provided.
@@ -236,10 +239,10 @@ def get_artifact_path(self, object_name: str) -> str:
         """
         if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX):
             experiment_id, run_id, object_name = self.parse_dbfs_path(object_name)
-            if (experiment_id != self.experiment_id and experiment_id != PLACEHOLDER_EXPERIMENT_ID):
+            if (experiment_id != self.experiment_id and experiment_id != MLFLOW_EXPERIMENT_ID_PLACEHOLDER):
                 raise ValueError(f'Object {object_name} belongs to experiment ID {experiment_id}, '
                                  f'but MLFlowObjectStore is associated with experiment ID {self.experiment_id}.')
-            if (run_id != self.run_id and run_id != PLACEHOLDER_RUN_ID):
+            if (run_id != self.run_id and run_id != MLFLOW_RUN_ID_PLACEHOLDER):
                 raise ValueError(f'Object {object_name} belongs to run ID {run_id}, '
                                  f'but MLFlowObjectStore is associated with run ID {self.run_id}.')
         return object_name
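
Finally, a sketch tying the placeholder constants to the DBFS path layout (path hypothetical, assuming the databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/{path} format that `parse_dbfs_path` presumably expects):

from composer.utils.object_store.mlflow_object_store import (MLFLOW_EXPERIMENT_ID_PLACEHOLDER,
                                                             MLFLOW_RUN_ID_PLACEHOLDER, MLFlowObjectStore)

path = ('databricks/mlflow-tracking/' + MLFLOW_EXPERIMENT_ID_PLACEHOLDER + '/' +
        MLFLOW_RUN_ID_PLACEHOLDER + '/artifacts/checkpoints/ep0.pt')
experiment_id, run_id, artifact_path = MLFlowObjectStore.parse_dbfs_path(path)
# experiment_id == '{mlflow_experiment_id}', run_id == '{mlflow_run_id}',
# artifact_path == 'checkpoints/ep0.pt'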