Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Checkpoint] Fix symlink issue where symlink file uploaded before checkpoint files upload #3376

Merged
merged 66 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
f52c770
a
bigning Jun 4, 2024
8ee8364
a
bigning Jun 5, 2024
4fecdf6
a
bigning Jun 6, 2024
7e53a3b
a
bigning Jun 6, 2024
76a5f2d
Merge https://github.com/mosaicml/composer into checkpoint_saver
bigning Jun 6, 2024
20cac57
Merge remote-tracking branch 'remotes/origin/checkpoint_saver' into c…
bigning Jun 6, 2024
f772e33
a
bigning Jun 6, 2024
4e391a6
a
bigning Jun 8, 2024
55ac530
a
bigning Jun 8, 2024
e2d267b
a
bigning Jun 10, 2024
8035f50
fix test
bigning Jun 11, 2024
e65110d
a
bigning Jun 11, 2024
40cddfb
a
bigning Jun 11, 2024
91d838c
a
bigning Jun 11, 2024
a23552b
a
bigning Jun 11, 2024
cf4e0f1
fix unit test
bigning Jun 11, 2024
229b57d
Merge https://github.com/mosaicml/composer into checkpoint_saver
bigning Jun 12, 2024
dde135f
Merge branch 'checkpoint_saver' of https://github.com/mosaicml/compos…
bigning Jun 12, 2024
e6884fc
a
bigning Jun 13, 2024
9911766
a
bigning Jun 13, 2024
36a1dc5
a
bigning Jun 13, 2024
e4db035
a
bigning Jun 13, 2024
081033c
a
bigning Jun 13, 2024
ae5ece3
fix 2gpu unit test
bigning Jun 13, 2024
2f5d6b0
a
bigning Jun 13, 2024
28a36e0
a
bigning Jun 13, 2024
703ef5f
Merge https://github.com/mosaicml/composer into checkpoint_saver
bigning Jun 13, 2024
c78f475
a
bigning Jun 13, 2024
7ecfcf3
a
bigning Jun 13, 2024
1280266
fix doctest
bigning Jun 14, 2024
c0cb94d
a
bigning Jun 14, 2024
95fca9f
fix test and lint
bigning Jun 14, 2024
2c77da9
up
bigning Jun 14, 2024
ca46b4f
a
bigning Jun 14, 2024
4f3108c
a
bigning Jun 14, 2024
11307f0
Merge branch 'dev' into checkpoint_saver
bigning Jun 17, 2024
f415d60
a
bigning Jun 18, 2024
a0a3e92
a
bigning Jun 18, 2024
301dd67
a
bigning Jun 18, 2024
c4c094b
a
bigning Jun 18, 2024
5ec3e28
a
bigning Jun 20, 2024
9813816
a
bigning Jun 20, 2024
8c3c5cc
address comments
bigning Jun 20, 2024
c81cc2f
a
bigning Jun 20, 2024
c1174d4
a
bigning Jun 20, 2024
df601d2
a
bigning Jun 20, 2024
a41f427
a
bigning Jun 20, 2024
bc06a7b
rerun test
bigning Jun 20, 2024
c87f36c
add logging
bigning Jun 21, 2024
1ebf5a7
Merge branch 'dev' into checkpoint_saver
bigning Jun 21, 2024
0e8ae23
remove debug comments
bigning Jun 21, 2024
c7541c4
comments
bigning Jun 21, 2024
a9081c2
a
bigning Jun 25, 2024
b98ad33
cleanup
bigning Jun 26, 2024
8a6f5d1
a
bigning Jun 26, 2024
ebbcc46
linter
bigning Jun 26, 2024
3575d1e
lint
bigning Jun 26, 2024
fb8dbba
Update composer/callbacks/checkpoint_saver.py
bigning Jun 28, 2024
df4f59a
commenst
bigning Jun 28, 2024
4971526
a
bigning Jun 28, 2024
ebbbf56
fix test
bigning Jun 28, 2024
3bb10c9
fix test
bigning Jun 28, 2024
0d4c7af
comments
bigning Jul 2, 2024
9d4e112
Merge branch 'dev' into checkpoint_saver
bigning Jul 3, 2024
6ed9aa7
a
bigning Jul 3, 2024
b781375
Merge branch 'dev' into checkpoint_saver
bigning Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 157 additions & 22 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
FORMAT_NAME_WITH_DIST_AND_TIME_TABLE,
FORMAT_NAME_WITH_DIST_TABLE,
PartialFilePath,
RemoteFilesExistingCheckStatus,
RemoteUploader,
checkpoint,
create_interval_scheduler,
create_symlink_file,
Expand All @@ -28,6 +30,7 @@
format_name_with_dist,
format_name_with_dist_and_time,
is_model_deepspeed,
parse_uri,
partial_format,
)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME
Expand Down Expand Up @@ -287,8 +290,13 @@ def __init__(
num_checkpoints_to_keep: int = -1,
weights_only: bool = False,
ignore_keys: Optional[Union[list[str], Callable[[dict], None]]] = None,
num_concurrent_uploads: int = 1,
upload_timeout_in_seconds: int = 3600,
):
folder = str(folder)
backend, _, local_folder = parse_uri(str(folder))
if local_folder == '':
local_folder = '.'

