From e3b2734eec8466fb8713414d864bbb3089408a08 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Sat, 30 Dec 2023 16:04:01 -0800 Subject: [PATCH 01/17] Support checkpoint uploads to MLFlow (untested) Use MLFlow run tag for autoresume Add MLFlowLogger test for existing composer run tag --- composer/loggers/mlflow_logger.py | 21 ++++++--- .../loggers/remote_uploader_downloader.py | 44 +++++++++++++++---- composer/trainer/trainer.py | 5 +++ composer/utils/file_helpers.py | 21 +++++---- tests/loggers/test_mlflow_logger.py | 24 ++++++++++ 5 files changed, 94 insertions(+), 21 deletions(-) diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 430d99b5ca..48daa7d1cb 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -123,6 +123,10 @@ def init(self, state: State, logger: Logger) -> None: if self.run_name is None: self.run_name = state.run_name + # Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume. + self.tags = self.tags or {} + self.tags['composer_run_name'] = state.run_name + # Adjust name and group based on `rank_zero_only`. if not self._rank_zero_only: self.run_name += f'-rank{dist.get_global_rank()}' @@ -133,11 +137,18 @@ def init(self, state: State, logger: Logger) -> None: if env_run_id is not None: self._run_id = env_run_id else: - new_run = self._mlflow_client.create_run( - experiment_id=self._experiment_id, - run_name=self.run_name, - ) - self._run_id = new_run.info.run_id + # Search for an existing run tagged with this Composer run. + existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id], + filter_string=f'tags.composer_run_name = "{state.run_name}"', + output_format='list') + if len(existing_runs) > 0: + self._run_id = existing_runs[0].info.run_id + else: + new_run = self._mlflow_client.create_run( + experiment_id=self._experiment_id, + run_name=self.run_name, + ) + self._run_id = new_run.info.run_id mlflow.start_run( run_id=self._run_id, tags=self.tags, diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py index a3f9698483..106cc7a7f7 100644 --- a/composer/loggers/remote_uploader_downloader.py +++ b/composer/loggers/remote_uploader_downloader.py @@ -24,8 +24,10 @@ from composer.loggers.logger import Logger from composer.loggers.logger_destination import LoggerDestination -from composer.utils import (GCSObjectStore, LibcloudObjectStore, ObjectStore, ObjectStoreTransientError, OCIObjectStore, - S3ObjectStore, SFTPObjectStore, UCObjectStore, dist, format_name_with_dist, get_file, retry) +from composer.utils import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore, + ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore, + dist, format_name_with_dist, get_file, retry) +from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX if TYPE_CHECKING: from composer.core import State @@ -37,19 +39,32 @@ def _build_remote_backend(remote_backend_name: str, backend_kwargs: Dict[str, Any]): + remote_backend_cls = None remote_backend_name_to_cls = { 's3': S3ObjectStore, 'oci': OCIObjectStore, 'sftp': SFTPObjectStore, 'libcloud': LibcloudObjectStore, 'gs': GCSObjectStore, - 'dbfs': UCObjectStore, } - remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None) - if remote_backend_cls is None: - raise ValueError( - f'The remote backend {remote_backend_name} is not supported. 
Please use one of ({list(remote_backend_name_to_cls.keys())})' - ) + + # Handle `dbfs` backend as a special case, since it can map to either :class:`.UCObjectStore` + # or :class:`.MLFlowObjectStore`. + if remote_backend_name == 'dbfs': + path = backend_kwargs['path'] + if path.startswith(MLFLOW_DBFS_PATH_PREFIX): + remote_backend_cls = MLFlowObjectStore + else: + # Validate if the path conforms to the requirements for UC volume paths + UCObjectStore.validate_path(path) + remote_backend_cls = UCObjectStore + else: + remote_backend_cls = remote_backend_name_to_cls.get(remote_backend_name, None) + if remote_backend_cls is None: + supported_remote_backends = list(remote_backend_name_to_cls.keys()) + ['dbfs'] + raise ValueError( + f'The remote backend {remote_backend_name} is not supported. Please use one of ({supported_remote_backends})' + ) return remote_backend_cls(**backend_kwargs) @@ -319,6 +334,19 @@ def init(self, state: State, logger: Logger) -> None: if dist.get_global_rank() == 0: retry(ObjectStoreTransientError, self.num_attempts)(lambda: _validate_credentials(self.remote_backend, file_name_to_test))() + + # If the remote backend is an `MLFlowObjectStore`, the original path kwarg may have placeholders that can be + # updated with information generated at runtime, i.e., the MLFlow experiment and run IDs. This information + # must be propagated across all ranks before the workers are started so that all workers use the same + # MLFlow run. + if isinstance(self.remote_backend, MLFlowObjectStore): + if dist.get_global_rank() == 0: + self.backend_kwargs['path'] = self.remote_backend.get_dbfs_path(self.backend_kwargs['path']) + + path_list = [self.backend_kwargs['path']] + dist.broadcast_object_list(path_list, src=0) + self.backend_kwargs['path'] = path_list[0] + assert len(self._workers) == 0, 'workers should be empty if self._worker_flag was None' for _ in range(self._num_concurrent_uploads): worker = self._proc_class( diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index c8c6d325e0..fe0fa99e15 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1085,6 +1085,11 @@ def __init__( mosaicml_logger = MosaicMLLogger() loggers.append(mosaicml_logger) + # Remote Uploader Downloader + # Keep the ``RemoteUploaderDownloader`` below client-provided loggers so the loggers init callbacks run before + # the ``RemoteUploaderDownloader`` init. This is necessary to use an ``MLFlowObjectStore`` to log objects to a + # run managed by an ``MLFlowLogger``, as the ``MLFlowObjectStore`` relies on the ``MLFlowLogger`` to initialize + # the active MLFlow run. 
if save_folder is not None: remote_ud = maybe_create_remote_uploader_downloader_from_uri(save_folder, loggers) if remote_ud is not None: diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index a3d421259b..d94a31f1db 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -20,7 +20,9 @@ from composer.utils import dist from composer.utils.iter_helpers import iterate_with_callback -from composer.utils.object_store import GCSObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore +from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, + UCObjectStore) +from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX if TYPE_CHECKING: from composer.core import Timestamp @@ -350,9 +352,12 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]: elif backend == 'oci': return OCIObjectStore(bucket=bucket_name) elif backend == 'dbfs': - # validate if the path conforms to the requirements for UC volume paths - UCObjectStore.validate_path(path) - return UCObjectStore(path=path) + if path.startswith(MLFLOW_DBFS_PATH_PREFIX): + return MLFlowObjectStore(path) + else: + # validate if the path conforms to the requirements for UC volume paths + UCObjectStore.validate_path(path) + return UCObjectStore(path=path) else: raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use ' 'one of the supported object stores') @@ -388,13 +393,13 @@ def maybe_create_remote_uploader_downloader_from_uri( if backend in ['s3', 'oci', 'gs']: return RemoteUploaderDownloader(bucket_uri=f'{backend}://{bucket_name}') + elif backend == 'dbfs': + return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path}) + elif backend == 'wandb': raise NotImplementedError(f'There is no implementation for WandB via URI. Please use ' 'WandBLogger with log_artifacts set to True') - elif backend == 'dbfs': - # validate if the path conforms to the requirements for UC volume paths - UCObjectStore.validate_path(path) - return RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': path}) + else: raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI. Please use ' 'one of the supported RemoteUploaderDownloader object stores') diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 5ff0a2fa3c..12545fee78 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -164,6 +164,26 @@ def test_mlflow_experiment_init_experiment_name(monkeypatch): id_logger.post_close() +def test_mlflow_experiment_init_existing_composer_run(monkeypatch): + """ Test that an existing MLFlow run is used if one already exists in the experiment for the Composer run. 
+ """ + mlflow = pytest.importorskip('mlflow') + + monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) + monkeypatch.setattr(mlflow, 'start_run', MagicMock()) + + mock_state = MagicMock() + mock_state.run_name = 'dummy-run-name' + + existing_id = 'dummy-id' + mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))]) + monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs) + + test_logger = MLFlowLogger() + test_logger.init(state=mock_state, logger=MagicMock()) + assert test_logger._run_id == existing_id + + def test_mlflow_experiment_set_up(tmp_path): """ Test that MLFlow experiment is set up correctly within mlflow """ @@ -189,6 +209,7 @@ def test_mlflow_experiment_set_up(tmp_path): ) run_id = run.info.run_id experiment_id = run.info.experiment_id + tags = run.data.tags # Check uri set correctly. assert mlflow_uri.exists() @@ -207,6 +228,9 @@ def test_mlflow_experiment_set_up(tmp_path): actual_run_name = run_cfg['run_name'] assert actual_run_name == expected_run_name + # Check run tagged with Composer run name. + assert tags['composer_run_name'] == mock_state.run_name + # Check run ended. test_mlflow_logger.post_close() assert mlflow.active_run() is None From 8d8e7f99a7a1dd897dd330d258a7e6c0d8cc1528 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Tue, 2 Jan 2024 11:53:27 -0800 Subject: [PATCH 02/17] Try formatting mlflow save folder after INIT Make MLFlow experiment and run ID available on all ranks Fix path issue Format mlflow placeholders in remote filenames --- composer/callbacks/checkpoint_saver.py | 30 +++++++- composer/loggers/mlflow_logger.py | 11 +++ composer/trainer/trainer.py | 20 ++++- composer/utils/__init__.py | 76 ++++--------------- .../utils/object_store/mlflow_object_store.py | 15 ++-- composer/utils/string_helpers.py | 31 ++++++++ .../object_store/test_mlflow_object_store.py | 21 ++--- 7 files changed, 123 insertions(+), 81 deletions(-) create mode 100644 composer/utils/string_helpers.py diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py index c876343f21..bbbc69d9ec 100644 --- a/composer/callbacks/checkpoint_saver.py +++ b/composer/callbacks/checkpoint_saver.py @@ -15,12 +15,14 @@ from typing import Callable, List, Optional, Union from composer.core import Callback, Event, State, Time -from composer.loggers import Logger +from composer.loggers import Logger, MLFlowLogger from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath, checkpoint, create_interval_scheduler, create_symlink_file, dist, ensure_folder_has_no_conflicting_files, format_name_with_dist, - format_name_with_dist_and_time, is_model_deepspeed, reproducibility, using_torch_2) + format_name_with_dist_and_time, is_model_deepspeed, partial_format, reproducibility, + using_torch_2) from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME +from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY log = logging.getLogger(__name__) @@ -270,6 +272,30 @@ def __init__( self.start_batch = None def init(self, state: State, logger: Logger) -> None: + # If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths. + # Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers. 
+ for destination in logger.destinations: + if isinstance(destination, MLFlowLogger): + mlflow_format_kwargs = { + MLFLOW_EXPERIMENT_ID_FORMAT_KEY: destination._experiment_id, + MLFLOW_RUN_ID_FORMAT_KEY: destination._run_id + } + self.folder = partial_format(self.folder, **mlflow_format_kwargs) + + self.filename.folder = self.folder + if self.latest_filename is not None: + self.latest_filename.folder = self.folder + + # The remote paths have the placeholders in their filename rather than folder + if self.remote_file_name is not None: + self.remote_file_name.filename = partial_format(self.remote_file_name.filename, + **mlflow_format_kwargs) + if self.latest_remote_file_name is not None: + self.latest_remote_file_name.filename = partial_format(self.latest_remote_file_name.filename, + **mlflow_format_kwargs) + + break + folder = format_name_with_dist(self.folder, state.run_name) os.makedirs(folder, exist_ok=True) diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 48daa7d1cb..dd60b93f25 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -93,6 +93,10 @@ def __init__( self._rank_zero_only = rank_zero_only self._last_flush_time = time.time() self._flush_interval = flush_interval + + self._experiment_id = None + self._run_id = None + if self._enabled: self.tracking_uri = str(tracking_uri or mlflow.get_tracking_uri()) mlflow.set_tracking_uri(self.tracking_uri) @@ -155,6 +159,13 @@ def init(self, state: State, logger: Logger) -> None: log_system_metrics=self.log_system_metrics, ) + # If rank zero only, broadcast the MLFlow experiment and run IDs to other ranks, so the MLFlow run info is + # available to other ranks during runtime. + if self._rank_zero_only: + mlflow_ids_list = [self._experiment_id, self._run_id] + dist.broadcast_object_list(mlflow_ids_list, src=0) + self._experiment_id, self._run_id = mlflow_ids_list + def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None: if self._enabled: try: diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index fe0fa99e15..286b6a7a0c 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -39,7 +39,7 @@ PyTorchScheduler, State, Time, Timestamp, TimeUnit, TrainerMode, ensure_data_spec, ensure_evaluator, ensure_time, get_precision_context, validate_eval_automicrobatching) from composer.devices import Device, DeviceCPU, DeviceGPU, DeviceMPS, DeviceTPU -from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MosaicMLLogger, ProgressBarLogger, +from composer.loggers import (ConsoleLogger, Logger, LoggerDestination, MLFlowLogger, MosaicMLLogger, ProgressBarLogger, RemoteUploaderDownloader, WandBLogger) from composer.loggers.mosaicml_logger import MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR from composer.models import ComposerModel @@ -54,8 +54,9 @@ ensure_tuple, export_with_logger, extract_hparams, format_name_with_dist, get_composer_env_dict, get_device, get_file, is_tpu_installed, map_collection, maybe_create_object_store_from_uri, maybe_create_remote_uploader_downloader_from_uri, - model_eval_mode, parse_uri, reproducibility, using_torch_2) + model_eval_mode, parse_uri, partial_format, reproducibility, using_torch_2) from composer.utils.misc import is_model_deepspeed +from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_FORMAT_KEY, MLFLOW_RUN_ID_FORMAT_KEY if is_tpu_installed(): import torch_xla.core.xla_model as xm @@ -1163,6 +1164,21 @@ def __init__( 
# Run Event.INIT self.engine.run_event(Event.INIT) + # If the experiment is being tracked with an `MLFlowLogger`, then the `save_folder` and + # related paths/filenames may have placeholders or the MLFlow experiment and run IDs that must be populated + # after running Event.INIT. + if save_folder is not None: + for destination in self.logger.destinations: + if isinstance(destination, MLFlowLogger): + mlflow_format_kwargs = { + MLFLOW_EXPERIMENT_ID_FORMAT_KEY: destination._experiment_id, + MLFLOW_RUN_ID_FORMAT_KEY: destination._run_id + } + + save_folder = partial_format(save_folder, **mlflow_format_kwargs) + if latest_remote_file_name is not None: + latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs) + # Log hparams. if self.auto_log_hparams: self.local_hparams = extract_hparams(locals()) diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 30930250d9..81fabb4e45 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -26,68 +26,20 @@ UCObjectStore) from composer.utils.retrying import retry from composer.utils.string_enum import StringEnum +from composer.utils.string_helpers import partial_format __all__ = [ - 'ensure_tuple', - 'get_free_tcp_port', - 'map_collection', - 'IteratorFileStream', - 'FORMAT_NAME_WITH_DIST_AND_TIME_TABLE', - 'FORMAT_NAME_WITH_DIST_TABLE', - 'get_file', - 'PartialFilePath', - 'create_symlink_file', - 'ObjectStore', - 'ObjectStoreTransientError', - 'LibcloudObjectStore', - 'S3ObjectStore', - 'SFTPObjectStore', - 'OCIObjectStore', - 'GCSObjectStore', - 'UCObjectStore', - 'MLFlowObjectStore', - 'MissingConditionalImportError', - 'import_object', - 'is_model_deepspeed', - 'is_model_fsdp', - 'is_notebook', - 'StringEnum', - 'load_checkpoint', - 'save_checkpoint', - 'safe_torch_load', - 'ensure_folder_is_empty', - 'ensure_folder_has_no_conflicting_files', - 'export_for_inference', - 'export_with_logger', - 'quantize_dynamic', - 'format_name_with_dist', - 'format_name_with_dist_and_time', - 'is_tar', - 'maybe_create_object_store_from_uri', - 'maybe_create_remote_uploader_downloader_from_uri', - 'parse_uri', - 'batch_get', - 'batch_set', - 'configure_excepthook', - 'disable_env_report', - 'enable_env_report', - 'print_env', - 'get_composer_env_dict', - 'retry', - 'model_eval_mode', - 'get_device', - 'is_tpu_installed', - 'is_hpu_installed', - 'ExportFormat', - 'Transform', - 'export_with_logger', - 'extract_hparams', - 'convert_nested_dict_to_flat_dict', - 'convert_flat_dict_to_nested_dict', - 'using_torch_2', - 'create_interval_scheduler', - 'EvalClient', - 'LambdaEvalClient', - 'LocalEvalClient', - 'MosaicMLLambdaEvalClient', + 'ensure_tuple', 'get_free_tcp_port', 'map_collection', 'IteratorFileStream', 'FORMAT_NAME_WITH_DIST_AND_TIME_TABLE', + 'FORMAT_NAME_WITH_DIST_TABLE', 'get_file', 'PartialFilePath', 'create_symlink_file', 'ObjectStore', + 'ObjectStoreTransientError', 'LibcloudObjectStore', 'S3ObjectStore', 'SFTPObjectStore', 'OCIObjectStore', + 'GCSObjectStore', 'UCObjectStore', 'MLFlowObjectStore', 'MissingConditionalImportError', 'import_object', + 'is_model_deepspeed', 'is_model_fsdp', 'is_notebook', 'StringEnum', 'load_checkpoint', 'save_checkpoint', + 'safe_torch_load', 'ensure_folder_is_empty', 'ensure_folder_has_no_conflicting_files', 'export_for_inference', + 'export_with_logger', 'quantize_dynamic', 'format_name_with_dist', 'format_name_with_dist_and_time', 'is_tar', + 'maybe_create_object_store_from_uri', 'maybe_create_remote_uploader_downloader_from_uri', 'parse_uri', 
'batch_get', + 'batch_set', 'configure_excepthook', 'disable_env_report', 'enable_env_report', 'print_env', + 'get_composer_env_dict', 'retry', 'model_eval_mode', 'get_device', 'is_tpu_installed', 'is_hpu_installed', + 'ExportFormat', 'Transform', 'export_with_logger', 'extract_hparams', 'convert_nested_dict_to_flat_dict', + 'convert_flat_dict_to_nested_dict', 'using_torch_2', 'create_interval_scheduler', 'EvalClient', 'LambdaEvalClient', + 'LocalEvalClient', 'MosaicMLLambdaEvalClient', 'partial_format' ] diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 15f50bcdb0..27f3a5efbc 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -21,8 +21,11 @@ DEFAULT_MLFLOW_EXPERIMENT_NAME = 'mlflow-object-store' -PLACEHOLDER_EXPERIMENT_ID = '{mlflow_experiment_id}' -PLACEHOLDER_RUN_ID = '{mlflow_run_id}' +MLFLOW_EXPERIMENT_ID_FORMAT_KEY = 'mlflow_experiment_id' +MLFLOW_RUN_ID_FORMAT_KEY = 'mlflow_run_id' + +MLFLOW_EXPERIMENT_ID_PLACEHOLDER = '{' + MLFLOW_EXPERIMENT_ID_FORMAT_KEY + '}' +MLFLOW_RUN_ID_PLACEHOLDER = '{' + MLFLOW_RUN_ID_FORMAT_KEY + '}' log = logging.getLogger(__name__) @@ -132,9 +135,9 @@ def __init__(self, path: str, multipart_upload_chunk_size: int = 100 * 1024 * 10 mlflow.environment_variables.MLFLOW_MULTIPART_UPLOAD_CHUNK_SIZE.set(multipart_upload_chunk_size) experiment_id, run_id, _ = MLFlowObjectStore.parse_dbfs_path(path) - if experiment_id == PLACEHOLDER_EXPERIMENT_ID: + if experiment_id == MLFLOW_EXPERIMENT_ID_PLACEHOLDER: experiment_id = None - if run_id == PLACEHOLDER_RUN_ID: + if run_id == MLFLOW_RUN_ID_PLACEHOLDER: run_id = None # Construct the `experiment_id` and `run_id` depending on whether format placeholders were provided. @@ -236,10 +239,10 @@ def get_artifact_path(self, object_name: str) -> str: """ if object_name.startswith(MLFLOW_DBFS_PATH_PREFIX): experiment_id, run_id, object_name = self.parse_dbfs_path(object_name) - if (experiment_id != self.experiment_id and experiment_id != PLACEHOLDER_EXPERIMENT_ID): + if (experiment_id != self.experiment_id and experiment_id != MLFLOW_EXPERIMENT_ID_PLACEHOLDER): raise ValueError(f'Object {object_name} belongs to experiment ID {experiment_id}, ' f'but MLFlowObjectStore is associated with experiment ID {self.experiment_id}.') - if (run_id != self.run_id and run_id != PLACEHOLDER_RUN_ID): + if (run_id != self.run_id and run_id != MLFLOW_EXPERIMENT_ID_PLACEHOLDER): raise ValueError(f'Object {object_name} belongs to run ID {run_id}, ' f'but MLFlowObjectStore is associated with run ID {self.run_id}.') return object_name diff --git a/composer/utils/string_helpers.py b/composer/utils/string_helpers.py new file mode 100644 index 0000000000..1db1d7620d --- /dev/null +++ b/composer/utils/string_helpers.py @@ -0,0 +1,31 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utilities for string manipulation.""" + + +def partial_format(s, *args, **kwargs): + """Format a string with a partial set of arguments. + + Since `str.format()` raises a `KeyError` if a format key is missing from the arguments, this + function allows for a partial set of arguments to be provided. 
+ + For example: + + >>> partial_format('{foo} {bar}', foo='Hello') + 'Hello {bar}' + + >>> partial_format('{foo} {bar}', foo='Hello', bar='World') + 'Hello World' + """ + result = s + done = False + while not done: + try: + result = s.format(*args, **kwargs) + done = True + except KeyError as e: + key = e.args[0] + kwargs[key] = '{' + key + '}' + + return result diff --git a/tests/utils/object_store/test_mlflow_object_store.py b/tests/utils/object_store/test_mlflow_object_store.py index d46fc493a4..ecbedd2e50 100644 --- a/tests/utils/object_store/test_mlflow_object_store.py +++ b/tests/utils/object_store/test_mlflow_object_store.py @@ -8,7 +8,7 @@ import pytest from composer.utils import MLFlowObjectStore -from composer.utils.object_store.mlflow_object_store import PLACEHOLDER_EXPERIMENT_ID, PLACEHOLDER_RUN_ID +from composer.utils.object_store.mlflow_object_store import MLFLOW_EXPERIMENT_ID_PLACEHOLDER, MLFLOW_RUN_ID_PLACEHOLDER TEST_PATH_FORMAT = 'databricks/mlflow-tracking/{experiment_id}/{run_id}/artifacts/' EXPERIMENT_ID = '123' @@ -66,7 +66,7 @@ def test_init_with_experiment_and_no_run(monkeypatch): mock_mlflow_client.return_value.create_run.return_value = MagicMock( info=MagicMock(run_id=RUN_ID, run_name='test-run')) - store = MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + store = MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=EXPERIMENT_ID, run_id=MLFLOW_RUN_ID_PLACEHOLDER)) assert store.experiment_id == EXPERIMENT_ID assert store.run_id == RUN_ID @@ -76,7 +76,7 @@ def test_init_with_run_and_no_experiment(monkeypatch): monkeypatch.setattr(dbx_sdk, 'WorkspaceClient', MagicMock()) with pytest.raises(ValueError): - MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=RUN_ID)) + MLFlowObjectStore(TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=RUN_ID)) def test_init_with_active_run(monkeypatch): @@ -91,7 +91,7 @@ def test_init_with_active_run(monkeypatch): mock_active_run.return_value = MagicMock(info=MagicMock(experiment_id=EXPERIMENT_ID, run_id=RUN_ID)) store = MLFlowObjectStore( - TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=MLFLOW_RUN_ID_PLACEHOLDER)) assert store.experiment_id == EXPERIMENT_ID assert store.run_id == RUN_ID @@ -109,7 +109,7 @@ def test_init_with_existing_experiment_and_no_run(monkeypatch): info=MagicMock(run_id=RUN_ID, run_name='test-run')) store = MLFlowObjectStore( - TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=MLFLOW_RUN_ID_PLACEHOLDER)) assert store.experiment_id == EXPERIMENT_ID assert store.run_id == RUN_ID @@ -128,7 +128,7 @@ def test_init_with_no_experiment_and_no_run(monkeypatch): info=MagicMock(run_id=RUN_ID, run_name='test-run')) store = MLFlowObjectStore( - TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID)) + TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, run_id=MLFLOW_RUN_ID_PLACEHOLDER)) assert store.experiment_id == EXPERIMENT_ID assert store.run_id == RUN_ID @@ -190,16 +190,19 @@ def test_get_artifact_path(mlflow_object_store): assert mlflow_object_store.get_artifact_path(DEFAULT_PATH + ARTIFACT_PATH) == ARTIFACT_PATH # Absolute DBFS path with placeholders - path = 
TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH + path = TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, + run_id=MLFLOW_RUN_ID_PLACEHOLDER) + ARTIFACT_PATH assert mlflow_object_store.get_artifact_path(path) == ARTIFACT_PATH # Raises ValueError for different experiment ID - path = TEST_PATH_FORMAT.format(experiment_id='different-experiment', run_id=PLACEHOLDER_RUN_ID) + ARTIFACT_PATH + path = TEST_PATH_FORMAT.format(experiment_id='different-experiment', + run_id=MLFLOW_RUN_ID_PLACEHOLDER) + ARTIFACT_PATH with pytest.raises(ValueError): mlflow_object_store.get_artifact_path(path) # Raises ValueError for different run ID - path = TEST_PATH_FORMAT.format(experiment_id=PLACEHOLDER_EXPERIMENT_ID, run_id='different-run') + ARTIFACT_PATH + path = TEST_PATH_FORMAT.format(experiment_id=MLFLOW_EXPERIMENT_ID_PLACEHOLDER, + run_id='different-run') + ARTIFACT_PATH with pytest.raises(ValueError): mlflow_object_store.get_artifact_path(path) From cc264fa1e91f097c8e45ed91659c2002e34178ec Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Tue, 2 Jan 2024 14:08:48 -0800 Subject: [PATCH 03/17] Unit tests for partial_format --- composer/utils/string_helpers.py | 4 +++- tests/utils/test_string_helpers.py | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_string_helpers.py diff --git a/composer/utils/string_helpers.py b/composer/utils/string_helpers.py index 1db1d7620d..c66495d2bf 100644 --- a/composer/utils/string_helpers.py +++ b/composer/utils/string_helpers.py @@ -24,7 +24,9 @@ def partial_format(s, *args, **kwargs): try: result = s.format(*args, **kwargs) done = True - except KeyError as e: + except IndexError as e: # Missing positional arg + args += ('{}',) + except KeyError as e: # Missing keyword arg key = e.args[0] kwargs[key] = '{' + key + '}' diff --git a/tests/utils/test_string_helpers.py b/tests/utils/test_string_helpers.py new file mode 100644 index 0000000000..9c16bc514f --- /dev/null +++ b/tests/utils/test_string_helpers.py @@ -0,0 +1,21 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for string helpers.""" + +from composer.utils.string_helpers import partial_format + + +def test_partial_format(): + # Keyword args + assert partial_format('{foo} {bar}', foo='Hello') == 'Hello {bar}' + assert partial_format('{foo} {bar}', foo='Hello', bar='World') == 'Hello World' + + # Positional args + assert partial_format('{} {}', 'Hello') == 'Hello {}' + assert partial_format('{} {}', 'Hello', 'World') == 'Hello World' + + # Positional and keyword args + assert partial_format('{foo} {}', 'World') == '{foo} World' + assert partial_format('{foo} {}', foo='Hello') == 'Hello {}' + assert partial_format('{foo} {}', 'World', foo='Hello') == 'Hello World' From 7b15ffb63b080105413a7c567e4e875e96d88f91 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Tue, 2 Jan 2024 16:03:34 -0800 Subject: [PATCH 04/17] Log mlflow info as hyperparams --- composer/trainer/trainer.py | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 286b6a7a0c..f1701ecf25 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1164,20 +1164,35 @@ def __init__( # Run Event.INIT self.engine.run_event(Event.INIT) - # If the experiment is being tracked with an `MLFlowLogger`, then the `save_folder` and - # related 
paths/filenames may have placeholders or the MLFlow experiment and run IDs that must be populated - # after running Event.INIT. + # If the experiment is being tracked with an `MLFlowLogger`, then MLFlow experiment and run are available + # after Event.INIT. if save_folder is not None: + mlflow_logger = None for destination in self.logger.destinations: if isinstance(destination, MLFlowLogger): - mlflow_format_kwargs = { - MLFLOW_EXPERIMENT_ID_FORMAT_KEY: destination._experiment_id, - MLFLOW_RUN_ID_FORMAT_KEY: destination._run_id - } + mlflow_logger = destination + break + + if mlflow_logger is not None: + mlflow_experiment_id = mlflow_logger._experiment_id + mlflow_run_id = mlflow_logger._run_id + + # Log the experiment and run IDs as hyperparameters. + self.logger.log_hyperparameters({ + 'mlflow_experiment_id': mlflow_experiment_id, + 'mlflow_run_id': mlflow_run_id + }) + + # The save folder and related paths/filenames may contain format placeholders for the MLFlow IDs, so + # populate them now. + mlflow_format_kwargs = { + MLFLOW_EXPERIMENT_ID_FORMAT_KEY: mlflow_experiment_id, + MLFLOW_RUN_ID_FORMAT_KEY: mlflow_run_id + } - save_folder = partial_format(save_folder, **mlflow_format_kwargs) - if latest_remote_file_name is not None: - latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs) + save_folder = partial_format(save_folder, **mlflow_format_kwargs) + if latest_remote_file_name is not None: + latest_remote_file_name = partial_format(latest_remote_file_name, **mlflow_format_kwargs) # Log hparams. if self.auto_log_hparams: From 942b9a53105a34771088a8199008122608f41347 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Tue, 2 Jan 2024 16:13:31 -0800 Subject: [PATCH 05/17] partial_format doc update --- composer/utils/string_helpers.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/composer/utils/string_helpers.py b/composer/utils/string_helpers.py index c66495d2bf..df1aec7596 100644 --- a/composer/utils/string_helpers.py +++ b/composer/utils/string_helpers.py @@ -8,15 +8,8 @@ def partial_format(s, *args, **kwargs): """Format a string with a partial set of arguments. Since `str.format()` raises a `KeyError` if a format key is missing from the arguments, this - function allows for a partial set of arguments to be provided. - - For example: - - >>> partial_format('{foo} {bar}', foo='Hello') - 'Hello {bar}' - - >>> partial_format('{foo} {bar}', foo='Hello', bar='World') - 'Hello World' + function allows for a partial set of arguments to be provided. Any missing arguments will be + left as-is in the string. 
""" result = s done = False From 7983dce4394f6c5e785c08c130987815efcbe27a Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 5 Jan 2024 13:56:55 -0800 Subject: [PATCH 06/17] Fix formatting --- composer/utils/__init__.py | 76 +++++++++++++++++++++++++++++++------- 1 file changed, 63 insertions(+), 13 deletions(-) diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 81fabb4e45..43d9ffcd47 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -29,17 +29,67 @@ from composer.utils.string_helpers import partial_format __all__ = [ - 'ensure_tuple', 'get_free_tcp_port', 'map_collection', 'IteratorFileStream', 'FORMAT_NAME_WITH_DIST_AND_TIME_TABLE', - 'FORMAT_NAME_WITH_DIST_TABLE', 'get_file', 'PartialFilePath', 'create_symlink_file', 'ObjectStore', - 'ObjectStoreTransientError', 'LibcloudObjectStore', 'S3ObjectStore', 'SFTPObjectStore', 'OCIObjectStore', - 'GCSObjectStore', 'UCObjectStore', 'MLFlowObjectStore', 'MissingConditionalImportError', 'import_object', - 'is_model_deepspeed', 'is_model_fsdp', 'is_notebook', 'StringEnum', 'load_checkpoint', 'save_checkpoint', - 'safe_torch_load', 'ensure_folder_is_empty', 'ensure_folder_has_no_conflicting_files', 'export_for_inference', - 'export_with_logger', 'quantize_dynamic', 'format_name_with_dist', 'format_name_with_dist_and_time', 'is_tar', - 'maybe_create_object_store_from_uri', 'maybe_create_remote_uploader_downloader_from_uri', 'parse_uri', 'batch_get', - 'batch_set', 'configure_excepthook', 'disable_env_report', 'enable_env_report', 'print_env', - 'get_composer_env_dict', 'retry', 'model_eval_mode', 'get_device', 'is_tpu_installed', 'is_hpu_installed', - 'ExportFormat', 'Transform', 'export_with_logger', 'extract_hparams', 'convert_nested_dict_to_flat_dict', - 'convert_flat_dict_to_nested_dict', 'using_torch_2', 'create_interval_scheduler', 'EvalClient', 'LambdaEvalClient', - 'LocalEvalClient', 'MosaicMLLambdaEvalClient', 'partial_format' + 'ensure_tuple', + 'get_free_tcp_port', + 'map_collection', + 'IteratorFileStream', + 'FORMAT_NAME_WITH_DIST_AND_TIME_TABLE', + 'FORMAT_NAME_WITH_DIST_TABLE', + 'get_file', + 'PartialFilePath', + 'create_symlink_file', + 'ObjectStore', + 'ObjectStoreTransientError', + 'LibcloudObjectStore', + 'S3ObjectStore', + 'SFTPObjectStore', + 'OCIObjectStore', + 'GCSObjectStore', + 'UCObjectStore', + 'MLFlowObjectStore', + 'MissingConditionalImportError', + 'import_object', + 'is_model_deepspeed', + 'is_model_fsdp', + 'is_notebook', + 'StringEnum', + 'load_checkpoint', + 'save_checkpoint', + 'safe_torch_load', + 'ensure_folder_is_empty', + 'ensure_folder_has_no_conflicting_files', + 'export_for_inference', + 'export_with_logger', + 'quantize_dynamic', + 'format_name_with_dist', + 'format_name_with_dist_and_time', + 'is_tar', + 'maybe_create_object_store_from_uri', + 'maybe_create_remote_uploader_downloader_from_uri', + 'parse_uri', + 'batch_get', + 'batch_set', + 'configure_excepthook', + 'disable_env_report', + 'enable_env_report', + 'print_env', + 'get_composer_env_dict', + 'retry', + 'model_eval_mode', + 'get_device', + 'is_tpu_installed', + 'is_hpu_installed', + 'ExportFormat', + 'Transform', + 'export_with_logger', + 'extract_hparams', + 'convert_nested_dict_to_flat_dict', + 'convert_flat_dict_to_nested_dict', + 'using_torch_2', + 'create_interval_scheduler', + 'EvalClient', + 'LambdaEvalClient', + 'LocalEvalClient', + 'MosaicMLLambdaEvalClient', + 'partial_format', ] From 207eacc51e9c423f0439985e46fd56ea237798fc Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 5 Jan 
2024 15:51:12 -0800 Subject: [PATCH 07/17] Pull distributed logic out of MLFlowObjectStore Add debug tracebacks Bugfix Add path to debug info Try fixing RUD object store init Pyright --- composer/loggers/remote_uploader_downloader.py | 3 ++- composer/utils/file_helpers.py | 18 +++++++++++++++++- .../utils/object_store/mlflow_object_store.py | 2 +- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py index 106cc7a7f7..7d457e1d0a 100644 --- a/composer/loggers/remote_uploader_downloader.py +++ b/composer/loggers/remote_uploader_downloader.py @@ -339,8 +339,9 @@ def init(self, state: State, logger: Logger) -> None: # updated with information generated at runtime, i.e., the MLFlow experiment and run IDs. This information # must be propagated across all ranks before the workers are started so that all workers use the same # MLFlow run. - if isinstance(self.remote_backend, MLFlowObjectStore): + if self.backend_kwargs['path'].startswith(MLFLOW_DBFS_PATH_PREFIX): if dist.get_global_rank() == 0: + assert isinstance(self.remote_backend, MLFlowObjectStore) self.backend_kwargs['path'] = self.remote_backend.get_dbfs_path(self.backend_kwargs['path']) path_list = [self.backend_kwargs['path']] diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index d94a31f1db..a638f4eb9e 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -353,7 +353,23 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]: return OCIObjectStore(bucket=bucket_name) elif backend == 'dbfs': if path.startswith(MLFLOW_DBFS_PATH_PREFIX): - return MLFlowObjectStore(path) + store = None + if dist.get_global_rank() == 0: + store = MLFlowObjectStore(path) + + # The path may have had placeholders, so update it with the experiment/run IDs initialized by the store + path = store.get_dbfs_path(path) + + # Broadcast the rank 0 updated path to all ranks for their own object stores + path_list = [path] + dist.broadcast_object_list(path_list, src=0) + path = path_list[0] + + # Create the object store for all other ranks + if dist.get_global_rank() != 0: + store = MLFlowObjectStore(path) + + return store else: # validate if the path conforms to the requirements for UC volume paths UCObjectStore.validate_path(path) diff --git a/composer/utils/object_store/mlflow_object_store.py b/composer/utils/object_store/mlflow_object_store.py index 27f3a5efbc..eb90ec13a7 100644 --- a/composer/utils/object_store/mlflow_object_store.py +++ b/composer/utils/object_store/mlflow_object_store.py @@ -242,7 +242,7 @@ def get_artifact_path(self, object_name: str) -> str: if (experiment_id != self.experiment_id and experiment_id != MLFLOW_EXPERIMENT_ID_PLACEHOLDER): raise ValueError(f'Object {object_name} belongs to experiment ID {experiment_id}, ' f'but MLFlowObjectStore is associated with experiment ID {self.experiment_id}.') - if (run_id != self.run_id and run_id != MLFLOW_EXPERIMENT_ID_PLACEHOLDER): + if (run_id != self.run_id and run_id != MLFLOW_RUN_ID_PLACEHOLDER): raise ValueError(f'Object {object_name} belongs to run ID {run_id}, ' f'but MLFlowObjectStore is associated with run ID {self.run_id}.') return object_name From ef59a30f9286c4eb9fae6d4985a2f6030b175c3a Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Sat, 6 Jan 2024 12:02:04 -0800 Subject: [PATCH 08/17] Partial format in format_name helpers --- composer/utils/file_helpers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 
deletions(-) diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index a638f4eb9e..ea42e14e3a 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -18,7 +18,7 @@ import requests import tqdm -from composer.utils import dist +from composer.utils import dist, partial_format from composer.utils.iter_helpers import iterate_with_callback from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore) @@ -168,7 +168,8 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path] def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object): # noqa: D103 - formatted_str = format_str.format( + formatted_str = partial_format( + format_str, run_name=run_name, **_get_dist_config(strict=False), **extra_format_kwargs, @@ -261,7 +262,8 @@ def format_name_with_dist_and_time( timestamp: Timestamp, **extra_format_kwargs: object, ): # noqa: D103 - formatted_str = format_str.format( + formatted_str = partial_format( + format_str, run_name=run_name, epoch=int(timestamp.epoch), batch=int(timestamp.batch), From ebf758c653d6ed44f7c724db7379ceee8b597ef0 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Sat, 6 Jan 2024 12:04:36 -0800 Subject: [PATCH 09/17] Fix import --- composer/utils/file_helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index ea42e14e3a..0b408e6b72 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -18,11 +18,12 @@ import requests import tqdm -from composer.utils import dist, partial_format +from composer.utils import dist from composer.utils.iter_helpers import iterate_with_callback from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore) from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX +from composer.utils.string_helpers import partial_format if TYPE_CHECKING: from composer.core import Timestamp From 130e8f67ab4098dfd2c382362b64869acbec345c Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Sat, 6 Jan 2024 12:08:52 -0800 Subject: [PATCH 10/17] Add extra partial_format test --- tests/utils/test_string_helpers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/utils/test_string_helpers.py b/tests/utils/test_string_helpers.py index 9c16bc514f..f6a8083c01 100644 --- a/tests/utils/test_string_helpers.py +++ b/tests/utils/test_string_helpers.py @@ -7,6 +7,9 @@ def test_partial_format(): + # No args provided + assert partial_format('{foo} {bar} {}') == '{foo} {bar} {}' + # Keyword args assert partial_format('{foo} {bar}', foo='Hello') == 'Hello {bar}' assert partial_format('{foo} {bar}', foo='Hello', bar='World') == 'Hello World' From 0097938f5296a3baa7c4da57f357a6518974d82b Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Sat, 6 Jan 2024 12:14:16 -0800 Subject: [PATCH 11/17] Fix mlflow RUD check --- composer/loggers/remote_uploader_downloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/loggers/remote_uploader_downloader.py b/composer/loggers/remote_uploader_downloader.py index 7d457e1d0a..f9bdf45c2a 100644 --- a/composer/loggers/remote_uploader_downloader.py +++ b/composer/loggers/remote_uploader_downloader.py @@ -339,7 +339,7 @@ def init(self, state: State, logger: Logger) -> None: # updated with information generated at runtime, i.e., the MLFlow 
experiment and run IDs. This information # must be propagated across all ranks before the workers are started so that all workers use the same # MLFlow run. - if self.backend_kwargs['path'].startswith(MLFLOW_DBFS_PATH_PREFIX): + if self.backend_kwargs.get('path', '').startswith(MLFLOW_DBFS_PATH_PREFIX): if dist.get_global_rank() == 0: assert isinstance(self.remote_backend, MLFlowObjectStore) self.backend_kwargs['path'] = self.remote_backend.get_dbfs_path(self.backend_kwargs['path']) From 8dd25f8d1a58a376b75891fb3eb1cf6c3d579be1 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Mon, 8 Jan 2024 10:50:39 -0800 Subject: [PATCH 12/17] Fix test pyright No longer expect KeyError for format_with_dist using partial_format Refactor partial_format for readability --- composer/utils/string_helpers.py | 11 +++-------- tests/utils/test_file_helpers.py | 15 +++------------ 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/composer/utils/string_helpers.py b/composer/utils/string_helpers.py index df1aec7596..3f5d3e259a 100644 --- a/composer/utils/string_helpers.py +++ b/composer/utils/string_helpers.py @@ -4,23 +4,18 @@ """Utilities for string manipulation.""" -def partial_format(s, *args, **kwargs): +def partial_format(s, *args, **kwargs) -> str: """Format a string with a partial set of arguments. Since `str.format()` raises a `KeyError` if a format key is missing from the arguments, this function allows for a partial set of arguments to be provided. Any missing arguments will be left as-is in the string. """ - result = s - done = False - while not done: + while True: try: - result = s.format(*args, **kwargs) - done = True + return s.format(*args, **kwargs) except IndexError as e: # Missing positional arg args += ('{}',) except KeyError as e: # Missing keyword arg key = e.args[0] kwargs[key] = '{' + key + '}' - - return result diff --git a/tests/utils/test_file_helpers.py b/tests/utils/test_file_helpers.py index 2e757afbe4..4566da9fd8 100644 --- a/tests/utils/test_file_helpers.py +++ b/tests/utils/test_file_helpers.py @@ -213,17 +213,6 @@ def test_safe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size) assert format_name_with_dist(format_str, 'awesome_run') == expected_str -@world_size(2) -def test_unsafe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size): - """Node rank is deleted, but also in the format string, so expect error.""" - vars = ['run_name', 'node_rank'] - format_str = ','.join(f'{x}={{{x}}}' for x in vars) - - monkeypatch.delenv('NODE_RANK') - with pytest.raises(KeyError): - assert format_name_with_dist(format_str, 'awesome_run') == 'run_name=awesome_run,node_rank=3' - - def test_format_name_with_dist_and_time(): vars = [ 'run_name', @@ -357,7 +346,9 @@ def test_maybe_create_remote_uploader_downloader_from_uri(monkeypatch): backend_kwargs={'path': 'Volumes/checkpoint/for/my/model.pt'}) with pytest.raises(ValueError): - maybe_create_remote_uploader_downloader_from_uri('dbfs:/checkpoint/for/my/model.pt', loggers=[]) + rud = maybe_create_remote_uploader_downloader_from_uri('dbfs:/checkpoint/for/my/model.pt', loggers=[]) + assert rud is not None + _ = rud.remote_backend def test_ensure_folder_is_empty(tmp_path: pathlib.Path): From c1f88a0f9a10f1d781e2a8bd79ce3a76a5741998 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Mon, 8 Jan 2024 14:42:36 -0800 Subject: [PATCH 13/17] Max iters on partial_format --- composer/utils/string_helpers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/composer/utils/string_helpers.py 
b/composer/utils/string_helpers.py index 3f5d3e259a..fb7f5164a1 100644 --- a/composer/utils/string_helpers.py +++ b/composer/utils/string_helpers.py @@ -11,7 +11,8 @@ def partial_format(s, *args, **kwargs) -> str: function allows for a partial set of arguments to be provided. Any missing arguments will be left as-is in the string. """ - while True: + max_iters = 10_000 # Just in case we get stuck in a loop somehow. + while max_iters: try: return s.format(*args, **kwargs) except IndexError as e: # Missing positional arg From 2e10273be52c6324e5dfed587e69fe11b105e424 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Wed, 10 Jan 2024 10:51:49 -0800 Subject: [PATCH 14/17] Fix partial_format --- composer/utils/string_helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/composer/utils/string_helpers.py b/composer/utils/string_helpers.py index fb7f5164a1..4702550d73 100644 --- a/composer/utils/string_helpers.py +++ b/composer/utils/string_helpers.py @@ -12,7 +12,7 @@ def partial_format(s, *args, **kwargs) -> str: left as-is in the string. """ max_iters = 10_000 # Just in case we get stuck in a loop somehow. - while max_iters: + for _ in range(max_iters): try: return s.format(*args, **kwargs) except IndexError as e: # Missing positional arg @@ -20,3 +20,5 @@ def partial_format(s, *args, **kwargs) -> str: except KeyError as e: # Missing keyword arg key = e.args[0] kwargs[key] = '{' + key + '}' + + raise RuntimeError(f'Failed to format string {s} after {max_iters} iterations.') From 925a6c2a44f2a1637c458f4ad826e57ffea63768 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 12 Jan 2024 13:11:17 -0800 Subject: [PATCH 15/17] Clean up --- composer/loggers/mlflow_logger.py | 3 +++ composer/trainer/trainer.py | 6 ----- composer/utils/__init__.py | 3 +-- composer/utils/file_helpers.py | 2 +- composer/utils/misc.py | 20 ++++++++++++++++ composer/utils/string_helpers.py | 24 ------------------- .../{test_string_helpers.py => test_misc.py} | 2 -- 7 files changed, 25 insertions(+), 35 deletions(-) delete mode 100644 composer/utils/string_helpers.py rename tests/utils/{test_string_helpers.py => test_misc.py} (95%) diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index dd60b93f25..1b452ddda7 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -166,6 +166,9 @@ def init(self, state: State, logger: Logger) -> None: dist.broadcast_object_list(mlflow_ids_list, src=0) self._experiment_id, self._run_id = mlflow_ids_list + def after_load(self, state: State, logger: Logger) -> None: + logger.log_hyperparameters({'mlflow_experiment_id': self._experiment_id, 'mlflow_run_id': self._run_id}) + def log_table(self, columns: List[str], rows: List[List[Any]], name: str = 'Table') -> None: if self._enabled: try: diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 1f6ecad1fb..22bb7f52d2 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1174,12 +1174,6 @@ def __init__( mlflow_experiment_id = mlflow_logger._experiment_id mlflow_run_id = mlflow_logger._run_id - # Log the experiment and run IDs as hyperparameters. - self.logger.log_hyperparameters({ - 'mlflow_experiment_id': mlflow_experiment_id, - 'mlflow_run_id': mlflow_run_id - }) - # The save folder and related paths/filenames may contain format placeholders for the MLFlow IDs, so # populate them now. 
mlflow_format_kwargs = { diff --git a/composer/utils/__init__.py b/composer/utils/__init__.py index 43d9ffcd47..ad9ce17c5c 100644 --- a/composer/utils/__init__.py +++ b/composer/utils/__init__.py @@ -20,13 +20,12 @@ from composer.utils.inference import ExportFormat, Transform, export_for_inference, export_with_logger, quantize_dynamic from composer.utils.iter_helpers import IteratorFileStream, ensure_tuple, map_collection from composer.utils.misc import (create_interval_scheduler, get_free_tcp_port, is_model_deepspeed, is_model_fsdp, - is_notebook, model_eval_mode, using_torch_2) + is_notebook, model_eval_mode, partial_format, using_torch_2) from composer.utils.object_store import (GCSObjectStore, LibcloudObjectStore, MLFlowObjectStore, ObjectStore, ObjectStoreTransientError, OCIObjectStore, S3ObjectStore, SFTPObjectStore, UCObjectStore) from composer.utils.retrying import retry from composer.utils.string_enum import StringEnum -from composer.utils.string_helpers import partial_format __all__ = [ 'ensure_tuple', diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index 0b408e6b72..d62487e106 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -20,10 +20,10 @@ from composer.utils import dist from composer.utils.iter_helpers import iterate_with_callback +from composer.utils.misc import partial_format from composer.utils.object_store import (GCSObjectStore, MLFlowObjectStore, ObjectStore, OCIObjectStore, S3ObjectStore, UCObjectStore) from composer.utils.object_store.mlflow_object_store import MLFLOW_DBFS_PATH_PREFIX -from composer.utils.string_helpers import partial_format if TYPE_CHECKING: from composer.core import Timestamp diff --git a/composer/utils/misc.py b/composer/utils/misc.py index 76573f8901..fca6bb7076 100644 --- a/composer/utils/misc.py +++ b/composer/utils/misc.py @@ -224,3 +224,23 @@ def using_torch_2_0_1() -> bool: bool: Return True if current version is greater than or equal to 2.0.1 else False """ return version.parse(torch.__version__) >= version.parse('2.0.1') + + +def partial_format(s, *args, **kwargs) -> str: + """Format a string with a partial set of arguments. + + Since `str.format()` raises a `KeyError` if a format key is missing from the arguments, this + function allows for a partial set of arguments to be provided. Any missing arguments will be + left as-is in the string. + """ + max_iters = 10_000 # Just in case we get stuck in a loop somehow. + for _ in range(max_iters): + try: + return s.format(*args, **kwargs) + except IndexError as e: # Missing positional arg + args += ('{}',) + except KeyError as e: # Missing keyword arg + key = e.args[0] + kwargs[key] = '{' + key + '}' + + raise RuntimeError(f'Failed to format string {s} after {max_iters} iterations.') diff --git a/composer/utils/string_helpers.py b/composer/utils/string_helpers.py deleted file mode 100644 index 4702550d73..0000000000 --- a/composer/utils/string_helpers.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -"""Utilities for string manipulation.""" - - -def partial_format(s, *args, **kwargs) -> str: - """Format a string with a partial set of arguments. - - Since `str.format()` raises a `KeyError` if a format key is missing from the arguments, this - function allows for a partial set of arguments to be provided. Any missing arguments will be - left as-is in the string. - """ - max_iters = 10_000 # Just in case we get stuck in a loop somehow. 
- for _ in range(max_iters): - try: - return s.format(*args, **kwargs) - except IndexError as e: # Missing positional arg - args += ('{}',) - except KeyError as e: # Missing keyword arg - key = e.args[0] - kwargs[key] = '{' + key + '}' - - raise RuntimeError(f'Failed to format string {s} after {max_iters} iterations.') diff --git a/tests/utils/test_string_helpers.py b/tests/utils/test_misc.py similarity index 95% rename from tests/utils/test_string_helpers.py rename to tests/utils/test_misc.py index f6a8083c01..1bf1058f85 100644 --- a/tests/utils/test_string_helpers.py +++ b/tests/utils/test_misc.py @@ -1,8 +1,6 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for string helpers.""" - from composer.utils.string_helpers import partial_format From 12a5411dcd42853bdbd13c6039e2723e0eeb6529 Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 12 Jan 2024 13:34:28 -0800 Subject: [PATCH 16/17] fix test import --- tests/utils/test_misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_misc.py b/tests/utils/test_misc.py index 1bf1058f85..333262795d 100644 --- a/tests/utils/test_misc.py +++ b/tests/utils/test_misc.py @@ -1,7 +1,7 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 -from composer.utils.string_helpers import partial_format +from composer.utils.misc import partial_format def test_partial_format(): From ff99273eab81964ff57fa0e879a4dc17c46b497d Mon Sep 17 00:00:00 2001 From: Jerry Chen Date: Fri, 12 Jan 2024 14:02:31 -0800 Subject: [PATCH 17/17] Fix test --- tests/loggers/test_mlflow_logger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index b3d07eb48d..7e0b788825 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -514,7 +514,8 @@ def test_mlflow_logging_works(tmp_path, device): actual_params_list = [param_filepath.stem for param_filepath in param_path.iterdir()] expected_params_list = [ - 'num_cpus_per_node', 'node_name', 'num_nodes', 'rank_zero_seed', 'composer_version', 'composer_commit_hash' + 'num_cpus_per_node', 'node_name', 'num_nodes', 'rank_zero_seed', 'composer_version', 'composer_commit_hash', + 'mlflow_experiment_id', 'mlflow_run_id' ] assert set(expected_params_list) == set(actual_params_list)
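
Note (not part of the patch series): the sketch below illustrates the central mechanism these patches rely on — `partial_format` leaves unknown placeholders intact, so a `dbfs:` save path can be partially resolved early (run name) and finished after `Event.INIT`, once the `MLFlowLogger` has created or re-attached to its run and broadcast the experiment/run IDs. The path layout and ID values are hypothetical examples, and the helper simply restates the `partial_format` added to `composer/utils/misc.py` above; treat it as an illustration, not as the library's exact call sequence.

    # Illustrative sketch only (hypothetical path layout and IDs): resolving a
    # dbfs save path that contains MLFlow placeholders in two stages.

    def partial_format(s, *args, **kwargs) -> str:
        """Format a string, leaving any missing positional/keyword fields as-is."""
        max_iters = 10_000  # guard against an unexpected infinite loop
        for _ in range(max_iters):
            try:
                return s.format(*args, **kwargs)
            except IndexError:  # missing positional arg: keep '{}' in place
                args += ('{}',)
            except KeyError as e:  # missing keyword arg: keep '{key}' in place
                key = e.args[0]
                kwargs[key] = '{' + key + '}'
        raise RuntimeError(f'Failed to format string {s} after {max_iters} iterations.')

    # Hypothetical save folder as a user might configure it before training starts.
    save_folder = 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/{run_name}'

    # Early on, only the Composer run name is known; the MLFlow placeholders survive.
    partially_resolved = partial_format(save_folder, run_name='my-run')

    # After Event.INIT, the MLFlowLogger's experiment and run IDs are available on
    # all ranks, so the remaining placeholders can be filled.
    fully_resolved = partial_format(partially_resolved, mlflow_experiment_id='123', mlflow_run_id='abc')

    assert fully_resolved == 'dbfs:/databricks/mlflow-tracking/123/abc/artifacts/my-run'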