Skip to content

Commit

Permalink
Merge pull request #1628 from CartoDB/jarroyo/ch72815/implement-if-ex…
Browse files Browse the repository at this point in the history
…ists-truncate

Improve replace logic with truncate (+refactor)
  • Loading branch information
Jesus89 authored May 10, 2020
2 parents 0b8ab8e + afd32dc commit ae806d9
Show file tree
Hide file tree
Showing 12 changed files with 235 additions and 193 deletions.
2 changes: 1 addition & 1 deletion cartoframes/data/clients/auth_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def create_api_key(self, sources, apis=['sql', 'maps'], permissions=['select'],
else:
raise e

return api_key.token, tables_names
return api_key.name, api_key.token, tables_names


def _get_table_dict(schema, name, permissions):
Expand Down
1 change: 0 additions & 1 deletion cartoframes/io/carto.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@


GEOM_COLUMN_NAME = 'the_geom'

IF_EXISTS_OPTIONS = ['fail', 'replace', 'append']


Expand Down
83 changes: 59 additions & 24 deletions cartoframes/io/managers/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ...utils.logger import log
from ...utils.geom_utils import encode_geometry_ewkb
from ...utils.utils import is_sql_query, check_credentials, encode_row, map_geom_type, PG_NULL
from ...utils.columns import Column, get_dataframe_columns_info, obtain_converters, \
from ...utils.columns import get_dataframe_columns_info, get_query_columns_info, obtain_converters, \
date_columns_names, normalize_name

DEFAULT_RETRY_TIMES = 3
Expand Down Expand Up @@ -47,16 +47,21 @@ def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True):
table_name = self.normalize_table_name(table_name)
columns = get_dataframe_columns_info(gdf)

if if_exists == 'replace' or not self.has_table(table_name, schema):
log.debug('Creating table "{}"'.format(table_name))
self._create_table_from_columns(table_name, columns, schema, cartodbfy)
elif if_exists == 'fail':
raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. '
'Please choose a different `table_name` or use '
'if_exists="replace" to overwrite it.'.format(
table_name=table_name, schema=schema))
else: # 'append'
pass
if self.has_table(table_name, schema):
if if_exists == 'replace':
if self._compare_columns(table_name, schema, columns):
self._truncate_table_from_columns(table_name, schema, columns, cartodbfy)
else:
self._drop_create_table_from_columns(table_name, schema, columns, cartodbfy)
elif if_exists == 'fail':
raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. '
'Please choose a different `table_name` or use '
'if_exists="replace" to overwrite it.'.format(
table_name=table_name, schema=schema))
else: # 'append'
pass
else:
self._drop_create_table_from_columns(table_name, schema, columns, cartodbfy)

self._copy_from(gdf, table_name, columns)
return table_name
Expand All @@ -65,16 +70,18 @@ def create_table_from_query(self, query, table_name, if_exists, cartodbfy=True):
schema = self.get_schema()
table_name = self.normalize_table_name(table_name)

if if_exists == 'replace' or not self.has_table(table_name, schema):
log.debug('Creating table "{}"'.format(table_name))
self._create_table_from_query(query, table_name, schema, cartodbfy)
elif if_exists == 'fail':
raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. '
'Please choose a different `table_name` or use '
'if_exists="replace" to overwrite it.'.format(
table_name=table_name, schema=schema))
else: # 'append'
pass
if self.has_table(table_name, schema):
if if_exists == 'replace':
self._drop_create_table_from_query(table_name, schema, query, cartodbfy)
elif if_exists == 'fail':
raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. '
'Please choose a different `table_name` or use '
'if_exists="replace" to overwrite it.'.format(
table_name=table_name, schema=schema))
else: # 'append'
pass
else:
self._drop_create_table_from_query(table_name, schema, query, cartodbfy)

return table_name

Expand Down Expand Up @@ -187,22 +194,43 @@ def get_table_names(self, query):
tables = [table.split('.')[1] if '.' in table else table for table in result['rows'][0]['tables']]
return tables

def _create_table_from_query(self, query, table_name, schema, cartodbfy=True):
def _compare_columns(self, table_name, schema, columns):
GEOM_COL = 'the_geom_webmercator'

query = self._compute_query_from_table(table_name, schema)
existing_columns = [c for c in self._get_query_columns_info(query) if (c.name != GEOM_COL)]

existing_columns.sort()
columns.sort()

return existing_columns == columns

def _drop_create_table_from_query(self, table_name, schema, query, cartodbfy):
log.debug('DROP + CREATE table "{}"'.format(table_name))
query = 'BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'.format(
drop=_drop_table_query(table_name),
create=_create_table_from_query_query(table_name, query),
cartodbfy=_cartodbfy_query(table_name, schema) if cartodbfy else ''
)
self.execute_long_running_query(query)

def _create_table_from_columns(self, table_name, columns, schema, cartodbfy=True):
def _drop_create_table_from_columns(self, table_name, schema, columns, cartodbfy):
log.debug('DROP + CREATE table "{}"'.format(table_name))
query = 'BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'.format(
drop=_drop_table_query(table_name),
create=_create_table_from_columns_query(table_name, columns),
cartodbfy=_cartodbfy_query(table_name, schema) if cartodbfy else ''
)
self.execute_long_running_query(query)

def _truncate_table_from_columns(self, table_name, schema, columns, cartodbfy):
log.debug('TRUNCATE table "{}"'.format(table_name))
query = 'BEGIN; {truncate}; {cartodbfy}; COMMIT;'.format(
truncate=_truncate_table_query(table_name),
cartodbfy=_cartodbfy_query(table_name, schema) if cartodbfy else ''
)
self.execute_long_running_query(query)

def compute_query(self, source, schema=None):
if is_sql_query(source):
return source
Expand All @@ -226,7 +254,7 @@ def _check_exists(self, query):
def _get_query_columns_info(self, query):
query = 'SELECT * FROM ({}) _q LIMIT 0'.format(query)
table_info = self.execute_query(query)
return Column.from_sql_api_fields(table_info['fields'])
return get_query_columns_info(table_info['fields'])

def _get_copy_query(self, query, columns, limit):
query_columns = [
Expand All @@ -245,6 +273,7 @@ def _get_copy_query(self, query, columns, limit):
return query

def _copy_to(self, query, columns, retry_times):
log.debug('COPY TO')
copy_query = 'COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL \'{1}\')'.format(query, PG_NULL)

try:
Expand Down Expand Up @@ -272,6 +301,7 @@ def _copy_to(self, query, columns, retry_times):
return df

def _copy_from(self, dataframe, table_name, columns):
log.debug('COPY FROM')
query = """
COPY {table_name}({columns}) FROM stdin WITH (FORMAT csv, DELIMITER '|', NULL '{null}');
""".format(
Expand All @@ -297,6 +327,11 @@ def _drop_table_query(table_name, if_exists=True):
if_exists='IF EXISTS' if if_exists else '')


def _truncate_table_query(table_name):
return '''TRUNCATE TABLE {table_name}'''.format(
table_name=table_name)


def _create_table_from_columns_query(table_name, columns):
columns = ['{name} {type}'.format(name=column.dbname, type=column.dbtype) for column in columns]

Expand Down
Loading

0 comments on commit ae806d9

Please sign in to comment.