Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLFlowObjectStore] [2/2] Support checkpointing with MLFlow #2810

Merged
merged 19 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from typing import Any, Callable, Dict, List, Optional, Union

from composer.core import Callback, Event, State, Time, Timestamp
from composer.loggers import Logger
from composer.loggers import Logger, MLFlowLogger
from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
checkpoint, create_interval_scheduler, create_symlink_file, dist,
ensure_folder_has_no_conflicting_files, format_name_with_dist,
format_name_with_dist_and_time, is_model_deepspeed, using_torch_2)
format_name_with_dist_and_time, is_model_deepspeed, partial_format, using_torch_2)
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -270,6 +271,30 @@ def __init__(
self.start_batch = None

def init(self, state: State, logger: Logger) -> None:
# If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths.
# Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers.
for destination in logger.destinations:
if isinstance(destination, MLFlowLogger):
mlflow_format_kwargs = {
MLFLOW_EXPERIMENT_ID_FORMAT_KEY: destination._experiment_id,
MLFLOW_RUN_ID_FORMAT_KEY: destination._run_id
}
self.folder = partial_format(self.folder, **mlflow_format_kwargs)

self.filename.folder = self.folder
if self.latest_filename is not None:
self.latest_filename.folder = self.folder

# The remote paths have the placeholders in their filename rather than folder
if self.remote_file_name is not None:
self.remote_file_name.filename = partial_format(self.remote_file_name.filename,
**mlflow_format_kwargs)
if self.latest_remote_file_name is not None:
self.latest_remote_file_name.filename = partial_format(self.latest_remote_file_name.filename,
**mlflow_format_kwargs)

break

folder = format_name_with_dist(self.folder, state.run_name)
os.makedirs(folder, exist_ok=True)

Expand Down
35 changes: 30 additions & 5 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def __init__(
self._rank_zero_only = rank_zero_only
self._last_flush_time = time.time()
self._flush_interval = flush_interval

self._experiment_id = None
self._run_id = None

if self._enabled:
self.tracking_uri = str(tracking_uri or mlflow.get_tracking_uri())
mlflow.set_tracking_uri(self.tracking_uri)
Expand Down Expand Up @@ -128,6 +132,10 @@ def init(self, state: State, logger: Logger) -> None:
if self.run_name is None:
self.run_name = state.run_name

# Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume.
self.tags = self.tags or {}
self.tags['composer_run_name'] = state.run_name

# Adjust name and group based on `rank_zero_only`.
if not self._rank_zero_only:
self.run_name += f'-rank{dist.get_global_rank()}'
Expand All @@ -141,17 +149,34 @@ def init(self, state: State, logger: Logger) -> None:
if env_run_id is not None:
self._run_id = env_run_id
else:
new_run = self._mlflow_client.create_run(
experiment_id=self._experiment_id,
run_name=self.run_name,
)
self._run_id = new_run.info.run_id
# Search for an existing run tagged with this Composer run.
existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id],
filter_string=f'tags.composer_run_name = "{state.run_name}"',
output_format='list')
if len(existing_runs) > 0:
self._run_id = existing_runs[0].info.run_id
else:
new_run = self._mlflow_client.create_run(
experiment_id=self._experiment_id,
run_name=self.run_name,
)
self._run_id = new_run.info.run_id
mlflow.start_run(
run_id=self._run_id,
tags=self.tags,
log_system_metrics=self.log_system_metrics,
)

# If rank zero only, broadcast the MLFlow experiment and run IDs to other ranks, so the MLFlow run info is
# available to other ranks during runtime.
if self._rank_zero_only:
mlflow_ids_list = [self._experiment_id, self._run_id]
dist.broadcast_object_list(mlflow_ids_list, src=0)
self._experiment_id, self._run_id = mlflow_ids_list

def after_load(self, state: State, logger: Logger) -> None:
    """Record the MLFlow experiment/run IDs as hyperparameters once a checkpoint load completes."""
    mlflow_ids = {
        'mlflow_experiment_id': self._experiment_id,
        'mlflow_run_id': self._run_id,
    }
    logger.log_hyperparameters(mlflow_ids)

def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None:
if self._enabled:
try:
Expand Down
45 changes: 37 additions & 8 deletions composer/loggers/remote_uploader_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@

from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
from composer.utils import (GCSObjectStore, LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, OCIObjectStore,
S3ObjectStore, SFTPObjectStore, UCObjectStore, dist, format_name_with_dist, get_file, retry)
from composer.utils import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore,
dist, format_name_with_dist, get_file, retry)
from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

if TYPE_CHECKING:
from composer.core import State
Expand All @@ -37,19 +39,32 @@


