Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shnela committed Aug 30, 2022
1 parent 164c26d commit 6bfcc3b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 11 deletions.
16 changes: 14 additions & 2 deletions tests/neptune/new/client/abstract_tables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import uuid
from abc import abstractmethod
from datetime import datetime
from typing import List

from mock import Mock, patch

Expand All @@ -32,6 +33,7 @@
LeaderboardEntry,
)
from neptune.new.internal.backends.neptune_backend_mock import NeptuneBackendMock
from neptune.new.metadata_containers.metadata_containers_table import Table, TableEntry


@patch(
Expand All @@ -43,11 +45,11 @@ class AbstractTablesTestMixin:
expected_container_type = None

@abstractmethod
def get_table(self):
def get_table(self, **kwargs) -> Table:
pass

@abstractmethod
def get_table_entries(self, table):
def get_table_entries(self, table) -> List[TableEntry]:
pass

@classmethod
Expand Down Expand Up @@ -95,6 +97,16 @@ def build_attributes_leaderboard(now: datetime):
attributes.append(AttributeWithProperties("image/series", AttributeType.IMAGE_SERIES, None))
return attributes

@patch.object(NeptuneBackendMock, "search_leaderboard_entries")
def test_get_table_with_columns_filter(self, search_leaderboard_entries):
# when
self.get_table(columns=["datetime"])

# then
self.assertEqual(1, search_leaderboard_entries.call_count)
parameters = search_leaderboard_entries.call_args[1]
self.assertEqual({"sys/id", "datetime"}, parameters.get("columns"))

@patch.object(NeptuneBackendMock, "search_leaderboard_entries")
def test_get_table_as_pandas(self, search_leaderboard_entries):
# given
Expand Down
8 changes: 5 additions & 3 deletions tests/neptune/new/client/test_model_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
#

import unittest
from typing import List

from neptune.new import get_project
from neptune.new.internal.container_type import ContainerType
from neptune.new.metadata_containers.metadata_containers_table import Table, TableEntry
from tests.neptune.new.client.abstract_tables_test import AbstractTablesTestMixin


class TestModelTables(AbstractTablesTestMixin, unittest.TestCase):
expected_container_type = ContainerType.MODEL

def get_table(self):
return get_project("organization/project").fetch_models_table()
def get_table(self, **kwargs) -> Table:
return get_project("organization/project").fetch_models_table(**kwargs)

def get_table_entries(self, table):
def get_table_entries(self, table) -> List[TableEntry]:
return table.to_rows()
8 changes: 5 additions & 3 deletions tests/neptune/new/client/test_model_version_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@
#

import unittest
from typing import List

from neptune.new import init_model
from neptune.new.internal.container_type import ContainerType
from neptune.new.metadata_containers.metadata_containers_table import Table, TableEntry
from tests.neptune.new.client.abstract_tables_test import AbstractTablesTestMixin


class TestModelVersionTables(AbstractTablesTestMixin, unittest.TestCase):
expected_container_type = ContainerType.MODEL_VERSION

def get_table(self):
def get_table(self, **kwargs) -> Table:
return init_model(
model="organization/project",
project="PRO-MOD",
mode="read-only",
).fetch_model_versions_table()
).fetch_model_versions_table(**kwargs)

def get_table_entries(self, table):
def get_table_entries(self, table) -> List[TableEntry]:
return table.to_rows()
8 changes: 5 additions & 3 deletions tests/neptune/new/client/test_run_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
#

import unittest
from typing import List

from neptune.new import get_project
from neptune.new.internal.container_type import ContainerType
from neptune.new.metadata_containers.metadata_containers_table import Table, TableEntry
from tests.neptune.new.client.abstract_tables_test import AbstractTablesTestMixin


class TestRunTables(AbstractTablesTestMixin, unittest.TestCase):
expected_container_type = ContainerType.RUN

def get_table(self):
return get_project("organization/project").fetch_runs_table()
def get_table(self, **kwargs) -> Table:
return get_project("organization/project").fetch_runs_table(**kwargs)

def get_table_entries(self, table):
def get_table_entries(self, table) -> List[TableEntry]:
return table.to_rows()

0 comments on commit 6bfcc3b

Please sign in to comment.