diff --git a/dvc/commands/experiments/__init__.py b/dvc/commands/experiments/__init__.py index 6b29bfacb5..0413765d2b 100644 --- a/dvc/commands/experiments/__init__.py +++ b/dvc/commands/experiments/__init__.py @@ -58,6 +58,15 @@ def add_parser(subparsers, parent_parser): hide_subparsers_from_help(experiments_subparsers) +def add_keep_selection_flag(experiments_subcmd_parser): + experiments_subcmd_parser.add_argument( + "--keep", + action="store_true", + default=False, + help="Keep the selected experiments instead of removing them.", + ) + + def add_rev_selection_flags( experiments_subcmd_parser, command: str, default: bool = True ): diff --git a/dvc/commands/experiments/remove.py b/dvc/commands/experiments/remove.py index f7c48d73ac..ecb6541c07 100644 --- a/dvc/commands/experiments/remove.py +++ b/dvc/commands/experiments/remove.py @@ -34,6 +34,7 @@ def run(self): num=self.args.num, queue=self.args.queue, git_remote=self.args.git_remote, + keep=self.args.keep, ) if removed: ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}") @@ -44,7 +45,7 @@ def run(self): def add_parser(experiments_subparsers, parent_parser): - from . import add_rev_selection_flags + from . import add_keep_selection_flag, add_rev_selection_flags EXPERIMENTS_REMOVE_HELP = "Remove experiments." experiments_remove_parser = experiments_subparsers.add_parser( @@ -57,6 +58,7 @@ def add_parser(experiments_subparsers, parent_parser): ) remove_group = experiments_remove_parser.add_mutually_exclusive_group() add_rev_selection_flags(experiments_remove_parser, "Remove", False) + add_keep_selection_flag(experiments_remove_parser) remove_group.add_argument( "--queue", action="store_true", help="Remove all queued experiments." ) diff --git a/dvc/repo/experiments/refs.py b/dvc/repo/experiments/refs.py index 5497a25c77..3a34ff35a0 100644 --- a/dvc/repo/experiments/refs.py +++ b/dvc/repo/experiments/refs.py @@ -67,3 +67,12 @@ def from_ref(cls, ref: str): baseline_sha = parts[2] + parts[3] name = parts[4] if len(parts) == 5 else None return cls(baseline_sha, name) + + def __eq__(self, other): + if not isinstance(other, ExpRefInfo): + return False + + return self.baseline_sha == other.baseline_sha and self.name == other.name + + def __hash__(self): + return hash((self.baseline_sha, self.name)) diff --git a/dvc/repo/experiments/remove.py b/dvc/repo/experiments/remove.py index ef6ee086ab..cd8ca07e8b 100644 --- a/dvc/repo/experiments/remove.py +++ b/dvc/repo/experiments/remove.py @@ -6,7 +6,7 @@ from dvc.repo.scm_context import scm_context from dvc.scm import Git, iter_revs -from .exceptions import UnresolvedExpNamesError +from .exceptions import InvalidArgumentError, UnresolvedExpNamesError from .utils import exp_refs, exp_refs_by_baseline, push_refspec if TYPE_CHECKING: @@ -30,10 +30,16 @@ def remove( # noqa: C901, PLR0912 num: int = 1, queue: bool = False, git_remote: Optional[str] = None, + keep: bool = False, ) -> list[str]: removed: list[str] = [] + + if all([keep, queue]): + raise InvalidArgumentError("Cannot use both `--keep` and `--queue`.") + if not any([exp_names, queue, all_commits, rev]): return removed + celery_queue: LocalCeleryQueue = repo.experiments.celery_queue if queue: @@ -43,6 +49,7 @@ def remove( # noqa: C901, PLR0912 exp_ref_list: list[ExpRefInfo] = [] queue_entry_list: list[QueueEntry] = [] + if exp_names: results: dict[str, ExpRefAndQueueEntry] = ( celery_queue.get_ref_and_entry_by_names(exp_names, git_remote) @@ -70,6 +77,10 @@ def remove( # noqa: C901, PLR0912 exp_ref_list.extend(exp_refs(repo.scm, git_remote)) removed = [ref.name for ref in exp_ref_list] + if keep: + exp_ref_list = list(set(exp_refs(repo.scm, git_remote)) - set(exp_ref_list)) + removed = [ref.name for ref in exp_ref_list] + if exp_ref_list: _remove_commited_exps(repo.scm, exp_ref_list, git_remote) @@ -83,6 +94,7 @@ def remove( # noqa: C901, PLR0912 removed_refs = [str(r) for r in exp_ref_list] notify_refs_to_studio(repo, git_remote, removed=removed_refs) + return removed diff --git a/tests/func/experiments/test_remove.py b/tests/func/experiments/test_remove.py index 60d653d1e8..1864cc541d 100644 --- a/tests/func/experiments/test_remove.py +++ b/tests/func/experiments/test_remove.py @@ -179,3 +179,85 @@ def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage): assert scm.get_ref(str(baseline_exp_ref)) is None assert scm.get_ref(str(new_exp_ref)) is None + + +@pytest.mark.parametrize( + "keep, expected_removed", + [ + [["exp1"], ["exp2", "exp3"]], + [["exp1", "exp2"], ["exp3"]], + [["exp1", "exp2", "exp3"], []], + [[], []], # remove does nothing if no experiments are specified + ], +) +def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage, keep, expected_removed): + # Setup: Run experiments + refs = {} + for i in range(1, len(keep) + len(expected_removed) + 1): + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"], name=f"exp{i}" + ) + refs[f"exp{i}"] = first(exp_refs_by_rev(scm, first(results))) + assert scm.get_ref(str(refs[f"exp{i}"])) is not None + + removed = dvc.experiments.remove(exp_names=keep, keep=True) + assert sorted(removed) == sorted(expected_removed) + + for exp in expected_removed: + assert scm.get_ref(str(refs[exp])) is None + + for exp in keep: + assert scm.get_ref(str(refs[exp])) is not None + + +def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage): + # non existent name should raise an error + with pytest.raises(UnresolvedExpNamesError): + dvc.experiments.remove(exp_names=["nonexistent"], keep=True) + + +@pytest.mark.parametrize( + "num_exps, rev, num, expected_removed", + [ + [2, "exp1", 1, ["exp2"]], + [3, "exp3", 1, ["exp1", "exp2"]], + [3, "exp3", 2, ["exp1"]], + [3, "exp3", 3, []], + [3, "exp2", 2, ["exp3"]], + [4, "exp2", 2, ["exp3", "exp4"]], + [4, "exp4", 2, ["exp1", "exp2"]], + [1, None, 1, []], # remove does nothing if no experiments are specified + ], +) +def test_keep_selected_by_rev( + tmp_dir, scm, dvc, exp_stage, num_exps, rev, num, expected_removed +): + refs = {} + revs = {} + # Setup: Run experiments and commit + for i in range(1, num_exps + 1): + scm.commit(f"commit{i}") + results = dvc.experiments.run( + exp_stage.addressing, params=[f"foo={i}"], name=f"exp{i}" + ) + refs[f"exp{i}"] = first(exp_refs_by_rev(scm, first(results))) + revs[f"exp{i}"] = scm.get_rev() + assert scm.get_ref(str(refs[f"exp{i}"])) is not None + + # Keep the experiment from the new revision + removed = dvc.experiments.remove(rev=revs.get(rev), num=num, keep=True) + assert sorted(removed) == sorted(expected_removed) + + # Check remaining experiments + for exp in expected_removed: + assert scm.get_ref(str(refs[exp])) is None + + for exp, ref in refs.items(): + if exp not in expected_removed: + assert scm.get_ref(str(ref)) is not None + + +def test_remove_with_queue_and_keep(tmp_dir, scm, dvc, exp_stage): + # This should raise an exception, until decided otherwise + with pytest.raises(InvalidArgumentError): + dvc.experiments.remove(queue=True, keep=True) diff --git a/tests/unit/command/test_experiments.py b/tests/unit/command/test_experiments.py index 686d414b92..5831eacad4 100644 --- a/tests/unit/command/test_experiments.py +++ b/tests/unit/command/test_experiments.py @@ -384,6 +384,7 @@ def test_experiments_remove_flag(dvc, scm, mocker, capsys, caplog): num=2, queue=False, git_remote="myremote", + keep=False, ) @@ -410,6 +411,7 @@ def test_experiments_remove_special(dvc, scm, mocker, capsys, caplog): num=1, queue=False, git_remote="myremote", + keep=False, )