def _build_remote_backend(remote_backend_name: str, backend_kwargs: Dict[str, Any]):
remote_backend_cls = None
remote_backend_name_to_cls = {
's3': S3ObjectStore,
'oci': OCIObjectStore,
'sftp': SFTPObjectStore,
'libcloud': LibcloudObjectStore,
'gs': GCSObjectStore,
'dbfs': UCObjectStore,
}
remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
if remote_backend_cls is None:
raise ValueError(
f'The remote backend {remote_backend_name} is not supported. Please use one of ({list(remote_backend_name_to_cls.keys())})'
)

# Handle `dbfs` backend as a special case, since it can map to either :class:`.UCObjectStore`
# or :class:`.MLFlowObjectStore`.
if remote_backend_name == 'dbfs':
path = backend_kwargs['path']
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
remote_backend_cls = MLFlowObjectStore
else:
# Validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
remote_backend_cls = UCObjectStore
else:
remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None)
if remote_backend_cls is None:
supported_remote_backends = list(remote_backend_name_to_cls.keys()) + ['dbfs']
raise ValueError(
f'The remote backend {remote_backend_name} is not supported. Please use one of ({supported_remote_backends})'
)

return remote_backend_cls(**backend_kwargs)

Expand Down Expand Up @@ -322,6 +337,20 @@ def init(self, state: State, logger: Logger) -> None:
if dist.get_global_rank() == 0:
retry(ObjectStoreTransientError,
self.num_attempts)(lambda: _validate_credentials(self.remote_backend, file_name_to_test))()

# If the remote backend is an `MLFlowObjectStore`, the original path kwarg may have placeholders that can be
# updated with information generated at runtime, i.e., the MLFlow experiment and run IDs. This information
# must be propagated across all ranks before the workers are started so that all workers use the same
# MLFlow run.
if self.backend_kwargs.get('path', '').startswith(MLFLOW_DBFS_PATH_PREFIX):
if dist.get_global_rank() == 0:
assert isinstance(self.remote_backend, MLFlowObjectStore)
self.backend_kwargs['path'] = self.remote_backend.get_dbfs_path(self.backend_kwargs['path'])

path_list = [self.backend_kwargs['path']]
dist.broadcast_object_list(path_list, src=0)
self.backend_kwargs['path'] = path_list[0]

assert len(self._workers) == 0, 'workers should be empty if self._worker_flag was None'
for _ in range(self._num_concurrent_uploads):
worker = self._proc_class(
Expand Down
34 changes: 32 additions & 2 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec,
ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching)
from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU
from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger,
from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MLFlowLogger, MosaicMLLogger, ProgressBarLogger,
RemoteUploaderDownloader, WandBLogger)
from composer.loggers.mosaicml_logger import MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR
from composer.models import ComposerModel
Expand All @@ -54,8 +54,9 @@
ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist,
get_composer_env_dict, get_device, get_file, is_tpu_installed, map_collection,
maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri,
model_eval_mode, parse_uri, reproducibility, using_torch_2)
model_eval_mode, parse_uri, partial_format, reproducibility, using_torch_2)
from composer.utils.misc import is_model_deepspeed
from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY

if is_tpu_installed():
import torch_xla.core.xla_model as xm
Expand Down Expand Up @@ -1085,6 +1086,11 @@ def __init__(
mosaicml_logger = MosaicMLLogger()
loggers.append(mosaicml_logger)

# Remote Uploader Downloader
# Keep the ``RemoteUploaderDownloader`` below client-provided loggers so the loggers init callbacks run before
# the ``RemoteUploaderDownloader`` init. This is necessary to use an ``MLFlowObjectStore`` to log objects to a
# run managed by an ``MLFlowLogger``, as the ``MLFlowObjectStore`` relies on the ``MLFlowLogger`` to initialize
# the active MLFlow run.
jerrychen109 marked this conversation as resolved.
Show resolved Hide resolved
if save_folder is not None:
remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers)
if remote_ud is not None:
Expand Down Expand Up @@ -1158,6 +1164,30 @@ def __init__(
# Run Event.INIT
self.engine.run_event(Event.INIT)

# If the experiment is being tracked with an `MLFlowLogger`, then MLFlow experiment and run are available
# after Event.INIT.
if save_folder is not None:
mlflow_logger = None
for destination in self.logger.destinations:
if isinstance(destination, MLFlowLogger):
mlflow_logger = destination
break

if mlflow_logger is not None:
mlflow_experiment_id = mlflow_logger._experiment_id
mlflow_run_id = mlflow_logger._run_id

# The save folder and related paths/filenames may contain format placeholders for the MLFlow IDs, so
# populate them now.
mlflow_format_kwargs = {
MLFLOW_EXPERIMENT_ID_FORMAT_KEY: mlflow_experiment_id,
MLFLOW_RUN_ID_FORMAT_KEY: mlflow_run_id
}

save_folder = partial_format(save_folder, **mlflow_format_kwargs)
if latest_remote_file_name is not None:
latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs)

