[Tune][Fix]Remove the clear_checkpoint function during Trial restoration error handling. #48532

Merged
10 changes: 2 additions & 8 deletions python/ray/tune/experiment/trial.py
@@ -793,11 +793,11 @@ def get_error(self) -> Optional[TuneError]:
         return None
 
     def _handle_restore_error(self, exc: Exception):
+        # For restoration errors, we only increment the regular failure count
+        # once the number of restore failures reaches the restore retry limit.
         if self.temporary_state.num_restore_failures >= int(
             os.environ.get("TUNE_RESTORE_RETRY_NUM", 0)
         ):
-            # Restore was unsuccessful, try again without checkpoint.
-            self.clear_checkpoint()
             self.run_metadata.num_failures += 1
Contributor:
What is the difference between self.run_metadata.num_failures and self.temporary_state.num_restore_failures?

Contributor Author:
num_restore_failures is the number of failed restorations.
num_failures is the number of failures caused by user/application code.

Because restoration is not user-defined behavior but a feature we provide, we don't treat a restoration failure the same as a normal application failure. The behavior is: when the program fails because of the application, we increment num_failures and try to restore the application. If the restoration succeeds, the program just goes on. If the restoration fails, we keep trying to restore, incrementing num_restore_failures by 1 each time. Once the number of restore attempts reaches TUNE_RESTORE_RETRY_NUM, we stop restoring and increment num_failures by another 1. (A short sketch of this counting behavior follows the rest of this hunk below.)

         else:
             self.temporary_state.num_restore_failures += 1
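To make the counting behavior described in the answer above concrete, here is a minimal, hedged sketch. FakeTrial is an illustrative stand-in, not Ray's Trial class; only the TUNE_RESTORE_RETRY_NUM environment variable is taken from the diff.

# A minimal sketch of the two counters, assuming a simplified stand-in class.
import os


class FakeTrial:
    """Illustrative only; the real logic lives in ray.tune.experiment.trial.Trial."""

    def __init__(self):
        self.num_failures = 0          # failures caused by user/application code
        self.num_restore_failures = 0  # failed restore attempts

    def handle_restore_error(self):
        retry_limit = int(os.environ.get("TUNE_RESTORE_RETRY_NUM", 0))
        if self.num_restore_failures >= retry_limit:
            # Retry budget exhausted: count this as a regular trial failure.
            self.num_failures += 1
        else:
            # Still within the retry budget: only track the restore failure.
            self.num_restore_failures += 1


os.environ["TUNE_RESTORE_RETRY_NUM"] = "2"
trial = FakeTrial()
for _ in range(3):  # three consecutive restore errors
    trial.handle_restore_error()
print(trial.num_restore_failures, trial.num_failures)  # -> 2 1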
@@ -883,12 +883,6 @@ def should_checkpoint(self):
     def has_checkpoint(self) -> bool:
         return self.checkpoint is not None
 
-    def clear_checkpoint(self):
-        if self.latest_checkpoint_result:
-            self.latest_checkpoint_result.checkpoint = None
-        self.temporary_state.restoring_from = None
-        self.run_metadata.invalidate_cache()
-
     def on_checkpoint(self, checkpoint_result: _TrainingResult):
         """Hook for handling checkpoints taken by the Trainable.
 
20 changes: 17 additions & 3 deletions python/ray/tune/tests/test_tuner_restore.py
@@ -537,7 +537,21 @@ def test_tuner_restore_latest_available_checkpoint(
 
 @pytest.mark.parametrize("retry_num", [0, 2])
 def test_restore_retry(ray_start_2_cpus, tmpdir, retry_num):
-    """Test retrying restore on a trial level by setting `TUNE_RESTORE_RETRY_NUM`."""
+    """
+    Test retrying restore on a trial level by setting `TUNE_RESTORE_RETRY_NUM`.
+
+    This unit test uses the following parameters:
+    - `retry_num`: Maximum number of retry attempts for restoring a trial. This
+      value is assigned to the environment variable `TUNE_RESTORE_RETRY_NUM`.
+      If restoration still fails after `retry_num` attempts, the trial
+      increments its total failure count by 1.
+
+    - `retry_num_to_fail`: Number of restore attempts that should fail. In this
+      test, it is set to 2, causing the first two restore attempts to fail.
+
+    - `max_failures`: Maximum allowable failures during training. Here,
+      `max_failures` is set to 2, so training terminates after two total failures.
+    """
 
     class MockTrainable(Trainable):
         """A trainable that can generate one failure during training and
@@ -546,7 +560,7 @@ class MockTrainable(Trainable):
         def setup(self, config):
             self.idx = 0
             self.tag_file_path = config["tag_file_path"]
-            self.retry_num_to_fail = config.get("retry_num_to_fail", 2)
+            self.retry_num_to_fail = 2
             self._is_restored = False
 
         def step(self):
@@ -592,7 +606,7 @@ def load_checkpoint(self, checkpoint_dir):
             name="tryout_restore",
             stop={"training_iteration": 5},
             storage_path=str(tmpdir),
-            failure_config=FailureConfig(max_failures=1),
+            failure_config=FailureConfig(max_failures=2),
             checkpoint_config=CheckpointConfig(checkpoint_frequency=1),
         ),
         param_space={"tag_file_path": tag_file},
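For reference, a condensed, hypothetical sketch of how the knobs exercised by this test fit together outside of pytest: FlakyTrainable is an illustrative stand-in (not the test's MockTrainable), and the import paths assume a recent Ray 2.x layout.

# A condensed sketch, not the test itself: FlakyTrainable and its behavior
# are illustrative; only the config names come from the diff above.
import os

from ray.train import FailureConfig, RunConfig
from ray.tune import Trainable, Tuner


class FlakyTrainable(Trainable):
    def setup(self, config):
        self.idx = 0

    def step(self):
        self.idx += 1
        return {"score": self.idx}


# Allow each trial up to 2 restore retries before a restore error is
# counted as a regular trial failure.
os.environ["TUNE_RESTORE_RETRY_NUM"] = "2"

tuner = Tuner(
    FlakyTrainable,
    run_config=RunConfig(
        stop={"training_iteration": 5},
        # Terminate the trial after 2 failures that were not absorbed
        # by the restore-retry budget above.
        failure_config=FailureConfig(max_failures=2),
    ),
)
results = tuner.fit()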