filename = str(filename)
remote_file_name = str(remote_file_name) if remote_file_name is not None else None
latest_filename = str(latest_filename) if latest_filename is not None else None
Expand All @@ -304,10 +312,10 @@ def __init__(
self.save_interval = save_interval
self.last_checkpoint_batch: Optional[Time] = None

self.folder = folder
self.folder = local_folder

self.filename = PartialFilePath(filename.lstrip('/'), folder)
self.latest_filename = PartialFilePath(latest_filename.lstrip('/'), folder) if latest_filename else None
self.filename = PartialFilePath(filename.lstrip('/'), local_folder)
self.latest_filename = PartialFilePath(latest_filename.lstrip('/'), local_folder) if latest_filename else None
self.remote_file_name = PartialFilePath(remote_file_name) if remote_file_name else None
self.latest_remote_file_name = PartialFilePath(latest_remote_file_name) if latest_remote_file_name else None

Expand All @@ -320,6 +328,23 @@ def __init__(

self.start_batch = None

self.remote_uploader = None
self.rank_saves_symlinks: bool = False
self.tmp_dir_for_symlink = tempfile.TemporaryDirectory()
self.num_concurrent_uploads = num_concurrent_uploads
self.upload_timeout_in_seconds = upload_timeout_in_seconds
# Allow unit test to override this to make it faster
self._symlink_upload_wait_before_next_try_in_seconds = 30.0
self.pid = os.getpid()
self.symlink_count = 0
bigning marked this conversation as resolved.
Show resolved Hide resolved
self.symlink_upload_tasks = []

if backend != '':
self.remote_uploader = RemoteUploader(
remote_folder=str(folder),
num_concurrent_uploads=self.num_concurrent_uploads,
)

def init(self, state: State, logger: Logger) -> None:
# If MLFlowLogger is being used, format MLFlow-specific placeholders in the save folder and paths.
# Assumes that MLFlowLogger comes before CheckpointSaver in the list of loggers.
Expand All @@ -346,9 +371,10 @@ def init(self, state: State, logger: Logger) -> None:
self.latest_remote_file_name.filename,
**mlflow_format_kwargs,
)

break

if self.remote_uploader is not None:
self.remote_uploader.init()
folder = format_name_with_dist(self.folder, state.run_name)
os.makedirs(folder, exist_ok=True)

Expand Down Expand Up @@ -410,6 +436,27 @@ def load_state_dict(self, state: dict[str, Any]):
load_timestamp.load_state_dict(timestamp_state)
self.all_saved_checkpoints_to_timestamp[save_filename] = load_timestamp

def _upload_checkpoint(
    self,
    remote_file_name: str,
    local_file_name: str,
    local_remote_file_names: list[str],
    logger: Logger,
):
    """Send one local checkpoint file to its remote destination.

    When ``self.remote_uploader`` is set, the file is submitted through
    ``upload_file_async`` and ``remote_file_name`` is appended to
    ``local_remote_file_names``. Otherwise the upload is delegated to
    ``logger.upload_file`` and ``local_remote_file_names`` is left untouched.

    Args:
        remote_file_name: Destination name of the file on the remote side.
        local_file_name: Path of the checkpoint file on local disk.
        local_remote_file_names: List that collects the remote names submitted
            through the remote uploader; mutated in place.
        logger: Logger used for the upload when no remote uploader is configured.
    """
    uploader = self.remote_uploader

    # No dedicated uploader: fall back to the logger-driven upload path.
    if uploader is None:
        logger.upload_file(
            remote_file_name=remote_file_name,
            file_path=local_file_name,
            overwrite=self.overwrite,
        )
        return

    uploader.upload_file_async(
        remote_file_name=remote_file_name,
        file_path=pathlib.Path(local_file_name),
        overwrite=self.overwrite,
    )
    # Record the name only after the submit call returned without raising.
    local_remote_file_names.append(remote_file_name)

def _save_checkpoint(self, state: State, logger: Logger):
self.last_checkpoint_batch = state.timestamp.batch

Expand All @@ -432,7 +479,14 @@ def _save_checkpoint(self, state: State, logger: Logger):
)
log.debug(f'Checkpoint locally saved to {saved_path}')

