Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fetch_runs_table method for state param usage #1253

Merged
merged 12 commits into from
Feb 22, 2023
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
- Removed `get_run_url` method ([#1238](https://github.com/neptune-ai/neptune-client/pull/1238))
- Removed `neptune.new.sync` module ([#1240](https://github.com/neptune-ai/neptune-client/pull/1240))
- Added exception for unsupported types ([#1229](https://github.com/neptune-ai/neptune-client/pull/1229))
- Change run status to Active / Inactive ([#1233](https://github.com/neptune-ai/neptune-client/pull/1233))
- Change run status in the table returned by `fetch_runs_table` to Active / Inactive ([#1233](https://github.com/neptune-ai/neptune-client/pull/1233))
- Package renamed from `neptune-client` to `neptune` ([#1225](https://github.com/neptune-ai/neptune-client/pull/1225))
- Changed values used to filter runs table by state ([#1253](https://github.com/neptune-ai/neptune-client/pull/1253))

## neptune-client 0.16.18

Expand Down
13 changes: 13 additions & 0 deletions src/neptune/internal/utils/run_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ class RunState(enum.Enum):
_api_active = "running"
_api_inactive = "idle"

@classmethod
def from_string(cls, value: str) -> "RunState":
try:
return cls(value.capitalize())
except ValueError as e:
raise NeptuneException(f"Can't map RunState to API: {value}") from e

@staticmethod
def from_api(value: str) -> "RunState":
if value == RunState._api_active.value:
Expand All @@ -35,3 +42,9 @@ def from_api(value: str) -> "RunState":
return RunState.inactive
else:
raise NeptuneException(f"Unknown RunState: {value}")

def to_api(self) -> str:
if self is RunState.active:
return self._api_active.value
if self is RunState.inactive:
return self._api_inactive.value
4 changes: 2 additions & 2 deletions src/neptune/metadata_containers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from neptune.internal.operation_processors.operation_processor import OperationProcessor
from neptune.internal.state import ContainerState
from neptune.internal.utils import as_list
from neptune.internal.utils.run_state import RunState
from neptune.metadata_containers import MetadataContainer
from neptune.metadata_containers.metadata_containers_table import Table
from neptune.types.mode import Mode
Expand Down Expand Up @@ -131,7 +132,7 @@ def _prepare_nql_query(ids, states, owners, tags):
name="sys/state",
type=NQLAttributeType.EXPERIMENT_STATE,
operator=NQLAttributeOperator.EQUALS,
value=state,
value=RunState.from_string(state).to_api(),
)
for state in states
],
Expand Down Expand Up @@ -261,7 +262,6 @@ def fetch_runs_table(
tags = as_list("tag", tag)

nql_query = self._prepare_nql_query(ids, states, owners, tags)

return MetadataContainer._fetch_entries(
self,
child_type=ContainerType.RUN,
Expand Down
18 changes: 18 additions & 0 deletions tests/e2e/standard/test_fetch_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,21 @@ def get_model_versions_as_rows(**kwargs):
return container.fetch_model_versions_table(**kwargs).to_rows()

self._test_fetch_from_container(init_run, get_model_versions_as_rows)

def test_fetch_runs_table_by_state(self, environment, project):
tag = str(uuid.uuid4())
random_val = random.random()
with neptune.init_run(project=environment.project) as run:
run["sys/tags"] = tag
run["some_random_val"] = random_val
runs = project.fetch_runs_table(state="active").to_rows()
assert len(runs) == 1
assert runs[0].get_attribute_value("sys/tags") == tag
assert runs[0].get_attribute_value("some_random_val") == random_val

time.sleep(5)

runs = project.fetch_runs_table(state="inactive").to_rows()
assert len(runs) > 0
assert runs[0].get_attribute_value("sys/tags") == tag
assert runs[0].get_attribute_value("some_random_val") == random_val
22 changes: 22 additions & 0 deletions tests/unit/neptune/new/client/test_run_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import unittest
from typing import List

from mock import patch

from neptune import init_project
from neptune.exceptions import NeptuneException
from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock
from neptune.internal.container_type import ContainerType
from neptune.metadata_containers.metadata_containers_table import (
Table,
Expand All @@ -34,3 +38,21 @@ def get_table(self, **kwargs) -> Table:

def get_table_entries(self, table) -> List[TableEntry]:
return table.to_rows()

@patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock)
def test_fetch_runs_table_is_case_insensitive(self):
states = ["active", "inactive", "Active", "Inactive", "aCTive", "INacTiVe"]
for state in states:
with self.subTest(state):
try:
self.get_table(state=state)
except Exception as e:
self.fail(e)

@patch("neptune.internal.backends.factory.HostedNeptuneBackend", NeptuneBackendMock)
def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self):
for incorrect_state in ["idle", "running", "some_arbitrary_state"]:
with self.subTest(incorrect_state):
with self.assertRaises(NeptuneException) as context:
self.get_table(state=incorrect_state)
self.assertEquals(f"Can't map RunState to API: {incorrect_state}", str(context.exception))