From bc50049b165a91d575c3a07768be42bc08097c6f Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:19:47 -0800 Subject: [PATCH 01/10] Gate FSDP param init test on torch 2.1 (#2774) --- tests/trainer/test_fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 99ec862ecc..25b294fea9 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -85,8 +85,8 @@ def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: @pytest.mark.parametrize('device', _INIT_DEVICES) @world_size(2) @pytest.mark.gpu -@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), - reason='FSDP requires PyTorch 1.13 or higher') +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('2.1.0'), + reason='This has only been fixed and tested starting with torch 2.1.0') def test_fsdp_inits_params_once(model: ComposerClassifier, device: str, world_size: int, expected_param_inits: int): resolved_device = device if device == 'mixed': From aad8901f59e6f7e9478bfe70e909140967b92f32 Mon Sep 17 00:00:00 2001 From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com> Date: Tue, 12 Dec 2023 08:53:50 -0800 Subject: [PATCH 02/10] Parallelize OCI multipart download (#2750) --- .../utils/object_store/oci_object_store.py | 71 ++++++++++++++----- .../object_store/test_oci_object_store.py | 4 +- 2 files changed, 57 insertions(+), 18 deletions(-) diff --git a/composer/utils/object_store/oci_object_store.py b/composer/utils/object_store/oci_object_store.py index bd9aed7709..ce3fd5ea2c 100644 --- a/composer/utils/object_store/oci_object_store.py +++ b/composer/utils/object_store/oci_object_store.py @@ -5,9 +5,11 @@ from __future__ import annotations +import concurrent.futures import os import pathlib import uuid +from tempfile import TemporaryDirectory from typing import Callable, List, Optional, Union from composer.utils.import_helpers import MissingConditionalImportError @@ -116,12 +118,24 @@ def upload_object( except Exception as e: _reraise_oci_errors(self.get_uri(object_name), e) + def _download_part(self, object_name, filename, start_byte, end_byte, part_number): + range_header = f'bytes={start_byte}-{end_byte}' + tmp_part_path = os.path.join(filename, f'part-{part_number}-{uuid.uuid4()}.tmp') + response = self.client.get_object(namespace_name=self.namespace, + bucket_name=self.bucket, + object_name=object_name, + range=range_header) + with open(tmp_part_path, 'wb') as f: + f.write(response.data.content) + return part_number, tmp_part_path + def download_object( self, object_name: str, filename: Union[str, pathlib.Path], overwrite: bool = False, callback: Optional[Callable[[int, int], None]] = None, + num_parts: int = 10, ): del callback if os.path.exists(filename) and not overwrite: @@ -130,24 +144,47 @@ def download_object( dirname = os.path.dirname(filename) if dirname: os.makedirs(dirname, exist_ok=True) - tmp_path = str(filename) + f'.{uuid.uuid4()}.tmp' - try: - response = self.client.get_object( - namespace_name=self.namespace, - bucket_name=self.bucket, - object_name=object_name, - ) - except Exception as e: - _reraise_oci_errors(self.get_uri(object_name), e) - - with open(tmp_path, 'wb') as f: - f.write(response.data.content) - - if overwrite: - os.replace(tmp_path, filename) - else: - os.rename(tmp_path, filename) + # Get the size of the object + head_object_response = self.client.head_object(self.namespace, self.bucket, object_name) + object_size = head_object_response.headers['content-length'] + # Calculate the part sizes + base_part_size, remainder = divmod(int(object_size), num_parts) + part_sizes = [base_part_size] * num_parts + for i in range(remainder): + part_sizes[i] += 1 + part_sizes = [part_size for part_size in part_sizes if part_size > 0] + + with TemporaryDirectory(dir=dirname, prefix=f'{str(filename)}') as temp_dir: + try: + # Download parts in parallel + parts = [] + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + start_byte = 0 + for i, part_size in enumerate(part_sizes): + end_byte = start_byte + part_size - 1 + futures.append( + executor.submit(self._download_part, object_name, temp_dir, start_byte, end_byte, i)) + start_byte = end_byte + 1 + + for future in concurrent.futures.as_completed(futures): + parts.append(future.result()) + parts = sorted(parts, key=lambda x: x[0]) + except Exception as e: + _reraise_oci_errors(self.get_uri(object_name), e) + + # Combine parts + tmp_path = os.path.join(temp_dir, f'{str(filename)}-{uuid.uuid4()}.tmp') + with open(tmp_path, 'wb') as outfile: + for i, part_file_name in parts: + with open(part_file_name, 'rb') as infile: + outfile.write(infile.read()) + + if overwrite: + os.replace(tmp_path, filename) + else: + os.rename(tmp_path, filename) def list_objects(self, prefix: Optional[str] = None) -> List[str]: if prefix is None: diff --git a/tests/utils/object_store/test_oci_object_store.py b/tests/utils/object_store/test_oci_object_store.py index 97807b177c..49676cd2e8 100644 --- a/tests/utils/object_store/test_oci_object_store.py +++ b/tests/utils/object_store/test_oci_object_store.py @@ -85,10 +85,12 @@ def test_download_object(test_oci_obj_store, monkeypatch, tmp_path, mock_bucket_ oci_os.download_object(object_name=mock_object_name, filename=file_to_download_to) mock_get_object.assert_called_once_with(namespace_name=oci_os.namespace, bucket_name=mock_bucket_name, - object_name=mock_object_name) + object_name=mock_object_name, + range='bytes=0-0') with open(file_to_download_to, 'rb') as f: actual_content = f.readline() + assert actual_content == file_content elif result == 'file_exists': From f497e6051171ccb97ce821a43937427e70a7bf95 Mon Sep 17 00:00:00 2001 From: Harsh Panchal <68880048+panchalhp-db@users.noreply.github.com> Date: Tue, 12 Dec 2023 09:33:20 -0800 Subject: [PATCH 03/10] [UCVolumes] Add support for list API (#2769) --- .../utils/object_store/uc_object_store.py | 26 +++++-- .../object_store/test_uc_object_store.py | 71 ++++++++++++++++++- 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/composer/utils/object_store/uc_object_store.py b/composer/utils/object_store/uc_object_store.py index af6f60321e..3063000ee8 100644 --- a/composer/utils/object_store/uc_object_store.py +++ b/composer/utils/object_store/uc_object_store.py @@ -5,6 +5,7 @@ from __future__ import annotations +import json import logging import os import pathlib @@ -46,6 +47,8 @@ class UCObjectStore(ObjectStore): not other Databricks Filesystems. """ + _UC_VOLUME_LIST_API_ENDPOINT = '/api/2.0/fs/list' + def __init__(self, path: str) -> None: try: from databricks.sdk import WorkspaceClient @@ -73,7 +76,7 @@ def validate_path(path: str) -> str: """ path = os.path.normpath(path) if not path.startswith('Volumes'): - raise ValueError('Databricks Unity Catalog Volumes paths should start with "/Volumes".') + raise ValueError('Databricks Unity Catalog Volumes paths should start with "Volumes".') dirs = path.split(os.sep) if len(dirs) < 4: @@ -203,12 +206,27 @@ def get_object_size(self, object_name: str) -> int: def list_objects(self, prefix: Optional[str]) -> List[str]: """List all objects in the object store with the given prefix. + .. note:: + + This function removes the directories from the returned list. + Args: prefix (str): The prefix to search for. Returns: list[str]: A list of object names that match the prefix. """ - # TODO: Implement this function once UC volumes list endpoint is available in the SDK - del prefix # unused - raise NotImplementedError(f'{type(self).__name__}.list_objects is not implemented') + if not prefix: + prefix = self.prefix + + from databricks.sdk.core import DatabricksError + try: + data = json.dumps({'path': self._get_object_path(prefix)}) + # NOTE: This API is in preview and should not be directly used outside of this instance + resp = self.client.api_client.do(method='GET', + path=self._UC_VOLUME_LIST_API_ENDPOINT, + data=data, + headers={'Source': 'mosaicml/composer'}) + return [f['path'] for f in resp.get('files', []) if not f['is_dir']] + except DatabricksError as e: + _wrap_errors(self.get_uri(prefix), e) diff --git a/tests/utils/object_store/test_uc_object_store.py b/tests/utils/object_store/test_uc_object_store.py index 792d7b3914..1f84143186 100644 --- a/tests/utils/object_store/test_uc_object_store.py +++ b/tests/utils/object_store/test_uc_object_store.py @@ -18,8 +18,10 @@ @pytest.fixture def ws_client(monkeypatch): mock_files = MagicMock() + mock_api_client = MagicMock() mock_ws_client = MagicMock() monkeypatch.setattr(mock_ws_client, 'files', mock_files) + monkeypatch.setattr(mock_ws_client, 'api_client', mock_api_client) return mock_ws_client @@ -61,7 +63,10 @@ def test_uc_object_store_without_env(): UCObjectStore(path='Volumes/test-volume/') -def test_uc_object_store_invalid_prefix(): +def test_uc_object_store_invalid_prefix(monkeypatch): + monkeypatch.setenv('DATABRICKS_HOST', 'test-host') + monkeypatch.setenv('DATABRICKS_TOKEN', 'test-token') + with pytest.raises(ValueError): UCObjectStore(path='root/') with pytest.raises(ValueError): @@ -97,8 +102,8 @@ def test_upload_object(ws_client, uc_object_store, tmp_path): with open(file_to_upload, 'wb') as f: f.write(bytes(range(20))) - uc_object_store.upload_object(object_name='train.txt', filename=file_to_upload) - ws_client.files.upload.assert_called_with('/Volumes/catalog/schema/volume/train.txt', ANY) + uc_object_store.upload_object(object_name='path/train.txt', filename=file_to_upload) + ws_client.files.upload.assert_called_with('/Volumes/catalog/schema/volume/path/train.txt', ANY) @pytest.mark.parametrize('result', ['success', 'file_exists', 'overwrite_file', 'not_found', 'error']) @@ -155,6 +160,66 @@ def generate_dummy_file(_): raise NotImplementedError(f'Test for result={result} is not implemented.') +@pytest.mark.parametrize('result', ['success', 'prefix_none', 'not_found', 'error']) +def test_list_objects(ws_client, uc_object_store, result): + expected_files = [ + '/Volumes/catalog/volume/schema/path/to/folder/file1.txt', + '/Volumes/catalog/volume/schema/path/to/folder/file2.txt', + ] + uc_list_api_response = { + 'files': [{ + 'path': '/Volumes/catalog/volume/schema/path/to/folder/file1.txt', + 'is_dir': False + }, { + 'path': '/Volumes/catalog/volume/schema/path/to/folder/file2.txt', + 'is_dir': False + }, { + 'path': '/Volumes/catalog/volume/schema/path/to/folder/samples/', + 'is_dir': True + }] + } + + prefix = 'Volumes/catalog/schema/volume/path/to/folder' + + if result == 'success': + ws_client.api_client.do.return_value = uc_list_api_response + actual_files = uc_object_store.list_objects(prefix=prefix) + + assert actual_files == expected_files + ws_client.api_client.do.assert_called_once_with( + method='GET', + path=uc_object_store._UC_VOLUME_LIST_API_ENDPOINT, + data='{"path": "/Volumes/catalog/schema/volume/path/to/folder"}', + headers={'Source': 'mosaicml/composer'}) + + elif result == 'prefix_none': + ws_client.api_client.do.return_value = uc_list_api_response + actual_files = uc_object_store.list_objects(prefix=None) + + assert actual_files == expected_files + ws_client.api_client.do.assert_called_once_with(method='GET', + path=uc_object_store._UC_VOLUME_LIST_API_ENDPOINT, + data='{"path": "/Volumes/catalog/schema/volume/."}', + headers={'Source': 'mosaicml/composer'}) + + elif result == 'not_found': + db_core = pytest.importorskip('databricks.sdk.core', reason='requires databricks') + ws_client.api_client.do.side_effect = db_core.DatabricksError( + 'The path you provided does not exist or is not a directory.', error_code='NOT_FOUND') + with pytest.raises(FileNotFoundError): + uc_object_store.list_objects(prefix=prefix) + + elif result == 'error': + db_core = pytest.importorskip('databricks.sdk.core', reason='requires databricks') + ws_client.api_client.do.side_effect = db_core.DatabricksError + + with pytest.raises(ObjectStoreTransientError): + uc_object_store.list_objects(prefix=prefix) + + else: + raise NotImplementedError(f'Test for result={result} is not implemented.') + + def test_uc_object_store_with_remote_ud(uc_object_store): uri = 'dbfs:/Volumes/path/to/my/folder/' rud = RemoteUploaderDownloader(bucket_uri=uri, backend_kwargs={'path': 'Volumes/catalog/schema/volume/path'}) From a7cad7c221ce8ad9697bde50db0b3f37f8b8025e Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Tue, 12 Dec 2023 12:28:24 -0800 Subject: [PATCH 04/10] Add the memory timeline profiling support through the PyTorch profiler. (#2771) * v1 * fix issues * add logs * change names * comment * add device * uncomment original trace * add custome plot * fix pyright * Update composer/profiler/torch_profiler.py Co-authored-by: Charles Tang * address comments * fix code check * fix formatting * address comments * add unit test * fix check * fix check * fix check * fix check * fix print * add test comment * add test comment --------- Co-authored-by: Mihir Patel Co-authored-by: Charles Tang --- composer/profiler/profiler.py | 11 ++ composer/profiler/torch_profiler.py | 145 ++++++++++++++++++++----- composer/profiler/utils.py | 97 +++++++++++++++++ tests/profiler/test_memory_timeline.py | 52 +++++++++ 4 files changed, 279 insertions(+), 26 deletions(-) create mode 100644 composer/profiler/utils.py create mode 100644 tests/profiler/test_memory_timeline.py diff --git a/composer/profiler/profiler.py b/composer/profiler/profiler.py index 294dfb8471..876282dd99 100644 --- a/composer/profiler/profiler.py +++ b/composer/profiler/profiler.py @@ -76,6 +76,9 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs): torch_prof_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. torch_prof_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. Additionally supports full object store paths e.g: s3://bucket/path/to/file. + torch_prof_memory_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. + torch_prof_memory_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. + Additionally supports full object store paths e.g: s3://bucket/path/to/file. torch_prof_overwrite (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. torch_prof_use_gzip (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. torch_prof_record_shapes (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. @@ -97,6 +100,9 @@ def __init__( torch_prof_folder: str = '{run_name}/torch_traces', torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json', torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json', + torch_prof_memory_filename: str = 'rank{rank}.{batch}.pt.memory_trace.html', + torch_prof_memory_remote_file_name: Optional[ + str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html', torch_prof_overwrite: bool = False, torch_prof_use_gzip: bool = False, torch_prof_record_shapes: bool = False, @@ -116,6 +122,9 @@ def __init__( if torch_prof_remote_file_name: self.remote_filenames.append(torch_prof_remote_file_name) _, _, torch_prof_remote_file_name = parse_uri(torch_prof_remote_file_name) + if torch_prof_memory_remote_file_name: + self.remote_filenames.append(torch_prof_memory_remote_file_name) + _, _, torch_prof_memory_remote_file_name = parse_uri(torch_prof_memory_remote_file_name) for handler in self._trace_handlers: if isinstance(handler, JSONTraceHandler): if handler.remote_file_name: @@ -139,6 +148,8 @@ def __init__( TorchProfiler(filename=torch_prof_filename, folder=torch_prof_folder, remote_file_name=torch_prof_remote_file_name, + memory_filename=torch_prof_memory_filename, + memory_remote_file_name=torch_prof_memory_remote_file_name, num_traces_to_keep=torch_prof_num_traces_to_keep, overwrite=torch_prof_overwrite, record_shapes=torch_prof_record_shapes, diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py index dc33d829aa..cfd4c0a48b 100644 --- a/composer/profiler/torch_profiler.py +++ b/composer/profiler/torch_profiler.py @@ -11,7 +11,9 @@ import textwrap from typing import TYPE_CHECKING, Optional, OrderedDict +import torch.cuda import torch.profiler +from packaging import version from torch.profiler.profiler import ProfilerAction as TorchProfilerAction from composer.core.callback import Callback @@ -92,9 +94,9 @@ class TorchProfiler(Callback): # noqa: D101 Each rank (process) will save traces to:: - awesome-training-run/torch_traces/ep1-ba42-rank0.json - awesome-training-run/torch_traces/ep1-ba42-rank1.json - awesome-training-run/torch_traces/ep1-ba42-rank2.json + awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.json + awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.json + awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.json ... remote_file_name (str, optional): Format string for a Torch Profiler trace file's remote file name. @@ -107,6 +109,43 @@ class TorchProfiler(Callback): # noqa: D101 Leading slashes (``'/'``) will be stripped. + To disable uploading trace files, set this parameter to ``None``. + memory_filename (str, optional): A format string describing how to name Torch Profiler memory trace files. + Defaults to ``'rank{{rank}}.{{batch}}.pt.trace.memory.html'``. + + At the end of each batch where :meth:`~composer.profiler.Profiler.get_action` returns + :attr:`~composer.profiler._profiler_action.ProfilerAction.ACTIVE_AND_SAVE`, trace files are saved + approximately to ``{{folder.format(...)}}/{{memory_filename.format(...)}}``. + + The following format variables are available: + + {textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix=' ')} + + Consider the following scenario, where: + + * The :attr:`~.State.run_name` is ``'awesome-training-run'``. + * The default ``trace_folder='{{run_name}}/torch_traces'`` is used. + * The default ``name='rank{{rank}}.{{batch}}.pt.trace.memory.html'`` is used. + * The current epoch count is ``1``. + * The current batch count is ``42``. + + Each rank (process) will save traces to:: + + awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.memory.html + awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.memory.html + awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.memory.html + ... + + memory_remote_file_name (str, optional): Format string for a Torch Profiler memory trace file's remote file name. + Defaults to ``'{{run_name}}/torch_traces/rank{{rank}}.{{batch}}.pt.trace.memory.json'``. + + Whenever a trace file is saved, it is also uploaded as a file according to this format string. + The same format variables as for ``filename`` are available. + + .. seealso:: :doc:`Uploading Files` for notes for file uploading. + + Leading slashes (``'/'``) will be stripped. + To disable uploading trace files, set this parameter to ``None``. overwrite (bool, optional): Whether to override existing Torch Profiler traces. Defaults to False. @@ -146,7 +185,10 @@ def __init__( folder: str = '{run_name}/torch_traces', filename: str = 'rank{rank}.{batch}.pt.trace.json', remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json', - *, + memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.trace.memory.html', + memory_remote_file_name: Optional[ + str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html', + memory_custom_plot: bool = True, overwrite: bool = False, use_gzip: bool = False, record_shapes: bool = False, @@ -157,12 +199,26 @@ def __init__( ) -> None: self.overwrite = overwrite self.folder = folder - if use_gzip and not filename.endswith('.gz'): - filename += '.gz' + + if use_gzip: + if not filename.endswith('.gz'): + filename += '.gz' self.filename = filename - if use_gzip and remote_file_name is not None and not remote_file_name.endswith('.gz'): - remote_file_name += '.gz' + + if use_gzip: + if remote_file_name is not None and not remote_file_name.endswith('.gz'): + remote_file_name += '.gz' self.remote_file_name = remote_file_name + + if memory_filename is not None: + assert memory_filename.endswith('.html'), f'memory_filename must end with .html, got {memory_filename}' + self.memory_filename = memory_filename + + if memory_remote_file_name is not None: + assert memory_remote_file_name.endswith( + '.html'), f'memory_remote_file_name must end with .html, got {memory_remote_file_name}' + self.memory_remote_file_name = memory_remote_file_name + self.record_shapes = record_shapes self.profile_memory = profile_memory self.with_stack = with_stack @@ -170,6 +226,7 @@ def __init__( self.num_traces_to_keep = num_traces_to_keep self.saved_traces = OrderedDict() self.profiler: Optional[torch.profiler.profile] = None + self.memory_custom_plot = memory_custom_plot def init(self, state: State, logger: Logger) -> None: if state.profiler is None: @@ -203,27 +260,63 @@ def handler_fn(prof: torch.profiler.profiler.profile): timestamp = state.timestamp - trace_file_name = os.path.join( - folder_name, - format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp), - ) - trace_file_dirname = os.path.dirname(trace_file_name) - if trace_file_dirname: - os.makedirs(trace_file_dirname, exist_ok=True) - prof.export_chrome_trace(trace_file_name) - state.profiler.record_chrome_json_trace_file(trace_file_name) - if self.remote_file_name is not None: - trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name, - run_name=state.run_name, - timestamp=timestamp) - trace_remote_file_name = trace_remote_file_name.lstrip('/') - logger.upload_file(remote_file_name=trace_remote_file_name, - file_path=trace_file_name, - overwrite=self.overwrite) + log.info(f'PyTorch Chrome trace profiler enabled: {self.filename if self.filename else False}') + if self.filename is not None: + trace_file_name = os.path.join( + folder_name, + format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp), + ) + trace_file_dirname = os.path.dirname(trace_file_name) + if trace_file_dirname: + os.makedirs(trace_file_dirname, exist_ok=True) + prof.export_chrome_trace(trace_file_name) + state.profiler.record_chrome_json_trace_file(trace_file_name) + if self.remote_file_name is not None: + trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name, + run_name=state.run_name, + timestamp=timestamp) + trace_remote_file_name = trace_remote_file_name.lstrip('/') + logger.upload_file(remote_file_name=trace_remote_file_name, + file_path=trace_file_name, + overwrite=self.overwrite) + + log.info( + f'PyTorch memory timeline profiler enabled: {self.memory_filename if self.memory_filename else False}') + if self.memory_filename is not None: + if version.parse(torch.__version__) > version.parse('2.1.0.dev'): # type: ignore + # memory timeline profiling is only supported in torch v2.1.0-rc1 or higher + memory_trace_file_name = os.path.join( + folder_name, + format_name_with_dist_and_time(self.memory_filename, + run_name=state.run_name, + timestamp=timestamp), + ) + log.debug(f'Saving memory trace to {memory_trace_file_name}') + memory_trace_file_dirname = os.path.dirname(memory_trace_file_name) + if memory_trace_file_dirname: + os.makedirs(memory_trace_file_dirname, exist_ok=True) + if self.memory_custom_plot: + from composer.profiler.utils import export_memory_timeline_html + export_memory_timeline_html(prof, memory_trace_file_name, + torch.cuda.current_device()) # type: ignore + else: + prof.export_memory_timeline(memory_trace_file_name, torch.cuda.current_device()) # type: ignore + log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}') + if self.memory_remote_file_name is not None: + memory_trace_remote_file_name = format_name_with_dist_and_time(self.memory_remote_file_name, + run_name=state.run_name, + timestamp=timestamp) + memory_trace_remote_file_name = memory_trace_remote_file_name.lstrip('/') + log.debug( + f'Uploading memory trace to {memory_trace_remote_file_name} from {memory_trace_file_name}') + logger.upload_file(remote_file_name=memory_trace_remote_file_name, + file_path=memory_trace_file_name, + overwrite=self.overwrite) + else: + log.warning('Memory timeline is supported after PyTorch 2.1.0. Skipping memory trace.') if self.num_traces_to_keep >= 0: while len(self.saved_traces) > self.num_traces_to_keep: - # self.saved_traces is an ordered dict, so the zeroth item will be the oldest checkpoint timestamp, filepaths = next(iter(self.saved_traces.items())) if dist.get_global_rank() < len(filepaths): diff --git a/composer/profiler/utils.py b/composer/profiler/utils.py new file mode 100644 index 0000000000..b4df8396a7 --- /dev/null +++ b/composer/profiler/utils.py @@ -0,0 +1,97 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for torch profiler.""" + +import importlib.util +import logging +from base64 import b64encode +from os import remove +from tempfile import NamedTemporaryFile +from typing import Any, Optional, Union + +import numpy as np +import torch +import torch.cuda +from packaging import version +from torch.profiler.profiler import profile as TorchProfile + +log = logging.getLogger(__name__) + + +def export_memory_timeline_html(prof: TorchProfile, + path: str, + device: Optional[str] = None, + figsize=(20, 12), + title=None, + yxis_step_size: float = 1.0, + return_fig: bool = False) -> Optional[Union[None, Any]]: + """Exports a memory timeline to an HTML file. Similar to the PyTorch plotting function, but with adjusted axis tickers and grids.""" + if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): + log.warning('export_memory_timeline_html failed because memory timeline is supported after PyTorch 2.1.0.') + return + + from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline + + # Default to device 0, if unset. Fallback on cpu. + if device is None and prof.use_device and prof.use_device != 'cuda': + device = prof.use_device + ':0' + + if device is None: + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + + # Construct the memory timeline plot data + mem_tl = MemoryProfileTimeline(prof._memory_profile()) + + # Check if user has matplotlib installed, return gracefully if not. + matplotlib_spec = importlib.util.find_spec('matplotlib') + if matplotlib_spec is None: + log.warning('export_memory_timeline_html failed because matplotlib was not found.') + return + import matplotlib.pyplot as plt + + mt = mem_tl._coalesce_timeline(device) + times, sizes = np.array(mt[0]), np.array(mt[1]) + stacked = np.cumsum(sizes, axis=1) / 1024**3 + max_memory_allocated = torch.cuda.max_memory_allocated() + max_memory_reserved = torch.cuda.max_memory_reserved() + + # Plot memory timeline as stacked data + fig = plt.figure(figsize=figsize, dpi=80) + axes = fig.gca() + for category, color in _CATEGORY_TO_COLORS.items(): + i = _CATEGORY_TO_INDEX[category] + axes.fill_between(times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7) + fig.legend(['Unknown' if i is None else i.name for i in _CATEGORY_TO_COLORS]) + axes.set_xlabel('Time (us)') + axes.set_ylabel('Memory (GB)') + _, end = axes.get_ylim() + axes.grid(True) + axes.set_yticks(np.arange(0, end, yxis_step_size)) + title = '\n\n'.join(([title] if title else []) + [ + f'Max memory allocated: {max_memory_allocated/(10**9):.2f} GB \n' + f'Max memory reserved: {max_memory_reserved/(10**9):.2f} GB' + ]) + axes.set_title(title) + + if return_fig: + return fig + + # Embed the memory timeline image into the HTML file + tmpfile = NamedTemporaryFile('wb', suffix='.png', delete=False) + tmpfile.close() + fig.savefig(tmpfile.name, format='png') + + with open(tmpfile.name, 'rb') as tmp: + encoded = b64encode(tmp.read()).decode('utf-8') + html = f""" + GPU Memory Timeline HTML + + + + """ + + with open(path, 'w') as f: + f.write(html) + log.debug('Memory timeline exported to', path, '.') + remove(tmpfile.name) diff --git a/tests/profiler/test_memory_timeline.py b/tests/profiler/test_memory_timeline.py new file mode 100644 index 0000000000..c8e685df0b --- /dev/null +++ b/tests/profiler/test_memory_timeline.py @@ -0,0 +1,52 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +import os +import pathlib + +import pytest +import torch +from packaging import version + +from composer.profiler.utils import export_memory_timeline_html + + +@pytest.mark.gpu +def test_memory_timeline(tmp_path: pathlib.Path) -> None: + if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): + # memory timeline is supported after PyTorch 2.1.0. + return + import torch.profiler._memory_profiler as _memory_profiler + + model = torch.nn.Sequential( + torch.nn.Linear(1024, 1024, bias=True), + torch.nn.ReLU(), + torch.nn.Linear(1024, 1024, bias=False), + torch.nn.Softmax(dim=1), + ).to('cuda') + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + x = torch.ones((1024, 1024), device='cuda') + targets = torch.ones((1024, 1024), device='cuda') + with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof: + y = model(x) + loss = torch.nn.functional.mse_loss(y, targets) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + memory_profile = prof._memory_profile() + timeline = memory_profile.timeline + + # this checks the default memory timeline event value (t == -1) for preexisting tensors + assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline) + + fig = export_memory_timeline_html( + prof, + os.path.join(tmp_path, 'test_memory_timeline.html'), + yxis_step_size=0.01, + return_fig=True, + ) + assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True' + _, end = fig.gca().get_ylim() + assert round(end, 2) == 0.06 From db3d18798f19beadf2891295a884902fe18271bc Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 13 Dec 2023 11:48:27 -0800 Subject: [PATCH 05/10] Improve torch memory profiling arguments processing (#2777) * improve torch profile args * improve torch profile args * change default torch_prof_memory_filename * add memory profiling arg test * fix check * fix check * fix check * fix check * fix check * fix check --- composer/profiler/marker.py | 10 +-- composer/profiler/profiler.py | 17 +++- composer/profiler/torch_profiler.py | 11 +-- .../performance_tutorials/profiling.md | 49 +++++------ examples/profiler_demo.py | 1 + tests/callbacks/test_callbacks.py | 16 +++- tests/profiler/test_json_trace_handler.py | 1 + tests/profiler/test_memory_timeline.py | 52 ------------ tests/profiler/test_profiler.py | 84 ++++++++++++++++++- 9 files changed, 142 insertions(+), 99 deletions(-) delete mode 100644 tests/profiler/test_memory_timeline.py diff --git a/composer/profiler/marker.py b/composer/profiler/marker.py index a87bfc02ae..26dcc388ed 100644 --- a/composer/profiler/marker.py +++ b/composer/profiler/marker.py @@ -40,7 +40,7 @@ class Marker: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -57,7 +57,7 @@ class Marker: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -124,7 +124,7 @@ def start(self) -> None: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -187,7 +187,7 @@ def instant(self) -> None: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -213,7 +213,7 @@ def counter(self, values: Dict[str, Union[float, int]]) -> None: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: diff --git a/composer/profiler/profiler.py b/composer/profiler/profiler.py index 876282dd99..4e5a6bbbb2 100644 --- a/composer/profiler/profiler.py +++ b/composer/profiler/profiler.py @@ -45,6 +45,7 @@ class Profiler: def new_profiler_init(self, dummy_ellipsis=None, **kwargs): if 'trace_handlers' not in kwargs: kwargs['trace_handlers'] = [] + kwargs['torch_prof_memory_filename'] = None original_profiler_init(self, **kwargs) Profiler.__init__ = new_profiler_init @@ -62,6 +63,7 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs): active=4, repeat=1, ), + torch_prof_memory_filename=None, ) trace_handlers (TraceHandler | Sequence[TraceHandler]): Trace handlers which record and @@ -100,7 +102,7 @@ def __init__( torch_prof_folder: str = '{run_name}/torch_traces', torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json', torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json', - torch_prof_memory_filename: str = 'rank{rank}.{batch}.pt.memory_trace.html', + torch_prof_memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.memory_trace.html', torch_prof_memory_remote_file_name: Optional[ str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html', torch_prof_overwrite: bool = False, @@ -143,6 +145,17 @@ def __init__( profile_net=sys_prof_net, stats_thread_interval_seconds=sys_prof_stats_thread_interval_seconds)) + if torch_prof_memory_filename is not None: + if not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory): + raise ValueError( + f'torch_prof_memory_filename is set. Generating the memory timeline graph requires all the three flags torch_prof_with_stack, torch_prof_record_shapes, and torch_prof_profile_memory to be true. Got torch_prof_with_stack={torch_prof_with_stack}, torch_prof_record_shapes={torch_prof_record_shapes}, torch_prof_profile_memory={torch_prof_profile_memory}' + ) + log.info( + f'Memory profiling is enabled and uses {torch_prof_memory_filename} as the filename to generate the memory timeline graph. To disable the memory timeline graph generation, explicitly set torch_prof_memory_filename to None.' + ) + else: + log.info(f'torch_prof_memory_filename is explicitly set to None. Memory timeline will not be be generated.') + if torch_prof_record_shapes or torch_prof_profile_memory or torch_prof_with_stack or torch_prof_with_flops: self._callbacks.append( TorchProfiler(filename=torch_prof_filename, @@ -230,7 +243,7 @@ def marker( from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) state.profiler = profiler diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py index cfd4c0a48b..ef3fad2554 100644 --- a/composer/profiler/torch_profiler.py +++ b/composer/profiler/torch_profiler.py @@ -188,7 +188,6 @@ def __init__( memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.trace.memory.html', memory_remote_file_name: Optional[ str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html', - memory_custom_plot: bool = True, overwrite: bool = False, use_gzip: bool = False, record_shapes: bool = False, @@ -226,7 +225,6 @@ def __init__( self.num_traces_to_keep = num_traces_to_keep self.saved_traces = OrderedDict() self.profiler: Optional[torch.profiler.profile] = None - self.memory_custom_plot = memory_custom_plot def init(self, state: State, logger: Logger) -> None: if state.profiler is None: @@ -295,12 +293,9 @@ def handler_fn(prof: torch.profiler.profiler.profile): memory_trace_file_dirname = os.path.dirname(memory_trace_file_name) if memory_trace_file_dirname: os.makedirs(memory_trace_file_dirname, exist_ok=True) - if self.memory_custom_plot: - from composer.profiler.utils import export_memory_timeline_html - export_memory_timeline_html(prof, memory_trace_file_name, - torch.cuda.current_device()) # type: ignore - else: - prof.export_memory_timeline(memory_trace_file_name, torch.cuda.current_device()) # type: ignore + from composer.profiler.utils import export_memory_timeline_html + export_memory_timeline_html(prof, memory_trace_file_name, + torch.cuda.current_device()) # type: ignore log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}') if self.memory_remote_file_name is not None: memory_trace_remote_file_name = format_name_with_dist_and_time(self.memory_remote_file_name, diff --git a/docs/source/trainer/performance_tutorials/profiling.md b/docs/source/trainer/performance_tutorials/profiling.md index 0bc930cd47..6c87e2fea8 100644 --- a/docs/source/trainer/performance_tutorials/profiling.md +++ b/docs/source/trainer/performance_tutorials/profiling.md @@ -83,6 +83,7 @@ Note, we support both local and object store paths for the composer profiler, e. profiler = Profiler( trace_handlers=[JSONTraceHandler(remote_file_name='oci://your-bucket/composer_profiler/')], torch_remote_filename='s3://your-bucket/torch_profiler/', + torch_prof_memory_filename=None, ... ) ``` @@ -119,30 +120,30 @@ For example, let’s assume the profiling options are set as follows: Given the configuration above, profiling will be performed as follows: -| Epoch | Batch | Profiler State | Profiler Action | -| --- | --- | --- | --- | -| 0 | 0 | skip_first | Do not record | -| | 1 | wait | Do not record | -| | 2 | warmup | Record, Torch Profiler does not record | -| | 3 | active | Record | -| | 4 | active | Record | -| | 5 | wait | Do not record | -| | 6 | warmup | Record, Torch Profiler does not record | -| | 7 | active | Record | -| | 8 | active | Record | -| | 9 | disabled | Do not record | -| | ... | | | -| 1 | 0 | skip_first | Do not record | -| | 1 | wait | Do not record | -| | 2 | warmup | Record, Torch Profiler does not record | -| | 3 | active | Record | -| | 4 | active | Record | -| | 5 | wait | Do not record | -| | 6 | warmup | Record, Torch Profiler does not record | -| | 7 | active | Record | -| | 8 | active | Record | -| | 9 | disabled | Do not record | -| | ... | | | +| Epoch | Batch | Profiler State | Profiler Action | +| ----- | ----- | -------------- | -------------------------------------- | +| 0 | 0 | skip_first | Do not record | +| | 1 | wait | Do not record | +| | 2 | warmup | Record, Torch Profiler does not record | +| | 3 | active | Record | +| | 4 | active | Record | +| | 5 | wait | Do not record | +| | 6 | warmup | Record, Torch Profiler does not record | +| | 7 | active | Record | +| | 8 | active | Record | +| | 9 | disabled | Do not record | +| | ... | | | +| 1 | 0 | skip_first | Do not record | +| | 1 | wait | Do not record | +| | 2 | warmup | Record, Torch Profiler does not record | +| | 3 | active | Record | +| | 4 | active | Record | +| | 5 | wait | Do not record | +| | 6 | warmup | Record, Torch Profiler does not record | +| | 7 | active | Record | +| | 8 | active | Record | +| | 9 | disabled | Do not record | +| | ... | | | As we can see above, the profiler skips the first batch of each epoch and is in the wait state during the following batch, after which the profiler performs warms up in the next batch and actively records trace data for the diff --git a/examples/profiler_demo.py b/examples/profiler_demo.py index c166efa315..f06fa17f06 100644 --- a/examples/profiler_demo.py +++ b/examples/profiler_demo.py @@ -63,6 +63,7 @@ ), torch_prof_folder=torch_trace_dir, torch_prof_overwrite=True, + torch_prof_memory_filename=None, )) # [trainer-end] diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 0e6b137369..695be08c55 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -53,7 +53,9 @@ def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: S """Test that callbacks do not crash when Event.FIT_START and Event.FIT_END is called multiple times.""" cb_kwargs = get_cb_kwargs(cb_cls) dummy_state.callbacks.append(cb_cls(**cb_kwargs)) - dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]) + dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None) dummy_state.profiler.bind_to_state(dummy_state) logger = Logger(dummy_state) @@ -71,7 +73,9 @@ def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State): """Test that callbacks do not crash when .close() and .post_close() are called multiple times.""" cb_kwargs = get_cb_kwargs(cb_cls) dummy_state.callbacks.append(cb_cls(**cb_kwargs)) - dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]) + dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None) dummy_state.profiler.bind_to_state(dummy_state) logger = Logger(dummy_state) @@ -85,7 +89,9 @@ def test_multiple_init_and_close(self, cb_cls: Type[Callback], dummy_state: Stat """Test that callbacks do not crash when INIT/.close()/.post_close() are called multiple times in that order.""" cb_kwargs = get_cb_kwargs(cb_cls) dummy_state.callbacks.append(cb_cls(**cb_kwargs)) - dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]) + dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None) dummy_state.profiler.bind_to_state(dummy_state) logger = Logger(dummy_state) @@ -125,7 +131,9 @@ def _get_trainer(self, cb: Callback, device_train_microbatch_size: int): device_train_microbatch_size=device_train_microbatch_size, callbacks=callbacks, loggers=loggers, - profiler=Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]), + profiler=Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None), ) def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool): diff --git a/tests/profiler/test_json_trace_handler.py b/tests/profiler/test_json_trace_handler.py index c09ae00fe6..1d13aed18e 100644 --- a/tests/profiler/test_json_trace_handler.py +++ b/tests/profiler/test_json_trace_handler.py @@ -34,6 +34,7 @@ def test_json_trace_profiler_handler(tmp_path: pathlib.Path): torch_prof_profile_memory=False, torch_prof_with_stack=False, torch_prof_with_flops=False, + torch_prof_memory_filename=None, ) trainer = Trainer( model=SimpleModel(), diff --git a/tests/profiler/test_memory_timeline.py b/tests/profiler/test_memory_timeline.py deleted file mode 100644 index c8e685df0b..0000000000 --- a/tests/profiler/test_memory_timeline.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -import os -import pathlib - -import pytest -import torch -from packaging import version - -from composer.profiler.utils import export_memory_timeline_html - - -@pytest.mark.gpu -def test_memory_timeline(tmp_path: pathlib.Path) -> None: - if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): - # memory timeline is supported after PyTorch 2.1.0. - return - import torch.profiler._memory_profiler as _memory_profiler - - model = torch.nn.Sequential( - torch.nn.Linear(1024, 1024, bias=True), - torch.nn.ReLU(), - torch.nn.Linear(1024, 1024, bias=False), - torch.nn.Softmax(dim=1), - ).to('cuda') - optimizer = torch.optim.Adam(model.parameters(), lr=0.1) - - x = torch.ones((1024, 1024), device='cuda') - targets = torch.ones((1024, 1024), device='cuda') - with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof: - y = model(x) - loss = torch.nn.functional.mse_loss(y, targets) - loss.backward() - optimizer.step() - optimizer.zero_grad() - - memory_profile = prof._memory_profile() - timeline = memory_profile.timeline - - # this checks the default memory timeline event value (t == -1) for preexisting tensors - assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline) - - fig = export_memory_timeline_html( - prof, - os.path.join(tmp_path, 'test_memory_timeline.html'), - yxis_step_size=0.01, - return_fig=True, - ) - assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True' - _, end = fig.gca().get_ylim() - assert round(end, 2) == 0.06 diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 5d960f0dd1..2ae9383d79 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -1,12 +1,18 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import os +import pathlib +from typing import Union from unittest.mock import MagicMock import pytest +import torch +from packaging import version from composer.core import State from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule +from composer.profiler.utils import export_memory_timeline_html @pytest.mark.parametrize('repeat', [1, 0]) @@ -50,6 +56,7 @@ def test_profiler_init(minimal_state: State): trace_handlers=[mock_trace_handler], schedule=cyclic_schedule(), torch_prof_profile_memory=True, + torch_prof_memory_filename=None, sys_prof_cpu=True, ) profiler.bind_to_state(minimal_state) @@ -59,10 +66,9 @@ def test_profiler_init(minimal_state: State): def test_marker(dummy_state: State): mock_trace_handler = MagicMock() - profiler = Profiler( - trace_handlers=[mock_trace_handler], - schedule=cyclic_schedule(), - ) + profiler = Profiler(trace_handlers=[mock_trace_handler], + schedule=cyclic_schedule(), + torch_prof_memory_filename=None) profiler.bind_to_state(dummy_state) dummy_state.profiler = profiler marker = profiler.marker('name', @@ -94,3 +100,73 @@ def func_to_profile2(bar: int): assert mock_trace_handler.process_duration_event.call_count == 8 assert mock_trace_handler.process_instant_event.call_count == 1 + + +@pytest.mark.parametrize('torch_prof_with_stack', [True, False]) +@pytest.mark.parametrize('torch_prof_record_shapes', [True, False]) +@pytest.mark.parametrize('torch_prof_profile_memory', [True, False]) +@pytest.mark.parametrize('torch_prof_memory_filename', [None, 'test.html']) +def test_profiler_error_message(torch_prof_with_stack: bool, torch_prof_record_shapes: bool, + torch_prof_profile_memory: bool, torch_prof_memory_filename: Union[None, str]) -> None: + # Construct a profiler and assert that it triggers the ValueError if the arguments are invalid + if (torch_prof_memory_filename is not None and + not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory)): + with pytest.raises(ValueError): + _ = Profiler( + trace_handlers=[MagicMock()], + schedule=cyclic_schedule(), + torch_prof_with_stack=torch_prof_with_stack, + torch_prof_record_shapes=torch_prof_record_shapes, + torch_prof_profile_memory=torch_prof_profile_memory, + torch_prof_memory_filename=torch_prof_memory_filename, + ) + else: + _ = Profiler( + trace_handlers=[MagicMock()], + schedule=cyclic_schedule(), + torch_prof_with_stack=torch_prof_with_stack, + torch_prof_record_shapes=torch_prof_record_shapes, + torch_prof_profile_memory=torch_prof_profile_memory, + torch_prof_memory_filename=torch_prof_memory_filename, + ) + + +@pytest.mark.gpu +def test_memory_timeline(tmp_path: pathlib.Path) -> None: + if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): + # memory timeline is supported after PyTorch 2.1.0. + return + import torch.profiler._memory_profiler as _memory_profiler + + model = torch.nn.Sequential( + torch.nn.Linear(1024, 1024, bias=True), + torch.nn.ReLU(), + torch.nn.Linear(1024, 1024, bias=False), + torch.nn.Softmax(dim=1), + ).to('cuda') + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + x = torch.ones((1024, 1024), device='cuda') + targets = torch.ones((1024, 1024), device='cuda') + with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof: + y = model(x) + loss = torch.nn.functional.mse_loss(y, targets) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + memory_profile = prof._memory_profile() + timeline = memory_profile.timeline + + # this checks the default memory timeline event value (t == -1) for preexisting tensors + assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline) + + fig = export_memory_timeline_html( + prof, + os.path.join(tmp_path, 'test_memory_timeline.html'), + yxis_step_size=0.01, + return_fig=True, + ) + assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True' + _, end = fig.gca().get_ylim() + assert round(end, 2) == 0.06 From 0d61164dd4b606f0a74bebeea41eafe7c1a9b453 Mon Sep 17 00:00:00 2001 From: willgleich <22464726+willgleich@users.noreply.github.com> Date: Wed, 13 Dec 2023 15:15:51 -0700 Subject: [PATCH 06/10] Add platform AWS and bump aws ofi nccl version (#2776) --- docker/Dockerfile | 5 +++-- docker/build_matrix.yaml | 6 +++--- docker/generate_build_matrix.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 38633ee9c1..ea72ebc7b4 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -38,7 +38,7 @@ ARG MOFED_VERSION=5.5-1.0.3.2 # Version of EFA Drivers to install (for AWS Elastic Fabric Adapter support) # Leave blank for no EFA Drivers -ARG AWS_OFI_NCCL_VERSION=v1.7.3-aws +ARG AWS_OFI_NCCL_VERSION=v1.7.4-aws # Upgrade certifi to resolve CVE-2022-23491 ARG CERTIFI_VERSION='>=2022.12.7' @@ -250,7 +250,8 @@ RUN if [ -n "$AWS_OFI_NCCL_VERSION" ] ; then \ ./configure --prefix=/opt/aws-ofi-nccl/install \ --with-libfabric=/opt/amazon/efa/ \ --with-cuda=/usr/local/cuda \ - --disable-tests && \ + --disable-tests \ + --enable-platform-aws && \ make && make install ; \ fi diff --git a/docker/build_matrix.yaml b/docker/build_matrix.yaml index c564cf5cc6..60fd2f0222 100644 --- a/docker/build_matrix.yaml +++ b/docker/build_matrix.yaml @@ -27,7 +27,7 @@ - mosaicml/pytorch:latest TARGET: pytorch_stage TORCHVISION_VERSION: 0.16.1 -- AWS_OFI_NCCL_VERSION: v1.7.3-aws +- AWS_OFI_NCCL_VERSION: v1.7.4-aws BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04 CUDA_VERSION: 12.1.0 IMAGE_NAME: torch-2-1-1-cu121-aws @@ -89,7 +89,7 @@ - mosaicml/pytorch:2.0.1_cu118-python3.10-ubuntu20.04 TARGET: pytorch_stage TORCHVISION_VERSION: 0.15.2 -- AWS_OFI_NCCL_VERSION: v1.7.3-aws +- AWS_OFI_NCCL_VERSION: v1.7.4-aws BASE_IMAGE: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 CUDA_VERSION: 11.8.0 IMAGE_NAME: torch-2-0-1-cu118-aws @@ -136,7 +136,7 @@ - mosaicml/pytorch:1.13.1_cu117-python3.10-ubuntu20.04 TARGET: pytorch_stage TORCHVISION_VERSION: 0.14.1 -- AWS_OFI_NCCL_VERSION: v1.7.3-aws +- AWS_OFI_NCCL_VERSION: v1.7.4-aws BASE_IMAGE: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 CUDA_VERSION: 11.7.1 IMAGE_NAME: torch-1-13-1-cu117-aws diff --git a/docker/generate_build_matrix.py b/docker/generate_build_matrix.py index e2f6581a34..ebb28d0adf 100644 --- a/docker/generate_build_matrix.py +++ b/docker/generate_build_matrix.py @@ -224,7 +224,7 @@ def _main(): if interconnect != 'EFA': entry['AWS_OFI_NCCL_VERSION'] = '' else: - entry['AWS_OFI_NCCL_VERSION'] = 'v1.7.3-aws' + entry['AWS_OFI_NCCL_VERSION'] = 'v1.7.4-aws' pytorch_entries.append(entry) nightly_entry = { From 776d172e05b02a6662f7070757fa435adbea5af6 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Wed, 13 Dec 2023 16:40:10 -0800 Subject: [PATCH 07/10] Extend checkpoint loading to accept a validation function (#2726) --- composer/utils/checkpoint.py | 100 ++++++++++++++++++++++++-- tests/trainer/test_checkpoint.py | 41 ++++++++++- tests/trainer/test_fsdp_checkpoint.py | 41 +++++++++++ 3 files changed, 175 insertions(+), 7 deletions(-) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 48a5dc51c8..f18b1abd22 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -14,6 +14,7 @@ import tempfile import textwrap import warnings +from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Union @@ -39,6 +40,66 @@ _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME = f'__{dist.get_global_rank()}_0.distcp' +def _get_checkpoint_validation_function() -> Optional[Callable[[Union[Path, str]], bool]]: + """Get the validation function by name. + + Args: + name (str): Qualified name of the checkpoint validation function. + It should be in the form '{module_name}.{fn_name}'. + + Returns: + Callable[[Union[Path, str]], bool] The checkpoint validation function that returns + True given a valid checkpoint and False otherwise. + """ + name = os.environ.get('CHECKPOINT_VALIDATION_FUNCTION', None) + if name is None: + return None + splits = name.split('.') + module_name, fn_name = '.'.join(splits[:-1]), splits[-1] + module = import_module(module_name) + fn = getattr(module, fn_name) + log.debug(f'Checkpoint validation function {name} has been found.') + return fn + + +def _ensure_valid_checkpoint(checkpoint_filepath: Union[Path, str]) -> Union[Path, str]: + """Ensures that the checkpoint at checkpoint_filepath is valid. + + using the function specified by the CHECKPOINT_VALIDATION_FUNCTION environment variable. + If CHECKPOINT_VALIDATION_FUNCTION is not set, we skip validation. + + Args: + checkpoint_filepath (Union[Path,str]): The path to the checkpoint file. + + Raises: + ValueError if checkpoint file is invalid. + """ + # Get the validation function by name. + validate = _get_checkpoint_validation_function() + + # No function name has been specified. + if validate is None: + log.debug('No validation function specified. Skipping checkpoint validation.') + return checkpoint_filepath + + # Validate the checkpoint. + if not validate(checkpoint_filepath): + raise ValueError(f'Checkpoint at {checkpoint_filepath} is invalid.') + + log.debug(f'Checkpoint at {checkpoint_filepath} is valid.') + return checkpoint_filepath + + +def _torch_load_with_validation(checkpoint_filepath: Union[Path, str], map_location: str) -> Any: + """Validates and loads a torch checkpoint. + + Args: + checkpoint_filepath (Union[Path,str]): The path to the checkpoint file. + map_location (str): The location to load the checkpoint to. + """ + return torch.load(_ensure_valid_checkpoint(checkpoint_filepath), map_location=map_location) + + def _format_path_with_rank_zero(path: str) -> str: """Formats ``path`` with the rank zero values.""" return path.format( @@ -338,8 +399,37 @@ def _get_num_ranks_that_saved_rng(metadata: Metadata): rng_inds = set(rng_inds) return len(rng_inds) - # A subclass of FileSystemReader that downloads files from the object store before reading them from the local filesystem. - class DistCPObjectStoreReader(dist_cp.FileSystemReader): + class FileSystemReaderWithValidation(dist_cp.FileSystemReader): + """FileSystemReader that validates checkpoint files prior to reading.""" + + def __init__(self, path: str): + if _get_checkpoint_validation_function() is None: + log.info('No checkpoint validation function found when loading sharded checkpoints.') + super().__init__(path) + + def read_data(self, plan: LoadPlan, planner: LoadPlanner): + """Reads data file. + + Raises: + ValueError if the data file is invalid. + """ + for read_item in plan.items: + data_path = self.path / self.storage_data[read_item.storage_index].relative_path + _ensure_valid_checkpoint(data_path) + return super().read_data(plan, planner) + + def read_metadata(self) -> Metadata: + """Reads metadata file. + + Raises: + ValueError if the metadata file is invalid. + """ + metadata_file_path = self.path / '.metadata' + _ensure_valid_checkpoint(metadata_file_path) + return super().read_metadata() + + # A subclass of FileSystemReaderWithValidation that downloads files from the object store before reading them from the local filesystem. + class DistCPObjectStoreReader(FileSystemReaderWithValidation): def __init__(self, source_path: str, destination_path: str, object_store): self.source_path = source_path @@ -401,7 +491,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): Path(rank0_download_tempdir) / Path('checkpoints')), object_store=object_store) else: - storage_reader = dist_cp.FileSystemReader(source_path) + storage_reader = FileSystemReaderWithValidation(source_path) # We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph. with torch.no_grad(): @@ -695,7 +785,7 @@ def safe_torch_load( model = None optimizer = None if dist.get_global_rank() == 0: - state_dict_list[0] = torch.load(composer_states_filepath, map_location=map_location) + state_dict_list[0] = _torch_load_with_validation(composer_states_filepath, map_location=map_location) # Don't broadcast model/optimizer state if they exist if 'model' in state_dict_list[0]['state']: model = state_dict_list[0]['state']['model'] @@ -716,7 +806,7 @@ def safe_torch_load( return state_dict else: - return torch.load(composer_states_filepath, map_location=map_location) + return _torch_load_with_validation(composer_states_filepath, map_location=map_location) except TypeError as e: if 'Accuracy.__new__() missing 1 required positional argument' in str(e): raise Exception('As of v0.10.0, torchmetrics introduces a new required argument to Accuracy which ' diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index 52cc427e20..ebb4f4b422 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -11,7 +11,7 @@ import time from glob import glob from typing import Any, Dict, List, Optional, Union -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest import torch @@ -29,7 +29,7 @@ from composer.trainer import trainer from composer.trainer.trainer import Trainer from composer.utils import dist, is_tar -from composer.utils.checkpoint import glob_filter +from composer.utils.checkpoint import _ensure_valid_checkpoint, glob_filter from composer.utils.object_store.object_store import ObjectStore from composer.utils.object_store.s3_object_store import S3ObjectStore from tests.common import (RandomClassificationDataset, RandomImageDataset, RandomTextLMDataset, SimpleConvModel, @@ -1289,3 +1289,40 @@ def test_rotate_checkpoints( assert len(symlink_files) == ((1 if not deepspeed_enabled else world_size) if num_keep != 0 else 0) dist.barrier() # all ranks finish before cleaning up tmpdir + + +def simple_validate(filepath: str): + with open(filepath, 'r') as f: + return f.read() == 'good' + + +def test_checkpoint_validation(tmp_path): + checkpoint_filepath = tmp_path / 'dummy' + with open(checkpoint_filepath, 'w') as f: + f.write('good') + + # No validation function specified. + result = _ensure_valid_checkpoint(checkpoint_filepath) + assert result == checkpoint_filepath + + # Non-existent module specified. + with patch.dict(os.environ, {'CHECKPOINT_VALIDATION_FUNCTION': 'bad_module.bad_function'}): + with pytest.raises(ModuleNotFoundError): + _ensure_valid_checkpoint(checkpoint_filepath) + + # Non-existent function specified. + with patch.dict(os.environ, {'CHECKPOINT_VALIDATION_FUNCTION': 'tests.trainer.test_checkpoint.bad_function'}): + with pytest.raises(AttributeError): + _ensure_valid_checkpoint(checkpoint_filepath) + + # Correct usage and successful validation. + with patch.dict(os.environ, {'CHECKPOINT_VALIDATION_FUNCTION': 'tests.trainer.test_checkpoint.simple_validate'}): + result = _ensure_valid_checkpoint(checkpoint_filepath) + assert result == checkpoint_filepath + + # Correct usage and failed validation. + with open(checkpoint_filepath, 'w') as f: + f.write('bad') + with patch.dict(os.environ, {'CHECKPOINT_VALIDATION_FUNCTION': 'tests.trainer.test_checkpoint.simple_validate'}): + with pytest.raises(ValueError): + _ensure_valid_checkpoint(checkpoint_filepath) diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 44468bd51c..897a87dbd6 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -9,8 +9,10 @@ import pathlib import textwrap import uuid +from contextlib import nullcontext as does_not_raise from functools import partial from typing import Any, Callable, Optional, Sequence +from unittest.mock import patch import numpy as np import pytest @@ -545,6 +547,45 @@ def test_fsdp_full_state_dict_load_with_ema( trainer2.close() +@pytest.mark.gpu +@world_size(2) +@pytest.mark.parametrize('is_valid_checkpoint', [True, False]) +@pytest.mark.parametrize('state_dict_type', ['full', 'sharded']) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), + reason='requires PyTorch 1.13 or higher') +@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning') +@pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning') +@pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning') +def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_checkpoint: bool, state_dict_type: str): + from torch.distributed.checkpoint.api import CheckpointException + + def mock_get_checkpoint_validation_function(): + return lambda _: is_valid_checkpoint + + tmp_paths = dist.all_gather_object(os.path.abspath(tmp_path)) + save_folder = os.path.join(tmp_paths[0], 'checkpoints') + fsdp_config = FSDPConfig(state_dict_type=state_dict_type) + + # First trainer saves checkpoints. + trainer = get_trainer(save_folder=save_folder, fsdp_config=fsdp_config, max_duration='1ba') + trainer.fit() + trainer.close() + + expectation = does_not_raise() if is_valid_checkpoint else pytest.raises((ValueError, CheckpointException)) + + checkpoint_relpath = 'ba1-rank0.pt' if state_dict_type == 'full' else 'ba1' + + # Load checkpoints with checkpoint validation. + with expectation: + with patch('composer.utils.checkpoint._get_checkpoint_validation_function', + mock_get_checkpoint_validation_function): + trainer = get_trainer(load_path=os.path.join(save_folder, checkpoint_relpath), + max_duration='2ba', + fsdp_config=fsdp_config) + trainer.fit() + trainer.close() + + @pytest.mark.gpu @world_size(2) @pytest.mark.parametrize('weights_only', [False, True]) From 09f458018c35ed98cdef4b20b90858fe3c3c0fb0 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Thu, 14 Dec 2023 10:11:45 -0800 Subject: [PATCH 08/10] Fix checkpoint validation tests for torch 1.13 (#2779) * fix checkpoint validation tests for torch 1.13 * Fix --- tests/trainer/test_fsdp_checkpoint.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 897a87dbd6..d9b7c5b5ee 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -550,14 +550,21 @@ def test_fsdp_full_state_dict_load_with_ema( @pytest.mark.gpu @world_size(2) @pytest.mark.parametrize('is_valid_checkpoint', [True, False]) -@pytest.mark.parametrize('state_dict_type', ['full', 'sharded']) +@pytest.mark.parametrize('state_dict_type', ['sharded', 'full']) @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), reason='requires PyTorch 1.13 or higher') @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning') @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning') @pytest.mark.filterwarnings(r'ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning') def test_checkpoint_loading_with_validation(world_size, tmp_path, is_valid_checkpoint: bool, state_dict_type: str): - from torch.distributed.checkpoint.api import CheckpointException + # Set the error expectations. + expectation = does_not_raise() + if not is_valid_checkpoint: + if using_torch_2() and state_dict_type == 'sharded': + from torch.distributed.checkpoint import CheckpointException + expectation = pytest.raises(CheckpointException) + else: + expectation = pytest.raises(ValueError) def mock_get_checkpoint_validation_function(): return lambda _: is_valid_checkpoint @@ -571,9 +578,13 @@ def mock_get_checkpoint_validation_function(): trainer.fit() trainer.close() - expectation = does_not_raise() if is_valid_checkpoint else pytest.raises((ValueError, CheckpointException)) - - checkpoint_relpath = 'ba1-rank0.pt' if state_dict_type == 'full' else 'ba1' + # Determine the checkpoint path for loading. + checkpoint_relpath = 'ba1-rank0.pt' + if state_dict_type == 'sharded': + if using_torch_2(): + checkpoint_relpath = 'ba1' + else: + checkpoint_relpath = 'ba1/ba1-rank{rank}.pt' # Load checkpoints with checkpoint validation. with expectation: From 7e0e40a72a6b0c87cca4d408792336e4e5c742da Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Thu, 14 Dec 2023 13:56:04 -0600 Subject: [PATCH 09/10] Bump version to 0.17.2 (#2780) * bump version * 0.17.2 * update matrix --- composer/_version.py | 2 +- docker/README.md | 4 ++-- docker/build_matrix.yaml | 12 ++++++------ docker/generate_build_matrix.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/composer/_version.py b/composer/_version.py index c00efb3088..a41361e246 100644 --- a/composer/_version.py +++ b/composer/_version.py @@ -3,4 +3,4 @@ """The Composer Version.""" -__version__ = '0.17.1' +__version__ = '0.17.2' diff --git a/docker/README.md b/docker/README.md index b62ee5e677..e0680d38e1 100644 --- a/docker/README.md +++ b/docker/README.md @@ -15,8 +15,8 @@ all dependencies for both NLP and Vision models. They are built on top of the | Composer Version | CUDA Support | Docker Tag | |--------------------|----------------|----------------------------------------------------------------| -| 0.17.1 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.17.1` | -| 0.17.1 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.17.1_cpu` | +| 0.17.2 | Yes | `mosaicml/composer:latest`, `mosaicml/composer:0.17.2` | +| 0.17.2 | No | `mosaicml/composer:latest_cpu`, `mosaicml/composer:0.17.2_cpu` | **Note**: For a lightweight installation, we recommended using a [MosaicML PyTorch Image](#pytorch-images) and manually diff --git a/docker/build_matrix.yaml b/docker/build_matrix.yaml index 60fd2f0222..f12223928c 100644 --- a/docker/build_matrix.yaml +++ b/docker/build_matrix.yaml @@ -193,9 +193,9 @@ TORCHVISION_VERSION: 0.17.0 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.17.1 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.17.2 CUDA_VERSION: 12.1.0 - IMAGE_NAME: composer-0-17-1 + IMAGE_NAME: composer-0-17-2 MOFED_VERSION: 5.5-1.0.3.2 NVIDIA_REQUIRE_CUDA_OVERRIDE: cuda>=12.1 brand=tesla,driver>=450,driver<451 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 @@ -216,15 +216,15 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.1.1 TAGS: - - mosaicml/composer:0.17.1 + - mosaicml/composer:0.17.2 - mosaicml/composer:latest TARGET: composer_stage TORCHVISION_VERSION: 0.16.1 - AWS_OFI_NCCL_VERSION: '' BASE_IMAGE: ubuntu:20.04 - COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.17.1 + COMPOSER_INSTALL_COMMAND: mosaicml[all]==0.17.2 CUDA_VERSION: '' - IMAGE_NAME: composer-0-17-1-cpu + IMAGE_NAME: composer-0-17-2-cpu MOFED_VERSION: 5.5-1.0.3.2 NVIDIA_REQUIRE_CUDA_OVERRIDE: '' PYTHON_VERSION: '3.10' @@ -232,7 +232,7 @@ PYTORCH_NIGHTLY_VERSION: '' PYTORCH_VERSION: 2.1.1 TAGS: - - mosaicml/composer:0.17.1_cpu + - mosaicml/composer:0.17.2_cpu - mosaicml/composer:latest_cpu TARGET: composer_stage TORCHVISION_VERSION: 0.16.1 diff --git a/docker/generate_build_matrix.py b/docker/generate_build_matrix.py index ebb28d0adf..d265b624ca 100644 --- a/docker/generate_build_matrix.py +++ b/docker/generate_build_matrix.py @@ -246,7 +246,7 @@ def _main(): composer_entries = [] # The `GIT_COMMIT` is a placeholder and Jenkins will substitute it with the actual git commit for the `composer_staging` images - composer_versions = ['0.17.1'] # Only build images for the latest composer version + composer_versions = ['0.17.2'] # Only build images for the latest composer version composer_python_versions = [LATEST_PYTHON_VERSION] # just build composer against the latest for product in itertools.product(composer_python_versions, composer_versions, cuda_options): From 45bb1359de800586ab528751f4c13512d03cc59a Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 15 Dec 2023 13:03:00 -0800 Subject: [PATCH 10/10] bump transformers version (#2781) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6bf8fa3154..372b0156d9 100644 --- a/setup.py +++ b/setup.py @@ -184,7 +184,7 @@ def package_files(prefix: str, directory: str, extension: str): ] extra_deps['nlp'] = [ - 'transformers>=4.11,<4.36,!=4.34.0', + 'transformers>=4.11,<4.37,!=4.34.0', 'datasets>=2.4,<3', ]