Skip to content

Commit

Permalink
Merge pull request #1625 from CartoDB/jarroyo/ch72147/fix-providers-i…
Browse files Browse the repository at this point in the history
…n-catalog

Add providers to Catalog
  • Loading branch information
Jesus89 authored May 7, 2020
2 parents d53a347 + 013152b commit 7dca384
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 10 deletions.
18 changes: 16 additions & 2 deletions cartoframes/data/observatory/catalog/catalog.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .dataset import Dataset
from .entity import is_slug_value
from .category import Category
from .country import Country
from .category import Category
from .provider import Provider
from .dataset import Dataset
from .geography import Geography
from .subscriptions import Subscriptions
from .repository.constants import COUNTRY_FILTER, CATEGORY_FILTER, GEOGRAPHY_FILTER, PROVIDER_FILTER, PUBLIC_FILTER
Expand Down Expand Up @@ -143,6 +144,19 @@ def categories(self):
"""
return Category.get_all(self.filters)

@property
def providers(self):
"""Get all the providers in the Catalog.
Returns:
:py:class:`CatalogList <cartoframes.data.observatory.entity.CatalogList>`
Raises:
CatalogError: if there's a problem when connecting to the catalog or no datasets are found.
"""
return Provider.get_all(self.filters)

@property
def datasets(self):
"""Get all the datasets in the Catalog.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .constants import COUNTRY_FILTER
from .constants import COUNTRY_FILTER, PROVIDER_FILTER
from .entity_repo import EntityRepository


_CATEGORY_ID_FIELD = 'id'
_ALLOWED_FILTERS = [COUNTRY_FILTER]
_ALLOWED_FILTERS = [COUNTRY_FILTER, PROVIDER_FILTER]


def get_category_repo():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from .constants import CATEGORY_FILTER
from .constants import CATEGORY_FILTER, PROVIDER_FILTER
from .entity_repo import EntityRepository


_COUNTRY_ID_FIELD = 'id'
_ALLOWED_FILTERS = [CATEGORY_FILTER]
_ALLOWED_FILTERS = [CATEGORY_FILTER, PROVIDER_FILTER]


def get_country_repo():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .constants import CATEGORY_FILTER, COUNTRY_FILTER
from .entity_repo import EntityRepository


_PROVIDER_ID_FIELD = 'id'
_ALLOWED_FILTERS = [CATEGORY_FILTER, COUNTRY_FILTER]


def get_provider_repo():
Expand All @@ -11,7 +13,7 @@ def get_provider_repo():
class ProviderRepository(EntityRepository):

def __init__(self):
super(ProviderRepository, self).__init__(_PROVIDER_ID_FIELD, [])
super(ProviderRepository, self).__init__(_PROVIDER_ID_FIELD, _ALLOWED_FILTERS)

@classmethod
def _get_entity_class(cls):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def test_get_all_only_uses_allowed_filters(self, mocked_repo):

# Then
mocked_repo.assert_called_once_with({
COUNTRY_FILTER: 'usa'
COUNTRY_FILTER: 'usa',
PROVIDER_FILTER: 'open_data'
})
assert categories == test_categories

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def test_get_all_only_uses_allowed_filters(self, mocked_repo):

# Then
mocked_repo.assert_called_once_with({
CATEGORY_FILTER: 'demographics'
CATEGORY_FILTER: 'demographics',
PROVIDER_FILTER: 'open_data'
})
assert countries == test_countries

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
from cartoframes.data.observatory.catalog.provider import Provider
from cartoframes.data.observatory.catalog.repository.provider_repo import ProviderRepository
from cartoframes.data.observatory.catalog.repository.repo_client import RepoClient
from cartoframes.data.observatory.catalog.repository.constants import (
CATEGORY_FILTER, COUNTRY_FILTER, DATASET_FILTER, GEOGRAPHY_FILTER, VARIABLE_FILTER,
VARIABLE_GROUP_FILTER
)
from ..examples import test_provider1, test_providers, db_provider1, db_provider2


Expand Down Expand Up @@ -39,6 +43,31 @@ def test_get_all_when_empty(self, mocked_repo):
mocked_repo.assert_called_once_with(None)
assert providers is None

@patch.object(RepoClient, 'get_providers')
def test_get_all_only_uses_allowed_filters(self, mocked_repo):
# Given
mocked_repo.return_value = [db_provider1, db_provider2]
repo = ProviderRepository()
filters = {
COUNTRY_FILTER: 'usa',
DATASET_FILTER: 'carto-do.project.census2011',
CATEGORY_FILTER: 'demographics',
VARIABLE_FILTER: 'population',
GEOGRAPHY_FILTER: 'census-geo',
VARIABLE_GROUP_FILTER: 'var-group',
'fake_field_id': 'fake_value'
}

# When
providers = repo.get_all(filters)

# Then
mocked_repo.assert_called_once_with({
COUNTRY_FILTER: 'usa',
CATEGORY_FILTER: 'demographics'
})
assert providers == test_providers

@patch.object(RepoClient, 'get_providers')
def test_get_by_id(self, mocked_repo):
# Given
Expand Down
17 changes: 16 additions & 1 deletion tests/unit/data/observatory/catalog/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cartoframes.data.observatory.catalog.geography import Geography
from cartoframes.data.observatory.catalog.country import Country
from cartoframes.data.observatory.catalog.category import Category
from cartoframes.data.observatory.catalog.provider import Provider
from cartoframes.data.observatory.catalog.catalog import Catalog
from cartoframes.data.observatory.catalog.subscriptions import Subscriptions
from cartoframes.data.observatory.catalog.repository.geography_repo import GeographyRepository
Expand All @@ -15,7 +16,8 @@
)
from .examples import (
test_country2, test_country1, test_category1, test_category2, test_dataset1, test_dataset2,
test_geographies, test_datasets, test_categories, test_countries, test_geography1, test_geography2
test_geographies, test_datasets, test_categories, test_countries, test_geography1, test_geography2,
test_provider1, test_provider2
)


Expand Down Expand Up @@ -47,6 +49,19 @@ def test_categories(self, mocked_categories):
# Then
assert categories == expected_categories

@patch.object(Provider, 'get_all')
def test_providers(self, mocked_providers):
# Given
expected_providers = [test_provider1, test_provider2]
mocked_providers.return_value = expected_providers
catalog = Catalog()

# When
providers = catalog.providers

# Then
assert providers == expected_providers

@patch.object(Dataset, 'get_all')
def test_datasets(self, mocked_datasets):
# Given
Expand Down

0 comments on commit 7dca384

Please sign in to comment.