diff --git a/.gitignore b/.gitignore index 9b74f2d70..e332fc6b7 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ docs/guides/*.csv __pycache__ .*.sw[nop] .vscode +_debug # OS .DS_Store @@ -63,6 +64,7 @@ htmlcov test_*.json .pytest_cache tmp_file.csv +my_dataset.csv fake_path # Sphinx documentation diff --git a/CHANGELOG.md b/CHANGELOG.md index ad0e6f3de..646eeec20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Remove unused BigQueryClient code (#1602) - Repo clean-up. Refactor docs (#1682) - Add tests for notebook execution (#1696) +- Use regenerate table in replace strategy (#1707) ### Fixed - Remove the batch_size parameter in the call to bulk_geocode (#1666) diff --git a/cartoframes/io/managers/context_manager.py b/cartoframes/io/managers/context_manager.py index dbb85d6eb..7343f03ed 100644 --- a/cartoframes/io/managers/context_manager.py +++ b/cartoframes/io/managers/context_manager.py @@ -83,7 +83,8 @@ def copy_to(self, source, schema=None, limit=None, retry_times=DEFAULT_RETRY_TIM copy_query = self._get_copy_query(query, columns, limit) return self._copy_to(copy_query, columns, retry_times) - def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True, retry_times=DEFAULT_RETRY_TIMES): + def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True, + retry_times=DEFAULT_RETRY_TIMES): schema = self.get_schema() table_name = self.normalize_table_name(table_name) df_columns = get_dataframe_columns_info(gdf) @@ -98,7 +99,8 @@ def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True, retry_tim self._truncate_table(table_name, schema, cartodbfy) else: # Diff columns: truncate table and drop + add columns - self._truncate_and_drop_add_columns(table_name, schema, df_columns, table_columns, cartodbfy) + self._truncate_and_drop_add_columns( + table_name, schema, df_columns, table_columns, cartodbfy) elif if_exists == 'fail': raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. ' @@ -190,7 +192,9 @@ def get_schema(self): """Get user schema from current credentials""" query = 'SELECT current_schema()' result = self.execute_query(query, do_post=False) - return result['rows'][0]['current_schema'] + schema = result['rows'][0]['current_schema'] + log.debug('schema: {}'.format(schema)) + return schema def get_geom_type(self, query): """Fetch geom type of a remote table or query""" @@ -290,7 +294,8 @@ def _truncate_table(self, table_name, schema, cartodbfy): def _truncate_and_drop_add_columns(self, table_name, schema, df_columns, table_columns, cartodbfy): log.debug('TRUNCATE AND DROP + ADD columns table "{}"'.format(table_name)) - query = 'BEGIN; {truncate}; {drop_columns}; {add_columns}; {cartodbfy}; COMMIT;'.format( + query = '{regenerate}; BEGIN; {truncate}; {drop_columns}; {add_columns}; {cartodbfy}; COMMIT;'.format( + regenerate=_regenerate_table_query(table_name, schema) if self._check_regenerate_table_exists() else '', truncate=_truncate_table_query(table_name), drop_columns=_drop_columns_query(table_name, table_columns), add_columns=_add_columns_query(table_name, df_columns), @@ -317,6 +322,16 @@ def _check_exists(self, query): except CartoException: return False + def _check_regenerate_table_exists(self): + query = ''' + SELECT 1 + FROM pg_catalog.pg_proc p + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + WHERE p.proname = 'cdb_regeneratetable' AND n.nspname = 'cartodb'; + ''' + result = self.execute_query(query) + return len(result['rows']) > 0 + def _get_query_columns_info(self, query): query = 'SELECT * FROM ({}) _q LIMIT 0'.format(query) table_info = self.execute_query(query) @@ -343,7 +358,7 @@ def _get_copy_query(self, query, columns, limit): @retry_copy def _copy_to(self, query, columns, retry_times=DEFAULT_RETRY_TIMES): log.debug('COPY TO') - copy_query = 'COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL \'{1}\')'.format(query, PG_NULL) + copy_query = "COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL '{1}')".format(query, PG_NULL) raw_result = self.copy_client.copyto_stream(copy_query) @@ -424,7 +439,12 @@ def _create_table_from_query_query(table_name, query): def _cartodbfy_query(table_name, schema): - return 'SELECT CDB_CartodbfyTable(\'{schema}\', \'{table_name}\')'.format( + return "SELECT CDB_CartodbfyTable('{schema}', '{table_name}')".format( + schema=schema, table_name=table_name) + + +def _regenerate_table_query(table_name, schema): + return "SELECT CDB_RegenerateTable('{schema}.{table_name}'::regclass)".format( schema=schema, table_name=table_name) diff --git a/tests/unit/data/observatory/catalog/test_dataset.py b/tests/unit/data/observatory/catalog/test_dataset.py index 6fb980269..9d3b519e6 100644 --- a/tests/unit/data/observatory/catalog/test_dataset.py +++ b/tests/unit/data/observatory/catalog/test_dataset.py @@ -294,28 +294,28 @@ def test_datasets_are_exported_as_dataframe(self): assert isinstance(sliced_dataset, pd.Series) assert sliced_dataset.equals(expected_dataset_df) - @patch.object(DatasetRepository, 'get_all') + @patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids') @patch.object(DatasetRepository, 'get_by_id') @patch.object(DODataset, 'download_stream') - def test_dataset_download(self, download_stream_mock, get_by_id_mock, get_all_mock): + def test_dataset_download(self, mock_download_stream, mock_get_by_id, mock_subscription_ids): # Given - get_by_id_mock.return_value = test_dataset1 + mock_get_by_id.return_value = test_dataset1 dataset = Dataset.get(test_dataset1.id) - get_all_mock.return_value = [dataset] - download_stream_mock.return_value = [] + mock_download_stream.return_value = [] + mock_subscription_ids.return_value = [test_dataset1.id] credentials = Credentials('fake_user', '1234') # Then dataset.to_csv('fake_path', credentials) os.remove('fake_path') - @patch.object(DatasetRepository, 'get_all') + @patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids') @patch.object(DatasetRepository, 'get_by_id') - def test_dataset_not_subscribed_download_fails(self, get_by_id_mock, get_all_mock): - # mock dataset - get_by_id_mock.return_value = test_dataset2 # is private + def test_dataset_not_subscribed_download_not_subscribed(self, mock_get_by_id, mock_subscription_ids): + # Given + mock_get_by_id.return_value = test_dataset2 # is private dataset = Dataset.get(test_dataset2.id) - get_all_mock.return_value = [] + mock_subscription_ids.return_value = [] credentials = Credentials('fake_user', '1234') # When @@ -327,32 +327,28 @@ def test_dataset_not_subscribed_download_fails(self, get_by_id_mock, get_all_moc 'You are not subscribed to this Dataset yet. ' 'Please, use the subscribe method first.') - @patch.object(DatasetRepository, 'get_all') @patch.object(DatasetRepository, 'get_by_id') @patch.object(DODataset, 'download_stream') - def test_dataset_download_not_subscribed_but_public(self, download_stream_mock, get_by_id_mock, get_all_mock): + def test_dataset_download_not_subscribed_but_public(self, mock_download_stream, mock_get_by_id): # Given - get_by_id_mock.return_value = test_dataset1 # is public + mock_get_by_id.return_value = test_dataset1 # is public dataset = Dataset.get(test_dataset1.id) - get_all_mock.return_value = [] - download_stream_mock.return_value = [] + mock_download_stream.return_value = [] credentials = Credentials('fake_user', '1234') dataset.to_csv('fake_path', credentials) os.remove('fake_path') - @patch.object(DatasetRepository, 'get_all') @patch.object(DatasetRepository, 'get_by_id') @patch.object(DODataset, 'download_stream') - def test_dataset_download_without_do_enabled(self, download_stream_mock, get_by_id_mock, get_all_mock): + def test_dataset_download_without_do_enabled(self, mock_download_stream, mock_get_by_id): # Given - get_by_id_mock.return_value = test_dataset1 + mock_get_by_id.return_value = test_dataset1 dataset = Dataset.get(test_dataset1.id) - get_all_mock.return_value = [] def raise_exception(limit=None, order_by=None, sql_query=None, add_geom=None, is_geography=None): raise ServerErrorException(['The user does not have Data Observatory enabled']) - download_stream_mock.side_effect = raise_exception + mock_download_stream.side_effect = raise_exception credentials = Credentials('fake_user', '1234') # When @@ -406,10 +402,10 @@ def test_dataset_subscribe_existing(self, mock_display_message, mock_display_for @patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids') @patch('cartoframes.data.observatory.catalog.utils.display_subscription_form') @patch('cartoframes.auth.defaults.get_default_credentials') - def test_dataset_subscribe_default_credentials(self, mocked_credentials, mock_display_form, mock_subscription_ids): + def test_dataset_subscribe_default_credentials(self, mock_credentials, mock_display_form, mock_subscription_ids): # Given expected_credentials = Credentials('fake_user', '1234') - mocked_credentials.return_value = expected_credentials + mock_credentials.return_value = expected_credentials dataset = Dataset(db_dataset1) # When @@ -480,10 +476,10 @@ def test_dataset_subscription_info(self, mock_fetch): @patch('cartoframes.data.observatory.catalog.subscription_info.fetch_subscription_info') @patch('cartoframes.auth.defaults.get_default_credentials') - def test_dataset_subscription_info_default_credentials(self, mocked_credentials, mock_fetch): + def test_dataset_subscription_info_default_credentials(self, mock_credentials, mock_fetch): # Given expected_credentials = Credentials('fake_user', '1234') - mocked_credentials.return_value = expected_credentials + mock_credentials.return_value = expected_credentials dataset = Dataset(db_dataset1) # When diff --git a/tests/unit/data/observatory/catalog/test_geography.py b/tests/unit/data/observatory/catalog/test_geography.py index 9a121ee0b..18447283c 100644 --- a/tests/unit/data/observatory/catalog/test_geography.py +++ b/tests/unit/data/observatory/catalog/test_geography.py @@ -220,31 +220,28 @@ def test_geographies_are_exported_as_dataframe(self): assert isinstance(sliced_geography, pd.Series) assert sliced_geography.equals(expected_geography_df) - @patch.object(GeographyRepository, 'get_all') + @patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids') @patch.object(GeographyRepository, 'get_by_id') @patch.object(DODataset, 'download_stream') - def test_geography_download(self, download_stream_mock, get_by_id_mock, get_all_mock): + def test_geography_download(self, mock_download_stream, mock_get_by_id, mock_subscription_ids): # Given - get_by_id_mock.return_value = test_geography1 + mock_get_by_id.return_value = test_geography1 geography = Geography.get(test_geography1.id) - get_all_mock.return_value = [geography] - download_stream_mock.return_value = [] + mock_download_stream.return_value = [] + mock_subscription_ids.return_value = [test_geography1.id] credentials = Credentials('fake_user', '1234') # Then geography.to_csv('fake_path', credentials) os.remove('fake_path') - @patch.object(GeographyRepository, 'get_all') + @patch('cartoframes.data.observatory.catalog.subscriptions.get_subscription_ids') @patch.object(GeographyRepository, 'get_by_id') - @patch.object(DODataset, 'download_stream') - def test_geography_download_not_subscribed(self, download_stream_mock, get_by_id_mock, get_all_mock): + def test_geography_download_not_subscribed(self, mock_get_by_id, mock_subscription_ids): # Given - get_by_id_mock.return_value = test_geography2 # is private - get_by_id_mock.return_value = test_geography2 + mock_get_by_id.return_value = test_geography2 # is private geography = Geography.get(test_geography2.id) - get_all_mock.return_value = [] - download_stream_mock.return_value = [] + mock_subscription_ids.return_value = [] credentials = Credentials('fake_user', '1234') with pytest.raises(Exception) as e: @@ -255,32 +252,28 @@ def test_geography_download_not_subscribed(self, download_stream_mock, get_by_id 'You are not subscribed to this Geography yet. ' 'Please, use the subscribe method first.') - @patch.object(GeographyRepository, 'get_all') @patch.object(GeographyRepository, 'get_by_id') @patch.object(DODataset, 'download_stream') - def test_geography_download_not_subscribed_but_public(self, download_stream_mock, get_by_id_mock, get_all_mock): + def test_geography_download_not_subscribed_but_public(self, mock_download_stream, mock_get_by_id): # Given - get_by_id_mock.return_value = test_geography1 # is public + mock_get_by_id.return_value = test_geography1 # is public geography = Geography.get(test_geography1.id) - get_all_mock.return_value = [] - download_stream_mock.return_value = [] + mock_download_stream.return_value = [] credentials = Credentials('fake_user', '1234') geography.to_csv('fake_path', credentials) os.remove('fake_path') - @patch.object(GeographyRepository, 'get_all') @patch.object(GeographyRepository, 'get_by_id') @patch.object(DODataset, 'download_stream') - def test_geography_download_without_do_enabled(self, download_stream_mock, get_by_id_mock, get_all_mock): + def test_geography_download_without_do_enabled(self, mock_download_stream, mock_get_by_id): # Given - get_by_id_mock.return_value = test_geography1 + mock_get_by_id.return_value = test_geography1 geography = Geography.get(test_geography1.id) - get_all_mock.return_value = [] def raise_exception(limit=None, order_by=None, sql_query=None, add_geom=None, is_geography=None): raise ServerErrorException(['The user does not have Data Observatory enabled']) - download_stream_mock.side_effect = raise_exception + mock_download_stream.side_effect = raise_exception credentials = Credentials('fake_user', '1234') # When diff --git a/tests/unit/io/test_carto.py b/tests/unit/io/test_carto.py index 515eef149..4acb40458 100644 --- a/tests/unit/io/test_carto.py +++ b/tests/unit/io/test_carto.py @@ -325,7 +325,7 @@ def test_to_carto_two_geom_columns(mocker): 'the_geom': '010100000000000000000000000000000000000000'}) # When - norm_table_name = to_carto(df, table_name, CREDENTIALS) + norm_table_name = to_carto(df, table_name, CREDENTIALS, skip_quota_warning=True) # Then assert cm_mock.call_args[0][1] == table_name @@ -343,7 +343,7 @@ def test_to_carto_two_geom_columns_and_geom_col(mocker): 'the_geom': '010100000000000000000000000000000000000000'}) # When - norm_table_name = to_carto(df, table_name, CREDENTIALS, geom_col='geometry') + norm_table_name = to_carto(df, table_name, CREDENTIALS, geom_col='geometry', skip_quota_warning=True) # Then assert cm_mock.call_args[0][1] == table_name