Merge pull request #1594 from CartoDB/dgaubert/ch61421/integrate-do-client-in-to-dataframe-and-to

Use DODataset to download a dataset from DO
dgaubert authored Apr 2, 2020
2 parents e93e955 + b75d937 commit 9d37c17
Showing 3 changed files with 42 additions and 60 deletions.
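
For orientation, here is a minimal sketch of the new download path introduced by this commit, pieced together from the entity.py diff below. The helper name download_do_dataset is hypothetical; DODataset, download_stream, get_api_key_auth_client and the CSV/DataFrame handling are taken from the diff:

import pandas as pd
from carto.do_dataset import DODataset

def download_do_dataset(credentials, dataset_id, file_path=None, limit=None, order_by=None):
    # Stream the dataset rows from the Data Observatory through DODataset.
    auth_client = credentials.get_api_key_auth_client()
    rows = DODataset(auth_client=auth_client).name(dataset_id).download_stream(limit=limit, order_by=order_by)

    if file_path:
        # Write the streamed bytes straight to a CSV file on disk.
        with open(file_path, 'w') as csvfile:
            for row in rows:
                csvfile.write(row.decode('utf-8'))
        return file_path

    # Otherwise parse the stream into a DataFrame.
    return pd.read_csv(rows)
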
36 changes: 9 additions & 27 deletions cartoframes/data/observatory/catalog/entity.py
@@ -2,7 +2,7 @@

from abc import ABC

from ...clients.bigquery_client import BigQueryClient
from carto.do_dataset import DODataset
from ....utils.logger import log
from ....exceptions import DOError

@@ -126,35 +126,21 @@ def _download(self, credentials, file_path=None, limit=None, order_by=None):
if not self._is_available_in('bq'):
raise DOError('{} is not ready for Download. Please, contact us for more information.'.format(self))

bq_client = _get_bigquery_client(credentials)

full_remote_table_name = self._get_remote_full_table_name(
bq_client.bq_project,
bq_client.bq_dataset,
bq_client.bq_public_project
)

project, dataset, table = full_remote_table_name.split('.')

column_names = bq_client.get_table_column_names(project, dataset, table)

query = 'SELECT * FROM `{}`'.format(full_remote_table_name)
if order_by:
query = '{} ORDER BY {}'.format(query, order_by)
if limit:
query = '{} LIMIT {}'.format(query, limit)

job = bq_client.query(query)

auth_client = credentials.get_api_key_auth_client()
rows = DODataset(auth_client=auth_client).name(self.id).download_stream(limit=limit, order_by=order_by)
if file_path:
bq_client.download_to_file(job, file_path, column_names=column_names)
with open(file_path, 'w') as csvfile:
for row in rows:
csvfile.write(row.decode('utf-8'))

log.info('Data saved: {}'.format(file_path))
if self.__class__.__name__ == 'Dataset':
log.info(_DATASET_READ_MSG.format(file_path))
elif self.__class__.__name__ == 'Geography':
log.info(_GEOGRAPHY_READ_MSG.format(file_path))
else:
return bq_client.download_to_dataframe(job)
dataframe = pd.read_csv(rows)
return dataframe

def _is_available_in(self, platform=_PLATFORM_BQ):
return self.data['available_in'] and platform in self.data['available_in']
@@ -172,10 +158,6 @@ def _get_remote_full_table_name(self, user_project, user_dataset, public_project
return self.id


def _get_bigquery_client(credentials):
return BigQueryClient(credentials)


def is_slug_value(id_value):
return len(id_value.split('.')) == 1

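
For context, a rough usage sketch of how this path is reached from the public catalog API, as exercised by the tests below. The dataset id and file path are placeholders, and the import paths are assumed from the cartoframes 1.x public API:

from cartoframes.auth import Credentials
from cartoframes.data.observatory import Dataset

credentials = Credentials('fake_user', '1234')
dataset = Dataset.get('carto-do.placeholder.dataset_slug')  # placeholder id
dataset.to_csv('fake_path', credentials)  # now streams rows via DODataset.download_stream
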
29 changes: 14 additions & 15 deletions tests/unit/data/observatory/catalog/test_dataset.py
@@ -16,7 +16,7 @@
test_dataset1, test_datasets, test_variables, test_variables_groups, db_dataset1, test_dataset2,
db_dataset2, test_subscription_info
)
from .mocks import BigQueryClientMock
from carto.do_dataset import DODataset


class TestDataset(object):
@@ -295,27 +295,25 @@ def test_datasets_are_exported_as_dataframe(self):

@patch.object(DatasetRepository, 'get_all')
@patch.object(DatasetRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_dataset_download(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_dataset_download(self, download_stream_mock, get_by_id_mock, get_all_mock):
# Given
get_by_id_mock.return_value = test_dataset1
dataset = Dataset.get(test_dataset1.id)
get_all_mock.return_value = [dataset]
mocked_bq_client.return_value = BigQueryClientMock()
download_stream_mock.return_value = []
credentials = Credentials('fake_user', '1234')

# Then
dataset.to_csv('fake_path', credentials)

@patch.object(DatasetRepository, 'get_all')
@patch.object(DatasetRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_dataset_not_subscribed_download_fails(self, mocked_bq_client, get_by_id_mock, get_all_mock):
def test_dataset_not_subscribed_download_fails(self, get_by_id_mock, get_all_mock):
# mock dataset
get_by_id_mock.return_value = test_dataset2 # is private
dataset = Dataset.get(test_dataset2.id)
get_all_mock.return_value = []
mocked_bq_client.return_value = BigQueryClientMock()
credentials = Credentials('fake_user', '1234')

# When
@@ -329,28 +327,29 @@ def test_dataset_not_subscribed_download_fails(self, mocked_bq_client, get_by_id

@patch.object(DatasetRepository, 'get_all')
@patch.object(DatasetRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_dataset_download_not_subscribed_but_public(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_dataset_download_not_subscribed_but_public(self, download_stream_mock, get_by_id_mock, get_all_mock):
# Given
get_by_id_mock.return_value = test_dataset1 # is public
dataset = Dataset.get(test_dataset1.id)
get_all_mock.return_value = []
mocked_bq_client.return_value = BigQueryClientMock()
download_stream_mock.return_value = []
credentials = Credentials('fake_user', '1234')

dataset.to_csv('fake_path', credentials)

@patch.object(DatasetRepository, 'get_all')
@patch.object(DatasetRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_dataset_download_without_do_enabled(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_dataset_download_without_do_enabled(self, download_stream_mock, get_by_id_mock, get_all_mock):
# Given
get_by_id_mock.return_value = test_dataset1
dataset = Dataset.get(test_dataset1.id)
get_all_mock.return_value = []
mocked_bq_client.return_value = BigQueryClientMock(
ServerErrorException(['The user does not have Data Observatory enabled'])
)

def raise_exception(limit=None, order_by=None):
raise ServerErrorException(['The user does not have Data Observatory enabled'])
download_stream_mock.side_effect = raise_exception
credentials = Credentials('fake_user', '1234')

# When
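
The tests in this file (and in test_geography.py below) switch from mocking the internal BigQuery client to patching DODataset.download_stream directly. A condensed sketch of the pattern, with names taken from the diffs; the test body is illustrative, not copied verbatim:

from unittest.mock import patch
from carto.do_dataset import DODataset

@patch.object(DODataset, 'download_stream')
def test_download_streams_rows(download_stream_mock):
    # An empty stream stands in for a successful download with nothing to write.
    download_stream_mock.return_value = []
    # ... build the Dataset/Geography under test and call to_csv() ...
    # Error paths are covered by making the mock raise instead, e.g. setting
    # download_stream_mock.side_effect to raise ServerErrorException.
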
37 changes: 19 additions & 18 deletions tests/unit/data/observatory/catalog/test_geography.py
@@ -15,7 +15,7 @@
test_geography1, test_geographies, test_datasets, db_geography1,
test_geography2, db_geography2, test_subscription_info
)
from .mocks import BigQueryClientMock
from carto.do_dataset import DODataset


class TestGeography(object):
@@ -221,8 +221,8 @@ def test_geographies_are_exported_as_dataframe(self):

@patch.object(GeographyRepository, 'get_all')
@patch.object(GeographyRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_geography_not_available_in_bq_download_fails(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_geography_not_available_in_bq_download_fails(self, download_stream_mock, get_by_id_mock, get_all_mock):
# mock geography
get_by_id_mock.return_value = test_geography2
geography = Geography.get(test_geography2.id)
@@ -231,7 +231,7 @@ def test_geography_not_available_in_bq_download_fails(self, mocked_bq_client, ge
get_all_mock.return_value = [geography]

# mock big query client
mocked_bq_client.return_value = BigQueryClientMock()
download_stream_mock.return_value = []

# test
credentials = Credentials('fake_user', '1234')
@@ -244,28 +244,28 @@ def test_geography_not_available_in_bq_download_fails(self, mocked_bq_client, ge

@patch.object(GeographyRepository, 'get_all')
@patch.object(GeographyRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_geography_download(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_geography_download(self, download_stream_mock, get_by_id_mock, get_all_mock):
# Given
get_by_id_mock.return_value = test_geography1
geography = Geography.get(test_geography1.id)
get_all_mock.return_value = [geography]
mocked_bq_client.return_value = BigQueryClientMock()
download_stream_mock.return_value = []
credentials = Credentials('fake_user', '1234')

# Then
geography.to_csv('fake_path', credentials)

@patch.object(GeographyRepository, 'get_all')
@patch.object(GeographyRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_geography_download_not_subscribed(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_geography_download_not_subscribed(self, download_stream_mock, get_by_id_mock, get_all_mock):
# Given
get_by_id_mock.return_value = test_geography2 # is private
get_by_id_mock.return_value = test_geography2
geography = Geography.get(test_geography2.id)
get_all_mock.return_value = []
mocked_bq_client.return_value = BigQueryClientMock()
download_stream_mock.return_value = []
credentials = Credentials('fake_user', '1234')

with pytest.raises(Exception) as e:
@@ -278,28 +278,29 @@ def test_geography_download_not_subscribed(self, mocked_bq_client, get_by_id_moc

@patch.object(GeographyRepository, 'get_all')
@patch.object(GeographyRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_geography_download_not_subscribed_but_public(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_geography_download_not_subscribed_but_public(self, download_stream_mock, get_by_id_mock, get_all_mock):
# Given
get_by_id_mock.return_value = test_geography1 # is public
geography = Geography.get(test_geography1.id)
get_all_mock.return_value = []
mocked_bq_client.return_value = BigQueryClientMock()
download_stream_mock.return_value = []
credentials = Credentials('fake_user', '1234')

geography.to_csv('fake_path', credentials)

@patch.object(GeographyRepository, 'get_all')
@patch.object(GeographyRepository, 'get_by_id')
@patch('cartoframes.data.observatory.catalog.entity._get_bigquery_client')
def test_geography_download_without_do_enabled(self, mocked_bq_client, get_by_id_mock, get_all_mock):
@patch.object(DODataset, 'download_stream')
def test_geography_download_without_do_enabled(self, download_stream_mock, get_by_id_mock, get_all_mock):
# Given
get_by_id_mock.return_value = test_geography1
geography = Geography.get(test_geography1.id)
get_all_mock.return_value = []
mocked_bq_client.return_value = BigQueryClientMock(
ServerErrorException(['The user does not have Data Observatory enabled'])
)

def raise_exception(limit=None, order_by=None):
raise ServerErrorException(['The user does not have Data Observatory enabled'])
download_stream_mock.side_effect = raise_exception
credentials = Credentials('fake_user', '1234')

# When
