Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use regenerate table in replace strategy #1707

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ docs/guides/*.csv
__pycache__
.*.sw[nop]
.vscode
_debug

# OS
.DS_Store
Expand Down Expand Up @@ -63,6 +64,7 @@ htmlcov
test_*.json
.pytest_cache
tmp_file.csv
my_dataset.csv
Jesus89 marked this conversation as resolved.
Show resolved Hide resolved
fake_path

# Sphinx documentation
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Remove unused BigQueryClient code (#1602)
- Repo clean-up. Refactor docs (#1682)
- Add tests for notebook execution (#1696)
- Use regenerate table in replace strategy (#1707)

### Fixed
- Remove the batch_size parameter in the call to bulk_geocode (#1666)
Expand Down
32 changes: 26 additions & 6 deletions cartoframes/io/managers/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def copy_to(self, source, schema=None, limit=None, retry_times=DEFAULT_RETRY_TIM
copy_query = self._get_copy_query(query, columns, limit)
return self._copy_to(copy_query, columns, retry_times)

def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True, retry_times=DEFAULT_RETRY_TIMES):
def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True,
retry_times=DEFAULT_RETRY_TIMES):
schema = self.get_schema()
table_name = self.normalize_table_name(table_name)
df_columns = get_dataframe_columns_info(gdf)
Expand All @@ -98,7 +99,8 @@ def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True, retry_tim
self._truncate_table(table_name, schema, cartodbfy)
else:
# Diff columns: truncate table and drop + add columns
self._truncate_and_drop_add_columns(table_name, schema, df_columns, table_columns, cartodbfy)
self._truncate_and_drop_add_columns(
table_name, schema, df_columns, table_columns, cartodbfy)

elif if_exists == 'fail':
raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. '
Expand Down Expand Up @@ -190,7 +192,9 @@ def get_schema(self):
"""Get user schema from current credentials"""
query = 'SELECT current_schema()'
result = self.execute_query(query, do_post=False)
return result['rows'][0]['current_schema']
schema = result['rows'][0]['current_schema']
log.debug('schema: {}'.format(schema))
return schema

def get_geom_type(self, query):
"""Fetch geom type of a remote table or query"""
Expand Down Expand Up @@ -290,7 +294,8 @@ def _truncate_table(self, table_name, schema, cartodbfy):

def _truncate_and_drop_add_columns(self, table_name, schema, df_columns, table_columns, cartodbfy):
log.debug('TRUNCATE AND DROP + ADD columns table "{}"'.format(table_name))
query = 'BEGIN; {truncate}; {drop_columns}; {add_columns}; {cartodbfy}; COMMIT;'.format(
query = '{regenerate}; BEGIN; {truncate}; {drop_columns}; {add_columns}; {cartodbfy}; COMMIT;'.format(
regenerate=_regenerate_table_query(table_name, schema) if self._check_regenerate_table_exists() else '',
truncate=_truncate_table_query(table_name),
drop_columns=_drop_columns_query(table_name, table_columns),
add_columns=_add_columns_query(table_name, df_columns),
Expand All @@ -317,6 +322,16 @@ def _check_exists(self, query):
except CartoException:
return False

def _check_regenerate_table_exists(self):
query = '''
SELECT 1
FROM pg_catalog.pg_proc p
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE p.proname = 'cdb_regeneratetable' AND n.nspname = 'cartodb';
'''
result = self.execute_query(query)
return len(result['rows']) > 0

def _get_query_columns_info(self, query):
query = 'SELECT * FROM ({}) _q LIMIT 0'.format(query)
table_info = self.execute_query(query)
Expand All @@ -343,7 +358,7 @@ def _get_copy_query(self, query, columns, limit):
@retry_copy
def _copy_to(self, query, columns, retry_times=DEFAULT_RETRY_TIMES):
log.debug('COPY TO')
copy_query = 'COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL \'{1}\')'.format(query, PG_NULL)
copy_query = "COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL '{1}')".format(query, PG_NULL)

raw_result = self.copy_client.copyto_stream(copy_query)

Expand Down Expand Up @@ -424,7 +439,12 @@ def _create_table_from_query_query(table_name, query):


def _cartodbfy_query(table_name, schema):
return 'SELECT CDB_CartodbfyTable(\'{schema}\', \'{table_name}\')'.format(
return "SELECT CDB_CartodbfyTable('{schema}', '{table_name}')".format(
schema=schema, table_name=table_name)


def _regenerate_table_query(table_name, schema):
return "SELECT CDB_RegenerateTable('{schema}.{table_name}'::regclass)".format(
schema=schema, table_name=table_name)


Expand Down
44 changes: 20 additions & 24 deletions tests/unit/data/observatory/catalog/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,28 +294,28 @@ def test_datasets_are_exported_as_dataframe(self):
assert isinstance(sliced_dataset, pd.Series)
assert sliced_dataset.equals(expected_dataset_df)

@patch.object(DatasetRepository, 'get_all')
@patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids')
@patch.object(DatasetRepository, 'get_by_id')
@patch.object(DODataset, 'download_stream')
def test_dataset_download(self, download_stream_mock, get_by_id_mock, get_all_mock):
def test_dataset_download(self, mock_download_stream, mock_get_by_id, mock_subscription_ids):
# Given
get_by_id_mock.return_value = test_dataset1
mock_get_by_id.return_value = test_dataset1
dataset = Dataset.get(test_dataset1.id)
get_all_mock.return_value = [dataset]
download_stream_mock.return_value = []
mock_download_stream.return_value = []
mock_subscription_ids.return_value = [test_dataset1.id]
credentials = Credentials('fake_user', '1234')

# Then
dataset.to_csv('fake_path', credentials)
os.remove('fake_path')

@patch.object(DatasetRepository, 'get_all')
@patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids')
@patch.object(DatasetRepository, 'get_by_id')
def test_dataset_not_subscribed_download_fails(self, get_by_id_mock, get_all_mock):
# mock dataset
get_by_id_mock.return_value = test_dataset2 # is private
def test_dataset_not_subscribed_download_not_subscribed(self, mock_get_by_id, mock_subscription_ids):
# Given
mock_get_by_id.return_value = test_dataset2 # is private
dataset = Dataset.get(test_dataset2.id)
get_all_mock.return_value = []
mock_subscription_ids.return_value = []
credentials = Credentials('fake_user', '1234')

# When
Expand All @@ -327,32 +327,28 @@ def test_dataset_not_subscribed_download_fails(self, get_by_id_mock, get_all_moc
'You are not subscribed to this Dataset yet. '
'Please, use the subscribe method first.')

@patch.object(DatasetRepository, 'get_all')
@patch.object(DatasetRepository, 'get_by_id')
@patch.object(DODataset, 'download_stream')
def test_dataset_download_not_subscribed_but_public(self, download_stream_mock, get_by_id_mock, get_all_mock):
def test_dataset_download_not_subscribed_but_public(self, mock_download_stream, mock_get_by_id):
# Given
get_by_id_mock.return_value = test_dataset1 # is public
mock_get_by_id.return_value = test_dataset1 # is public
dataset = Dataset.get(test_dataset1.id)
get_all_mock.return_value = []
download_stream_mock.return_value = []
mock_download_stream.return_value = []
credentials = Credentials('fake_user', '1234')

dataset.to_csv('fake_path', credentials)
os.remove('fake_path')

@patch.object(DatasetRepository, 'get_all')
@patch.object(DatasetRepository, 'get_by_id')
@patch.object(DODataset, 'download_stream')
def test_dataset_download_without_do_enabled(self, download_stream_mock, get_by_id_mock, get_all_mock):
def test_dataset_download_without_do_enabled(self, mock_download_stream, mock_get_by_id):
# Given
get_by_id_mock.return_value = test_dataset1
mock_get_by_id.return_value = test_dataset1
dataset = Dataset.get(test_dataset1.id)
get_all_mock.return_value = []

def raise_exception(limit=None, order_by=None, sql_query=None, add_geom=None, is_geography=None):
raise ServerErrorException(['The user does not have Data Observatory enabled'])
download_stream_mock.side_effect = raise_exception
mock_download_stream.side_effect = raise_exception
credentials = Credentials('fake_user', '1234')

# When
Expand Down Expand Up @@ -406,10 +402,10 @@ def test_dataset_subscribe_existing(self, mock_display_message, mock_display_for
@patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids')
@patch('cartoframes.data.observatory.catalog.utils.display_subscription_form')
@patch('cartoframes.auth.defaults.get_default_credentials')
def test_dataset_subscribe_default_credentials(self, mocked_credentials, mock_display_form, mock_subscription_ids):
def test_dataset_subscribe_default_credentials(self, mock_credentials, mock_display_form, mock_subscription_ids):
# Given
expected_credentials = Credentials('fake_user', '1234')
mocked_credentials.return_value = expected_credentials
mock_credentials.return_value = expected_credentials
dataset = Dataset(db_dataset1)

# When
Expand Down Expand Up @@ -480,10 +476,10 @@ def test_dataset_subscription_info(self, mock_fetch):

@patch('cartoframes.data.observatory.catalog.subscription_info.fetch_subscription_info')
@patch('cartoframes.auth.defaults.get_default_credentials')
def test_dataset_subscription_info_default_credentials(self, mocked_credentials, mock_fetch):
def test_dataset_subscription_info_default_credentials(self, mock_credentials, mock_fetch):
# Given
expected_credentials = Credentials('fake_user', '1234')
mocked_credentials.return_value = expected_credentials
mock_credentials.return_value = expected_credentials
dataset = Dataset(db_dataset1)

# When
Expand Down
37 changes: 15 additions & 22 deletions tests/unit/data/observatory/catalog/test_geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,31 +220,28 @@ def test_geographies_are_exported_as_dataframe(self):
assert isinstance(sliced_geography, pd.Series)
assert sliced_geography.equals(expected_geography_df)

@patch.object(GeographyRepository, 'get_all')
@patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids')
@patch.object(GeographyRepository, 'get_by_id')
@patch.object(DODataset, 'download_stream')
def test_geography_download(self, download_stream_mock, get_by_id_mock, get_all_mock):
def test_geography_download(self, mock_download_stream, mock_get_by_id, mock_subscription_ids):
# Given
get_by_id_mock.return_value = test_geography1
mock_get_by_id.return_value = test_geography1
geography = Geography.get(test_geography1.id)
get_all_mock.return_value = [geography]
download_stream_mock.return_value = []
mock_download_stream.return_value = []
mock_subscription_ids.return_value = [test_geography1.id]
credentials = Credentials('fake_user', '1234')

# Then
geography.to_csv('fake_path', credentials)
os.remove('fake_path')

@patch.object(GeographyRepository, 'get_all')
@patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids')
@patch.object(GeographyRepository, 'get_by_id')
@patch.object(DODataset, 'download_stream')
def test_geography_download_not_subscribed(self, download_stream_mock, get_by_id_mock, get_all_mock):
def test_geography_download_not_subscribed(self, mock_get_by_id, mock_subscription_ids):
# Given
get_by_id_mock.return_value = test_geography2 # is private
get_by_id_mock.return_value = test_geography2
mock_get_by_id.return_value = test_geography2 # is private
geography = Geography.get(test_geography2.id)
get_all_mock.return_value = []
download_stream_mock.return_value = []
mock_subscription_ids.return_value = []
credentials = Credentials('fake_user', '1234')

with pytest.raises(Exception) as e:
Expand All @@ -255,32 +252,28 @@ def test_geography_download_not_subscribed(self, download_stream_mock, get_by_id
'You are not subscribed to this Geography yet. '
'Please, use the subscribe method first.')

@patch.object(GeographyRepository, 'get_all')
@patch.object(GeographyRepository, 'get_by_id')
@patch.object(DODataset, 'download_stream')
def test_geography_download_not_subscribed_but_public(self, download_stream_mock, get_by_id_mock, get_all_mock):
def test_geography_download_not_subscribed_but_public(self, mock_download_stream, mock_get_by_id):
# Given
get_by_id_mock.return_value = test_geography1 # is public
mock_get_by_id.return_value = test_geography1 # is public
geography = Geography.get(test_geography1.id)
get_all_mock.return_value = []
download_stream_mock.return_value = []
mock_download_stream.return_value = []
credentials = Credentials('fake_user', '1234')

geography.to_csv('fake_path', credentials)
os.remove('fake_path')

@patch.object(GeographyRepository, 'get_all')
@patch.object(GeographyRepository, 'get_by_id')
@patch.object(DODataset, 'download_stream')
def test_geography_download_without_do_enabled(self, download_stream_mock, get_by_id_mock, get_all_mock):
def test_geography_download_without_do_enabled(self, mock_download_stream, mock_get_by_id):
# Given
get_by_id_mock.return_value = test_geography1
mock_get_by_id.return_value = test_geography1
geography = Geography.get(test_geography1.id)
get_all_mock.return_value = []

def raise_exception(limit=None, order_by=None, sql_query=None, add_geom=None, is_geography=None):
raise ServerErrorException(['The user does not have Data Observatory enabled'])
download_stream_mock.side_effect = raise_exception
mock_download_stream.side_effect = raise_exception
credentials = Credentials('fake_user', '1234')

# When
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/io/test_carto.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_to_carto_two_geom_columns(mocker):
'the_geom': '010100000000000000000000000000000000000000'})

# When
norm_table_name = to_carto(df, table_name, CREDENTIALS)
norm_table_name = to_carto(df, table_name, CREDENTIALS, skip_quota_warning=True)

# Then
assert cm_mock.call_args[0][1] == table_name
Expand All @@ -343,7 +343,7 @@ def test_to_carto_two_geom_columns_and_geom_col(mocker):
'the_geom': '010100000000000000000000000000000000000000'})

# When
norm_table_name = to_carto(df, table_name, CREDENTIALS, geom_col='geometry')
norm_table_name = to_carto(df, table_name, CREDENTIALS, geom_col='geometry', skip_quota_warning=True)

# Then
assert cm_mock.call_args[0][1] == table_name
Expand Down