self.symlink_count += 1
bigning marked this conversation as resolved.
Show resolved Hide resolved
# Remote checkpoint file names on this rank
local_remote_file_names = []
bigning marked this conversation as resolved.
Show resolved Hide resolved
all_remote_filenames = []

if not saved_path: # not all ranks save
if self.remote_file_name is not None and self.remote_uploader is not None:
all_remote_filenames = dist.all_gather_object(local_remote_file_names)
return

metadata_local_file_path = None
Expand All @@ -443,6 +497,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
state.timestamp,
)

self.rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if self.latest_filename is not None and self.num_checkpoints_to_keep != 0:
symlink = self.latest_filename.format(state, is_deepspeed)
os.makedirs(os.path.dirname(symlink), exist_ok=True)
Expand All @@ -455,8 +510,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
src_path = str(pathlib.Path(saved_path).parent)
else:
src_path = saved_path
this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if this_rank_saves_symlinks:
if self.rank_saves_symlinks:
os.symlink(os.path.relpath(src_path, os.path.dirname(symlink)), symlink)

# if remote file name provided, upload the checkpoint
Expand All @@ -482,10 +536,11 @@ def _save_checkpoint(self, state: State, logger: Logger):
state.timestamp,
)
assert metadata_local_file_path is not None
logger.upload_file(
self._upload_checkpoint(
remote_file_name=metadata_remote_file_name,
file_path=metadata_local_file_path,
overwrite=self.overwrite,
local_file_name=metadata_local_file_path,
local_remote_file_names=local_remote_file_names,
logger=logger,
)
else:
remote_file_name = self.remote_file_name.format(
Expand All @@ -495,12 +550,20 @@ def _save_checkpoint(self, state: State, logger: Logger):

log.debug(f'Uploading checkpoint to {remote_file_name}')
try:
logger.upload_file(remote_file_name=remote_file_name, file_path=saved_path, overwrite=self.overwrite)
self._upload_checkpoint(
remote_file_name=remote_file_name,
local_file_name=saved_path,
local_remote_file_names=local_remote_file_names,
logger=logger,
)
except FileExistsError as e:
raise FileExistsError(
f'Uploading checkpoint failed with error: {e}. overwrite was set to {self.overwrite}. To overwrite checkpoints with Trainer, set save_overwrite to True.',
) from e

if self.remote_uploader is not None:
all_remote_filenames = dist.all_gather_object(local_remote_file_names)

# symlinks stay the same with sharded checkpointing
if self.latest_remote_file_name is not None:
symlink_name = self.latest_remote_file_name.format(
Expand All @@ -509,17 +572,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
).lstrip('/') + '.symlink'

# create and upload a symlink file
with tempfile.TemporaryDirectory() as tmpdir:
symlink_filename = os.path.join(tmpdir, 'latest.symlink')
# Sharded checkpoints for torch >2.0 use directories not files for load_paths
if state.fsdp_sharded_state_dict_enabled:
src_path = str(pathlib.Path(remote_file_name).parent)
symlink_filename = os.path.join(
self.tmp_dir_for_symlink.name,
f'latest.{self.symlink_count}.symlink',
)
# Sharded checkpoints for torch >2.0 use directories not files for load_paths
if state.fsdp_sharded_state_dict_enabled:
src_path = str(pathlib.Path(remote_file_name).parent)
else:
src_path = remote_file_name
log.debug(f'Creating symlink file {symlink_filename} -> {src_path}')
if self.rank_saves_symlinks:
create_symlink_file(src_path, symlink_filename)
if self.remote_uploader is not None:
bigning marked this conversation as resolved.
Show resolved Hide resolved
remote_checkpoint_file_names = []
for file_names in all_remote_filenames:
remote_checkpoint_file_names += file_names
bigning marked this conversation as resolved.
Show resolved Hide resolved
check_remote_files_exist_future = self.remote_uploader.check_remote_files_exist_async(
remote_checkpoint_file_names=remote_checkpoint_file_names,
max_wait_time_in_seconds=self.upload_timeout_in_seconds,
wait_before_next_try_in_seconds=self._symlink_upload_wait_before_next_try_in_seconds,
)
self.symlink_upload_tasks.append(
(check_remote_files_exist_future, symlink_filename, symlink_name),
)
else:
src_path = remote_file_name
log.debug(f'Creating symlink file {symlink_filename} -> {src_path}')
this_rank_saves_symlinks = dist.get_global_rank() == 0 or not state.fsdp_sharded_state_dict_enabled
if this_rank_saves_symlinks:
create_symlink_file(src_path, symlink_filename)
logger.upload_file(
remote_file_name=symlink_name,
file_path=symlink_filename,
Expand All @@ -532,7 +609,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
self._rotate_checkpoints(sharding_enabled=state.fsdp_sharded_state_dict_enabled)

def _rotate_checkpoints(self, sharding_enabled: bool = False):

while len(self.saved_checkpoints) > self.num_checkpoints_to_keep:
prefix_dir = None
checkpoint_to_delete = self.saved_checkpoints.pop(0)
Expand All @@ -542,3 +618,62 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):
else:
if dist.get_global_rank() == 0:
shutil.rmtree(prefix_dir)

def batch_end(self, state: State, logger: Logger) -> None:
    """Batch-end hook: check the uploader workers and upload a ready symlink.

    Does nothing when no remote uploader is configured. Otherwise it calls
    ``check_workers()`` on the uploader and, on ranks where
    ``self.rank_saves_symlinks`` is true, scans ``self.symlink_upload_tasks``
    from newest to oldest. The first task whose existence-check future has
    completed is handled and the scan stops; tasks older than it are dropped,
    while newer tasks that are still pending are kept in their original order.

    Args:
        state: Unused.
        logger: Unused.

    Raises:
        RuntimeError: If a completed existence check returns anything other
            than ``RemoteFilesExistingCheckStatus.EXIST``.
    """
    del state, logger  # unused
    if self.remote_uploader is None:
        return
    self.remote_uploader.check_workers()
    if not self.rank_saves_symlinks:
        return

    # Collected newest-first while scanning; reversed once below, which avoids
    # an O(n) front insertion per pending task.
    still_pending = []
    for task in reversed(self.symlink_upload_tasks):
        check_remote_files_exist_future, local_symlink_file, remote_symlink_file = task
        if not check_remote_files_exist_future.done():
            still_pending.append(task)
            continue

        result = check_remote_files_exist_future.result()
        if result != RemoteFilesExistingCheckStatus.EXIST:
            raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')

        self.remote_uploader.upload_file_async(
            remote_file_name=remote_symlink_file,
            # Pass a Path, matching how _upload_checkpoint calls the uploader.
            file_path=pathlib.Path(local_symlink_file),
            overwrite=True,
        )
        # Only the newest completed check matters; older tasks are superseded.
        break

    still_pending.reverse()
    self.symlink_upload_tasks = still_pending

def fit_end(self, state: State, logger: Logger) -> None:
    """Fit-end hook: wait for uploads, then upload the newest symlink.

    Does nothing when no remote uploader is configured. Otherwise it calls
    ``wait()`` on the uploader and, on ranks where ``self.rank_saves_symlinks``
    is true and symlink tasks remain, blocks on the existence check of the
    last queued task, uploads its symlink file, and waits for that upload.

    Args:
        state: Unused.
        logger: Unused.

    Raises:
        RuntimeError: If the existence check returns anything other than
            ``RemoteFilesExistingCheckStatus.EXIST``.
    """
    del state, logger  # unused
    if self.remote_uploader is None:
        return
    log.info('Waiting for checkpoint uploading to finish')
    self.remote_uploader.wait()
    if self.rank_saves_symlinks and self.symlink_upload_tasks:
        log.debug('Uploading symlink to the latest checkpoint')
        # We only need to upload a symlink pointing to the latest checkpoint files,
        # so we can ignore successful uploads of older checkpoints.
        check_remote_files_exist_future, local_symlink_file, remote_symlink_file = self.symlink_upload_tasks[-1]
        result = check_remote_files_exist_future.result()
        if result != RemoteFilesExistingCheckStatus.EXIST:
            raise RuntimeError(f'Failed to check if checkpoint files upload finish: {result}')
        symlink_upload_future = self.remote_uploader.upload_file_async(
            remote_file_name=remote_symlink_file,
            # Pass a Path, matching how _upload_checkpoint calls the uploader.
            file_path=pathlib.Path(local_symlink_file),
            overwrite=True,
        )
        # Block until the symlink upload itself has completed (or raised).
        symlink_upload_future.result()
    log.info('Checkpoint uploading finished!')

def post_close(self):
    """Close the remote uploader, if one is configured.

    Any exception raised by ``wait_and_close()`` is logged as an error and
    not re-raised.
    """
    uploader = self.remote_uploader
    if uploader is None:
        return

    # Wait for the symlink file upload to finish, then close the remote uploader.
    try:
        uploader.wait_and_close()
    except Exception as e:
        log.error(f'RemoteUploader run into exception {e}')
Loading
Loading