diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f4a54ecc4dabbd..4a43ca2aa9cb21 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3790,7 +3790,10 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
         # Add additional tags in the case the model has already some tags and users pass
         # "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
         # from all models since Trainer does not call `model.push_to_hub`.
-        if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
+        if getattr(self.model, "model_tags", None) is not None:
+            if "tags" not in kwargs:
+                kwargs["tags"] = []
+
             # If it is a string, convert it to a list
             if isinstance(kwargs["tags"], str):
                 kwargs["tags"] = [kwargs["tags"]]
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 55cc35cf6aa3eb..b56fddf94bd579 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -30,7 +30,7 @@
 from unittest.mock import Mock, patch
 
 import numpy as np
-from huggingface_hub import HfFolder, delete_repo, list_repo_commits, list_repo_files
+from huggingface_hub import HfFolder, ModelCard, delete_repo, list_repo_commits, list_repo_files
 from parameterized import parameterized
 from requests.exceptions import HTTPError
 
@@ -2423,7 +2423,13 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step", "test-trainer-tensorboard"]:
+        for model in [
+            "test-trainer",
+            "test-trainer-epoch",
+            "test-trainer-step",
+            "test-trainer-tensorboard",
+            "test-trainer-tags",
+        ]:
             try:
                 delete_repo(token=cls._token, repo_id=model)
             except HTTPError:
@@ -2554,6 +2560,31 @@ def test_push_to_hub_with_tensorboard_logs(self):
 
             assert found_log is True, "No tensorboard log found in repo"
 
+    def test_push_to_hub_tags(self):
+        # Checks that `trainer.push_to_hub()` writes the tags registered on the model into the
+        # pushed model card even when no `tags` argument is passed to `push_to_hub`
+        # (regression test for the `model_tags` handling in `Trainer.push_to_hub`).
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = get_regression_trainer(
+                output_dir=os.path.join(tmp_dir, "test-trainer-tags"),
+                push_to_hub=True,
+                hub_token=self._token,
+            )
+
+            trainer.model.add_model_tags(["test-trainer-tags"])
+
+            url = trainer.push_to_hub()
+
+            # Extract repo_name from the url
+            re_search = re.search(ENDPOINT_STAGING + r"/([^/]+/[^/]+)/", url)
+            self.assertIsNotNone(re_search)
+            repo_name = re_search.groups()[0]
+
+            self.assertEqual(repo_name, f"{USER}/test-trainer-tags")
+
+            model_card = ModelCard.load(repo_name)
+            self.assertIn("test-trainer-tags", model_card.data.tags)
+
 
 @require_torch
 @require_optuna