Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for tracking checkpoint metrics with Orbax in T5X. #153

Merged
merged 1 commit into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"""Orbax API."""

# A new PyPI release will be pushed everytime `__version__` is increased.
__version__ = '0.0.18'
__version__ = '0.0.19'
32 changes: 25 additions & 7 deletions orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ class CheckpointManagerOptions:
function.
best_mode: one of ['max', 'min']. The best metric is determine on the basis of
this value.
keep_checkpoints_without_metrics: If False, checkpoints with metrics present
are eligible for cleanup. Otherwise, they will never be deleted.
step_prefix: if provided, step directories will take the form
f'{step_prefix}_<step>'. Otherwise, they will simply be an integer <step>.

Expand All @@ -88,8 +90,15 @@ class CheckpointManagerOptions:
keep_period: Optional[int] = None
best_fn: Optional[Callable[[PyTree], float]] = None
best_mode: str = 'max'
keep_checkpoints_without_metrics: bool = True
step_prefix: Optional[str] = None

def __post_init__(self):
if self.best_mode not in ('min', 'max'):
msg = ("`CheckpointManagerOptions.best_mode` must be one of None, 'min' "
"or 'max'. Got {self.dtype}.")
raise ValueError(msg)


@dataclasses.dataclass
class CheckpointInfo:
Expand Down Expand Up @@ -213,6 +222,8 @@ def best_step(self) -> Optional[int]:
if not self._checkpoints:
return None
_, sorted_checkpoints = self._sort_checkpoints_by_metrics(self._checkpoints)
if not sorted_checkpoints:
return None
return sorted_checkpoints[-1].step

def should_save(self, step: int) -> bool:
Expand Down Expand Up @@ -584,7 +595,7 @@ def get_metrics(step):
for s, t, m in zip(steps, times, metrics)
]

def _add_checkpoint_info(self, step, metrics):
def _add_checkpoint_info(self, step: int, metrics: Optional[PyTree]):
self._checkpoints.append(
CheckpointInfo(step, datetime.datetime.now(tz=datetime.timezone.utc),
metrics))
Expand Down Expand Up @@ -636,8 +647,12 @@ def _delete_directory(self, step: int):

def _remove_old_checkpoints(self):
"""Keeps the `max_to_keep` most recent checkpoint steps."""
# Must have set max_to_keep or keep_time_interval.
if not self._options.max_to_keep and not self._options.keep_time_interval:
return
# Not enough checkpoints accumulated to consider deletion.
if len(self._checkpoints) <= self._options.max_to_keep:
return
if self._track_best:
# Best steps (to keep) are at the end, after sorting.
checkpoints_without_metrics, sorted_checkpoints = self._sort_checkpoints_by_metrics(
Expand All @@ -647,12 +662,15 @@ def _remove_old_checkpoints(self):
checkpoints_without_metrics = []
sorted_checkpoints = self._checkpoints

to_remove = len(sorted_checkpoints) - self._options.max_to_keep
if to_remove <= 0:
return
maybe_delete = sorted_checkpoints[:to_remove]
active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[
to_remove:]
keep = int(self._options.max_to_keep)
if self._options.keep_checkpoints_without_metrics:
maybe_delete = sorted_checkpoints[:-keep]
active_checkpoints = checkpoints_without_metrics + sorted_checkpoints[
-keep:]
else:
all_checkpoints = checkpoints_without_metrics + sorted_checkpoints
maybe_delete = all_checkpoints[:-keep]
active_checkpoints = all_checkpoints[-keep:]

kept_checkpoints = []
for info in maybe_delete:
Expand Down