jerrychen109 marked this conversation as resolved.
Show resolved Hide resolved
# Log hparams.
if self.auto_log_hparams:
self.local_hparams = extract_hparams(locals())
Expand Down
3 changes: 2 additions & 1 deletion composer/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic
from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection
from composer.utils.misc import (create_interval_scheduler, get_free_tcp_port, is_model_deepspeed, is_model_fsdp,
is_notebook, model_eval_mode, using_torch_2)
is_notebook, model_eval_mode, partial_format, using_torch_2)
from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore,
ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore,
UCObjectStore)
Expand Down Expand Up @@ -92,4 +92,5 @@
'LambdaEvalClient',
'LocalEvalClient',
'MosaicMLLambdaEvalClient',
'partial_format',
]
44 changes: 34 additions & 10 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

from composer.utils import dist
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.object_store import GCSObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore
from composer.utils.misc import partial_format
from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore,
UCObjectStore)
from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX

if TYPE_CHECKING:
from composer.core import Timestamp
Expand Down Expand Up @@ -166,7 +169,8 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]


def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object): # noqa: D103
formatted_str = format_str.format(
formatted_str = partial_format(
jerrychen109 marked this conversation as resolved.
Show resolved Hide resolved
format_str,
run_name=run_name,
**_get_dist_config(strict=False),
**extra_format_kwargs,
Expand Down Expand Up @@ -259,7 +263,8 @@ def format_name_with_dist_and_time(
timestamp: Timestamp,
**extra_format_kwargs: object,
): # noqa: D103
formatted_str = format_str.format(
formatted_str = partial_format(
format_str,
run_name=run_name,
epoch=int(timestamp.epoch),
batch=int(timestamp.batch),
Expand Down Expand Up @@ -350,9 +355,28 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'dbfs':
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
store = None
if dist.get_global_rank() == 0:
store = MLFlowObjectStore(path)

# The path may have had placeholders, so update it with the experiment/run IDs initialized by the store
path = store.get_dbfs_path(path)

# Broadcast the rank 0 updated path to all ranks for their own object stores
path_list = [path]
dist.broadcast_object_list(path_list, src=0)
path = path_list[0]

# Create the object store for all other ranks
if dist.get_global_rank() != 0:
store = MLFlowObjectStore(path)

return store
else:
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores')
Expand Down Expand Up @@ -388,13 +412,13 @@ def maybe_create_remote_uploader_downloader_from_uri(
if backend in ['s3', 'oci', 'gs']:
return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}')

elif backend == 'dbfs':
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

elif backend == 'wandb':
raise NotImplementedError(f'There is no implementation for WandB via URI. Please use '
'WandBLogger with log_artifacts set to True')
elif backend == 'dbfs':
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path})

else:
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported RemoteUploaderDownloader object stores')
Expand Down
20 changes: 20 additions & 0 deletions composer/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,23 @@ def using_torch_2_0_1() -> bool:
bool: Return True if current version is greater than or equal to 2.0.1 else False
"""
return version.parse(torch.__version__) >= version.parse('2.0.1')


def partial_format(s: str, *args: object, **kwargs: object) -> str:
    """Format a string with a partial set of arguments.

    Since `str.format()` raises a `KeyError` if a format key is missing from the arguments, this
    function allows for a partial set of arguments to be provided. Any missing arguments will be
    left as-is in the string.

    Args:
        s (str): The format string to fill in.
        *args: Positional format arguments.
        **kwargs: Keyword format arguments.

    Returns:
        str: ``s`` with all available placeholders substituted; placeholders with no matching
            argument remain in the output unchanged.

    Raises:
        RuntimeError: If formatting fails to converge within the iteration limit.
    """
    max_iters = 10_000  # Just in case we get stuck in a loop somehow.
    for _ in range(max_iters):
        try:
            return s.format(*args, **kwargs)
        except IndexError:  # Missing positional arg: pad with a literal '{}' placeholder.
            args += ('{}',)
        except KeyError as e:  # Missing keyword arg: substitute the named placeholder literally.
            key = e.args[0]
            kwargs[key] = '{' + key + '}'

    raise RuntimeError(f'Failed to format string {s} after {max_iters} iterations.')
Loading
Loading