Skip to content

Commit

Permalink
Fix #10593 -- add --keep option for dvc experiments remove (#10633)
Browse files Browse the repository at this point in the history
* Add keep_selected parameter, and corresponding code to keep only the selected exps (and remove all the other ones)

* test keep_selected_by_name

* test keep_selected_by_rev

* test keep_selected multiple, by name

* test keep all by name

* test keep by rev, with num=2

* added option to cli

* refactoring to meet pr needs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed test_experiments to add keep_selected=False to remove tests

* rename parameter to match cli option

* follow the normal path, then invert the selection before removing

* fixed tests for list ordering + fixed test with non existent name, it didn't make sense to delete everything if an exp name did not exist

* changed cli option comment

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed typing issue

* updated parameter name

* removed handling queued experiments (since --queue would remove them all)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* code simplification, added __eq__ and __hash__ to be able to compare ExpRefs, updated and parametrized tests.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixed linting issues

* - --keep and --queue together raise an InvalidArgumentError
    - added a test to check if the error is raised
    - fixed CLI message

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* re-run gh tests. Some tests which did not involve my changes started failing while they were passing fine before.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
rmic and pre-commit-ci[bot] authored Nov 30, 2024
1 parent 64ccd9c commit 368c785
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 2 deletions.
9 changes: 9 additions & 0 deletions dvc/commands/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ def add_parser(subparsers, parent_parser):
hide_subparsers_from_help(experiments_subparsers)


def add_keep_selection_flag(experiments_subcmd_parser):
experiments_subcmd_parser.add_argument(
"--keep",
action="store_true",
default=False,
help="Keep the selected experiments instead of removing them.",
)


def add_rev_selection_flags(
experiments_subcmd_parser, command: str, default: bool = True
):
Expand Down
4 changes: 3 additions & 1 deletion dvc/commands/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def run(self):
num=self.args.num,
queue=self.args.queue,
git_remote=self.args.git_remote,
keep=self.args.keep,
)
if removed:
ui.write(f"Removed experiments: {humanize.join(map(repr, removed))}")
Expand All @@ -44,7 +45,7 @@ def run(self):


def add_parser(experiments_subparsers, parent_parser):
from . import add_rev_selection_flags
from . import add_keep_selection_flag, add_rev_selection_flags

EXPERIMENTS_REMOVE_HELP = "Remove experiments."
experiments_remove_parser = experiments_subparsers.add_parser(
Expand All @@ -57,6 +58,7 @@ def add_parser(experiments_subparsers, parent_parser):
)
remove_group = experiments_remove_parser.add_mutually_exclusive_group()
add_rev_selection_flags(experiments_remove_parser, "Remove", False)
add_keep_selection_flag(experiments_remove_parser)
remove_group.add_argument(
"--queue", action="store_true", help="Remove all queued experiments."
)
Expand Down
9 changes: 9 additions & 0 deletions dvc/repo/experiments/refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,12 @@ def from_ref(cls, ref: str):
baseline_sha = parts[2] + parts[3]
name = parts[4] if len(parts) == 5 else None
return cls(baseline_sha, name)

def __eq__(self, other):
if not isinstance(other, ExpRefInfo):
return False

return self.baseline_sha == other.baseline_sha and self.name == other.name

def __hash__(self):
return hash((self.baseline_sha, self.name))
14 changes: 13 additions & 1 deletion dvc/repo/experiments/remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dvc.repo.scm_context import scm_context
from dvc.scm import Git, iter_revs

from .exceptions import UnresolvedExpNamesError
from .exceptions import InvalidArgumentError, UnresolvedExpNamesError
from .utils import exp_refs, exp_refs_by_baseline, push_refspec

if TYPE_CHECKING:
Expand All @@ -30,10 +30,16 @@ def remove( # noqa: C901, PLR0912
num: int = 1,
queue: bool = False,
git_remote: Optional[str] = None,
keep: bool = False,
) -> list[str]:
removed: list[str] = []

if all([keep, queue]):
raise InvalidArgumentError("Cannot use both `--keep` and `--queue`.")

if not any([exp_names, queue, all_commits, rev]):
return removed

celery_queue: LocalCeleryQueue = repo.experiments.celery_queue

