From b6a72f229a50cd5c8550685a85095ece3a1feaa1 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Wed, 6 Dec 2023 09:43:32 +0100 Subject: [PATCH 01/22] add fetching with iterators (#1585) --- CHANGELOG.md | 1 + .../backends/hosted_neptune_backend.py | 19 ++++++----- .../internal/backends/neptune_backend.py | 3 +- .../internal/backends/neptune_backend_mock.py | 3 +- .../metadata_containers_table.py | 27 ++++++++++------ tests/e2e/standard/test_fetch_tables.py | 5 +++ .../new/client/abstract_tables_test.py | 32 +++++++++++++++++++ 7 files changed, 68 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc14e3142..6cc339431 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Features - Add support for seaborn figures ([#1613](https://github.com/neptune-ai/neptune-client/pull/1613)) +- Added fetching with iterators in `fetch_*_table()` methods ([#1585](https://github.com/neptune-ai/neptune-client/pull/1585)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index e02684676..550489684 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -24,6 +24,7 @@ TYPE_CHECKING, Any, Dict, + Generator, Iterable, List, Optional, @@ -1021,22 +1022,20 @@ def search_leaderboard_entries( types: Optional[Iterable[ContainerType]] = None, query: Optional[NQLQuery] = None, columns: Optional[Iterable[str]] = None, - ) -> List[LeaderboardEntry]: + ) -> Generator[LeaderboardEntry, None, None]: step_size = int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) types_filter = list(map(lambda container_type: container_type.to_api(), types)) if types else None attributes_filter = {"attributeFilters": [{"path": column} for column in columns]} if columns else {} try: - return list( - iter_over_pages( - client=self.leaderboard_client, - project_id=project_id, - types=types_filter, - query=query, - attributes_filter=attributes_filter, - step_size=step_size, - ) + return iter_over_pages( + client=self.leaderboard_client, + project_id=project_id, + types=types_filter, + query=query, + attributes_filter=attributes_filter, + step_size=step_size, ) except HTTPNotFound: raise ProjectNotFound(project_id) diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index e9693f40a..e989e0344 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -18,6 +18,7 @@ import abc from typing import ( Any, + Generator, List, Optional, Tuple, @@ -303,7 +304,7 @@ def search_leaderboard_entries( types: Optional[List[ContainerType]] = None, query: Optional[NQLQuery] = None, columns: Optional[List[str]] = None, - ) -> List[LeaderboardEntry]: + ) -> Generator[LeaderboardEntry, None, None]: pass @abc.abstractmethod diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 4c103363e..f5ed71d96 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -23,6 +23,7 @@ from typing import ( Any, Dict, + Generator, Iterable, List, Optional, @@ -542,7 +543,7 @@ def search_leaderboard_entries( types: Optional[Iterable[ContainerType]] = None, query: Optional[NQLQuery] = None, columns: Optional[Iterable[str]] = None, - ) -> List[LeaderboardEntry]: + ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" class AttributeTypeConverterValueVisitor(ValueVisitor[AttributeType]): diff --git a/src/neptune/metadata_containers/metadata_containers_table.py b/src/neptune/metadata_containers/metadata_containers_table.py index 5ebf49992..19aafab6e 100644 --- a/src/neptune/metadata_containers/metadata_containers_table.py +++ b/src/neptune/metadata_containers/metadata_containers_table.py @@ -20,6 +20,7 @@ from typing import ( Any, Dict, + Generator, List, Optional, Union, @@ -157,22 +158,28 @@ def __init__( self, backend: NeptuneBackend, container_type: ContainerType, - entries: List[LeaderboardEntry], + entries: Generator[LeaderboardEntry, None, None], ): self._backend = backend self._entries = entries self._container_type = container_type + self._iterator = iter(entries if entries else ()) def to_rows(self) -> List[TableEntry]: - return [ - TableEntry( - backend=self._backend, - container_type=self._container_type, - _id=e.id, - attributes=e.attributes, - ) - for e in self._entries - ] + return list(self) + + def __iter__(self) -> "Table": + return self + + def __next__(self) -> TableEntry: + entry = next(self._iterator) + + return TableEntry( + backend=self._backend, + container_type=self._container_type, + _id=entry.id, + attributes=entry.attributes, + ) def to_pandas(self): import pandas as pd diff --git a/tests/e2e/standard/test_fetch_tables.py b/tests/e2e/standard/test_fetch_tables.py index 375e1dd67..dab720b96 100644 --- a/tests/e2e/standard/test_fetch_tables.py +++ b/tests/e2e/standard/test_fetch_tables.py @@ -67,6 +67,11 @@ def test_fetch_model_versions_with_correct_ids(self, container: Model, environme for index in range(versions_to_initialize): assert versions_table[index].get_attribute_value("sys/id") == f"{model_sys_id}-{index + 1}" + versions_table_gen = container.fetch_model_versions_table() + for te1, te2 in zip(list(versions_table_gen), versions_table): + assert te1._id == te2._id + assert te1._container_type == te2._container_type + def _test_fetch_from_container(self, init_container, get_containers_as_rows): container_id1, container_id2 = None, None key1 = self.gen_key() diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index 6504c8c76..a78d7baf8 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -128,6 +128,38 @@ def test_get_table_as_pandas(self, search_leaderboard_entries): with self.assertRaises(KeyError): self.assertTrue(df["image/series"]) + @patch.object(NeptuneBackendMock, "search_leaderboard_entries") + def test_get_table_as_rows(self, search_leaderboard_entries): + # given + now = datetime.now() + attributes = self.build_attributes_leaderboard(now) + + # and + empty_entry = LeaderboardEntry(str(uuid.uuid4()), []) + filled_entry = LeaderboardEntry(str(uuid.uuid4()), attributes) + search_leaderboard_entries.return_value = [empty_entry, filled_entry] + + # and + # (check if using both to_rows and table generator produces the same results) + table_gen = self.get_table() + next(table_gen) # to move to the second table entry + # when + for row in (self.get_table().to_rows()[1], next(table_gen)): + # then + self.assertEqual("Inactive", row.get_attribute_value("run/state")) + self.assertEqual(12.5, row.get_attribute_value("float")) + self.assertEqual("some text", row.get_attribute_value("string")) + self.assertEqual(now, row.get_attribute_value("datetime")) + self.assertEqual(8.7, row.get_attribute_value("float/series")) + self.assertEqual("last text", row.get_attribute_value("string/series")) + self.assertEqual({"a", "b"}, row.get_attribute_value("string/set")) + self.assertEqual("abcdef0123456789", row.get_attribute_value("git/ref")) + + with self.assertRaises(MetadataInconsistency): + row.get_attribute_value("file") + with self.assertRaises(MetadataInconsistency): + row.get_attribute_value("image/series") + @patch.object(NeptuneBackendMock, "search_leaderboard_entries") @patch.object(NeptuneBackendMock, "download_file") @patch.object(NeptuneBackendMock, "download_file_set") From 7b38f7481e475c860644da775041ef68b3c9f9d7 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 11 Dec 2023 10:42:41 +0100 Subject: [PATCH 02/22] Add limit parameter to fetch_*_table (#1593) --- CHANGELOG.md | 1 + .../backends/hosted_neptune_backend.py | 5 +++- .../internal/backends/neptune_backend.py | 1 + .../internal/backends/neptune_backend_mock.py | 1 + .../metadata_containers/metadata_container.py | 11 +++++---- src/neptune/metadata_containers/model.py | 10 +++++++- src/neptune/metadata_containers/project.py | 22 +++++++++++++++++- .../new/client/abstract_tables_test.py | 23 +++++++++++++++++++ 8 files changed, 66 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cc339431..09390d73b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ### Features - Add support for seaborn figures ([#1613](https://github.com/neptune-ai/neptune-client/pull/1613)) - Added fetching with iterators in `fetch_*_table()` methods ([#1585](https://github.com/neptune-ai/neptune-client/pull/1585)) +- Added `limit` parameter to `fetch_*_table()` methods ([#1593](https://github.com/neptune-ai/neptune-client/pull/1593)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 550489684..b56b812ab 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1022,8 +1022,11 @@ def search_leaderboard_entries( types: Optional[Iterable[ContainerType]] = None, query: Optional[NQLQuery] = None, columns: Optional[Iterable[str]] = None, + limit: Optional[int] = None, ) -> Generator[LeaderboardEntry, None, None]: - step_size = int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) + default_step_size = int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) + + step_size = min(default_step_size, limit) if limit else default_step_size types_filter = list(map(lambda container_type: container_type.to_api(), types)) if types else None attributes_filter = {"attributeFilters": [{"path": column} for column in columns]} if columns else {} diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index e989e0344..17a83d647 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -304,6 +304,7 @@ def search_leaderboard_entries( types: Optional[List[ContainerType]] = None, query: Optional[NQLQuery] = None, columns: Optional[List[str]] = None, + limit: Optional[int] = None, ) -> Generator[LeaderboardEntry, None, None]: pass diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index f5ed71d96..7d291cf0a 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -543,6 +543,7 @@ def search_leaderboard_entries( types: Optional[Iterable[ContainerType]] = None, query: Optional[NQLQuery] = None, columns: Optional[Iterable[str]] = None, + limit: Optional[int] = None, ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" diff --git a/src/neptune/metadata_containers/metadata_container.py b/src/neptune/metadata_containers/metadata_container.py index b332b1ec3..1d2c5538c 100644 --- a/src/neptune/metadata_containers/metadata_container.py +++ b/src/neptune/metadata_containers/metadata_container.py @@ -652,19 +652,20 @@ def _startup(self, debug_mode): def _shutdown_hook(self): self.stop() - def _fetch_entries(self, child_type: ContainerType, query: NQLQuery, columns: Optional[Iterable[str]]) -> Table: + def _fetch_entries( + self, child_type: ContainerType, query: NQLQuery, columns: Optional[Iterable[str]], limit: Optional[int] + ) -> Table: if columns is not None: # always return entries with `sys/id` column when filter applied columns = set(columns) columns.add("sys/id") leaderboard_entries = self._backend.search_leaderboard_entries( - project_id=self._project_id, - types=[child_type], - query=query, - columns=columns, + project_id=self._project_id, types=[child_type], query=query, columns=columns, limit=limit ) + leaderboard_entries = itertools.islice(leaderboard_entries, limit) if limit else leaderboard_entries + return Table( backend=self._backend, container_type=child_type, diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index b028d045e..99781d0a2 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -257,7 +257,9 @@ def get_url(self) -> str: sys_id=self._sys_id, ) - def fetch_model_versions_table(self, *, columns: Optional[Iterable[str]] = None) -> Table: + def fetch_model_versions_table( + self, *, columns: Optional[Iterable[str]] = None, limit: Optional[int] = None + ) -> Table: """Retrieve all versions of the given model. Args: @@ -267,6 +269,7 @@ def fetch_model_versions_table(self, *, columns: Optional[Iterable[str]] = None) Fields: `["params/lr", "params/batch", "val/acc"]` - these fields are included as columns. Namespaces: `["params", "val"]` - all the fields inside the namespaces are included as columns. If `None` (default), all the columns of the model versions table are included. + limit: How many entries to return at most (default: None - return all entries). Returns: `Table` object containing `ModelVersion` objects that match the specified criteria. @@ -299,6 +302,10 @@ def fetch_model_versions_table(self, *, columns: Optional[Iterable[str]] = None) See also the API referene: https://docs.neptune.ai/api/model/#fetch_model_versions_table """ + verify_type("limit", limit, (int, type(None))) + + if isinstance(limit, int) and limit <= 0: + raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") return MetadataContainer._fetch_entries( self, child_type=ContainerType.MODEL_VERSION, @@ -320,4 +327,5 @@ def fetch_model_versions_table(self, *, columns: Optional[Iterable[str]] = None) aggregator=NQLAggregator.AND, ), columns=columns, + limit=limit, ) diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index bb12af11b..ae9b61e85 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -196,6 +196,7 @@ def fetch_runs_table( tag: Optional[Union[str, Iterable[str]]] = None, columns: Optional[Iterable[str]] = None, trashed: Optional[bool] = False, + limit: Optional[int] = None, ) -> Table: """Retrieve runs matching the specified criteria. @@ -227,6 +228,7 @@ def fetch_runs_table( If `True`, only trashed runs are retrieved. If `False` (default), only not-trashed runs are retrieved. If `None`, both trashed and not-trashed runs are retrieved. + limit: How many entries to return at most (default: None - return all entries). Returns: `Table` object containing `Run` objects matching the specified criteria. @@ -275,6 +277,10 @@ def fetch_runs_table( tags = as_list("tag", tag) verify_type("trashed", trashed, (bool, type(None))) + verify_type("limit", limit, (int, type(None))) + + if isinstance(limit, int) and limit <= 0: + raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") nql_query = prepare_nql_query(ids, states, owners, tags, trashed) @@ -283,9 +289,16 @@ def fetch_runs_table( child_type=ContainerType.RUN, query=nql_query, columns=columns, + limit=limit, ) - def fetch_models_table(self, *, columns: Optional[Iterable[str]] = None, trashed: Optional[bool] = False) -> Table: + def fetch_models_table( + self, + *, + columns: Optional[Iterable[str]] = None, + trashed: Optional[bool] = False, + limit: Optional[int] = None, + ) -> Table: """Retrieve models stored in the project. Args: @@ -299,6 +312,7 @@ def fetch_models_table(self, *, columns: Optional[Iterable[str]] = None, trashed Fields: `["datasets/test", "info/size"]` - these fields are included as columns. Namespaces: `["datasets", "info"]` - all the fields inside the namespaces are included as columns. If `None` (default), all the columns of the models table are included. + limit: How many entries to return at most (default: None - return all entries). Returns: `Table` object containing `Model` objects. @@ -330,6 +344,11 @@ def fetch_models_table(self, *, columns: Optional[Iterable[str]] = None, trashed You may also want to check the API referene in the docs: https://docs.neptune.ai/api/project#fetch_models_table """ + verify_type("limit", limit, (int, type(None))) + + if isinstance(limit, int) and limit <= 0: + raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") + return MetadataContainer._fetch_entries( self, child_type=ContainerType.MODEL, @@ -342,4 +361,5 @@ def fetch_models_table(self, *, columns: Optional[Iterable[str]] = None, trashed if trashed is not None else NQLEmptyQuery, columns=columns, + limit=limit, ) diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index a78d7baf8..19a99e211 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -19,6 +19,7 @@ from datetime import datetime from typing import List +import pytest from mock import patch from neptune import ANONYMOUS_API_TOKEN @@ -213,3 +214,25 @@ def test_get_table_as_table_entries( path=["file", "set"], destination="some_directory", ) + + @patch.object(NeptuneBackendMock, "search_leaderboard_entries") + def test_table_limit(self, search_leaderboard_entries): + # given + now = datetime.now() + attributes = self.build_attributes_leaderboard(now) + + # and + empty_entry = LeaderboardEntry(str(uuid.uuid4()), []) + filled_entry = LeaderboardEntry(str(uuid.uuid4()), attributes) + search_leaderboard_entries.return_value = [empty_entry, filled_entry] + + # and + for limit, expected_len in [(1, 1), (2, 2), (3, 2), (10_000, 2)]: + # then + assert len(self.get_table(limit=limit).to_rows()) == expected_len + + with pytest.raises(ValueError): + self.get_table(limit=-4) + + with pytest.raises(ValueError): + self.get_table(limit=0) From 57ce6837870380f06fe2ef3b08b80cb0c7d7a7a9 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Fri, 15 Dec 2023 16:53:56 +0100 Subject: [PATCH 03/22] Sorting for `fetch_*_table` (#1595) --- CHANGELOG.md | 1 + src/neptune/api/searching_entries.py | 8 ++- .../backends/hosted_neptune_backend.py | 61 +++++++++++++++++++ .../internal/backends/neptune_backend.py | 1 + .../internal/backends/neptune_backend_mock.py | 1 + .../metadata_containers/metadata_container.py | 17 +++++- src/neptune/metadata_containers/model.py | 11 +++- src/neptune/metadata_containers/project.py | 12 ++++ tests/e2e/standard/test_fetch_tables.py | 49 +++++++++++++++ .../neptune/new/api/test_searching_entries.py | 26 ++++---- .../new/client/abstract_tables_test.py | 2 +- .../internal/backends/test_hosted_client.py | 32 ++++++++++ 12 files changed, 203 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09390d73b..6b49e315d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Add support for seaborn figures ([#1613](https://github.com/neptune-ai/neptune-client/pull/1613)) - Added fetching with iterators in `fetch_*_table()` methods ([#1585](https://github.com/neptune-ai/neptune-client/pull/1585)) - Added `limit` parameter to `fetch_*_table()` methods ([#1593](https://github.com/neptune-ai/neptune-client/pull/1593)) +- Added `sort_by` parameter to `fetch_*_table()` methods ([#1595](https://github.com/neptune-ai/neptune-client/pull/1595)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 319a2a2b3..6999affcf 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -61,6 +61,7 @@ def get_single_page( limit: int, offset: int, sort_by: Optional[str] = None, + sort_by_column_type: Optional[str] = None, types: Optional[Iterable[str]] = None, query: Optional["NQLQuery"] = None, searching_after: Optional[str] = None, @@ -84,7 +85,10 @@ def get_single_page( "sorting": { "dir": "ascending", "aggregationMode": "none", - "sortBy": {"name": sort_by, "type": "string"}, + "sortBy": { + "name": sort_by, + "type": sort_by_column_type if sort_by_column_type else AttributeType.STRING.value, + }, } } if sort_by @@ -141,6 +145,7 @@ def iter_over_pages( step_size: int, sort_by: str = "sys/id", max_offset: int = MAX_SERVER_OFFSET, + sort_by_column_type: Optional[str] = None, **kwargs: Any, ) -> Generator[Any, None, None]: searching_after = None @@ -160,6 +165,7 @@ def iter_over_pages( limit=min(step_size, max_offset - offset), offset=offset, sort_by=sort_by, + sort_by_column_type=sort_by_column_type, searching_after=searching_after, **kwargs, ) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index b56b812ab..d09aca6a0 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -48,6 +48,10 @@ NeptuneException, ) from neptune.common.patterns import PROJECT_QUALIFIED_NAME_PATTERN +from neptune.common.warnings import ( + NeptuneWarning, + warn_once, +) from neptune.envs import NEPTUNE_FETCH_TABLE_STEP_SIZE from neptune.exceptions import ( AmbiguousProjectName, @@ -150,6 +154,15 @@ _logger = logging.getLogger(__name__) +ATOMIC_ATTRIBUTE_TYPES = { + AttributeType.INT.value, + AttributeType.FLOAT.value, + AttributeType.STRING.value, + AttributeType.BOOL.value, + AttributeType.DATETIME.value, + AttributeType.RUN_STATE.value, +} + class HostedNeptuneBackend(NeptuneBackend): def __init__(self, credentials: Credentials, proxies: Optional[Dict[str, str]] = None): @@ -1015,6 +1028,20 @@ def _get_file_set_download_request(self, container_id: str, container_type: Cont except HTTPNotFound: raise FetchAttributeNotFoundException(path_to_str(path)) + @with_api_exceptions_handler + def _get_column_types(self, project_id: UniqueId, column: str, types: Optional[Iterable[str]] = None) -> List[Any]: + params = { + "projectIdentifier": project_id, + "search": f"/^{column}$/", # exact regex match + "type": types, + "params": {}, + **DEFAULT_REQUEST_KWARGS, + } + try: + return self.leaderboard_client.api.searchLeaderboardAttributes(**params).response().result.entries + except HTTPNotFound as e: + raise ProjectNotFound(project_id=project_id) from e + @with_api_exceptions_handler def search_leaderboard_entries( self, @@ -1023,6 +1050,7 @@ def search_leaderboard_entries( query: Optional[NQLQuery] = None, columns: Optional[Iterable[str]] = None, limit: Optional[int] = None, + sort_by: str = "sys/creation_time", ) -> Generator[LeaderboardEntry, None, None]: default_step_size = int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) @@ -1031,6 +1059,12 @@ def search_leaderboard_entries( types_filter = list(map(lambda container_type: container_type.to_api(), types)) if types else None attributes_filter = {"attributeFilters": [{"path": column} for column in columns]} if columns else {} + if sort_by == "sys/creation_time": + sort_by_column_type = AttributeType.DATETIME.value + else: + sort_by_column_type_candidates = self._get_column_types(project_id, sort_by, types_filter) + sort_by_column_type = _get_column_type_from_entries(sort_by_column_type_candidates, sort_by) + try: return iter_over_pages( client=self.leaderboard_client, @@ -1039,6 +1073,8 @@ def search_leaderboard_entries( query=query, attributes_filter=attributes_filter, step_size=step_size, + sort_by=sort_by, + sort_by_column_type=sort_by_column_type, ) except HTTPNotFound: raise ProjectNotFound(project_id) @@ -1065,3 +1101,28 @@ def get_model_version_url( ) -> str: base_url = self.get_display_address() return f"{base_url}/{workspace}/{project_name}/m/{model_id}/v/{sys_id}" + + +def _get_column_type_from_entries(entries: List[Any], column: str) -> str: + if not entries: # column chosen is not present in the table + raise ValueError(f"Column '{column}' chosen for sorting is not present in the table") + + if len(entries) == 1 and entries[0].name == column: + return entries[0].type + + types = set() + for entry in entries: + if entry.name != column: # caught by regex, but it's not this column + continue + if entry.type not in ATOMIC_ATTRIBUTE_TYPES: # non-atomic type - no need to look further + raise ValueError(f"Column {column} used for sorting is not of atomic type.") + types.add(entry.type) + + if types == {AttributeType.INT.value, AttributeType.FLOAT.value}: + return AttributeType.FLOAT.value + + warn_once( + f"Column {column} contains more than one atomic data type. Sorting result might be inaccurate.", + exception=NeptuneWarning, + ) + return AttributeType.STRING.value diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index 17a83d647..7823e1761 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -305,6 +305,7 @@ def search_leaderboard_entries( query: Optional[NQLQuery] = None, columns: Optional[List[str]] = None, limit: Optional[int] = None, + sort_by: str = "sys/creation_time", ) -> Generator[LeaderboardEntry, None, None]: pass diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 7d291cf0a..3328e269b 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -544,6 +544,7 @@ def search_leaderboard_entries( query: Optional[NQLQuery] = None, columns: Optional[Iterable[str]] = None, limit: Optional[int] = None, + sort_by: str = "sys/creation_time", ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" diff --git a/src/neptune/metadata_containers/metadata_container.py b/src/neptune/metadata_containers/metadata_container.py index 1d2c5538c..a2dcd6aee 100644 --- a/src/neptune/metadata_containers/metadata_container.py +++ b/src/neptune/metadata_containers/metadata_container.py @@ -653,15 +653,26 @@ def _shutdown_hook(self): self.stop() def _fetch_entries( - self, child_type: ContainerType, query: NQLQuery, columns: Optional[Iterable[str]], limit: Optional[int] + self, + child_type: ContainerType, + query: NQLQuery, + columns: Optional[Iterable[str]], + limit: Optional[int], + sort_by: str, ) -> Table: if columns is not None: - # always return entries with `sys/id` column when filter applied + # always return entries with 'sys/id' and the column chosen for sorting when filter applied columns = set(columns) columns.add("sys/id") + columns.add(sort_by) leaderboard_entries = self._backend.search_leaderboard_entries( - project_id=self._project_id, types=[child_type], query=query, columns=columns, limit=limit + project_id=self._project_id, + types=[child_type], + query=query, + columns=columns, + limit=limit, + sort_by=sort_by, ) leaderboard_entries = itertools.islice(leaderboard_entries, limit) if limit else leaderboard_entries diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index 99781d0a2..365a8c277 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -258,7 +258,11 @@ def get_url(self) -> str: ) def fetch_model_versions_table( - self, *, columns: Optional[Iterable[str]] = None, limit: Optional[int] = None + self, + *, + columns: Optional[Iterable[str]] = None, + limit: Optional[int] = None, + sort_by: str = "sys/creation_time", ) -> Table: """Retrieve all versions of the given model. @@ -270,6 +274,9 @@ def fetch_model_versions_table( Namespaces: `["params", "val"]` - all the fields inside the namespaces are included as columns. If `None` (default), all the columns of the model versions table are included. limit: How many entries to return at most (default: None - return all entries). + sort_by: Name of the column to sort the results by. + Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. + Default: 'sys/creation_time. Returns: `Table` object containing `ModelVersion` objects that match the specified criteria. @@ -303,6 +310,7 @@ def fetch_model_versions_table( https://docs.neptune.ai/api/model/#fetch_model_versions_table """ verify_type("limit", limit, (int, type(None))) + verify_type("sort_by", sort_by, str) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -328,4 +336,5 @@ def fetch_model_versions_table( ), columns=columns, limit=limit, + sort_by=sort_by, ) diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index ae9b61e85..0153b468a 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -197,6 +197,7 @@ def fetch_runs_table( columns: Optional[Iterable[str]] = None, trashed: Optional[bool] = False, limit: Optional[int] = None, + sort_by: str = "sys/creation_time", ) -> Table: """Retrieve runs matching the specified criteria. @@ -229,6 +230,9 @@ def fetch_runs_table( If `False` (default), only not-trashed runs are retrieved. If `None`, both trashed and not-trashed runs are retrieved. limit: How many entries to return at most (default: None - return all entries). + sort_by: Name of the column to sort the results by. + Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. + Default: 'sys/creation_time. Returns: `Table` object containing `Run` objects matching the specified criteria. @@ -278,6 +282,7 @@ def fetch_runs_table( verify_type("trashed", trashed, (bool, type(None))) verify_type("limit", limit, (int, type(None))) + verify_type("sort_by", sort_by, str) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -290,6 +295,7 @@ def fetch_runs_table( query=nql_query, columns=columns, limit=limit, + sort_by=sort_by, ) def fetch_models_table( @@ -298,6 +304,7 @@ def fetch_models_table( columns: Optional[Iterable[str]] = None, trashed: Optional[bool] = False, limit: Optional[int] = None, + sort_by: str = "sys/creation_time", ) -> Table: """Retrieve models stored in the project. @@ -313,6 +320,9 @@ def fetch_models_table( Namespaces: `["datasets", "info"]` - all the fields inside the namespaces are included as columns. If `None` (default), all the columns of the models table are included. limit: How many entries to return at most (default: None - return all entries). + sort_by: Name of the column to sort the results by. + Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. + Default: 'sys/creation_time. Returns: `Table` object containing `Model` objects. @@ -345,6 +355,7 @@ def fetch_models_table( https://docs.neptune.ai/api/project#fetch_models_table """ verify_type("limit", limit, (int, type(None))) + verify_type("sort_by", sort_by, str) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -362,4 +373,5 @@ def fetch_models_table( else NQLEmptyQuery, columns=columns, limit=limit, + sort_by=sort_by, ) diff --git a/tests/e2e/standard/test_fetch_tables.py b/tests/e2e/standard/test_fetch_tables.py index dab720b96..298510a00 100644 --- a/tests/e2e/standard/test_fetch_tables.py +++ b/tests/e2e/standard/test_fetch_tables.py @@ -179,3 +179,52 @@ def test_fetch_runs_table_by_state(self, environment, project): assert not runs.empty assert tag in runs["sys/tags"].values assert random_val in runs["some_random_val"].values + + def test_fetch_runs_table_sorting(self, environment, project): + # given + with neptune.init_run(project=environment.project, custom_run_id="run1") as run: + run["metrics/accuracy"] = 0.95 + run["some_val"] = "b" + + with neptune.init_run(project=environment.project, custom_run_id="run2") as run: + run["metrics/accuracy"] = 0.90 + run["some_val"] = "a" + + time.sleep(30) + + # when + runs = project.fetch_runs_table(sort_by="sys/creation_time").to_pandas() + + # then + # runs are correctly sorted by creation time -> run1 was first + assert not runs.empty + assert runs["sys/custom_run_id"].dropna().to_list() == ["run1", "run2"] + + # when + runs = project.fetch_runs_table(sort_by="metrics/accuracy").to_pandas() + + # then + # run2 has lower accuracy + assert not runs.empty + assert runs["sys/custom_run_id"].dropna().to_list() == ["run2", "run1"] + + # when + runs = project.fetch_runs_table(sort_by="some_val").to_pandas() + + # then + # run2 has a "lower" "some_val" field value + assert not runs.empty + assert runs["sys/custom_run_id"].dropna().to_list() == ["run2", "run1"] + + # test if now it fails when we add a non-atomic type to that field + + # given + with neptune.init_run(project=environment.project, custom_run_id="run3") as run: + for i in range(5): + run["metrics/accuracy"].log(0.95) + + time.sleep(30) + + # then + with pytest.raises(ValueError): + project.fetch_runs_table(sort_by="metrics/accuracy") diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index 13c1ce002..77158ce38 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -95,10 +95,10 @@ def test__iter_over_pages__single_pagination(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) assert get_single_page.mock_calls == [ - call(limit=3, offset=0, sort_by="sys/id", searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", searching_after=None), - call(limit=3, offset=6, sort_by="sys/id", searching_after=None), - call(limit=3, offset=9, sort_by="sys/id", searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(limit=3, offset=6, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(limit=3, offset=9, sort_by="sys/id", sort_by_column_type=None, searching_after=None), ] @@ -118,10 +118,10 @@ def test__iter_over_pages__multiple_search_after(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) assert get_single_page.mock_calls == [ - call(limit=3, offset=0, sort_by="sys/id", searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", searching_after="f"), - call(limit=3, offset=3, sort_by="sys/id", searching_after="f"), + call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after="f"), + call(limit=3, offset=3, sort_by="sys/id", sort_by_column_type=None, searching_after="f"), ] @@ -135,7 +135,9 @@ def test__iter_over_pages__empty(get_single_page): # then assert result == [] - assert get_single_page.mock_calls == [call(limit=3, offset=0, sort_by="sys/id", searching_after=None)] + assert get_single_page.mock_calls == [ + call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after=None) + ] @patch("neptune.api.searching_entries.get_single_page") @@ -153,9 +155,9 @@ def test__iter_over_pages__max_server_offset(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) assert get_single_page.mock_calls == [ - call(offset=0, limit=3, sort_by="sys/id", searching_after=None), - call(offset=3, limit=2, sort_by="sys/id", searching_after=None), - call(offset=0, limit=3, sort_by="sys/id", searching_after="e"), + call(offset=0, limit=3, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(offset=3, limit=2, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(offset=0, limit=3, sort_by="sys/id", sort_by_column_type=None, searching_after="e"), ] diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index 19a99e211..7c128deeb 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -96,7 +96,7 @@ def test_get_table_with_columns_filter(self, search_leaderboard_entries): # then self.assertEqual(1, search_leaderboard_entries.call_count) parameters = search_leaderboard_entries.call_args[1] - self.assertEqual({"sys/id", "datetime"}, parameters.get("columns")) + self.assertEqual({"sys/id", "sys/creation_time", "datetime"}, parameters.get("columns")) @patch.object(NeptuneBackendMock, "search_leaderboard_entries") def test_get_table_as_pandas(self, search_leaderboard_entries): diff --git a/tests/unit/neptune/new/internal/backends/test_hosted_client.py b/tests/unit/neptune/new/internal/backends/test_hosted_client.py index 3cf9f1366..f2d43a6fd 100644 --- a/tests/unit/neptune/new/internal/backends/test_hosted_client.py +++ b/tests/unit/neptune/new/internal/backends/test_hosted_client.py @@ -15,7 +15,9 @@ # import unittest import uuid +from dataclasses import dataclass +import pytest from bravado.exception import ( HTTPBadRequest, HTTPConflict, @@ -30,6 +32,7 @@ patch, ) +from neptune.internal.backends.api_model import AttributeType from neptune.internal.backends.hosted_client import ( DEFAULT_REQUEST_KWARGS, _get_token_client, @@ -39,6 +42,7 @@ create_leaderboard_client, get_client_config, ) +from neptune.internal.backends.hosted_neptune_backend import _get_column_type_from_entries from neptune.internal.backends.utils import verify_host_resolution from neptune.management import ( MemberRole, @@ -530,3 +534,31 @@ def test_remove_project_member_permissions(self, swagger_client_factory): # then: with self.assertRaises(AccessRevokedOnMemberRemoval): remove_project_member(project="org/proj", username="tester", api_token=API_TOKEN) + + +def test__get_column_type_from_entries(): + @dataclass + class DTO: + type: str + name: str = "test_column" + + # when + test_cases = [ + {"entries": [], "exc": ValueError}, + {"entries": [DTO(type="float")], "result": AttributeType.FLOAT.value}, + {"entries": [DTO(type="string")], "result": AttributeType.STRING.value}, + {"entries": [DTO(type="float"), DTO(type="floatSeries")], "exc": ValueError}, + {"entries": [DTO(type="float"), DTO(type="int")], "result": AttributeType.FLOAT.value}, + {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="datetime")], "result": AttributeType.STRING.value}, + {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="string")], "result": AttributeType.STRING.value}, + ] + + # then + for tc in test_cases: + exc = tc.get("exc", None) + if exc is not None: + with pytest.raises(exc): + _get_column_type_from_entries(tc["entries"], column="test_column") + else: + result = _get_column_type_from_entries(tc["entries"], column="test_column") + assert result == tc["result"] From 48eaf749effd97fdc36b40526b5f71ae1dce8a66 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 18 Dec 2023 08:54:19 +0100 Subject: [PATCH 04/22] fix sorting bug (#1601) --- src/neptune/internal/backends/hosted_neptune_backend.py | 2 +- .../unit/neptune/new/internal/backends/test_hosted_client.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index d09aca6a0..eb30064f1 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1032,7 +1032,7 @@ def _get_file_set_download_request(self, container_id: str, container_type: Cont def _get_column_types(self, project_id: UniqueId, column: str, types: Optional[Iterable[str]] = None) -> List[Any]: params = { "projectIdentifier": project_id, - "search": f"/^{column}$/", # exact regex match + "search": column, "type": types, "params": {}, **DEFAULT_REQUEST_KWARGS, diff --git a/tests/unit/neptune/new/internal/backends/test_hosted_client.py b/tests/unit/neptune/new/internal/backends/test_hosted_client.py index f2d43a6fd..aefe62b1a 100644 --- a/tests/unit/neptune/new/internal/backends/test_hosted_client.py +++ b/tests/unit/neptune/new/internal/backends/test_hosted_client.py @@ -551,6 +551,10 @@ class DTO: {"entries": [DTO(type="float"), DTO(type="int")], "result": AttributeType.FLOAT.value}, {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="datetime")], "result": AttributeType.STRING.value}, {"entries": [DTO(type="float"), DTO(type="int"), DTO(type="string")], "result": AttributeType.STRING.value}, + { + "entries": [DTO(type="float"), DTO(type="int"), DTO(type="string", name="test_column_different")], + "result": AttributeType.FLOAT.value, + }, ] # then From 3eedaad1e09d8cd4c28b50b0484fdd682661d489 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 18 Dec 2023 19:52:32 +0100 Subject: [PATCH 05/22] Add `ascending` parameter to `fetch_*_table` (#1602) --- CHANGELOG.md | 1 + src/neptune/api/searching_entries.py | 5 ++- .../backends/hosted_neptune_backend.py | 2 + .../internal/backends/neptune_backend.py | 1 + .../internal/backends/neptune_backend_mock.py | 1 + .../metadata_containers/metadata_container.py | 2 + src/neptune/metadata_containers/model.py | 5 +++ src/neptune/metadata_containers/project.py | 10 +++++ tests/e2e/standard/test_fetch_tables.py | 37 ++++++++++++++----- .../neptune/new/api/test_searching_entries.py | 24 ++++++------ 10 files changed, 65 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b49e315d..ddea1301c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ - Added fetching with iterators in `fetch_*_table()` methods ([#1585](https://github.com/neptune-ai/neptune-client/pull/1585)) - Added `limit` parameter to `fetch_*_table()` methods ([#1593](https://github.com/neptune-ai/neptune-client/pull/1593)) - Added `sort_by` parameter to `fetch_*_table()` methods ([#1595](https://github.com/neptune-ai/neptune-client/pull/1595)) +- Added `ascending` parameter to `fetch_*_table()` methods ([#1602](https://github.com/neptune-ai/neptune-client/pull/1602)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 6999affcf..954af1880 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -62,6 +62,7 @@ def get_single_page( offset: int, sort_by: Optional[str] = None, sort_by_column_type: Optional[str] = None, + ascending: bool = False, types: Optional[Iterable[str]] = None, query: Optional["NQLQuery"] = None, searching_after: Optional[str] = None, @@ -83,7 +84,7 @@ def get_single_page( sorting = ( { "sorting": { - "dir": "ascending", + "dir": "ascending" if ascending else "descending", "aggregationMode": "none", "sortBy": { "name": sort_by, @@ -146,6 +147,7 @@ def iter_over_pages( sort_by: str = "sys/id", max_offset: int = MAX_SERVER_OFFSET, sort_by_column_type: Optional[str] = None, + ascending: bool = False, **kwargs: Any, ) -> Generator[Any, None, None]: searching_after = None @@ -167,6 +169,7 @@ def iter_over_pages( sort_by=sort_by, sort_by_column_type=sort_by_column_type, searching_after=searching_after, + ascending=ascending, **kwargs, ) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index eb30064f1..2d50c8d29 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1051,6 +1051,7 @@ def search_leaderboard_entries( columns: Optional[Iterable[str]] = None, limit: Optional[int] = None, sort_by: str = "sys/creation_time", + ascending: bool = False, ) -> Generator[LeaderboardEntry, None, None]: default_step_size = int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) @@ -1074,6 +1075,7 @@ def search_leaderboard_entries( attributes_filter=attributes_filter, step_size=step_size, sort_by=sort_by, + ascending=ascending, sort_by_column_type=sort_by_column_type, ) except HTTPNotFound: diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index 7823e1761..b69950b8e 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -306,6 +306,7 @@ def search_leaderboard_entries( columns: Optional[List[str]] = None, limit: Optional[int] = None, sort_by: str = "sys/creation_time", + ascending: bool = False, ) -> Generator[LeaderboardEntry, None, None]: pass diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 3328e269b..75cf2f3cf 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -545,6 +545,7 @@ def search_leaderboard_entries( columns: Optional[Iterable[str]] = None, limit: Optional[int] = None, sort_by: str = "sys/creation_time", + ascending: bool = False, ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" diff --git a/src/neptune/metadata_containers/metadata_container.py b/src/neptune/metadata_containers/metadata_container.py index a2dcd6aee..32beb4853 100644 --- a/src/neptune/metadata_containers/metadata_container.py +++ b/src/neptune/metadata_containers/metadata_container.py @@ -659,6 +659,7 @@ def _fetch_entries( columns: Optional[Iterable[str]], limit: Optional[int], sort_by: str, + ascending: bool, ) -> Table: if columns is not None: # always return entries with 'sys/id' and the column chosen for sorting when filter applied @@ -673,6 +674,7 @@ def _fetch_entries( columns=columns, limit=limit, sort_by=sort_by, + ascending=ascending, ) leaderboard_entries = itertools.islice(leaderboard_entries, limit) if limit else leaderboard_entries diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index 365a8c277..d99ec39dc 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -263,6 +263,7 @@ def fetch_model_versions_table( columns: Optional[Iterable[str]] = None, limit: Optional[int] = None, sort_by: str = "sys/creation_time", + ascending: bool = False, ) -> Table: """Retrieve all versions of the given model. @@ -277,6 +278,8 @@ def fetch_model_versions_table( sort_by: Name of the column to sort the results by. Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. + ascending: Whether to sort model versions in the ascending order of the sorting column values. + Default: False - descending order. Returns: `Table` object containing `ModelVersion` objects that match the specified criteria. @@ -311,6 +314,7 @@ def fetch_model_versions_table( """ verify_type("limit", limit, (int, type(None))) verify_type("sort_by", sort_by, str) + verify_type("ascending", ascending, bool) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -337,4 +341,5 @@ def fetch_model_versions_table( columns=columns, limit=limit, sort_by=sort_by, + ascending=ascending, ) diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 0153b468a..2e10640d7 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -198,6 +198,7 @@ def fetch_runs_table( trashed: Optional[bool] = False, limit: Optional[int] = None, sort_by: str = "sys/creation_time", + ascending: bool = False, ) -> Table: """Retrieve runs matching the specified criteria. @@ -233,6 +234,8 @@ def fetch_runs_table( sort_by: Name of the column to sort the results by. Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. + ascending: Whether to sort runs in the ascending order of the sorting column values. + Default: False - descending order. Returns: `Table` object containing `Run` objects matching the specified criteria. @@ -283,6 +286,7 @@ def fetch_runs_table( verify_type("trashed", trashed, (bool, type(None))) verify_type("limit", limit, (int, type(None))) verify_type("sort_by", sort_by, str) + verify_type("ascending", ascending, bool) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -296,6 +300,7 @@ def fetch_runs_table( columns=columns, limit=limit, sort_by=sort_by, + ascending=ascending, ) def fetch_models_table( @@ -305,6 +310,7 @@ def fetch_models_table( trashed: Optional[bool] = False, limit: Optional[int] = None, sort_by: str = "sys/creation_time", + ascending: bool = False, ) -> Table: """Retrieve models stored in the project. @@ -323,6 +329,8 @@ def fetch_models_table( sort_by: Name of the column to sort the results by. Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. + ascending: Whether to sort models in the ascending order of the sorting column values. + Default: False - descending order. Returns: `Table` object containing `Model` objects. @@ -356,6 +364,7 @@ def fetch_models_table( """ verify_type("limit", limit, (int, type(None))) verify_type("sort_by", sort_by, str) + verify_type("ascending", ascending, bool) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -374,4 +383,5 @@ def fetch_models_table( columns=columns, limit=limit, sort_by=sort_by, + ascending=ascending, ) diff --git a/tests/e2e/standard/test_fetch_tables.py b/tests/e2e/standard/test_fetch_tables.py index 298510a00..8fcf974a6 100644 --- a/tests/e2e/standard/test_fetch_tables.py +++ b/tests/e2e/standard/test_fetch_tables.py @@ -67,7 +67,7 @@ def test_fetch_model_versions_with_correct_ids(self, container: Model, environme for index in range(versions_to_initialize): assert versions_table[index].get_attribute_value("sys/id") == f"{model_sys_id}-{index + 1}" - versions_table_gen = container.fetch_model_versions_table() + versions_table_gen = container.fetch_model_versions_table(ascending=True) for te1, te2 in zip(list(versions_table_gen), versions_table): assert te1._id == te2._id assert te1._container_type == te2._container_type @@ -180,7 +180,8 @@ def test_fetch_runs_table_by_state(self, environment, project): assert tag in runs["sys/tags"].values assert random_val in runs["some_random_val"].values - def test_fetch_runs_table_sorting(self, environment, project): + @pytest.mark.parametrize("ascending", [True, False]) + def test_fetch_runs_table_sorting(self, environment, project, ascending): # given with neptune.init_run(project=environment.project, custom_run_id="run1") as run: run["metrics/accuracy"] = 0.95 @@ -193,33 +194,49 @@ def test_fetch_runs_table_sorting(self, environment, project): time.sleep(30) # when - runs = project.fetch_runs_table(sort_by="sys/creation_time").to_pandas() + runs = project.fetch_runs_table(sort_by="sys/creation_time", ascending=ascending).to_pandas() # then # runs are correctly sorted by creation time -> run1 was first assert not runs.empty - assert runs["sys/custom_run_id"].dropna().to_list() == ["run1", "run2"] + run_list = runs["sys/custom_run_id"].dropna().to_list() + if ascending: + assert run_list == ["run1", "run2"] + else: + assert run_list == ["run2", "run1"] # when - runs = project.fetch_runs_table(sort_by="metrics/accuracy").to_pandas() + runs = project.fetch_runs_table(sort_by="metrics/accuracy", ascending=ascending).to_pandas() # then - # run2 has lower accuracy assert not runs.empty - assert runs["sys/custom_run_id"].dropna().to_list() == ["run2", "run1"] + run_list = runs["sys/custom_run_id"].dropna().to_list() + + if ascending: + assert run_list == ["run2", "run1"] + else: + assert run_list == ["run1", "run2"] # when - runs = project.fetch_runs_table(sort_by="some_val").to_pandas() + runs = project.fetch_runs_table(sort_by="some_val", ascending=ascending).to_pandas() # then - # run2 has a "lower" "some_val" field value assert not runs.empty - assert runs["sys/custom_run_id"].dropna().to_list() == ["run2", "run1"] + run_list = runs["sys/custom_run_id"].dropna().to_list() + + if ascending: + assert run_list == ["run2", "run1"] + else: + assert run_list == ["run1", "run2"] + def test_fetch_runs_table_non_atomic_type(self, environment, project): # test if now it fails when we add a non-atomic type to that field # given with neptune.init_run(project=environment.project, custom_run_id="run3") as run: + run["metrics/accuracy"] = 0.9 + + with neptune.init_run(project=environment.project, custom_run_id="run4") as run: for i in range(5): run["metrics/accuracy"].log(0.95) diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index 77158ce38..e2d912ace 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -95,10 +95,10 @@ def test__iter_over_pages__single_pagination(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) assert get_single_page.mock_calls == [ - call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", sort_by_column_type=None, searching_after=None), - call(limit=3, offset=6, sort_by="sys/id", sort_by_column_type=None, searching_after=None), - call(limit=3, offset=9, sort_by="sys/id", sort_by_column_type=None, searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=3, offset=6, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=3, offset=9, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), ] @@ -118,10 +118,10 @@ def test__iter_over_pages__multiple_search_after(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) assert get_single_page.mock_calls == [ - call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after=None), - call(limit=3, offset=3, sort_by="sys/id", sort_by_column_type=None, searching_after=None), - call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after="f"), - call(limit=3, offset=3, sort_by="sys/id", sort_by_column_type=None, searching_after="f"), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="f"), + call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="f"), ] @@ -136,7 +136,7 @@ def test__iter_over_pages__empty(get_single_page): # then assert result == [] assert get_single_page.mock_calls == [ - call(limit=3, offset=0, sort_by="sys/id", sort_by_column_type=None, searching_after=None) + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None) ] @@ -155,9 +155,9 @@ def test__iter_over_pages__max_server_offset(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) assert get_single_page.mock_calls == [ - call(offset=0, limit=3, sort_by="sys/id", sort_by_column_type=None, searching_after=None), - call(offset=3, limit=2, sort_by="sys/id", sort_by_column_type=None, searching_after=None), - call(offset=0, limit=3, sort_by="sys/id", sort_by_column_type=None, searching_after="e"), + call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(offset=3, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), + call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="e"), ] From 825d796fe03a7581b3ed2203cd72090ebf521dc0 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Thu, 18 Jan 2024 16:04:46 +0100 Subject: [PATCH 06/22] Progress bars in `fetch_*_table` (#1599) --- CHANGELOG.md | 1 + src/neptune/api/searching_entries.py | 83 +++++++++++++------ .../backends/hosted_neptune_backend.py | 3 + .../internal/backends/neptune_backend.py | 3 + .../internal/backends/neptune_backend_mock.py | 2 + src/neptune/internal/backends/utils.py | 44 ++++++++++ src/neptune/internal/utils/runningmode.py | 16 ++-- .../metadata_containers/metadata_container.py | 4 + src/neptune/metadata_containers/model.py | 6 ++ src/neptune/metadata_containers/project.py | 8 ++ src/neptune/typing.py | 18 +++- src/neptune/utils.py | 74 ++++++++++++++++- tests/e2e/standard/test_fetch_tables.py | 26 +++--- .../neptune/new/api/test_searching_entries.py | 34 +++++--- .../new/internal/backends/test_utils.py | 57 ++++++++++++- 15 files changed, 315 insertions(+), 64 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ddea1301c..35f9da729 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ - Added `limit` parameter to `fetch_*_table()` methods ([#1593](https://github.com/neptune-ai/neptune-client/pull/1593)) - Added `sort_by` parameter to `fetch_*_table()` methods ([#1595](https://github.com/neptune-ai/neptune-client/pull/1595)) - Added `ascending` parameter to `fetch_*_table()` methods ([#1602](https://github.com/neptune-ai/neptune-client/pull/1602)) +- Added `progress_bar` parameter to `fetch_*_table()` methods ([#1599](https://github.com/neptune-ai/neptune-client/pull/1599)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 954af1880..1b1520b53 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -23,6 +23,8 @@ Iterable, List, Optional, + Type, + Union, ) from bravado.client import construct_request # type: ignore @@ -43,7 +45,9 @@ NQLQueryAggregate, NQLQueryAttribute, ) +from neptune.internal.backends.utils import which_progress_bar from neptune.internal.init.parameters import MAX_SERVER_OFFSET +from neptune.typing import ProgressBarCallback if TYPE_CHECKING: from neptune.internal.backends.swagger_client_wrapper import SwaggerClientWrapper @@ -66,7 +70,7 @@ def get_single_page( types: Optional[Iterable[str]] = None, query: Optional["NQLQuery"] = None, searching_after: Optional[str] = None, -) -> List[Any]: +) -> Any: normalized_query = query or NQLEmptyQuery() if sort_by and searching_after: sort_by_as_nql = NQLQueryAttribute( @@ -113,14 +117,12 @@ def get_single_page( http_client = client.swagger_spec.http_client - result = ( + return ( http_client.request(request_params, operation=None, request_config=request_config) .response() .incoming_response.json() ) - return list(map(to_leaderboard_entry, result.get("entries", []))) - def to_leaderboard_entry(entry: Dict[str, Any]) -> LeaderboardEntry: return LeaderboardEntry( @@ -148,34 +150,67 @@ def iter_over_pages( max_offset: int = MAX_SERVER_OFFSET, sort_by_column_type: Optional[str] = None, ascending: bool = False, + progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, **kwargs: Any, ) -> Generator[Any, None, None]: searching_after = None last_page = None - while True: - if last_page: - page_attribute = find_attribute(entry=last_page[-1], path=sort_by) - - if not page_attribute: - raise ValueError(f"Cannot find attribute {sort_by} in last page") + total = 0 - searching_after = page_attribute.properties["value"] + progress_bar = progress_bar if step_size >= total else None - for offset in range(0, max_offset, step_size): - page = get_single_page( - limit=min(step_size, max_offset - offset), - offset=offset, - sort_by=sort_by, - sort_by_column_type=sort_by_column_type, - searching_after=searching_after, - ascending=ascending, + with _construct_progress_bar(progress_bar) as bar: + # beginning of the first page + bar.update( + by=0, + total=get_single_page( + limit=0, + offset=0, **kwargs, - ) + ).get("matchingItemCount", 0), + ) + + while True: + if last_page: + page_attribute = find_attribute(entry=last_page[-1], path=sort_by) + + if not page_attribute: + raise ValueError(f"Cannot find attribute {sort_by} in last page") + + searching_after = page_attribute.properties["value"] + + for offset in range(0, max_offset, step_size): + result = get_single_page( + limit=min(step_size, max_offset - offset), + offset=offset, + sort_by=sort_by, + sort_by_column_type=sort_by_column_type, + searching_after=searching_after, + ascending=ascending, + **kwargs, + ) + + # fetch the item count everytime a new page is started + if offset == 0: + total += result.get("matchingItemCount", 0) + + page = _entries_from_page(result) + + if not page: + return + + bar.update(by=step_size, total=total) + + yield from page + + last_page = page + - if not page: - return +def _construct_progress_bar(progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]]) -> ProgressBarCallback: + progress_bar_type = which_progress_bar(progress_bar) + return progress_bar_type(description="Fetching table...") - yield from page - last_page = page +def _entries_from_page(single_page: Dict[str, Any]) -> List[LeaderboardEntry]: + return list(map(to_leaderboard_entry, single_page.get("entries", []))) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 2d50c8d29..a7e174b10 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -144,6 +144,7 @@ from neptune.internal.utils.paths import path_to_str from neptune.internal.websockets.websockets_factory import WebsocketsFactory from neptune.management.exceptions import ObjectNotFound +from neptune.typing import ProgressBarCallback from neptune.version import version as neptune_client_version if TYPE_CHECKING: @@ -1052,6 +1053,7 @@ def search_leaderboard_entries( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, + progress_bar: Optional[Union[bool, typing.Type[ProgressBarCallback]]] = None, ) -> Generator[LeaderboardEntry, None, None]: default_step_size = int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) @@ -1077,6 +1079,7 @@ def search_leaderboard_entries( sort_by=sort_by, ascending=ascending, sort_by_column_type=sort_by_column_type, + progress_bar=progress_bar, ) except HTTPNotFound: raise ProjectNotFound(project_id) diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index b69950b8e..c2354ccf6 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -22,6 +22,7 @@ List, Optional, Tuple, + Type, Union, ) @@ -59,6 +60,7 @@ from neptune.internal.operation_processors.operation_storage import OperationStorage from neptune.internal.utils.git import GitInfo from neptune.internal.websockets.websockets_factory import WebsocketsFactory +from neptune.typing import ProgressBarCallback class NeptuneBackend: @@ -307,6 +309,7 @@ def search_leaderboard_entries( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, + progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, ) -> Generator[LeaderboardEntry, None, None]: pass diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index 75cf2f3cf..cdc870c60 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -132,6 +132,7 @@ from neptune.types.sets.string_set import StringSet from neptune.types.value import Value from neptune.types.value_visitor import ValueVisitor +from neptune.typing import ProgressBarCallback Val = TypeVar("Val", bound=Value) @@ -546,6 +547,7 @@ def search_leaderboard_entries( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, + progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" diff --git a/src/neptune/internal/backends/utils.py b/src/neptune/internal/backends/utils.py index ea724ced9..a44e98f73 100644 --- a/src/neptune/internal/backends/utils.py +++ b/src/neptune/internal/backends/utils.py @@ -26,6 +26,7 @@ "ssl_verify", "parse_validation_errors", "ExecuteOperationsBatchingManager", + "which_progress_bar", ] import dataclasses @@ -45,6 +46,8 @@ Mapping, Optional, Text, + Type, + Union, ) from urllib.parse import ( urljoin, @@ -64,6 +67,10 @@ ) from neptune.common.backends.utils import with_api_exceptions_handler +from neptune.common.warnings import ( + NeptuneWarning, + warn_once, +) from neptune.envs import NEPTUNE_ALLOW_SELF_SIGNED_CERTIFICATE from neptune.exceptions import ( CannotResolveHostname, @@ -79,6 +86,11 @@ ) from neptune.internal.utils import replace_patch_version from neptune.internal.utils.logger import logger +from neptune.typing import ProgressBarCallback +from neptune.utils import ( + NullProgressBar, + TqdmProgressBar, +) _logger = logging.getLogger(__name__) @@ -276,3 +288,35 @@ def get_batch(self, ops: Iterable[Operation]) -> OperationsBatch: result.operations.append(op) return result + + +def _check_if_tqdm_installed() -> bool: + try: + import tqdm # noqa: F401 + + return True + except ImportError: # tqdm not installed + return False + + +def which_progress_bar(progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]]) -> Type[ProgressBarCallback]: + if isinstance(progress_bar, type) and issubclass( + progress_bar, ProgressBarCallback + ): # return whatever the user gave us + return progress_bar + + if not isinstance(progress_bar, bool) and progress_bar is not None: + raise TypeError(f"progress_bar should be None, bool or ProgressBarCallback, got {type(progress_bar).__name__}") + + if progress_bar or progress_bar is None: + tqdm_available = _check_if_tqdm_installed() + + if not tqdm_available: + warn_once( + "To use the default progress bar, please install tqdm: pip install tqdm", + exception=NeptuneWarning, + ) + return NullProgressBar + return TqdmProgressBar + + return NullProgressBar diff --git a/src/neptune/internal/utils/runningmode.py b/src/neptune/internal/utils/runningmode.py index 9d6ad20b4..0359a06fc 100644 --- a/src/neptune/internal/utils/runningmode.py +++ b/src/neptune/internal/utils/runningmode.py @@ -18,25 +18,19 @@ import sys -def in_interactive(): +def in_interactive() -> bool: """Based on: https://stackoverflow.com/a/2356427/1565454""" return hasattr(sys, "ps1") -def in_notebook(): +def in_notebook() -> bool: """Based on: https://stackoverflow.com/a/22424821/1565454""" try: from IPython import get_ipython ipy = get_ipython() - if ( - ipy is None - or not hasattr(ipy, "config") - or not isinstance(ipy.config, dict) - or "IPKernelApp" not in ipy.config - ): - return False + return ( + ipy is not None and hasattr(ipy, "config") and isinstance(ipy.config, dict) and "IPKernelApp" in ipy.config + ) except ImportError: return False - - return True diff --git a/src/neptune/metadata_containers/metadata_container.py b/src/neptune/metadata_containers/metadata_container.py index 32beb4853..f771d2367 100644 --- a/src/neptune/metadata_containers/metadata_container.py +++ b/src/neptune/metadata_containers/metadata_container.py @@ -32,6 +32,7 @@ Iterable, List, Optional, + Type, Union, ) @@ -95,6 +96,7 @@ from neptune.metadata_containers.metadata_containers_table import Table from neptune.types.mode import Mode from neptune.types.type_casting import cast_value +from neptune.typing import ProgressBarCallback from neptune.utils import stop_synchronization_callback if TYPE_CHECKING: @@ -660,6 +662,7 @@ def _fetch_entries( limit: Optional[int], sort_by: str, ascending: bool, + progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]], ) -> Table: if columns is not None: # always return entries with 'sys/id' and the column chosen for sorting when filter applied @@ -675,6 +678,7 @@ def _fetch_entries( limit=limit, sort_by=sort_by, ascending=ascending, + progress_bar=progress_bar, ) leaderboard_entries = itertools.islice(leaderboard_entries, limit) if limit else leaderboard_entries diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index d99ec39dc..c0a4afa53 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -21,6 +21,8 @@ Iterable, List, Optional, + Type, + Union, ) from typing_extensions import Literal @@ -59,6 +61,7 @@ from neptune.metadata_containers.abstract import NeptuneObjectCallback from neptune.metadata_containers.metadata_containers_table import Table from neptune.types.mode import Mode +from neptune.typing import ProgressBarCallback if TYPE_CHECKING: from neptune.internal.background_job import BackgroundJob @@ -264,6 +267,7 @@ def fetch_model_versions_table( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, + progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, ) -> Table: """Retrieve all versions of the given model. @@ -315,6 +319,7 @@ def fetch_model_versions_table( verify_type("limit", limit, (int, type(None))) verify_type("sort_by", sort_by, str) verify_type("ascending", ascending, bool) + verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback))) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -342,4 +347,5 @@ def fetch_model_versions_table( limit=limit, sort_by=sort_by, ascending=ascending, + progress_bar=progress_bar, ) diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 2e10640d7..92c2bdc5f 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -19,6 +19,7 @@ from typing import ( Iterable, Optional, + Type, Union, ) @@ -50,6 +51,7 @@ from neptune.metadata_containers.metadata_containers_table import Table from neptune.metadata_containers.utils import prepare_nql_query from neptune.types.mode import Mode +from neptune.typing import ProgressBarCallback class Project(MetadataContainer): @@ -199,6 +201,7 @@ def fetch_runs_table( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, + progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, ) -> Table: """Retrieve runs matching the specified criteria. @@ -287,6 +290,7 @@ def fetch_runs_table( verify_type("limit", limit, (int, type(None))) verify_type("sort_by", sort_by, str) verify_type("ascending", ascending, bool) + verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback))) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -301,6 +305,7 @@ def fetch_runs_table( limit=limit, sort_by=sort_by, ascending=ascending, + progress_bar=progress_bar, ) def fetch_models_table( @@ -311,6 +316,7 @@ def fetch_models_table( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, + progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, ) -> Table: """Retrieve models stored in the project. @@ -365,6 +371,7 @@ def fetch_models_table( verify_type("limit", limit, (int, type(None))) verify_type("sort_by", sort_by, str) verify_type("ascending", ascending, bool) + verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback))) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") @@ -384,4 +391,5 @@ def fetch_models_table( limit=limit, sort_by=sort_by, ascending=ascending, + progress_bar=progress_bar, ) diff --git a/src/neptune/typing.py b/src/neptune/typing.py index d319d4d28..a55f33981 100644 --- a/src/neptune/typing.py +++ b/src/neptune/typing.py @@ -13,10 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__all__ = ["SupportsNamespaces", "NeptuneObject", "NeptuneObjectCallback"] +__all__ = ["SupportsNamespaces", "NeptuneObject", "NeptuneObjectCallback", "ProgressBarCallback"] + +import abc +import contextlib +from typing import ( + Any, + Optional, +) from neptune.metadata_containers.abstract import ( NeptuneObject, NeptuneObjectCallback, SupportsNamespaces, ) + + +class ProgressBarCallback(contextlib.AbstractContextManager): + def __init__(self, *args: Any, **kwargs: Any) -> None: + ... + + @abc.abstractmethod + def update(self, *, by: int, total: Optional[int] = None) -> None: + ... diff --git a/src/neptune/utils.py b/src/neptune/utils.py index 298a5b41e..1136a6c96 100644 --- a/src/neptune/utils.py +++ b/src/neptune/utils.py @@ -14,19 +14,34 @@ # limitations under the License. # """Utility functions to support ML metadata logging with neptune.ai.""" -__all__ = ["stringify_unsupported", "stop_synchronization_callback"] +__all__ = [ + "stringify_unsupported", + "stop_synchronization_callback", + "TqdmProgressBar", + "NullProgressBar", +] +from types import TracebackType from typing import ( Any, Mapping, MutableMapping, + Optional, + Type, Union, ) from neptune.internal.init.parameters import DEFAULT_STOP_TIMEOUT from neptune.internal.types.stringify_value import StringifyValue from neptune.internal.utils.logger import logger -from neptune.typing import NeptuneObject +from neptune.internal.utils.runningmode import ( + in_interactive, + in_notebook, +) +from neptune.typing import ( + NeptuneObject, + ProgressBarCallback, +) def stringify_unsupported(value: Any) -> Union[StringifyValue, Mapping]: @@ -73,3 +88,58 @@ def stop_synchronization_callback(neptune_object: NeptuneObject) -> None: "Threshold for disrupted synchronization exceeded. Stopping the synchronization using the default callback." ) neptune_object.stop(seconds=DEFAULT_STOP_TIMEOUT) + + +class NullProgressBar(ProgressBarCallback): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + def update(self, *, by: int, total: Optional[int] = None) -> None: + pass + + +class TqdmProgressBar(ProgressBarCallback): + def __init__( + self, + *args: Any, + description: Optional[str] = None, + unit: Optional[str] = None, + unit_scale: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + interactive = in_interactive() or in_notebook() + + if interactive: + from tqdm.notebook import tqdm + else: + from tqdm import tqdm # type: ignore + + unit = unit if unit else "" + + self._progress_bar = tqdm(desc=description, unit=unit, unit_scale=unit_scale, **kwargs) + + def __enter__(self) -> "TqdmProgressBar": + self._progress_bar.__enter__() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self._progress_bar.__exit__(exc_type, exc_val, exc_tb) + + def update(self, *, by: int, total: Optional[int] = None) -> None: + if total: + self._progress_bar.total = total + self._progress_bar.update(by) diff --git a/tests/e2e/standard/test_fetch_tables.py b/tests/e2e/standard/test_fetch_tables.py index 8fcf974a6..9154c8092 100644 --- a/tests/e2e/standard/test_fetch_tables.py +++ b/tests/e2e/standard/test_fetch_tables.py @@ -43,7 +43,7 @@ def test_fetch_runs_by_tag(self, environment, project): # wait for the cache to fill time.sleep(5) - runs_table = project.fetch_runs_table(tag=[tag1, tag2]).to_rows() + runs_table = project.fetch_runs_table(tag=[tag1, tag2], progress_bar=False).to_rows() assert len(runs_table) == 1 assert runs_table[0].get_attribute_value("sys/id") == run_id1 @@ -60,14 +60,14 @@ def test_fetch_model_versions_with_correct_ids(self, container: Model, environme time.sleep(5) versions_table = sorted( - container.fetch_model_versions_table().to_rows(), + container.fetch_model_versions_table(progress_bar=False).to_rows(), key=lambda r: r.get_attribute_value("sys/id"), ) assert len(versions_table) == versions_to_initialize for index in range(versions_to_initialize): assert versions_table[index].get_attribute_value("sys/id") == f"{model_sys_id}-{index + 1}" - versions_table_gen = container.fetch_model_versions_table(ascending=True) + versions_table_gen = container.fetch_model_versions_table(ascending=True, progress_bar=False) for te1, te2 in zip(list(versions_table_gen), versions_table): assert te1._id == te2._id assert te1._container_type == te2._container_type @@ -136,7 +136,7 @@ def init_run(): return neptune.init_run(project=environment.project) def get_runs_as_rows(**kwargs): - return project.fetch_runs_table(**kwargs).to_rows() + return project.fetch_runs_table(**kwargs, progress_bar=False).to_rows() self._test_fetch_from_container(init_run, get_runs_as_rows) @@ -145,7 +145,7 @@ def init_run(): return neptune.init_model(project=environment.project, key=a_key()) def get_models_as_rows(**kwargs): - return project.fetch_models_table(**kwargs).to_rows() + return project.fetch_models_table(**kwargs, progress_bar=False).to_rows() self._test_fetch_from_container(init_run, get_models_as_rows) @@ -157,7 +157,7 @@ def init_run(): return neptune.init_model_version(model=model_sys_id, project=environment.project) def get_model_versions_as_rows(**kwargs): - return container.fetch_model_versions_table(**kwargs).to_rows() + return container.fetch_model_versions_table(**kwargs, progress_bar=False).to_rows() self._test_fetch_from_container(init_run, get_model_versions_as_rows) @@ -168,14 +168,14 @@ def test_fetch_runs_table_by_state(self, environment, project): run["some_random_val"] = random_val time.sleep(30) - runs = project.fetch_runs_table(state="active").to_pandas() + runs = project.fetch_runs_table(state="active", progress_bar=False).to_pandas() assert not runs.empty assert tag in runs["sys/tags"].values assert random_val in runs["some_random_val"].values time.sleep(30) - runs = project.fetch_runs_table(state="inactive").to_pandas() + runs = project.fetch_runs_table(state="inactive", progress_bar=False).to_pandas() assert not runs.empty assert tag in runs["sys/tags"].values assert random_val in runs["some_random_val"].values @@ -194,7 +194,9 @@ def test_fetch_runs_table_sorting(self, environment, project, ascending): time.sleep(30) # when - runs = project.fetch_runs_table(sort_by="sys/creation_time", ascending=ascending).to_pandas() + runs = project.fetch_runs_table( + sort_by="sys/creation_time", ascending=ascending, progress_bar=False + ).to_pandas() # then # runs are correctly sorted by creation time -> run1 was first @@ -206,7 +208,7 @@ def test_fetch_runs_table_sorting(self, environment, project, ascending): assert run_list == ["run2", "run1"] # when - runs = project.fetch_runs_table(sort_by="metrics/accuracy", ascending=ascending).to_pandas() + runs = project.fetch_runs_table(sort_by="metrics/accuracy", ascending=ascending, progress_bar=False).to_pandas() # then assert not runs.empty @@ -218,7 +220,7 @@ def test_fetch_runs_table_sorting(self, environment, project, ascending): assert run_list == ["run1", "run2"] # when - runs = project.fetch_runs_table(sort_by="some_val", ascending=ascending).to_pandas() + runs = project.fetch_runs_table(sort_by="some_val", ascending=ascending, progress_bar=False).to_pandas() # then assert not runs.empty @@ -244,4 +246,4 @@ def test_fetch_runs_table_non_atomic_type(self, environment, project): # then with pytest.raises(ValueError): - project.fetch_runs_table(sort_by="metrics/accuracy") + project.fetch_runs_table(sort_by="metrics/accuracy", progress_bar=False) diff --git a/tests/unit/neptune/new/api/test_searching_entries.py b/tests/unit/neptune/new/api/test_searching_entries.py index e2d912ace..d743e7161 100644 --- a/tests/unit/neptune/new/api/test_searching_entries.py +++ b/tests/unit/neptune/new/api/test_searching_entries.py @@ -79,10 +79,11 @@ def test__to_leaderboard_entry(): ] -@patch("neptune.api.searching_entries.get_single_page") -def test__iter_over_pages__single_pagination(get_single_page): +@patch("neptune.api.searching_entries._entries_from_page") +@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 9}) +def test__iter_over_pages__single_pagination(get_single_page, entries_from_page): # given - get_single_page.side_effect = [ + entries_from_page.side_effect = [ generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), @@ -95,6 +96,7 @@ def test__iter_over_pages__single_pagination(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) assert get_single_page.mock_calls == [ + call(limit=0, offset=0), # total checking call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), call(limit=3, offset=6, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), @@ -102,10 +104,11 @@ def test__iter_over_pages__single_pagination(get_single_page): ] -@patch("neptune.api.searching_entries.get_single_page") -def test__iter_over_pages__multiple_search_after(get_single_page): +@patch("neptune.api.searching_entries._entries_from_page") +@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 9}) +def test__iter_over_pages__multiple_search_after(get_single_page, entries_from_page): # given - get_single_page.side_effect = [ + entries_from_page.side_effect = [ generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e", "f"]), generate_leaderboard_entries(values=["g", "h", "j"]), @@ -118,6 +121,7 @@ def test__iter_over_pages__multiple_search_after(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e", "f", "g", "h", "j"]) assert get_single_page.mock_calls == [ + call(limit=0, offset=0), # total checking call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), call(limit=3, offset=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="f"), @@ -125,10 +129,11 @@ def test__iter_over_pages__multiple_search_after(get_single_page): ] -@patch("neptune.api.searching_entries.get_single_page") -def test__iter_over_pages__empty(get_single_page): +@patch("neptune.api.searching_entries._entries_from_page") +@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 1}) +def test__iter_over_pages__empty(get_single_page, entries_from_page): # given - get_single_page.side_effect = [[]] + entries_from_page.side_effect = [[]] # when result = list(iter_over_pages(step_size=3)) @@ -136,14 +141,16 @@ def test__iter_over_pages__empty(get_single_page): # then assert result == [] assert get_single_page.mock_calls == [ - call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None) + call(limit=0, offset=0), # total checking + call(limit=3, offset=0, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), ] -@patch("neptune.api.searching_entries.get_single_page") -def test__iter_over_pages__max_server_offset(get_single_page): +@patch("neptune.api.searching_entries._entries_from_page") +@patch("neptune.api.searching_entries.get_single_page", return_value={"matchingItemCount": 1}) +def test__iter_over_pages__max_server_offset(get_single_page, entries_from_page): # given - get_single_page.side_effect = [ + entries_from_page.side_effect = [ generate_leaderboard_entries(values=["a", "b", "c"]), generate_leaderboard_entries(values=["d", "e"]), None, @@ -155,6 +162,7 @@ def test__iter_over_pages__max_server_offset(get_single_page): # then assert result == generate_leaderboard_entries(values=["a", "b", "c", "d", "e"]) assert get_single_page.mock_calls == [ + call(limit=0, offset=0), # total checking call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), call(offset=3, limit=2, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after=None), call(offset=0, limit=3, sort_by="sys/id", ascending=False, sort_by_column_type=None, searching_after="e"), diff --git a/tests/unit/neptune/new/internal/backends/test_utils.py b/tests/unit/neptune/new/internal/backends/test_utils.py index cd47214e6..b402e92f9 100644 --- a/tests/unit/neptune/new/internal/backends/test_utils.py +++ b/tests/unit/neptune/new/internal/backends/test_utils.py @@ -15,7 +15,13 @@ # import unittest import uuid -from unittest.mock import Mock +from typing import Optional +from unittest.mock import ( + Mock, + patch, +) + +import pytest from neptune.attributes import ( Integer, @@ -26,9 +32,27 @@ from neptune.internal.backends.neptune_backend import NeptuneBackend from neptune.internal.backends.utils import ( ExecuteOperationsBatchingManager, + _check_if_tqdm_installed, build_operation_url, + which_progress_bar, ) from neptune.internal.container_type import ContainerType +from neptune.typing import ProgressBarCallback +from neptune.utils import ( + NullProgressBar, + TqdmProgressBar, +) + + +class CustomProgressBar(ProgressBarCallback): + def __enter__(self): + ... + + def __exit__(self, exc_type, exc_val, exc_tb): + ... + + def update(self, *, by: int, total: Optional[int] = None) -> None: + pass class TestNeptuneBackendMock(unittest.TestCase): @@ -154,3 +178,34 @@ def test_handle_failed_copy(self): self.assertEqual(operations[1:], batch.operations) self.assertEqual([backend.get_int_attribute.side_effect], batch.errors) self.assertEqual(1, batch.dropped_operations_count) + + +@patch("neptune.internal.backends.utils._check_if_tqdm_installed") +def test_which_progress_bar(mock_tqdm_installed): + mock_tqdm_installed.return_value = True + + assert which_progress_bar(None) == TqdmProgressBar + assert which_progress_bar(True) == TqdmProgressBar + assert which_progress_bar(False) == NullProgressBar + assert which_progress_bar(CustomProgressBar) == CustomProgressBar + + mock_tqdm_installed.return_value = False + assert which_progress_bar(None) == NullProgressBar + assert which_progress_bar(True) == NullProgressBar + assert which_progress_bar(False) == NullProgressBar + assert which_progress_bar(CustomProgressBar) == CustomProgressBar + + assert mock_tqdm_installed.call_count == 4 # 2 x 'None' + 2 x 'True' + + with pytest.raises(TypeError): + which_progress_bar(1) + + +@patch.dict("sys.modules", {"tqdm": None}) +def test_check_if_tqdm_installed_not_installed(): + assert not _check_if_tqdm_installed() + + +@patch.dict("sys.modules", {"tqdm": {}}) +def test_check_if_tqdm_installed_installed(): + assert _check_if_tqdm_installed() From 626f02927434b5d86b3c111000a34ea1fdf82061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Fri, 19 Jan 2024 11:00:38 +0100 Subject: [PATCH 07/22] Better value validation for state parameter of `fetch_*_table` (#1616) --- CHANGELOG.md | 2 ++ src/neptune/internal/utils/__init__.py | 13 ++++++++++++- src/neptune/metadata_containers/project.py | 8 +++++++- tests/unit/neptune/new/client/test_run_tables.py | 4 +--- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35f9da729..20ed895ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,11 +11,13 @@ ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) - Handle `None` values in distribution sorting in `InferDependeciesStrategy` ([#1612](https://github.com/neptune-ai/neptune-client/pull/1612)) +- Better value validation for `state` parameter of `fetch_*_table()` methods ([#1616](https://github.com/neptune-ai/neptune-client/pull/1616)) ### Changes - Use literals instead of str for Mode typing ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) - Flag added for cleaning internal data ([#1589](https://github.com/neptune-ai/neptune-client/pull/1589)) + ## 1.8.6 ### Fixes diff --git a/src/neptune/internal/utils/__init__.py b/src/neptune/internal/utils/__init__.py index cc8845f93..f9118fd20 100644 --- a/src/neptune/internal/utils/__init__.py +++ b/src/neptune/internal/utils/__init__.py @@ -16,6 +16,7 @@ __all__ = [ "replace_patch_version", "verify_type", + "verify_value", "is_stream", "is_bool", "is_int", @@ -43,6 +44,7 @@ from glob import glob from io import IOBase from typing import ( + Any, Iterable, List, Mapping, @@ -80,6 +82,11 @@ def verify_type(var_name: str, var, expected_type: Union[type, tuple]): raise TypeError("{} is a stream, which does not implement read method".format(var_name)) +def verify_value(var_name: str, var: Any, expected_values: Iterable[T]) -> None: + if var not in expected_values: + raise ValueError(f"{var_name} must be one of {expected_values} (was `{var}`)") + + def is_stream(var): return isinstance(var, IOBase) and hasattr(var, "read") @@ -184,11 +191,15 @@ def is_ipython() -> bool: return False -def as_list(name: str, value: Optional[Union[str, Iterable[str]]]) -> Optional[Iterable[str]]: +def as_list(name: str, value: Optional[Union[str, Iterable[str]]]) -> Iterable[str]: verify_type(name, value, (type(None), str, Iterable)) + if value is None: return [] + if isinstance(value, str): return [value] + verify_collection_type(name, value, str) + return value diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 92c2bdc5f..07e783b73 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -44,7 +44,9 @@ from neptune.internal.state import ContainerState from neptune.internal.utils import ( as_list, + verify_collection_type, verify_type, + verify_value, ) from neptune.metadata_containers import MetadataContainer from neptune.metadata_containers.abstract import NeptuneObjectCallback @@ -193,7 +195,7 @@ def fetch_runs_table( self, *, id: Optional[Union[str, Iterable[str]]] = None, - state: Optional[Union[str, Iterable[str]]] = None, + state: Optional[Union[Literal["inactive", "active"], Iterable[Literal["inactive", "active"]]]] = None, owner: Optional[Union[str, Iterable[str]]] = None, tag: Optional[Union[str, Iterable[str]]] = None, columns: Optional[Iterable[str]] = None, @@ -291,6 +293,10 @@ def fetch_runs_table( verify_type("sort_by", sort_by, str) verify_type("ascending", ascending, bool) verify_type("progress_bar", progress_bar, (type(None), bool, type(ProgressBarCallback))) + verify_collection_type("state", states, str) + + for state in states: + verify_value("state", state.lower(), ("inactive", "active")) if isinstance(limit, int) and limit <= 0: raise ValueError(f"Parameter 'limit' must be a positive integer or None. Got {limit}.") diff --git a/tests/unit/neptune/new/client/test_run_tables.py b/tests/unit/neptune/new/client/test_run_tables.py index a90f77a22..ed36615b9 100644 --- a/tests/unit/neptune/new/client/test_run_tables.py +++ b/tests/unit/neptune/new/client/test_run_tables.py @@ -20,7 +20,6 @@ from mock import patch from neptune import init_project -from neptune.exceptions import NeptuneException from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType from neptune.metadata_containers.metadata_containers_table import ( @@ -53,6 +52,5 @@ def test_fetch_runs_table_is_case_insensitive(self): def test_fetch_runs_table_raises_correct_exception_for_incorrect_states(self): for incorrect_state in ["idle", "running", "some_arbitrary_state"]: with self.subTest(incorrect_state): - with self.assertRaises(NeptuneException) as context: + with self.assertRaises(ValueError): self.get_table(state=incorrect_state) - self.assertEquals(f"Can't map RunState to API: {incorrect_state}", str(context.exception)) From 9faf34ca13ee4d5d28c873d38f2f507fd599821f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Mon, 22 Jan 2024 09:31:59 +0100 Subject: [PATCH 08/22] Update src/neptune/metadata_containers/model.py Co-authored-by: Sabine --- src/neptune/metadata_containers/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index c0a4afa53..dd0b3ef72 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -278,7 +278,7 @@ def fetch_model_versions_table( Fields: `["params/lr", "params/batch", "val/acc"]` - these fields are included as columns. Namespaces: `["params", "val"]` - all the fields inside the namespaces are included as columns. If `None` (default), all the columns of the model versions table are included. - limit: How many entries to return at most (default: None - return all entries). + limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. From 6974b0d9ad13f332550c5474ecd911a767131958 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Jankowski?= Date: Mon, 22 Jan 2024 09:32:08 +0100 Subject: [PATCH 09/22] Update src/neptune/metadata_containers/model.py Co-authored-by: Sabine --- src/neptune/metadata_containers/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index dd0b3ef72..5c57cdf52 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -282,8 +282,7 @@ def fetch_model_versions_table( sort_by: Name of the column to sort the results by. Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. - ascending: Whether to sort model versions in the ascending order of the sorting column values. - Default: False - descending order. + ascending: Whether to sort model versions in ascending order of the sorting column values. Returns: `Table` object containing `ModelVersion` objects that match the specified criteria. From 6ab7e7bfc50a79d3f4ea8df8c447c176de8779e4 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 22 Jan 2024 09:56:29 +0100 Subject: [PATCH 10/22] Use simple instead of atomic phrasing --- src/neptune/internal/backends/hosted_neptune_backend.py | 7 +++++-- src/neptune/metadata_containers/model.py | 2 +- src/neptune/metadata_containers/project.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index a7e174b10..d30b69448 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1120,14 +1120,17 @@ def _get_column_type_from_entries(entries: List[Any], column: str) -> str: if entry.name != column: # caught by regex, but it's not this column continue if entry.type not in ATOMIC_ATTRIBUTE_TYPES: # non-atomic type - no need to look further - raise ValueError(f"Column {column} used for sorting is not of atomic type.") + raise ValueError( + f"Column {column} used for sorting is not of simple type. For more, " + f"see https://docs.neptune.ai/api/field_types/#simple-types" + ) types.add(entry.type) if types == {AttributeType.INT.value, AttributeType.FLOAT.value}: return AttributeType.FLOAT.value warn_once( - f"Column {column} contains more than one atomic data type. Sorting result might be inaccurate.", + f"Column {column} contains more than one simple data type. Sorting result might be inaccurate.", exception=NeptuneWarning, ) return AttributeType.STRING.value diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index 5c57cdf52..e1dd523b0 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -280,7 +280,7 @@ def fetch_model_versions_table( If `None` (default), all the columns of the model versions table are included. limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. - Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. + Must be a simple column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. ascending: Whether to sort model versions in ascending order of the sorting column values. diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 07e783b73..883ec57e3 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -237,7 +237,7 @@ def fetch_runs_table( If `None`, both trashed and not-trashed runs are retrieved. limit: How many entries to return at most (default: None - return all entries). sort_by: Name of the column to sort the results by. - Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. + Must be a simple column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. ascending: Whether to sort runs in the ascending order of the sorting column values. Default: False - descending order. @@ -339,7 +339,7 @@ def fetch_models_table( If `None` (default), all the columns of the models table are included. limit: How many entries to return at most (default: None - return all entries). sort_by: Name of the column to sort the results by. - Must be an atomic column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. + Must be a simple column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. Default: 'sys/creation_time. ascending: Whether to sort models in the ascending order of the sorting column values. Default: False - descending order. From b55f2234f1e77d9c46c86d01ab2837e5a43ea97a Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 22 Jan 2024 10:01:01 +0100 Subject: [PATCH 11/22] docstring review --- src/neptune/metadata_containers/model.py | 3 +-- src/neptune/metadata_containers/project.py | 6 ++---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index e1dd523b0..e877d54fa 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -280,8 +280,7 @@ def fetch_model_versions_table( If `None` (default), all the columns of the model versions table are included. limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. - Must be a simple column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. - Default: 'sys/creation_time. + Must be a simple column (string, float, datetime, integer, boolean). ascending: Whether to sort model versions in ascending order of the sorting column values. Returns: diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 883ec57e3..6a4323ff5 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -237,8 +237,7 @@ def fetch_runs_table( If `None`, both trashed and not-trashed runs are retrieved. limit: How many entries to return at most (default: None - return all entries). sort_by: Name of the column to sort the results by. - Must be a simple column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. - Default: 'sys/creation_time. + Must be a simple column (string, float, datetime, integer, boolean). ascending: Whether to sort runs in the ascending order of the sorting column values. Default: False - descending order. @@ -339,8 +338,7 @@ def fetch_models_table( If `None` (default), all the columns of the models table are included. limit: How many entries to return at most (default: None - return all entries). sort_by: Name of the column to sort the results by. - Must be a simple column (string, float, datetime, integer, boolean), otherwise raises `ValueError`. - Default: 'sys/creation_time. + Must be a simple column (string, float, datetime, integer, boolean). ascending: Whether to sort models in the ascending order of the sorting column values. Default: False - descending order. From 3afb7f5502d2c5e4b63f9296d295a44fa1a4f7c5 Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 22 Jan 2024 10:04:37 +0100 Subject: [PATCH 12/22] Suggestions reapplied --- src/neptune/metadata_containers/project.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 6a4323ff5..1e4a608c6 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -235,11 +235,10 @@ def fetch_runs_table( If `True`, only trashed runs are retrieved. If `False` (default), only not-trashed runs are retrieved. If `None`, both trashed and not-trashed runs are retrieved. - limit: How many entries to return at most (default: None - return all entries). + limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. Must be a simple column (string, float, datetime, integer, boolean). - ascending: Whether to sort runs in the ascending order of the sorting column values. - Default: False - descending order. + ascending: Whether to sort model versions in ascending order of the sorting column values. Returns: `Table` object containing `Run` objects matching the specified criteria. @@ -336,11 +335,10 @@ def fetch_models_table( Fields: `["datasets/test", "info/size"]` - these fields are included as columns. Namespaces: `["datasets", "info"]` - all the fields inside the namespaces are included as columns. If `None` (default), all the columns of the models table are included. - limit: How many entries to return at most (default: None - return all entries). + limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. Must be a simple column (string, float, datetime, integer, boolean). - ascending: Whether to sort models in the ascending order of the sorting column values. - Default: False - descending order. + ascending: Whether to sort model versions in ascending order of the sorting column values. Returns: `Table` object containing `Model` objects. From eeafb4def098dd14dfb01056b5ebe765dbdf4eac Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Mon, 22 Jan 2024 10:06:36 +0100 Subject: [PATCH 13/22] docstrings review 2 --- src/neptune/internal/backends/hosted_neptune_backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index d30b69448..38e2da74a 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -1121,7 +1121,7 @@ def _get_column_type_from_entries(entries: List[Any], column: str) -> str: continue if entry.type not in ATOMIC_ATTRIBUTE_TYPES: # non-atomic type - no need to look further raise ValueError( - f"Column {column} used for sorting is not of simple type. For more, " + f"Column {column} used for sorting is a complex type. For more, " f"see https://docs.neptune.ai/api/field_types/#simple-types" ) types.add(entry.type) From 95ae406b435d6ab7583e5bf8fe168bfecfaef7c1 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 22 Jan 2024 11:27:18 +0100 Subject: [PATCH 14/22] docstring proposition for ProgressBarCallback --- src/neptune/typing.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/neptune/typing.py b/src/neptune/typing.py index a55f33981..3bcded5ae 100644 --- a/src/neptune/typing.py +++ b/src/neptune/typing.py @@ -30,6 +30,47 @@ class ProgressBarCallback(contextlib.AbstractContextManager): + """ + Abstract base class for progress bar callbacks. + + You can use this class to implement your own progress bar callback that will be invoked in `fetch_*_table` methods. + + Example using `click`: + >>> from typing import Any, Optional, Type + >>> from types import TracebackType + >>> from neptune import init_project + >>> from neptune.typing import ProgressBarCallback + >>> class ClickProgressBar(ProgressBarCallback): + ... def __init__(self, *, description: Optional[str] = None, **_: Any) -> None: + ... super().__init__() + ... from click import progressbar + ... + ... self._progress_bar = progressbar(iterable=None, length=1, label=description) + ... + ... def update(self, *, by: int, total: Optional[int] = None) -> None: + ... if total: + ... self._progress_bar.length = total + ... self._progress_bar.update(by) + ... + ... def __enter__(self) -> "ClickProgressBar": + ... self._progress_bar.__enter__() + ... return self + ... + ... def __exit__( + ... self, + ... exc_type: Optional[Type[BaseException]], + ... exc_val: Optional[BaseException], + ... exc_tb: Optional[TracebackType], + ... ) -> None: + ... self._progress_bar.__exit__(exc_type, exc_val, exc_tb) + >>> with init_project() as project: + ... project.fetch_runs_table(progress_bar=ClickProgressBar) + ... project.fetch_models_table(progress_bar=ClickProgressBar) + + IMPORTANT: Remember to pass a type, not an instance to the `progress_bar` argument, + i.e. `ClickProgressBar`, not `ClickProgressBar()`. + """ + def __init__(self, *args: Any, **kwargs: Any) -> None: ... From 54ed8f480c3e7a4db1f5d61d6599810c5d68576f Mon Sep 17 00:00:00 2001 From: Sabine Date: Mon, 22 Jan 2024 20:16:00 +0200 Subject: [PATCH 15/22] docstring adjustments --- src/neptune/metadata_containers/model.py | 34 +++++------ src/neptune/metadata_containers/project.py | 68 +++++++++++----------- src/neptune/typing.py | 15 +++-- 3 files changed, 57 insertions(+), 60 deletions(-) diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index e877d54fa..d653c7878 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -274,14 +274,14 @@ def fetch_model_versions_table( Args: columns: Names of columns to include in the table, as a list of namespace or field names. The Neptune ID ("sys/id") is included automatically. - Examples: - Fields: `["params/lr", "params/batch", "val/acc"]` - these fields are included as columns. - Namespaces: `["params", "val"]` - all the fields inside the namespaces are included as columns. + If you pass the name of a namespace, all the fields inside the namespace are included as columns. If `None` (default), all the columns of the model versions table are included. limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. - Must be a simple column (string, float, datetime, integer, boolean). - ascending: Whether to sort model versions in ascending order of the sorting column values. + The column must represent a simple field type (string, float, datetime, integer, or Boolean). + ascending: Whether to sort the entries in ascending order of the sorting column values. + progress_bar: Set to `False` to disable the download progress bar, + or pass a `ProgressBarCallback` class to use your own progress bar callback. Returns: `Table` object containing `ModelVersion` objects that match the specified criteria. @@ -290,26 +290,18 @@ def fetch_model_versions_table( Examples: >>> import neptune - - >>> # Initialize model with the ID "CLS-FOREST" + ... # Initialize model with the ID "CLS-FOREST" ... model = neptune.init_model(with_id="CLS-FOREST") - - >>> # Fetch the metadata of all model versions as a pandas DataFrame + ... # Fetch the metadata of all the model's versions as a pandas DataFrame ... model_versions_df = model.fetch_model_versions_table().to_pandas() - >>> # Fetch the metadata of all model versions as a pandas DataFrame, - ... # including only the fields "params/lr" and "val/loss" as columns: - ... model_versions = model.fetch_model_versions_table(columns=["params/lr", "val/loss"]) - ... model_versions_df = model_versions.to_pandas() - - >>> # Sort model versions by size - ... model_versions_df = model_versions_df.sort_values(by="sys/size") - - >>> # Sort model versions by creation time - ... model_versions_df = model_versions_df.sort_values(by="sys/creation_time", ascending=False) + >>> # Include only the fields "params/lr" and "val/loss" as columns: + ... model_versions_df = model.fetch_model_versions_table(columns=["params/lr", "val/loss"]).to_pandas() - >>> # Extract the last model version ID - ... last_model_version_id = model_versions_df["sys/id"].values[0] + >>> # Sort model versions by size (space they take up in Neptune) + ... model_versions_df = model.fetch_model_versions_table(sort_by="sys/size").to_pandas() + ... # Extract the ID of the largest model version object + ... largest_model_version_id = model_versions_df["sys/id"].values[0] See also the API referene: https://docs.neptune.ai/api/model/#fetch_model_versions_table diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 1e4a608c6..9a588b582 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -227,9 +227,7 @@ def fetch_runs_table( Only runs that have all specified tags will match this criterion. columns: Names of columns to include in the table, as a list of namespace or field names. The Neptune ID ("sys/id") is included automatically. - Examples: - Fields: `["params/lr", "params/batch", "train/acc"]` - these fields are included as columns. - Namespaces: `["params", "train"]` - all the fields inside the namespaces are included as columns. + If you pass the name of a namespace, all the fields inside the namespace are included as columns. If `None` (default), all the columns of the runs table are included. trashed: Whether to retrieve trashed runs. If `True`, only trashed runs are retrieved. @@ -237,8 +235,10 @@ def fetch_runs_table( If `None`, both trashed and not-trashed runs are retrieved. limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. - Must be a simple column (string, float, datetime, integer, boolean). - ascending: Whether to sort model versions in ascending order of the sorting column values. + The column must represent a simple field type (string, float, datetime, integer, or Boolean). + ascending: Whether to sort the entries in ascending order of the sorting column values. + progress_bar: Set to `False` to disable the download progress bar, + or pass a `ProgressBarCallback` class to use your own progress bar callback. Returns: `Table` object containing `Run` objects matching the specified criteria. @@ -247,22 +247,28 @@ def fetch_runs_table( Examples: >>> import neptune - - >>> # Fetch project "jackie/sandbox" + ... # Fetch project "jackie/sandbox" ... project = neptune.init_project(mode="read-only", project="jackie/sandbox") >>> # Fetch the metadata of all runs as a pandas DataFrame ... runs_table_df = project.fetch_runs_table().to_pandas() + ... # Extract the ID of the last run + ... last_run_id = runs_table_df["sys/id"].values[0] - >>> # Fetch the metadata of all runs as a pandas DataFrame, including only the field "train/loss" - ... # and the fields from the "params" namespace as columns: - ... runs_table_df = project.fetch_runs_table(columns=["params", "train/loss"]).to_pandas() + >>> # Fetch the 100 oldest runs + ... runs_table_df = project.fetch_runs_table( + ... sort_by="sys/creation_time", ascending=True, limit=100 + ... ).to_pandas() - >>> # Sort runs by creation time - ... runs_table_df = runs_table_df.sort_values(by="sys/creation_time", ascending=False) + >>> # Fetch the 100 largest runs (space they take up in Neptune) + ... runs_table_df = project.fetch_runs_table(sort_by="sys/size", limit=100).to_pandas() - >>> # Extract the id of the last run - ... last_run_id = runs_table_df["sys/id"].values[0] + >>> # Include only the field "train/loss" and the fields from the "params" namespace as columns: + ... runs_table_df = project.fetch_runs_table(columns=["params", "train/loss"]).to_pandas() + + >>> # Pass a custom progress bar callback + ... runs_table_df = project.fetch_runs_table(progress_bar=MyProgressBar).to_pandas() + ... # The class MyProgressBar(ProgressBarCallback) must be defined You can also filter the runs table by state, owner, tag, or a combination of these: @@ -327,18 +333,18 @@ def fetch_models_table( Args: trashed: Whether to retrieve trashed models. If `True`, only trashed models are retrieved. - If `False` (default), only not-trashed models are retrieved. + If `False`, only not-trashed models are retrieved. If `None`, both trashed and not-trashed models are retrieved. columns: Names of columns to include in the table, as a list of namespace or field names. The Neptune ID ("sys/id") is included automatically. - Examples: - Fields: `["datasets/test", "info/size"]` - these fields are included as columns. - Namespaces: `["datasets", "info"]` - all the fields inside the namespaces are included as columns. - If `None` (default), all the columns of the models table are included. + If you pass the name of a namespace, all the fields inside the namespace are included as columns. + If `None`, all the columns of the models table are included. limit: How many entries to return at most. If `None`, all entries are returned. sort_by: Name of the column to sort the results by. - Must be a simple column (string, float, datetime, integer, boolean). - ascending: Whether to sort model versions in ascending order of the sorting column values. + The column must represent a simple field type (string, float, datetime, integer, or Boolean). + ascending: Whether to sort the entries in ascending order of the sorting column values. + progress_bar: Set to `False` to disable the download progress bar, + or pass a `ProgressBarCallback` class to use your own progress bar callback. Returns: `Table` object containing `Model` objects. @@ -347,27 +353,23 @@ def fetch_models_table( Examples: >>> import neptune - - >>> # Fetch project "jackie/sandbox" + ... # Fetch project "jackie/sandbox" ... project = neptune.init_project(mode="read-only", project="jackie/sandbox") >>> # Fetch the metadata of all models as a pandas DataFrame ... models_table_df = project.fetch_models_table().to_pandas() - >>> # Fetch the metadata of all models as a pandas DataFrame, - ... # including only the "datasets" namespace and "info/size" field as columns: + >>> # Include only the "datasets" namespace and "info/size" field as columns: ... models_table_df = project.fetch_models_table(columns=["datasets", "info/size"]).to_pandas() - >>> # Sort model objects by size - ... models_table_df = models_table_df.sort_values(by="sys/size") - - >>> # Sort models by creation time - ... models_table_df = models_table_df.sort_values(by="sys/creation_time", ascending=False) - - >>> # Extract the last model id + >>> # Fetch 10 oldest model objects + ... models_table_df = project.fetch_models_table( + ... sort_by="sys/creation_time", ascending=True, limit=10 + ... ).to_pandas() + ... # Extract the ID of the first listed (oldest) model object ... last_model_id = models_table_df["sys/id"].values[0] - You may also want to check the API referene in the docs: + See also the API reference in the docs: https://docs.neptune.ai/api/project#fetch_models_table """ verify_type("limit", limit, (int, type(None))) diff --git a/src/neptune/typing.py b/src/neptune/typing.py index 3bcded5ae..42f2a2243 100644 --- a/src/neptune/typing.py +++ b/src/neptune/typing.py @@ -30,15 +30,17 @@ class ProgressBarCallback(contextlib.AbstractContextManager): - """ - Abstract base class for progress bar callbacks. + """Abstract base class for progress bar callbacks. + + You can use this class to implement your own progress bar callback that will be invoked in table fetching methods: - You can use this class to implement your own progress bar callback that will be invoked in `fetch_*_table` methods. + - `fetch_runs_table()` + - `fetch_models_table()` + - `fetch_model_versions_table()` Example using `click`: >>> from typing import Any, Optional, Type >>> from types import TracebackType - >>> from neptune import init_project >>> from neptune.typing import ProgressBarCallback >>> class ClickProgressBar(ProgressBarCallback): ... def __init__(self, *, description: Optional[str] = None, **_: Any) -> None: @@ -63,12 +65,13 @@ class ProgressBarCallback(contextlib.AbstractContextManager): ... exc_tb: Optional[TracebackType], ... ) -> None: ... self._progress_bar.__exit__(exc_type, exc_val, exc_tb) + >>> from neptune import init_project >>> with init_project() as project: ... project.fetch_runs_table(progress_bar=ClickProgressBar) ... project.fetch_models_table(progress_bar=ClickProgressBar) - IMPORTANT: Remember to pass a type, not an instance to the `progress_bar` argument, - i.e. `ClickProgressBar`, not `ClickProgressBar()`. + IMPORTANT: Pass a type, not an instance to the `progress_bar` argument. + That is, `ClickProgressBar`, not `ClickProgressBar()`. """ def __init__(self, *args: Any, **kwargs: Any) -> None: From 5ed4e3243d017143e163c770222bc7a745ac37dd Mon Sep 17 00:00:00 2001 From: Rafal Jankowski Date: Tue, 23 Jan 2024 10:17:55 +0100 Subject: [PATCH 16/22] Docstrings fixes --- src/neptune/metadata_containers/model.py | 4 ++-- src/neptune/metadata_containers/project.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index d653c7878..49f24700d 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -277,8 +277,8 @@ def fetch_model_versions_table( If you pass the name of a namespace, all the fields inside the namespace are included as columns. If `None` (default), all the columns of the model versions table are included. limit: How many entries to return at most. If `None`, all entries are returned. - sort_by: Name of the column to sort the results by. - The column must represent a simple field type (string, float, datetime, integer, or Boolean). + sort_by: Name of the field to sort the results by. + The field must represent a simple type (string, float, datetime, integer, or Boolean). ascending: Whether to sort the entries in ascending order of the sorting column values. progress_bar: Set to `False` to disable the download progress bar, or pass a `ProgressBarCallback` class to use your own progress bar callback. diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 9a588b582..774440588 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -234,8 +234,8 @@ def fetch_runs_table( If `False` (default), only not-trashed runs are retrieved. If `None`, both trashed and not-trashed runs are retrieved. limit: How many entries to return at most. If `None`, all entries are returned. - sort_by: Name of the column to sort the results by. - The column must represent a simple field type (string, float, datetime, integer, or Boolean). + sort_by: Name of the field to sort the results by. + The field must represent a simple type (string, float, datetime, integer, or Boolean). ascending: Whether to sort the entries in ascending order of the sorting column values. progress_bar: Set to `False` to disable the download progress bar, or pass a `ProgressBarCallback` class to use your own progress bar callback. @@ -340,8 +340,8 @@ def fetch_models_table( If you pass the name of a namespace, all the fields inside the namespace are included as columns. If `None`, all the columns of the models table are included. limit: How many entries to return at most. If `None`, all entries are returned. - sort_by: Name of the column to sort the results by. - The column must represent a simple field type (string, float, datetime, integer, or Boolean). + sort_by: Name of the field to sort the results by. + The field must represent a simple type (string, float, datetime, integer, or Boolean). ascending: Whether to sort the entries in ascending order of the sorting column values. progress_bar: Set to `False` to disable the download progress bar, or pass a `ProgressBarCallback` class to use your own progress bar callback. From 7ef63a2f06040cf7af1ef16fbdd7013060b9f719 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 5 Feb 2024 19:38:43 +0100 Subject: [PATCH 17/22] Use progress bar in downloading File and FileSet (#1620) Co-authored-by: Sabine --- CHANGELOG.md | 1 + src/neptune/api/searching_entries.py | 15 ++----- src/neptune/attributes/atoms/artifact.py | 3 +- src/neptune/attributes/atoms/file.py | 9 ++++- src/neptune/attributes/file_set.py | 9 ++++- src/neptune/attributes/series/file_series.py | 6 ++- src/neptune/handler.py | 14 ++++++- .../backends/hosted_file_operations.py | 40 +++++++++---------- .../backends/hosted_neptune_backend.py | 10 ++++- .../internal/backends/neptune_backend.py | 8 ++-- .../internal/backends/neptune_backend_mock.py | 7 +++- .../backends/offline_neptune_backend.py | 7 +++- src/neptune/internal/backends/utils.py | 17 ++++++-- .../metadata_containers/metadata_container.py | 5 +-- .../metadata_containers_table.py | 17 +++++++- src/neptune/metadata_containers/model.py | 9 +++-- src/neptune/metadata_containers/project.py | 10 +++-- src/neptune/typing.py | 9 ++++- .../new/client/abstract_tables_test.py | 2 + .../backends/test_hosted_file_operations.py | 22 +++------- 20 files changed, 139 insertions(+), 81 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20ed895ce..edc1e9de3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Added `sort_by` parameter to `fetch_*_table()` methods ([#1595](https://github.com/neptune-ai/neptune-client/pull/1595)) - Added `ascending` parameter to `fetch_*_table()` methods ([#1602](https://github.com/neptune-ai/neptune-client/pull/1602)) - Added `progress_bar` parameter to `fetch_*_table()` methods ([#1599](https://github.com/neptune-ai/neptune-client/pull/1599)) +- Added `progress_bar` parameter to `download()` method of the `Handler` class ([#1620](https://github.com/neptune-ai/neptune-client/pull/1620)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) diff --git a/src/neptune/api/searching_entries.py b/src/neptune/api/searching_entries.py index 1b1520b53..bb3df690a 100644 --- a/src/neptune/api/searching_entries.py +++ b/src/neptune/api/searching_entries.py @@ -23,8 +23,6 @@ Iterable, List, Optional, - Type, - Union, ) from bravado.client import construct_request # type: ignore @@ -45,9 +43,9 @@ NQLQueryAggregate, NQLQueryAttribute, ) -from neptune.internal.backends.utils import which_progress_bar +from neptune.internal.backends.utils import construct_progress_bar from neptune.internal.init.parameters import MAX_SERVER_OFFSET -from neptune.typing import ProgressBarCallback +from neptune.typing import ProgressBarType if TYPE_CHECKING: from neptune.internal.backends.swagger_client_wrapper import SwaggerClientWrapper @@ -150,7 +148,7 @@ def iter_over_pages( max_offset: int = MAX_SERVER_OFFSET, sort_by_column_type: Optional[str] = None, ascending: bool = False, - progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, + progress_bar: Optional[ProgressBarType] = None, **kwargs: Any, ) -> Generator[Any, None, None]: searching_after = None @@ -160,7 +158,7 @@ def iter_over_pages( progress_bar = progress_bar if step_size >= total else None - with _construct_progress_bar(progress_bar) as bar: + with construct_progress_bar(progress_bar, "Fetching table...") as bar: # beginning of the first page bar.update( by=0, @@ -207,10 +205,5 @@ def iter_over_pages( last_page = page -def _construct_progress_bar(progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]]) -> ProgressBarCallback: - progress_bar_type = which_progress_bar(progress_bar) - return progress_bar_type(description="Fetching table...") - - def _entries_from_page(single_page: Dict[str, Any]) -> List[LeaderboardEntry]: return list(map(to_leaderboard_entry, single_page.get("entries", []))) diff --git a/src/neptune/attributes/atoms/artifact.py b/src/neptune/attributes/atoms/artifact.py index 5b9069f6d..1d9145ccb 100644 --- a/src/neptune/attributes/atoms/artifact.py +++ b/src/neptune/attributes/atoms/artifact.py @@ -30,6 +30,7 @@ TrackFilesToArtifact, ) from neptune.types.atoms.artifact import Artifact as ArtifactVal +from neptune.typing import ProgressBarType class Artifact(Atom): @@ -62,7 +63,7 @@ def fetch_files_list(self) -> typing.List[ArtifactFileData]: artifact_hash, ) - def download(self, destination: str = None): + def download(self, destination: str = None, progress_bar: typing.Optional[ProgressBarType] = None): self._check_feature() for file_definition in self.fetch_files_list(): driver: typing.Type[ArtifactDriver] = ArtifactDriversMap.match_type(file_definition.type) diff --git a/src/neptune/attributes/atoms/file.py b/src/neptune/attributes/atoms/file.py index 01f385d84..50091cba9 100644 --- a/src/neptune/attributes/atoms/file.py +++ b/src/neptune/attributes/atoms/file.py @@ -21,6 +21,7 @@ from neptune.internal.operation import UploadFile from neptune.internal.utils import verify_type from neptune.types.atoms.file import File as FileVal +from neptune.typing import ProgressBarType class File(Atom): @@ -39,9 +40,13 @@ def assign(self, value: FileVal, *, wait: bool = False) -> None: def upload(self, value, *, wait: bool = False) -> None: self.assign(FileVal.create_from(value), wait=wait) - def download(self, destination: Optional[str] = None) -> None: + def download( + self, + destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, + ) -> None: verify_type("destination", destination, (str, type(None))) - self._backend.download_file(self._container_id, self._container_type, self._path, destination) + self._backend.download_file(self._container_id, self._container_type, self._path, destination, progress_bar) def fetch_extension(self) -> str: val = self._backend.get_file_attribute(self._container_id, self._container_type, self._path) diff --git a/src/neptune/attributes/file_set.py b/src/neptune/attributes/file_set.py index 94e8ff390..922522a23 100644 --- a/src/neptune/attributes/file_set.py +++ b/src/neptune/attributes/file_set.py @@ -34,6 +34,7 @@ verify_type, ) from neptune.types.file_set import FileSet as FileSetVal +from neptune.typing import ProgressBarType class FileSet(Attribute): @@ -67,9 +68,13 @@ def _enqueue_upload_operation(self, globs: Iterable[str], *, reset: bool, wait: abs_file_globs = list(os.path.abspath(file_glob) for file_glob in globs) self._enqueue_operation(UploadFileSet(self._path, abs_file_globs, reset=reset), wait=wait) - def download(self, destination: Optional[str] = None) -> None: + def download( + self, + destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, + ) -> None: verify_type("destination", destination, (str, type(None))) - self._backend.download_file_set(self._container_id, self._container_type, self._path, destination) + self._backend.download_file_set(self._container_id, self._container_type, self._path, destination, progress_bar) def list_fileset_files(self, path: Optional[str] = None) -> List[FileEntry]: path = path or "" diff --git a/src/neptune/attributes/series/file_series.py b/src/neptune/attributes/series/file_series.py index 55637d8b8..761bcf8a3 100644 --- a/src/neptune/attributes/series/file_series.py +++ b/src/neptune/attributes/series/file_series.py @@ -45,6 +45,7 @@ from neptune.internal.utils.limits import image_size_exceeds_limit_for_logging from neptune.types import File from neptune.types.series.file_series import FileSeries as FileSeriesVal +from neptune.typing import ProgressBarType Val = FileSeriesVal Data = File @@ -95,14 +96,14 @@ def _get_base64_image_content(file: File) -> str: return base64_encode(file_content) - def download(self, destination: Optional[str]): + def download(self, destination: Optional[str], progress_bar: Optional[ProgressBarType] = None): target_dir = self._get_destination(destination) item_count = self._backend.get_image_series_values( self._container_id, self._container_type, self._path, 0, 1 ).totalItemCount for i in range(0, item_count): self._backend.download_file_series_by_index( - self._container_id, self._container_type, self._path, i, target_dir + self._container_id, self._container_type, self._path, i, target_dir, progress_bar ) def download_last(self, destination: Optional[str]): @@ -117,6 +118,7 @@ def download_last(self, destination: Optional[str]): self._path, item_count - 1, target_dir, + progress_bar=None, ) else: raise ValueError("Unable to download last file - series is empty") diff --git a/src/neptune/handler.py b/src/neptune/handler.py index 331f53001..b52bb0cd3 100644 --- a/src/neptune/handler.py +++ b/src/neptune/handler.py @@ -66,6 +66,7 @@ from neptune.types.atoms.file import File as FileVal from neptune.types.type_casting import cast_value_for_extend from neptune.types.value_copy import ValueCopy +from neptune.typing import ProgressBarType from neptune.utils import stringify_unsupported if TYPE_CHECKING: @@ -595,7 +596,11 @@ def delete_files(self, paths: Union[str, Iterable[str]], *, wait: bool = False) return self._pass_call_to_attr(function_name="delete_files", paths=paths, wait=wait) @check_protected_paths - def download(self, destination: str = None) -> None: + def download( + self, + destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, + ) -> None: """Downloads the stored files to the working directory or to the specified destination. Available for the following field types: @@ -612,11 +617,16 @@ def download(self, destination: str = None) -> None: composed from field name and extension (if present). If `destination` is a path to a file, the file will be downloaded under the specified name. Defaults to `None`. + progress_bar: (bool or Type of progress bar, optional): progress bar to be used while downloading assets. + If `None` or `True` the default tqdm-based progress bar will be used. + If `False` no progress bar will be used. + If a type of progress bar is passed, it will be used instead of the default one. + Defaults to `None`. For more information, see the docs: https://docs.neptune.ai/api-reference/field-types """ - return self._pass_call_to_attr(function_name="download", destination=destination) + return self._pass_call_to_attr(function_name="download", destination=destination, progress_bar=progress_bar) def download_last(self, destination: str = None) -> None: """Downloads the stored files to the working directory or to the specified destination. diff --git a/src/neptune/internal/backends/hosted_file_operations.py b/src/neptune/internal/backends/hosted_file_operations.py index 1c1034734..3c4b586d5 100644 --- a/src/neptune/internal/backends/hosted_file_operations.py +++ b/src/neptune/internal/backends/hosted_file_operations.py @@ -25,10 +25,10 @@ import json import os import time +from contextlib import ExitStack from io import BytesIO from typing import ( AnyStr, - Callable, Dict, Iterable, List, @@ -79,6 +79,7 @@ ) from neptune.internal.backends.utils import ( build_operation_url, + construct_progress_bar, handle_server_raw_response_messages, ) from neptune.internal.utils import ( @@ -86,6 +87,7 @@ get_common_root, ) from neptune.internal.utils.logger import logger +from neptune.typing import ProgressBarType DEFAULT_CHUNK_SIZE = 5 * BYTES_IN_ONE_MB DEFAULT_UPLOAD_CONFIG = AttributeUploadConfiguration(chunk_size=DEFAULT_CHUNK_SIZE) @@ -392,6 +394,7 @@ def download_image_series_element( attribute: str, index: int, destination: str, + progress_bar: Optional[ProgressBarType], ): url = build_operation_url( swagger_client.swagger_spec.api_url, @@ -413,6 +416,7 @@ def download_image_series_element( destination, "{}.{}".format(index, response.headers["content-type"].split("/")[-1]), ), + progress_bar=progress_bar, ) @@ -421,9 +425,7 @@ def download_file_attribute( container_id: str, attribute: str, destination: Optional[str] = None, - pre_download_hook: Callable[[int], None] = lambda x: None, - download_iter_hook: Callable[[int], None] = lambda x: None, - post_download_hook: Callable[[], None] = lambda: None, + progress_bar: Optional[ProgressBarType] = None, ): url = build_operation_url( swagger_client.swagger_spec.api_url, @@ -435,16 +437,14 @@ def download_file_attribute( headers={"Accept": "application/octet-stream"}, query_params={"experimentId": container_id, "attribute": attribute}, ) - _store_response_as_file(response, destination, pre_download_hook, download_iter_hook, post_download_hook) + _store_response_as_file(response, destination, progress_bar) def download_file_set_attribute( swagger_client: SwaggerClientWrapper, download_id: str, destination: Optional[str] = None, - pre_download_hook: Callable[[int], None] = lambda x: None, - download_iter_hook: Callable[[int], None] = lambda x: None, - post_download_hook: Callable[[], None] = lambda: None, + progress_bar: Optional[ProgressBarType] = None, ): download_url: Optional[str] = _get_download_url(swagger_client, download_id) next_sleep = 0.5 @@ -458,7 +458,7 @@ def download_file_set_attribute( url=download_url, headers={"Accept": "application/zip"}, ) - _store_response_as_file(response, destination, pre_download_hook, download_iter_hook, post_download_hook) + _store_response_as_file(response, destination, progress_bar) def _get_download_url(swagger_client: SwaggerClientWrapper, download_id: str): @@ -470,9 +470,7 @@ def _get_download_url(swagger_client: SwaggerClientWrapper, download_id: str): def _store_response_as_file( response: Response, destination: Optional[str] = None, - pre_download_hook: Callable[[int], None] = lambda x: None, - download_iter_hook: Callable[[int], None] = lambda x: None, - post_download_hook: Callable[[], None] = lambda: None, + progress_bar: Optional[ProgressBarType] = None, ) -> None: if destination is None: target_file = _get_content_disposition_filename(response) @@ -482,14 +480,16 @@ def _store_response_as_file( target_file = destination total_size = int(response.headers.get("content-length", 0)) - pre_download_hook(total_size) - with response: - with open(target_file, "wb") as f: - for chunk in response.iter_content(chunk_size=1024 * 1024): - if chunk: - f.write(chunk) - download_iter_hook(len(chunk) if chunk else 0) - post_download_hook() + # TODO: update syntax once py3.10 becomes min supported version (with (x(), y(), z()): ...) + with ExitStack() as stack: + bar = stack.enter_context(construct_progress_bar(progress_bar, "Fetching file...")) + response = stack.enter_context(response) + file_stream = stack.enter_context(open(target_file, "wb")) + + for chunk in response.iter_content(chunk_size=1024 * 1024): + if chunk: + file_stream.write(chunk) + bar.update(by=len(chunk), total=total_size) def _get_content_disposition_filename(response: Response) -> str: diff --git a/src/neptune/internal/backends/hosted_neptune_backend.py b/src/neptune/internal/backends/hosted_neptune_backend.py index 38e2da74a..30220933f 100644 --- a/src/neptune/internal/backends/hosted_neptune_backend.py +++ b/src/neptune/internal/backends/hosted_neptune_backend.py @@ -144,7 +144,7 @@ from neptune.internal.utils.paths import path_to_str from neptune.internal.websockets.websockets_factory import WebsocketsFactory from neptune.management.exceptions import ObjectNotFound -from neptune.typing import ProgressBarCallback +from neptune.typing import ProgressBarType from neptune.version import version as neptune_client_version if TYPE_CHECKING: @@ -709,6 +709,7 @@ def download_file_series_by_index( path: List[str], index: int, destination: str, + progress_bar: Optional[ProgressBarType], ): try: download_image_series_element( @@ -717,6 +718,7 @@ def download_file_series_by_index( attribute=path_to_str(path), index=index, destination=destination, + progress_bar=progress_bar, ) except ClientHttpError as e: if e.status == HTTPNotFound.status_code: @@ -730,6 +732,7 @@ def download_file( container_type: ContainerType, path: List[str], destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, ): try: download_file_attribute( @@ -737,6 +740,7 @@ def download_file( container_id=container_id, attribute=path_to_str(path), destination=destination, + progress_bar=progress_bar, ) except ClientHttpError as e: if e.status == HTTPNotFound.status_code: @@ -750,6 +754,7 @@ def download_file_set( container_type: ContainerType, path: List[str], destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, ): download_request = self._get_file_set_download_request(container_id, container_type, path) try: @@ -757,6 +762,7 @@ def download_file_set( swagger_client=self.leaderboard_client, download_id=download_request.id, destination=destination, + progress_bar=progress_bar, ) except ClientHttpError as e: if e.status == HTTPNotFound.status_code: @@ -1053,7 +1059,7 @@ def search_leaderboard_entries( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, - progress_bar: Optional[Union[bool, typing.Type[ProgressBarCallback]]] = None, + progress_bar: Optional[ProgressBarType] = None, ) -> Generator[LeaderboardEntry, None, None]: default_step_size = int(os.getenv(NEPTUNE_FETCH_TABLE_STEP_SIZE, "100")) diff --git a/src/neptune/internal/backends/neptune_backend.py b/src/neptune/internal/backends/neptune_backend.py index c2354ccf6..7954cd5d7 100644 --- a/src/neptune/internal/backends/neptune_backend.py +++ b/src/neptune/internal/backends/neptune_backend.py @@ -22,7 +22,6 @@ List, Optional, Tuple, - Type, Union, ) @@ -60,7 +59,7 @@ from neptune.internal.operation_processors.operation_storage import OperationStorage from neptune.internal.utils.git import GitInfo from neptune.internal.websockets.websockets_factory import WebsocketsFactory -from neptune.typing import ProgressBarCallback +from neptune.typing import ProgressBarType class NeptuneBackend: @@ -157,6 +156,7 @@ def download_file( container_type: ContainerType, path: List[str], destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, ): pass @@ -167,6 +167,7 @@ def download_file_set( container_type: ContainerType, path: List[str], destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, ): pass @@ -234,6 +235,7 @@ def download_file_series_by_index( path: List[str], index: int, destination: str, + progress_bar: Optional[ProgressBarType], ): pass @@ -309,7 +311,7 @@ def search_leaderboard_entries( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, - progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, + progress_bar: Optional[ProgressBarType] = None, ) -> Generator[LeaderboardEntry, None, None]: pass diff --git a/src/neptune/internal/backends/neptune_backend_mock.py b/src/neptune/internal/backends/neptune_backend_mock.py index cdc870c60..998d100a5 100644 --- a/src/neptune/internal/backends/neptune_backend_mock.py +++ b/src/neptune/internal/backends/neptune_backend_mock.py @@ -132,7 +132,7 @@ from neptune.types.sets.string_set import StringSet from neptune.types.value import Value from neptune.types.value_visitor import ValueVisitor -from neptune.typing import ProgressBarCallback +from neptune.typing import ProgressBarType Val = TypeVar("Val", bound=Value) @@ -332,6 +332,7 @@ def download_file( container_type: ContainerType, path: List[str], destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, ): run = self._get_container(container_id, container_type) value: File = run.get(path) @@ -351,6 +352,7 @@ def download_file_set( container_type: ContainerType, path: List[str], destination: Optional[str] = None, + progress_bar: Optional[ProgressBarType] = None, ): run = self._get_container(container_id, container_type) source_file_set_value: FileSet = run.get(path) @@ -488,6 +490,7 @@ def download_file_series_by_index( path: List[str], index: int, destination: str, + progress_bar: Optional[ProgressBarType], ): """Non relevant for backend""" @@ -547,7 +550,7 @@ def search_leaderboard_entries( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, - progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, + progress_bar: Optional[ProgressBarType] = None, ) -> Generator[LeaderboardEntry, None, None]: """Non relevant for mock""" diff --git a/src/neptune/internal/backends/offline_neptune_backend.py b/src/neptune/internal/backends/offline_neptune_backend.py index 759ae4de7..f5d2589bf 100644 --- a/src/neptune/internal/backends/offline_neptune_backend.py +++ b/src/neptune/internal/backends/offline_neptune_backend.py @@ -15,7 +15,10 @@ # __all__ = ["OfflineNeptuneBackend"] -from typing import List +from typing import ( + List, + Optional, +) from neptune.api.dtos import FileEntry from neptune.exceptions import NeptuneOfflineModeFetchException @@ -38,6 +41,7 @@ ) from neptune.internal.backends.neptune_backend_mock import NeptuneBackendMock from neptune.internal.container_type import ContainerType +from neptune.typing import ProgressBarType class OfflineNeptuneBackend(NeptuneBackendMock): @@ -128,6 +132,7 @@ def download_file_series_by_index( path: List[str], index: int, destination: str, + progress_bar: Optional[ProgressBarType], ): raise NeptuneOfflineModeFetchException diff --git a/src/neptune/internal/backends/utils.py b/src/neptune/internal/backends/utils.py index a44e98f73..9eaff77ea 100644 --- a/src/neptune/internal/backends/utils.py +++ b/src/neptune/internal/backends/utils.py @@ -27,6 +27,7 @@ "parse_validation_errors", "ExecuteOperationsBatchingManager", "which_progress_bar", + "construct_progress_bar", ] import dataclasses @@ -47,7 +48,6 @@ Optional, Text, Type, - Union, ) from urllib.parse import ( urljoin, @@ -86,7 +86,10 @@ ) from neptune.internal.utils import replace_patch_version from neptune.internal.utils.logger import logger -from neptune.typing import ProgressBarCallback +from neptune.typing import ( + ProgressBarCallback, + ProgressBarType, +) from neptune.utils import ( NullProgressBar, TqdmProgressBar, @@ -299,7 +302,7 @@ def _check_if_tqdm_installed() -> bool: return False -def which_progress_bar(progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]]) -> Type[ProgressBarCallback]: +def which_progress_bar(progress_bar: Optional[ProgressBarType]) -> Type[ProgressBarCallback]: if isinstance(progress_bar, type) and issubclass( progress_bar, ProgressBarCallback ): # return whatever the user gave us @@ -320,3 +323,11 @@ def which_progress_bar(progress_bar: Optional[Union[bool, Type[ProgressBarCallba return TqdmProgressBar return NullProgressBar + + +def construct_progress_bar( + progress_bar: Optional[ProgressBarType], + description: str, +) -> ProgressBarCallback: + progress_bar_type = which_progress_bar(progress_bar) + return progress_bar_type(description=description) diff --git a/src/neptune/metadata_containers/metadata_container.py b/src/neptune/metadata_containers/metadata_container.py index f771d2367..8b1f83e55 100644 --- a/src/neptune/metadata_containers/metadata_container.py +++ b/src/neptune/metadata_containers/metadata_container.py @@ -32,7 +32,6 @@ Iterable, List, Optional, - Type, Union, ) @@ -96,7 +95,7 @@ from neptune.metadata_containers.metadata_containers_table import Table from neptune.types.mode import Mode from neptune.types.type_casting import cast_value -from neptune.typing import ProgressBarCallback +from neptune.typing import ProgressBarType from neptune.utils import stop_synchronization_callback if TYPE_CHECKING: @@ -662,7 +661,7 @@ def _fetch_entries( limit: Optional[int], sort_by: str, ascending: bool, - progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]], + progress_bar: Optional[ProgressBarType], ) -> Table: if columns is not None: # always return entries with 'sys/id' and the column chosen for sorting when filter applied diff --git a/src/neptune/metadata_containers/metadata_containers_table.py b/src/neptune/metadata_containers/metadata_containers_table.py index 19aafab6e..b8275189a 100644 --- a/src/neptune/metadata_containers/metadata_containers_table.py +++ b/src/neptune/metadata_containers/metadata_containers_table.py @@ -39,6 +39,7 @@ parse_path, ) from neptune.internal.utils.run_state import RunState +from neptune.typing import ProgressBarType logger = logging.getLogger(__name__) @@ -102,7 +103,12 @@ def get_attribute_value(self, path: str) -> Any: return None raise ValueError("Could not find {} attribute".format(path)) - def download_file_attribute(self, path: str, destination: Optional[str]): + def download_file_attribute( + self, + path: str, + destination: Optional[str], + progress_bar: Optional[ProgressBarType] = None, + ): for attr in self._attributes: if attr.path == path: _type = attr.type @@ -112,12 +118,18 @@ def download_file_attribute(self, path: str, destination: Optional[str]): container_type=self._container_type, path=parse_path(path), destination=destination, + progress_bar=progress_bar, ) return raise MetadataInconsistency("Cannot download file from attribute of type {}".format(_type)) raise ValueError("Could not find {} attribute".format(path)) - def download_file_set_attribute(self, path: str, destination: Optional[str]): + def download_file_set_attribute( + self, + path: str, + destination: Optional[str], + progress_bar: Optional[ProgressBarType] = None, + ): for attr in self._attributes: if attr.path == path: _type = attr.type @@ -127,6 +139,7 @@ def download_file_set_attribute(self, path: str, destination: Optional[str]): container_type=self._container_type, path=parse_path(path), destination=destination, + progress_bar=progress_bar, ) return raise MetadataInconsistency("Cannot download ZIP archive from attribute of type {}".format(_type)) diff --git a/src/neptune/metadata_containers/model.py b/src/neptune/metadata_containers/model.py index 49f24700d..30a6a3f00 100644 --- a/src/neptune/metadata_containers/model.py +++ b/src/neptune/metadata_containers/model.py @@ -21,8 +21,6 @@ Iterable, List, Optional, - Type, - Union, ) from typing_extensions import Literal @@ -61,7 +59,10 @@ from neptune.metadata_containers.abstract import NeptuneObjectCallback from neptune.metadata_containers.metadata_containers_table import Table from neptune.types.mode import Mode -from neptune.typing import ProgressBarCallback +from neptune.typing import ( + ProgressBarCallback, + ProgressBarType, +) if TYPE_CHECKING: from neptune.internal.background_job import BackgroundJob @@ -267,7 +268,7 @@ def fetch_model_versions_table( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, - progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, + progress_bar: Optional[ProgressBarType] = None, ) -> Table: """Retrieve all versions of the given model. diff --git a/src/neptune/metadata_containers/project.py b/src/neptune/metadata_containers/project.py index 774440588..d812e6e0c 100644 --- a/src/neptune/metadata_containers/project.py +++ b/src/neptune/metadata_containers/project.py @@ -19,7 +19,6 @@ from typing import ( Iterable, Optional, - Type, Union, ) @@ -53,7 +52,10 @@ from neptune.metadata_containers.metadata_containers_table import Table from neptune.metadata_containers.utils import prepare_nql_query from neptune.types.mode import Mode -from neptune.typing import ProgressBarCallback +from neptune.typing import ( + ProgressBarCallback, + ProgressBarType, +) class Project(MetadataContainer): @@ -203,7 +205,7 @@ def fetch_runs_table( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, - progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, + progress_bar: Optional[ProgressBarType] = None, ) -> Table: """Retrieve runs matching the specified criteria. @@ -326,7 +328,7 @@ def fetch_models_table( limit: Optional[int] = None, sort_by: str = "sys/creation_time", ascending: bool = False, - progress_bar: Optional[Union[bool, Type[ProgressBarCallback]]] = None, + progress_bar: Optional[ProgressBarType] = None, ) -> Table: """Retrieve models stored in the project. diff --git a/src/neptune/typing.py b/src/neptune/typing.py index 42f2a2243..6b8c2f841 100644 --- a/src/neptune/typing.py +++ b/src/neptune/typing.py @@ -13,15 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__all__ = ["SupportsNamespaces", "NeptuneObject", "NeptuneObjectCallback", "ProgressBarCallback"] +__all__ = ["SupportsNamespaces", "NeptuneObject", "NeptuneObjectCallback", "ProgressBarCallback", "ProgressBarType"] import abc import contextlib from typing import ( Any, Optional, + Type, + Union, ) +from typing_extensions import TypeAlias + from neptune.metadata_containers.abstract import ( NeptuneObject, NeptuneObjectCallback, @@ -80,3 +84,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: @abc.abstractmethod def update(self, *, by: int, total: Optional[int] = None) -> None: ... + + +ProgressBarType: TypeAlias = Union[bool, Type[ProgressBarCallback]] diff --git a/tests/unit/neptune/new/client/abstract_tables_test.py b/tests/unit/neptune/new/client/abstract_tables_test.py index 7c128deeb..12132d267 100644 --- a/tests/unit/neptune/new/client/abstract_tables_test.py +++ b/tests/unit/neptune/new/client/abstract_tables_test.py @@ -205,6 +205,7 @@ def test_get_table_as_table_entries( container_type=self.expected_container_type, path=["file"], destination="some_directory", + progress_bar=None, ) table_entry["file/set"].download("some_directory") @@ -213,6 +214,7 @@ def test_get_table_as_table_entries( container_type=self.expected_container_type, path=["file", "set"], destination="some_directory", + progress_bar=None, ) @patch.object(NeptuneBackendMock, "search_leaderboard_entries") diff --git a/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py b/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py index 308971ccb..aa055fd7a 100644 --- a/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py +++ b/tests/unit/neptune/new/internal/backends/test_hosted_file_operations.py @@ -102,19 +102,12 @@ def test_download_file_attribute(self, download_raw, store_response_mock): swagger_mock = self._get_swagger_mock() exp_uuid = str(uuid.uuid4()) - pre_download_hook = MagicMock() - download_iter_hook = MagicMock() - post_download_hook = MagicMock() - # when download_file_attribute( swagger_client=swagger_mock, container_id=exp_uuid, attribute="some/attribute", destination=None, - pre_download_hook=pre_download_hook, - download_iter_hook=download_iter_hook, - post_download_hook=post_download_hook, ) # then @@ -125,7 +118,9 @@ def test_download_file_attribute(self, download_raw, store_response_mock): query_params={"experimentId": str(exp_uuid), "attribute": "some/attribute"}, ) store_response_mock.assert_called_once_with( - download_raw.return_value, None, pre_download_hook, download_iter_hook, post_download_hook + download_raw.return_value, + None, + None, ) @patch("neptune.internal.backends.hosted_file_operations._store_response_as_file") @@ -139,18 +134,11 @@ def test_download_file_set_attribute(self, download_raw, store_response_mock): swagger_mock = self._get_swagger_mock() download_id = str(uuid.uuid4()) - pre_download_hook = MagicMock() - download_iter_hook = MagicMock() - post_download_hook = MagicMock() - # when download_file_set_attribute( swagger_client=swagger_mock, download_id=download_id, destination=None, - pre_download_hook=pre_download_hook, - download_iter_hook=download_iter_hook, - post_download_hook=post_download_hook, ) # then @@ -160,7 +148,9 @@ def test_download_file_set_attribute(self, download_raw, store_response_mock): headers={"Accept": "application/zip"}, ) store_response_mock.assert_called_once_with( - download_raw.return_value, None, pre_download_hook, download_iter_hook, post_download_hook + download_raw.return_value, + None, + None, ) From 8b50b2e0184594472c7a6a2067066e00aebfd5e0 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 5 Feb 2024 21:02:36 +0100 Subject: [PATCH 18/22] Add progress bar to fetching FloatSeries and StringSeries --- .../attributes/series/fetchable_series.py | 17 ++++++++++++----- src/neptune/handler.py | 13 +++++++++++-- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/neptune/attributes/series/fetchable_series.py b/src/neptune/attributes/series/fetchable_series.py index 4fc41d5f5..8802855d6 100644 --- a/src/neptune/attributes/series/fetchable_series.py +++ b/src/neptune/attributes/series/fetchable_series.py @@ -20,6 +20,7 @@ from typing import ( Dict, Generic, + Optional, TypeVar, Union, ) @@ -28,6 +29,9 @@ FloatSeriesValues, StringSeriesValues, ) +from neptune.internal.backends.utils import construct_progress_bar +from neptune.internal.utils.paths import path_to_str +from neptune.typing import ProgressBarType Row = TypeVar("Row", StringSeriesValues, FloatSeriesValues) @@ -37,7 +41,7 @@ class FetchableSeries(Generic[Row]): def _fetch_values_from_backend(self, offset, limit) -> Row: pass - def fetch_values(self, *, include_timestamp=True): + def fetch_values(self, *, include_timestamp: bool = True, progress_bar: Optional[ProgressBarType] = None): import pandas as pd limit = 1000 @@ -53,10 +57,13 @@ def make_row(entry: Row) -> Dict[str, Union[str, float, datetime]]: row["timestamp"] = datetime.fromtimestamp(entry.timestampMillis / 1000) return row - while offset < val.totalItemCount: - batch = self._fetch_values_from_backend(offset, limit) - data.extend(batch.values) - offset += limit + path = path_to_str(self._path) if hasattr(self, "_path") else "" + with construct_progress_bar(progress_bar, f"Fetching {path} values") as bar: + while offset < val.totalItemCount: + batch = self._fetch_values_from_backend(offset, limit) + data.extend(batch.values) + offset += limit + bar.update(by=batch.totalItemCount, total=val.totalItemCount) rows = dict((n, make_row(entry)) for (n, entry) in enumerate(data)) diff --git a/src/neptune/handler.py b/src/neptune/handler.py index b52bb0cd3..2c06d1baa 100644 --- a/src/neptune/handler.py +++ b/src/neptune/handler.py @@ -557,7 +557,7 @@ def fetch_last(self): """ return self._pass_call_to_attr(function_name="fetch_last") - def fetch_values(self, *, include_timestamp: Optional[bool] = True): + def fetch_values(self, *, include_timestamp: Optional[bool] = True, progress_bar: Optional[ProgressBarType] = None): """Fetches all values stored in the series from Neptune. Available for the following field types: @@ -568,6 +568,11 @@ def fetch_values(self, *, include_timestamp: Optional[bool] = True): Args: include_timestamp (bool, optional): Whether the fetched data should include the timestamp field. Defaults to `True`. + progress_bar: (bool or Type of progress bar, optional): progress bar to be used while fetching values. + If `None` or `True` the default tqdm-based progress bar will be used. + If `False` no progress bar will be used. + If a type of progress bar is passed, it will be used instead of the default one. + Defaults to `None`. Returns: ``Pandas.DataFrame``: containing all the values and their indexes stored in the series field. @@ -575,7 +580,11 @@ def fetch_values(self, *, include_timestamp: Optional[bool] = True): For more information on field types, see the docs: https://docs.neptune.ai/api-reference/field-types """ - return self._pass_call_to_attr(function_name="fetch_values", include_timestamp=include_timestamp) + return self._pass_call_to_attr( + function_name="fetch_values", + include_timestamp=include_timestamp, + progress_bar=progress_bar, + ) @check_protected_paths def delete_files(self, paths: Union[str, Iterable[str]], *, wait: bool = False) -> None: From 3542188d090887f080847c4530e4d6324f4d10c5 Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 5 Feb 2024 21:05:10 +0100 Subject: [PATCH 19/22] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index edc1e9de3..15cec0a0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ - Added `ascending` parameter to `fetch_*_table()` methods ([#1602](https://github.com/neptune-ai/neptune-client/pull/1602)) - Added `progress_bar` parameter to `fetch_*_table()` methods ([#1599](https://github.com/neptune-ai/neptune-client/pull/1599)) - Added `progress_bar` parameter to `download()` method of the `Handler` class ([#1620](https://github.com/neptune-ai/neptune-client/pull/1620)) +- Added `progrees_bar` parameter to `fetch_values` method of the `Handler` class ([#1633](https://github.com/neptune-ai/neptune-client/pull/1633)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) From 57593d1570ec01477c39537e3b65cf4237ff1c0e Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Mon, 5 Feb 2024 21:46:48 +0100 Subject: [PATCH 20/22] take the first fetch into account --- src/neptune/attributes/series/fetchable_series.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/neptune/attributes/series/fetchable_series.py b/src/neptune/attributes/series/fetchable_series.py index 8802855d6..583c84158 100644 --- a/src/neptune/attributes/series/fetchable_series.py +++ b/src/neptune/attributes/series/fetchable_series.py @@ -59,11 +59,12 @@ def make_row(entry: Row) -> Dict[str, Union[str, float, datetime]]: path = path_to_str(self._path) if hasattr(self, "_path") else "" with construct_progress_bar(progress_bar, f"Fetching {path} values") as bar: + bar.update(by=len(data), total=val.totalItemCount) # first fetch before the loop while offset < val.totalItemCount: batch = self._fetch_values_from_backend(offset, limit) data.extend(batch.values) offset += limit - bar.update(by=batch.totalItemCount, total=val.totalItemCount) + bar.update(by=len(batch.values), total=val.totalItemCount) rows = dict((n, make_row(entry)) for (n, entry) in enumerate(data)) From 7c4c7135f317366d292ecc25bf52e4b752c50fcc Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Tue, 6 Feb 2024 09:16:23 +0100 Subject: [PATCH 21/22] Update CHANGELOG.md Co-authored-by: Sabine --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15cec0a0f..d5cf87a13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ - Added `ascending` parameter to `fetch_*_table()` methods ([#1602](https://github.com/neptune-ai/neptune-client/pull/1602)) - Added `progress_bar` parameter to `fetch_*_table()` methods ([#1599](https://github.com/neptune-ai/neptune-client/pull/1599)) - Added `progress_bar` parameter to `download()` method of the `Handler` class ([#1620](https://github.com/neptune-ai/neptune-client/pull/1620)) -- Added `progrees_bar` parameter to `fetch_values` method of the `Handler` class ([#1633](https://github.com/neptune-ai/neptune-client/pull/1633)) +- Added `progress_bar` parameter to `fetch_values()` method of the `Handler` class ([#1633](https://github.com/neptune-ai/neptune-client/pull/1633)) ### Fixes - Add direct requirement of `typing-extensions` ([#1586](https://github.com/neptune-ai/neptune-client/pull/1586)) From 15090d37e582aaddfe26aedf996325ea06b3344e Mon Sep 17 00:00:00 2001 From: AleksanderWWW Date: Tue, 6 Feb 2024 12:01:21 +0100 Subject: [PATCH 22/22] Update CHANGELOG.md --- CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e9feb9c4..985a90826 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,6 @@ - Changed internal directories path structure ([#1606](https://github.com/neptune-ai/neptune-client/pull/1606)) - ## 1.8.6 ### Fixes