diff --git a/src/huggingface_hub/_commit_scheduler.py b/src/huggingface_hub/_commit_scheduler.py index 62d7bf1d0d..ba0b63afc7 100644 --- a/src/huggingface_hub/_commit_scheduler.py +++ b/src/huggingface_hub/_commit_scheduler.py @@ -30,8 +30,9 @@ class CommitScheduler: """ Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes). - The scheduler is started when instantiated and run indefinitely. At the end of your script, a last commit is - triggered. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads) + The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is + properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually + with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads) to learn more about how to use it. Args: @@ -78,6 +79,22 @@ class CommitScheduler: >>> with csv_path.open("a") as f: ... f.write("second line") ``` + + Example using a context manager: + ```py + >>> from pathlib import Path + >>> from huggingface_hub import CommitScheduler + + >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler: + ... csv_path = Path("watched_folder/data.csv") + ... with csv_path.open("a") as f: + ... f.write("first line") + ... (...) + ... with csv_path.open("a") as f: + ... f.write("second line") + + # Scheduler is now stopped and last commit have been triggered + ``` """ def __init__( @@ -144,6 +161,15 @@ def stop(self) -> None: """ self.__stopped = True + def __enter__(self) -> "CommitScheduler": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + # Upload last changes before exiting + self.trigger().result() + self.stop() + return + def _run_scheduler(self) -> None: """Dumb thread waiting between each scheduled push to Hub.""" while True: diff --git a/tests/test_commit_scheduler.py b/tests/test_commit_scheduler.py index 7d3a04ca0a..a38d8cb947 100644 --- a/tests/test_commit_scheduler.py +++ b/tests/test_commit_scheduler.py @@ -160,6 +160,23 @@ def test_sync_and_squash_history(self) -> None: self.assertEqual(len(commits), 1) self.assertEqual(commits[0].title, "Super-squash branch 'main' using huggingface_hub") + def test_context_manager(self) -> None: + watched_folder = self.cache_dir / "watched_folder" + watched_folder.mkdir(exist_ok=True, parents=True) + file_path = watched_folder / "file.txt" + + with CommitScheduler( + folder_path=watched_folder, + repo_id=self.repo_name, + every=5, # every 5min + hf_api=self.api, + ) as scheduler: + with file_path.open("w") as f: + f.write("first line\n") + + assert "file.txt" in self.api.list_repo_files(scheduler.repo_id) + assert scheduler._CommitScheduler__stopped # means the scheduler has been stopped when exiting the context + @pytest.mark.usefixtures("fx_cache_dir") class TestPartialFileIO(unittest.TestCase):