if queue:
Expand All @@ -43,6 +49,7 @@ def remove( # noqa: C901, PLR0912

exp_ref_list: list[ExpRefInfo] = []
queue_entry_list: list[QueueEntry] = []

if exp_names:
results: dict[str, ExpRefAndQueueEntry] = (
celery_queue.get_ref_and_entry_by_names(exp_names, git_remote)
Expand Down Expand Up @@ -70,6 +77,10 @@ def remove( # noqa: C901, PLR0912
exp_ref_list.extend(exp_refs(repo.scm, git_remote))
removed = [ref.name for ref in exp_ref_list]

if keep:
exp_ref_list = list(set(exp_refs(repo.scm, git_remote)) - set(exp_ref_list))
removed = [ref.name for ref in exp_ref_list]

if exp_ref_list:
_remove_commited_exps(repo.scm, exp_ref_list, git_remote)

Expand All @@ -83,6 +94,7 @@ def remove( # noqa: C901, PLR0912

removed_refs = [str(r) for r in exp_ref_list]
notify_refs_to_studio(repo, git_remote, removed=removed_refs)

return removed


Expand Down
82 changes: 82 additions & 0 deletions tests/func/experiments/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,85 @@ def test_remove_multi_rev(tmp_dir, scm, dvc, exp_stage):

assert scm.get_ref(str(baseline_exp_ref)) is None
assert scm.get_ref(str(new_exp_ref)) is None


@pytest.mark.parametrize(
"keep, expected_removed",
[
[["exp1"], ["exp2", "exp3"]],
[["exp1", "exp2"], ["exp3"]],
[["exp1", "exp2", "exp3"], []],
[[], []], # remove does nothing if no experiments are specified
],
)
def test_keep_selected_by_name(tmp_dir, scm, dvc, exp_stage, keep, expected_removed):
# Setup: Run experiments
refs = {}
for i in range(1, len(keep) + len(expected_removed) + 1):
results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"], name=f"exp{i}"
)
refs[f"exp{i}"] = first(exp_refs_by_rev(scm, first(results)))
assert scm.get_ref(str(refs[f"exp{i}"])) is not None

removed = dvc.experiments.remove(exp_names=keep, keep=True)
assert sorted(removed) == sorted(expected_removed)

for exp in expected_removed:
assert scm.get_ref(str(refs[exp])) is None

for exp in keep:
assert scm.get_ref(str(refs[exp])) is not None


def test_keep_selected_by_nonexistent_name(tmp_dir, scm, dvc, exp_stage):
# non existent name should raise an error
with pytest.raises(UnresolvedExpNamesError):
dvc.experiments.remove(exp_names=["nonexistent"], keep=True)


@pytest.mark.parametrize(
"num_exps, rev, num, expected_removed",
[
[2, "exp1", 1, ["exp2"]],
[3, "exp3", 1, ["exp1", "exp2"]],
[3, "exp3", 2, ["exp1"]],
[3, "exp3", 3, []],
[3, "exp2", 2, ["exp3"]],
[4, "exp2", 2, ["exp3", "exp4"]],
[4, "exp4", 2, ["exp1", "exp2"]],
[1, None, 1, []], # remove does nothing if no experiments are specified
],
)
def test_keep_selected_by_rev(
tmp_dir, scm, dvc, exp_stage, num_exps, rev, num, expected_removed
):
refs = {}
revs = {}
# Setup: Run experiments and commit
for i in range(1, num_exps + 1):
scm.commit(f"commit{i}")
results = dvc.experiments.run(
exp_stage.addressing, params=[f"foo={i}"], name=f"exp{i}"
)
refs[f"exp{i}"] = first(exp_refs_by_rev(scm, first(results)))
revs[f"exp{i}"] = scm.get_rev()
assert scm.get_ref(str(refs[f"exp{i}"])) is not None

# Keep the experiment from the new revision
removed = dvc.experiments.remove(rev=revs.get(rev), num=num, keep=True)
assert sorted(removed) == sorted(expected_removed)

# Check remaining experiments
for exp in expected_removed:
assert scm.get_ref(str(refs[exp])) is None

for exp, ref in refs.items():
if exp not in expected_removed:
assert scm.get_ref(str(ref)) is not None


def test_remove_with_queue_and_keep(tmp_dir, scm, dvc, exp_stage):
# This should raise an exception, until decided otherwise
with pytest.raises(InvalidArgumentError):
dvc.experiments.remove(queue=True, keep=True)
2 changes: 2 additions & 0 deletions tests/unit/command/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def test_experiments_remove_flag(dvc, scm, mocker, capsys, caplog):
num=2,
queue=False,
git_remote="myremote",
keep=False,
)


Expand All @@ -410,6 +411,7 @@ def test_experiments_remove_special(dvc, scm, mocker, capsys, caplog):
num=1,
queue=False,
git_remote="myremote",
keep=False,
)


Expand Down

0 comments on commit 368c785

Please sign in to comment.