Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce querying capabilities to fetch_models_table #1677

Merged
merged 8 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- Added `get_workspace_status()` method to management API ([#1662](https://github.com/neptune-ai/neptune-client/pull/1662))
- Added auto-scaling pixel values for image logging ([#1664](https://github.com/neptune-ai/neptune-client/pull/1664))
- Introduce querying capabilities to `fetch_runs_table()` ([#1660](https://github.com/neptune-ai/neptune-client/pull/1660))
- Introduce querying capabilities to `fetch_models_table()` ([#1677](https://github.com/neptune-ai/neptune-client/pull/1677))

### Fixes
- Restored support for SSL verification exception ([#1661](https://github.com/neptune-ai/neptune-client/pull/1661))
Expand Down
26 changes: 12 additions & 14 deletions src/neptune/metadata_containers/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@
from neptune.envs import CONNECTION_MODE
from neptune.exceptions import InactiveProjectException
from neptune.internal.backends.api_model import ApiExperiment
from neptune.internal.backends.nql import (
NQLAttributeOperator,
NQLAttributeType,
NQLEmptyQuery,
NQLQueryAttribute,
)
from neptune.internal.container_type import ContainerType
from neptune.internal.init.parameters import (
ASYNC_LAG_THRESHOLD,
Expand Down Expand Up @@ -345,6 +339,7 @@ def fetch_runs_table(
def fetch_models_table(
self,
*,
query: Optional[str] = None,
columns: Optional[Iterable[str]] = None,
trashed: Optional[bool] = False,
limit: Optional[int] = None,
Expand All @@ -355,6 +350,8 @@ def fetch_models_table(
"""Retrieve models stored in the project.

Args:
query: NQL query string. Syntax: https://docs.neptune.ai/usage/nql/
Example: `"(model_size: float > 100) AND (backbone: string = VGG)"`.
trashed: Whether to retrieve trashed models.
If `True`, only trashed models are retrieved.
If `False`, only not-trashed models are retrieved.
Expand Down Expand Up @@ -392,9 +389,15 @@ def fetch_models_table(
... # Extract the ID of the first listed (oldest) model object
... last_model_id = models_table_df["sys/id"].values[0]

>>> # Fetch models with VGG backbone
... models_table_df = project.fetch_models_table(
...     query="(backbone: string = VGG)"
... ).to_pandas()

See also the API reference in the docs:
https://docs.neptune.ai/api/project#fetch_models_table
"""
verify_type("query", query, (str, type(None)))
verify_type("limit", limit, (int, type(None)))
verify_type("sort_by", sort_by, str)
verify_type("ascending", ascending, bool)
Expand All @@ -403,17 +406,12 @@ def fetch_models_table(
if isinstance(limit, int) and limit <= 0:
raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.")

query = query if query is not None else ""
nql = build_raw_query(query=query, trashed=trashed)
return MetadataContainer._fetch_entries(
self,
child_type=ContainerType.MODEL,
query=NQLQueryAttribute(
name="sys/trashed",
type=NQLAttributeType.BOOLEAN,
operator=NQLAttributeOperator.EQUALS,
value=trashed,
)
if trashed is not None
else NQLEmptyQuery,
query=nql,
columns=columns,
limit=limit,
sort_by=sort_by,
Expand Down
48 changes: 48 additions & 0 deletions tests/e2e/standard/test_fetch_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,51 @@ def test_fetch_runs_invalid_query_handling(self, project):
# then
with pytest.raises(NeptuneInvalidQueryException):
next(iter(runs_table))

def test_fetch_models_raw_query_trashed(self, environment, project):
    """Check that a raw NQL query passed to `fetch_models_table()` combines with `trashed`.

    Two models sharing the same `key` value are created, then one of them is
    trashed. The same query is re-run with `trashed` set to `True`, `False` and
    `None`, and the returned model names are compared with the expected subset.
    """
    # given
    val: float = 2.2
    raw_query = f"(key: float = {val})"

    for model_name in ("name-1", "name-2"):
        with neptune.init_model(project=environment.project, key=a_key(), name=model_name) as model:
            model["key"] = val

    # NOTE(review): presumably lets the backend catch up before querying - confirm
    time.sleep(5)

    def fetch_frame(trashed_flag):
        # Every fetch in this test uses the same query and only varies `trashed`.
        table = project.fetch_models_table(query=raw_query, progress_bar=False, trashed=trashed_flag)
        return table.to_pandas()

    def sorted_names(frame):
        return sorted(frame["sys/name"].dropna().to_list())

    # when
    not_trashed_frame = fetch_frame(False)

    # then
    assert sorted_names(not_trashed_frame) == sorted(["name-1", "name-2"])

    # when
    first_model_rows = not_trashed_frame[not_trashed_frame["sys/name"] == "name-1"]
    neptune.management.trash_objects(project=environment.project, ids=first_model_rows["sys/id"].item())

    # NOTE(review): presumably waits for the trash operation to become visible - confirm
    time.sleep(5)

    scenarios = (
        (True, ["name-1"]),
        (False, ["name-2"]),
        (None, ["name-1", "name-2"]),
    )

    for trashed_flag, expected_names in scenarios:
        # when
        fetched_frame = fetch_frame(trashed_flag)

        # then
        assert sorted_names(fetched_frame) == sorted(expected_names)

def test_fetch_models_invalid_query_handling(self, project):
    """Check that a malformed query surfaces as `NeptuneInvalidQueryException`.

    The `fetch_models_table()` call itself sits outside the `pytest.raises`
    block, so the test expects the exception only once the table is iterated.
    """
    # given
    malformed_query = "key: float = (-_-)"
    models_table = project.fetch_models_table(query=malformed_query, progress_bar=False)

    # then
    with pytest.raises(NeptuneInvalidQueryException):
        entries = iter(models_table)
        next(entries)
Loading