Skip to content

Commit

Permalink
Merge pull request #1253 from neptune-ai/aw/fix_fetch_runs_by_state
Browse files Browse the repository at this point in the history
Fix `fetch_runs_table` method for state param usage
  • Loading branch information
AleksanderWWW authored Feb 22, 2023
2 parents 187fe82 + 9c36143 commit 94b5b19
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 3 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
- Removed `get_run_url` method ([#1238](https://github.com/neptune-ai/neptune-client/pull/1238))
- Removed `neptune.new.sync` module ([#1240](https://github.com/neptune-ai/neptune-client/pull/1240))
- Added exception for unsupported types ([#1229](https://github.com/neptune-ai/neptune-client/pull/1229))
- Change run status to Active / Inactive ([#1233](https://github.com/neptune-ai/neptune-client/pull/1233))
- Change run status in the table returned by `fetch_runs_table` to Active / Inactive ([#1233](https://github.com/neptune-ai/neptune-client/pull/1233))
- Package renamed from `neptune-client` to `neptune` ([#1225](https://github.com/neptune-ai/neptune-client/pull/1225))
- Changed values used to filter runs table by state ([#1253](https://github.com/neptune-ai/neptune-client/pull/1253))

### Fixes
- Fixed input value type verification for `append()` method ([#1254](https://github.com/neptune-ai/neptune-client/pull/1254))
Expand Down
13 changes: 13 additions & 0 deletions src/neptune/internal/utils/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ class RunState(enum.Enum):
_api_active = "running"
_api_inactive = "idle"

@classmethod
def from_string(cls, value: str) -> "RunState":
try:
return cls(value.capitalize())
except ValueError as e:
raise NeptuneException(f"Can't map RunState to API: {value}") from e

@staticmethod
def from_api(value: str) -> "RunState":
if value == RunState._api_active.value:
Expand All @@ -35,3 +42,9 @@ def from_api(value: str) -> "RunState":
return RunState.inactive
else:
raise NeptuneException(f"Unknown RunState: {value}")

def to_api(self) -> str:
if self is RunState.active:
return self._api_active.value
if self is RunState.inactive:
return self._api_inactive.value
4 changes: 2 additions & 2 deletions src/neptune/metadata_containers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from neptune.internal.operation_processors.operation_processor import OperationProcessor
from neptune.internal.state import ContainerState
from neptune.internal.utils import as_list
from neptune.internal.utils.run_state import RunState
from neptune.metadata_containers import MetadataContainer
from neptune.metadata_containers.metadata_containers_table import Table
from neptune.types.mode import Mode
Expand Down Expand Up @@ -131,7 +132,7 @@ def _prepare_nql_query(ids, states, owners, tags):
name="sys/state",
type=NQLAttributeType.EXPERIMENT_STATE,
operator=NQLAttributeOperator.EQUALS,
value=state,
value=RunState.from_string(state).to_api(),
)
for state in states
],
Expand Down Expand Up @@ -261,7 +262,6 @@ def fetch_runs_table(
tags = as_list("tag", tag)

nql_query = self._prepare_nql_query(ids, states, owners, tags)

return MetadataContainer._fetch_entries(
self,
child_type=ContainerType.RUN,
Expand Down
18 changes: 18 additions & 0 deletions tests/e2e/standard/test_fetch_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,21 @@ def get_model_versions_as_rows(**kwargs):
return container.fetch_model_versions_table(**kwargs).to_rows()

self._test_fetch_from_container(init_run, get_model_versions_as_rows)

def test_fetch_runs_table_by_state(self, environment, project):
tag = str(uuid.uuid4())
random_val = random.random()
with neptune.init_run(project=environment.project) as run:
run["sys/tags"] = tag
run["some_random_val"] = random_val
runs = project.fetch_runs_table(state="active").to_rows()
assert len(runs) == 1
assert runs[0].get_attribute_value("sys/tags") == tag
assert runs[0].get_attribute_value("some_random_val") == random_val

time.sleep(5)

runs = project.fetch_runs_table(state="inactive").to_rows()
assert len(runs) > 0
assert runs[0].get_attribute_value("sys/tags") == tag
assert runs[0].get_attribute_value("some_random_val") == random_val
22 changes: 22 additions & 0 deletions tests/unit/neptune/new/client/test_run_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import unittest
from typing import List

from mock import patch

from neptune import init_project
from neptune.exceptions import NeptuneException
from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock
from neptune.internal.container_type import ContainerType
from neptune.metadata_containers.metadata_containers_table import (
Table,
Expand All @@ -34,3 +38,21 @@ def get_table(self, **kwargs) -> Table:

def get_table_entries(self, table) -> List[TableEntry]:
return table.to_rows()

@patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock)
def test_fetch_runs_table_is_case_insensitive(self):
states = ["active", "inactive", "Active", "Inactive", "aCTive", "INacTiVe"]
for state in states:
with self.subTest(state):
try:
self.get_table(state=state)
except Exception as e:
self.fail(e)

@patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock)
def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self):
for incorrect_state in ["idle", "running", "some_arbitrary_state"]:
with self.subTest(incorrect_state):
with self.assertRaises(NeptuneException) as context:
self.get_table(state=incorrect_state)
self.assertEquals(f"Can't map RunState to API: {incorrect_state}", str(context.exception))

0 comments on commit 94b5b19

Please sign in to comment.