From ca4974ffedb27b92dfb2ed6699e5efb09bd76243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Arroyo=20Torrens?= Date: Thu, 30 Apr 2020 15:25:20 +0200 Subject: [PATCH 1/6] Refactor columns --- cartoframes/utils/columns.py | 202 +++++++++++++++-------------------- 1 file changed, 89 insertions(+), 113 deletions(-) diff --git a/cartoframes/utils/columns.py b/cartoframes/utils/columns.py index f20429e7c..ab53b9616 100644 --- a/cartoframes/utils/columns.py +++ b/cartoframes/utils/columns.py @@ -5,131 +5,67 @@ from unidecode import unidecode from collections import namedtuple -from .utils import dtypes2pg, pg2dtypes, PG_NULL -from .geom_utils import decode_geometry_item, detect_encoding_type - - -class Column(object): - BOOL_DTYPE = 'bool' - OBJECT_DTYPE = 'object' - INT_DTYPES = ['int16', 'int32', 'int64'] - FLOAT_DTYPES = ['float32', 'float64'] - DATETIME_DTYPES = ['datetime64[D]', 'datetime64[ns]', 'datetime64[ns, UTC]'] - INDEX_COLUMN_NAME = 'cartodb_id' - FORBIDDEN_COLUMN_NAMES = ['the_geom_webmercator'] - MAX_LENGTH = 63 - MAX_COLLISION_LENGTH = MAX_LENGTH - 4 - RESERVED_WORDS = ('ALL', 'ANALYSE', 'ANALYZE', 'AND', 'ANY', 'ARRAY', 'AS', 'ASC', 'ASYMMETRIC', 'AUTHORIZATION', - 'BETWEEN', 'BINARY', 'BOTH', 'CASE', 'CAST', 'CHECK', 'COLLATE', 'COLUMN', 'CONSTRAINT', - 'CREATE', 'CROSS', 'CURRENT_DATE', 'CURRENT_ROLE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', - 'CURRENT_USER', 'DEFAULT', 'DEFERRABLE', 'DESC', 'DISTINCT', 'DO', 'ELSE', 'END', 'EXCEPT', - 'FALSE', 'FOR', 'FOREIGN', 'FREEZE', 'FROM', 'FULL', 'GRANT', 'GROUP', 'HAVING', 'ILIKE', 'IN', - 'INITIALLY', 'INNER', 'INTERSECT', 'INTO', 'IS', 'ISNULL', 'JOIN', 'LEADING', 'LEFT', 'LIKE', - 'LIMIT', 'LOCALTIME', 'LOCALTIMESTAMP', 'NATURAL', 'NEW', 'NOT', 'NOTNULL', 'NULL', 'OFF', - 'OFFSET', 'OLD', 'ON', 'ONLY', 'OR', 'ORDER', 'OUTER', 'OVERLAPS', 'PLACING', 'PRIMARY', - 'REFERENCES', 'RIGHT', 'SELECT', 'SESSION_USER', 'SIMILAR', 'SOME', 'SYMMETRIC', 'TABLE', 'THEN', - 'TO', 'TRAILING', 'TRUE', 'UNION', 'UNIQUE', 'USER', 'USING', 'VERBOSE', 'WHEN', 'WHERE', - 'XMIN', 'XMAX', 'FORMAT', 'CONTROLLER', 'ACTION', ) - NORMALIZED_GEOM_COL_NAME = 'the_geom' - - @staticmethod - def from_sql_api_fields(fields): - return [Column(column, normalize=False, pgtype=_extract_pgtype(fields[column])) for column in fields] - - def __init__(self, name, normalize=True, pgtype=None): - if not name: - raise ValueError('Column name cannot be null or empty') - - self.name = str(name) - self.pgtype = pgtype - self.dtype = pg2dtypes(pgtype) - if normalize: - self.normalize() - - def normalize(self, forbidden_column_names=None): - self._sanitize() - self.name = self._truncate() - - if forbidden_column_names: - i = 1 - while self.name in forbidden_column_names: - self.name = '{}_{}'.format(self._truncate(length=Column.MAX_COLLISION_LENGTH), str(i)) - i += 1 - - return self - - def _sanitize(self): - self.name = self._slugify(self.name) - - if self._is_reserved() or self._is_unsupported(): - self.name = '_{}'.format(self.name) - - def _is_reserved(self): - return self.name.upper() in Column.RESERVED_WORDS - - def _is_unsupported(self): - return not re.match(r'^[a-z_]+[a-z_0-9]*$', self.name) - - def _truncate(self, length=MAX_LENGTH): - return self.name[:length] - - def _slugify(self, value): - value = unidecode(str(value).lower()) - - value = re.sub(r'<[^>]+>', '', value) - value = re.sub(r'&.+?;', '-', value) - value = re.sub(r'[^a-z0-9 _-]', '-', value).strip().lower() - value = re.sub(r'\s+', '-', value) - value = re.sub(r' ', '-', value) - value = re.sub(r'-+', '-', value) - value = re.sub(r'-', '_', value) - - return value +from .utils import dtypes2pg, PG_NULL + +BOOL_DTYPE = 'bool' +OBJECT_DTYPE = 'object' +INT_DTYPES = ['int16', 'int32', 'int64'] +FLOAT_DTYPES = ['float32', 'float64'] +DATETIME_DTYPES = ['datetime64[D]', 'datetime64[ns]', 'datetime64[ns, UTC]'] +FORBIDDEN_COLUMN_NAMES = ['the_geom_webmercator'] +MAX_LENGTH = 63 +MAX_COLLISION_LENGTH = MAX_LENGTH - 4 +RESERVED_WORDS = ('ALL', 'ANALYSE', 'ANALYZE', 'AND', 'ANY', 'ARRAY', 'AS', 'ASC', 'ASYMMETRIC', 'AUTHORIZATION', + 'BETWEEN', 'BINARY', 'BOTH', 'CASE', 'CAST', 'CHECK', 'COLLATE', 'COLUMN', 'CONSTRAINT', + 'CREATE', 'CROSS', 'CURRENT_DATE', 'CURRENT_ROLE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', + 'CURRENT_USER', 'DEFAULT', 'DEFERRABLE', 'DESC', 'DISTINCT', 'DO', 'ELSE', 'END', 'EXCEPT', + 'FALSE', 'FOR', 'FOREIGN', 'FREEZE', 'FROM', 'FULL', 'GRANT', 'GROUP', 'HAVING', 'ILIKE', 'IN', + 'INITIALLY', 'INNER', 'INTERSECT', 'INTO', 'IS', 'ISNULL', 'JOIN', 'LEADING', 'LEFT', 'LIKE', + 'LIMIT', 'LOCALTIME', 'LOCALTIMESTAMP', 'NATURAL', 'NEW', 'NOT', 'NOTNULL', 'NULL', 'OFF', + 'OFFSET', 'OLD', 'ON', 'ONLY', 'OR', 'ORDER', 'OUTER', 'OVERLAPS', 'PLACING', 'PRIMARY', + 'REFERENCES', 'RIGHT', 'SELECT', 'SESSION_USER', 'SIMILAR', 'SOME', 'SYMMETRIC', 'TABLE', 'THEN', + 'TO', 'TRAILING', 'TRUE', 'UNION', 'UNIQUE', 'USER', 'USING', 'VERBOSE', 'WHEN', 'WHERE', + 'XMIN', 'XMAX', 'FORMAT', 'CONTROLLER', 'ACTION') ColumnInfo = namedtuple('ColumnInfo', ['name', 'dbname', 'dbtype', 'is_geom']) -def _extract_pgtype(fields): - if 'pgtype' in fields: - return fields['pgtype'] - return None - - def get_dataframe_columns_info(df): columns = [] - geom_type = _get_geometry_type(df) df_columns = [(name, df.dtypes[name]) for name in df.columns] for name, dtype in df_columns: if _is_valid_column(name): - columns.append(_compute_column_info(name, dtype, geom_type)) + columns.append(_compute_column_info(name, dtype)) return columns -def _get_geom_col_name(df): - return getattr(df, '_geometry_column_name', None) +def get_query_columns_info(fields): + columns = [] + field_columns = [(name, _dtype_from_field(fields[name])) for name in fields] + + for name, dtype in field_columns: + columns.append(_compute_column_info(name, dtype)) + + return columns -def _get_geometry_type(df): - geom_column = _get_geom_col_name(df) - if geom_column in df: - first_geom = _first_value(df[geom_column]) - if first_geom: - enc_type = detect_encoding_type(first_geom) - return decode_geometry_item(first_geom, enc_type).geom_type +def _dtype_from_field(field): + if field: + return field.get('pgtype', field.get('type')) def _is_valid_column(name): - return name.lower() not in Column.FORBIDDEN_COLUMN_NAMES + return name.lower() not in FORBIDDEN_COLUMN_NAMES -def _compute_column_info(name, dtype=None, geom_type=None): +def _compute_column_info(name, dtype=None): name = name dbname = normalize_name(name) if str(dtype) == 'geometry': - dbtype = 'geometry({}, 4326)'.format(geom_type or 'Point') + dbtype = 'geometry(Geometry, 4326)' is_geom = True else: dbtype = dtypes2pg(dtype) @@ -137,6 +73,13 @@ def _compute_column_info(name, dtype=None, geom_type=None): return ColumnInfo(name, dbname, dbtype, is_geom) +def normalize_name(column_name): + if column_name is None: + return None + + return normalize_names([column_name])[0] + + def normalize_names(column_names): """Given an arbitrary column name, translate to a SQL-normalized column name a la CARTO's Import API will translate to @@ -156,9 +99,9 @@ def normalize_names(column_names): * 'SELECT' -> '_select', * 'à' -> 'a', * 'longcolumnshouldbesplittedsomehowanditellyouwhereitsgonnabesplittedrightnow' -> \ - 'longcolumnshouldbesplittedsomehowanditellyouwhereitsgonnabespli', + 'longcolumnshouldbesplittedsomehowanditellyouwhereitsgonnabespli', * 'longcolumnshouldbesplittedsomehowanditellyouwhereitsgonnabesplittedrightnow' -> \ - 'longcolumnshouldbesplittedsomehowanditellyouwhereitsgonnabe_1', + 'longcolumnshouldbesplittedsomehowanditellyouwhereitsgonnabe_1', * 'all' -> '_all' Args: @@ -167,18 +110,51 @@ def normalize_names(column_names): list: List of SQL-normalized column names """ result = [] + for column_name in column_names: - column = Column(column_name).normalize(forbidden_column_names=result) - result.append(column.name) + result.append(_normalize(column_name, forbidden_column_names=result)) return result -def normalize_name(column_name): - if column_name is None: - return None +def _normalize(column_name, forbidden_column_names=None): + column_name = _truncate(_sanitize(_slugify(column_name))) - return normalize_names([column_name])[0] + if forbidden_column_names: + i = 1 + while column_name in forbidden_column_names: + column_name = '{}_{}'.format(_truncate(column_name, length=MAX_COLLISION_LENGTH), i) + i += 1 + + return column_name + + +def _slugify(value): + value = unidecode(str(value).lower()) + value = re.sub(r'<[^>]+>', '', value) + value = re.sub(r'&.+?;', '-', value) + value = re.sub(r'[^a-z0-9 _-]', '-', value).strip().lower() + value = re.sub(r'\s+', '-', value) + value = re.sub(r' ', '-', value) + value = re.sub(r'-+', '-', value) + value = re.sub(r'-', '_', value) + return value + + +def _sanitize(value): + return '_{}'.format(value) if _is_reserved(value) or _is_unsupported(value) else value + + +def _truncate(value, length=MAX_LENGTH): + return value[:length] + + +def _is_reserved(value): + return value.upper() in RESERVED_WORDS + + +def _is_unsupported(value): + return not re.match(r'^[a-z_]+[a-z_0-9]*$', value) def obtain_converters(columns): @@ -200,23 +176,23 @@ def obtain_converters(columns): def date_columns_names(columns): - return [x.name for x in columns if x.dtype in Column.DATETIME_DTYPES] + return [x.name for x in columns if x.dtype in DATETIME_DTYPES] def int_columns_names(columns): - return [x.name for x in columns if x.dtype in Column.INT_DTYPES] + return [x.name for x in columns if x.dtype in INT_DTYPES] def float_columns_names(columns): - return [x.name for x in columns if x.dtype in Column.FLOAT_DTYPES] + return [x.name for x in columns if x.dtype in FLOAT_DTYPES] def bool_columns_names(columns): - return [x.name for x in columns if x.dtype == Column.BOOL_DTYPE] + return [x.name for x in columns if x.dtype == BOOL_DTYPE] def object_columns_names(columns): - return [x.name for x in columns if x.dtype == Column.OBJECT_DTYPE] + return [x.name for x in columns if x.dtype == OBJECT_DTYPE] def _convert_int(x): From 11e972f34a6deec6ae21cfd4d8304fe34c3360dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Arroyo=20Torrens?= Date: Tue, 5 May 2020 18:42:08 +0200 Subject: [PATCH 2/6] Improve replace with truncate --- cartoframes/io/carto.py | 5 +- cartoframes/io/managers/context_manager.py | 83 ++++++++++++----- cartoframes/utils/columns.py | 101 ++++++++++++--------- cartoframes/utils/utils.py | 1 + 4 files changed, 120 insertions(+), 70 deletions(-) diff --git a/cartoframes/io/carto.py b/cartoframes/io/carto.py index f07892546..567275776 100644 --- a/cartoframes/io/carto.py +++ b/cartoframes/io/carto.py @@ -14,8 +14,6 @@ GEOM_COLUMN_NAME = 'the_geom' -IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] - @send_metrics('data_downloaded') def read_carto(source, credentials=None, limit=None, retry_times=3, schema=None, index_col=None, decode_geom=True): @@ -94,6 +92,7 @@ def to_carto(dataframe, table_name, credentials=None, if_exists='fail', geom_col if not is_valid_str(table_name): raise ValueError('Wrong table name. You should provide a valid table name.') + IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] if if_exists not in IF_EXISTS_OPTIONS: raise ValueError('Wrong option for the `if_exists` param. You should provide: {}.'.format( ', '.join(IF_EXISTS_OPTIONS))) @@ -228,6 +227,7 @@ def copy_table(table_name, new_table_name, credentials=None, if_exists='fail', l if not is_valid_str(new_table_name): raise ValueError('Wrong new table name. You should provide a valid table name.') + IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] if if_exists not in IF_EXISTS_OPTIONS: raise ValueError('Wrong option for the `if_exists` param. You should provide: {}.'.format( ', '.join(IF_EXISTS_OPTIONS))) @@ -261,6 +261,7 @@ def create_table_from_query(query, new_table_name, credentials=None, if_exists=' if not is_valid_str(new_table_name): raise ValueError('Wrong new table name. You should provide a valid table name.') + IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] if if_exists not in IF_EXISTS_OPTIONS: raise ValueError('Wrong option for the `if_exists` param. You should provide: {}.'.format( ', '.join(IF_EXISTS_OPTIONS))) diff --git a/cartoframes/io/managers/context_manager.py b/cartoframes/io/managers/context_manager.py index 67b28679e..5272661d1 100644 --- a/cartoframes/io/managers/context_manager.py +++ b/cartoframes/io/managers/context_manager.py @@ -13,7 +13,7 @@ from ...utils.logger import log from ...utils.geom_utils import encode_geometry_ewkb from ...utils.utils import is_sql_query, check_credentials, encode_row, map_geom_type, PG_NULL -from ...utils.columns import Column, get_dataframe_columns_info, obtain_converters, \ +from ...utils.columns import get_dataframe_columns_info, get_query_columns_info, obtain_converters, \ date_columns_names, normalize_name DEFAULT_RETRY_TIMES = 3 @@ -47,16 +47,21 @@ def copy_from(self, gdf, table_name, if_exists='fail', cartodbfy=True): table_name = self.normalize_table_name(table_name) columns = get_dataframe_columns_info(gdf) - if if_exists == 'replace' or not self.has_table(table_name, schema): - log.debug('Creating table "{}"'.format(table_name)) - self._create_table_from_columns(table_name, columns, schema, cartodbfy) - elif if_exists == 'fail': - raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. ' - 'Please choose a different `table_name` or use ' - 'if_exists="replace" to overwrite it.'.format( - table_name=table_name, schema=schema)) - else: # 'append' - pass + if self.has_table(table_name, schema): + if if_exists == 'replace': + if self._compare_columns(table_name, schema, columns): + self._truncate_table_from_columns(table_name, schema, columns, cartodbfy) + else: + self._drop_create_table_from_columns(table_name, schema, columns, cartodbfy) + elif if_exists == 'fail': + raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. ' + 'Please choose a different `table_name` or use ' + 'if_exists="replace" to overwrite it.'.format( + table_name=table_name, schema=schema)) + else: # 'append' + pass + else: + self._drop_create_table_from_columns(table_name, schema, columns, cartodbfy) self._copy_from(gdf, table_name, columns) return table_name @@ -65,16 +70,18 @@ def create_table_from_query(self, query, table_name, if_exists, cartodbfy=True): schema = self.get_schema() table_name = self.normalize_table_name(table_name) - if if_exists == 'replace' or not self.has_table(table_name, schema): - log.debug('Creating table "{}"'.format(table_name)) - self._create_table_from_query(query, table_name, schema, cartodbfy) - elif if_exists == 'fail': - raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. ' - 'Please choose a different `table_name` or use ' - 'if_exists="replace" to overwrite it.'.format( - table_name=table_name, schema=schema)) - else: # 'append' - pass + if self.has_table(table_name, schema): + if if_exists == 'replace': + self._drop_create_table_from_query(table_name, schema, query, cartodbfy) + elif if_exists == 'fail': + raise Exception('Table "{schema}.{table_name}" already exists in your CARTO account. ' + 'Please choose a different `table_name` or use ' + 'if_exists="replace" to overwrite it.'.format( + table_name=table_name, schema=schema)) + else: # 'append' + pass + else: + self._drop_create_table_from_query(table_name, schema, query, cartodbfy) return table_name @@ -187,7 +194,19 @@ def get_table_names(self, query): tables = [table.split('.')[1] if '.' in table else table for table in result['rows'][0]['tables']] return tables - def _create_table_from_query(self, query, table_name, schema, cartodbfy=True): + def _compare_columns(self, table_name, schema, columns): + GEOM_COL = 'the_geom_webmercator' + + query = self._compute_query_from_table(table_name, schema) + existing_columns = [c for c in self._get_query_columns_info(query) if (c.name != GEOM_COL)] + + existing_columns.sort() + columns.sort() + + return existing_columns == columns + + def _drop_create_table_from_query(self, table_name, schema, query, cartodbfy): + log.debug('DROP + CREATE table "{}"'.format(table_name)) query = 'BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'.format( drop=_drop_table_query(table_name), create=_create_table_from_query_query(table_name, query), @@ -195,7 +214,8 @@ def _create_table_from_query(self, query, table_name, schema, cartodbfy=True): ) self.execute_long_running_query(query) - def _create_table_from_columns(self, table_name, columns, schema, cartodbfy=True): + def _drop_create_table_from_columns(self, table_name, schema, columns, cartodbfy): + log.debug('DROP + CREATE table "{}"'.format(table_name)) query = 'BEGIN; {drop}; {create}; {cartodbfy}; COMMIT;'.format( drop=_drop_table_query(table_name), create=_create_table_from_columns_query(table_name, columns), @@ -203,6 +223,14 @@ def _create_table_from_columns(self, table_name, columns, schema, cartodbfy=True ) self.execute_long_running_query(query) + def _truncate_table_from_columns(self, table_name, schema, columns, cartodbfy): + log.debug('TRUNCATE table "{}"'.format(table_name)) + query = 'BEGIN; {truncate}; {cartodbfy}; COMMIT;'.format( + truncate=_truncate_table_query(table_name), + cartodbfy=_cartodbfy_query(table_name, schema) if cartodbfy else '' + ) + self.execute_long_running_query(query) + def compute_query(self, source, schema=None): if is_sql_query(source): return source @@ -226,7 +254,7 @@ def _check_exists(self, query): def _get_query_columns_info(self, query): query = 'SELECT * FROM ({}) _q LIMIT 0'.format(query) table_info = self.execute_query(query) - return Column.from_sql_api_fields(table_info['fields']) + return get_query_columns_info(table_info['fields']) def _get_copy_query(self, query, columns, limit): query_columns = [ @@ -245,6 +273,7 @@ def _get_copy_query(self, query, columns, limit): return query def _copy_to(self, query, columns, retry_times): + log.debug('COPY TO') copy_query = 'COPY ({0}) TO stdout WITH (FORMAT csv, HEADER true, NULL \'{1}\')'.format(query, PG_NULL) try: @@ -272,6 +301,7 @@ def _copy_to(self, query, columns, retry_times): return df def _copy_from(self, dataframe, table_name, columns): + log.debug('COPY FROM') query = """ COPY {table_name}({columns}) FROM stdin WITH (FORMAT csv, DELIMITER '|', NULL '{null}'); """.format( @@ -297,6 +327,11 @@ def _drop_table_query(table_name, if_exists=True): if_exists='IF EXISTS' if if_exists else '') +def _truncate_table_query(table_name): + return '''TRUNCATE TABLE {table_name}'''.format( + table_name=table_name) + + def _create_table_from_columns_query(table_name, columns): columns = ['{name} {type}'.format(name=column.dbname, type=column.dbtype) for column in columns] diff --git a/cartoframes/utils/columns.py b/cartoframes/utils/columns.py index ab53b9616..6d6df5189 100644 --- a/cartoframes/utils/columns.py +++ b/cartoframes/utils/columns.py @@ -3,15 +3,14 @@ import re from unidecode import unidecode -from collections import namedtuple -from .utils import dtypes2pg, PG_NULL +from .utils import dtypes2pg, pg2dtypes, PG_NULL -BOOL_DTYPE = 'bool' -OBJECT_DTYPE = 'object' -INT_DTYPES = ['int16', 'int32', 'int64'] -FLOAT_DTYPES = ['float32', 'float64'] -DATETIME_DTYPES = ['datetime64[D]', 'datetime64[ns]', 'datetime64[ns, UTC]'] +BOOL_DBTYPES = ['bool', 'boolean'] +OBJECT_DBTYPES = ['text'] +INT_DBTYPES = ['int2', 'int4', 'int2', 'int', 'int8', 'smallint', 'integer', 'bigint'] +FLOAT_DBTYPES = ['float4', 'float8', 'real', 'double precision', 'numeric', 'decimal'] +DATETIME_DBTYPES = ['date', 'timestamp', 'timestampz'] FORBIDDEN_COLUMN_NAMES = ['the_geom_webmercator'] MAX_LENGTH = 63 MAX_COLLISION_LENGTH = MAX_LENGTH - 4 @@ -28,48 +27,74 @@ 'XMIN', 'XMAX', 'FORMAT', 'CONTROLLER', 'ACTION') -ColumnInfo = namedtuple('ColumnInfo', ['name', 'dbname', 'dbtype', 'is_geom']) +class ColumnInfo: + + def __init__(self, name, dbname, dbtype, is_geom): + self.name = name + self.dbname = dbname + self.dbtype = dbtype + self.is_geom = is_geom + + def __repr__(self): + params = ', '.join([self.name, self.dbname, self.dbtype, str(self.is_geom)]) + return 'ColumnInfo({})'.format(params) + + def __eq__(self, other): + if self.name == 'cartodb_id': + # Skip cartodb_id comparison because cartodbfy converts bigint to integer + return True + else: + return self.name == other.name and \ + self.dbname == other.dbname and \ + self.dbtype == other.dbtype and \ + self.is_geom == other.is_geom + + def __gt__(self, other): + return self.name > other.name + + def __ge__(self, other): + return self.name >= other.name + + def __lt__(self, other): + return self.name < other.name + + def __le__(self, other): + return self.name <= other.name def get_dataframe_columns_info(df): columns = [] - df_columns = [(name, df.dtypes[name]) for name in df.columns] - for name, dtype in df_columns: + for name in df.columns: if _is_valid_column(name): - columns.append(_compute_column_info(name, dtype)) + dbtype = dtypes2pg(str(df.dtypes[name])) + columns.append(_create_column_info(name, dbtype)) return columns def get_query_columns_info(fields): columns = [] - field_columns = [(name, _dtype_from_field(fields[name])) for name in fields] - for name, dtype in field_columns: - columns.append(_compute_column_info(name, dtype)) + for name in fields: + field = fields[name] + pgtype = field.get('pgtype') + dbtype = dtypes2pg(pg2dtypes(pgtype)) if pgtype else field.get('type') + columns.append(_create_column_info(name, dbtype)) return columns -def _dtype_from_field(field): - if field: - return field.get('pgtype', field.get('type')) - - def _is_valid_column(name): return name.lower() not in FORBIDDEN_COLUMN_NAMES -def _compute_column_info(name, dtype=None): - name = name +def _create_column_info(name, dbtype=None): + is_geom = False dbname = normalize_name(name) - if str(dtype) == 'geometry': + if dbtype == 'geometry': dbtype = 'geometry(Geometry, 4326)' is_geom = True - else: - dbtype = dtypes2pg(dtype) - is_geom = False return ColumnInfo(name, dbname, dbtype, is_geom) @@ -160,39 +185,27 @@ def _is_unsupported(value): def obtain_converters(columns): converters = {} - for int_column_name in int_columns_names(columns): + for int_column_name in type_columns_names(columns, INT_DBTYPES): converters[int_column_name] = _convert_int - for float_column_name in float_columns_names(columns): + for float_column_name in type_columns_names(columns, FLOAT_DBTYPES): converters[float_column_name] = _convert_float - for bool_column_name in bool_columns_names(columns): + for bool_column_name in type_columns_names(columns, BOOL_DBTYPES): converters[bool_column_name] = _convert_bool - for object_column_name in object_columns_names(columns): + for object_column_name in type_columns_names(columns, OBJECT_DBTYPES): converters[object_column_name] = _convert_object return converters def date_columns_names(columns): - return [x.name for x in columns if x.dtype in DATETIME_DTYPES] - - -def int_columns_names(columns): - return [x.name for x in columns if x.dtype in INT_DTYPES] - - -def float_columns_names(columns): - return [x.name for x in columns if x.dtype in FLOAT_DTYPES] - - -def bool_columns_names(columns): - return [x.name for x in columns if x.dtype == BOOL_DTYPE] + return type_columns_names(columns, DATETIME_DBTYPES) -def object_columns_names(columns): - return [x.name for x in columns if x.dtype == OBJECT_DTYPE] +def type_columns_names(columns, dbtypes): + return [x.name for x in columns if x.dbtype in dbtypes] def _convert_int(x): diff --git a/cartoframes/utils/utils.py b/cartoframes/utils/utils.py index b65188a7c..1506c500b 100644 --- a/cartoframes/utils/utils.py +++ b/cartoframes/utils/utils.py @@ -121,6 +121,7 @@ def dtypes2pg(dtype): 'bool': 'boolean', 'datetime64[ns]': 'timestamp', 'datetime64[ns, UTC]': 'timestamp', + 'geometry': 'geometry' } return mapping.get(str(dtype), 'text') From 114c4a4294fcf6957464e117634dcc8b71386bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Arroyo=20Torrens?= Date: Tue, 5 May 2020 19:01:19 +0200 Subject: [PATCH 3/6] Update tests --- .../utils/managers/test_context_manager.py | 23 ++++++++++++++++--- tests/unit/utils/test_columns.py | 21 ++++------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/unit/utils/managers/test_context_manager.py b/tests/unit/utils/managers/test_context_manager.py index 159b23217..477f19ece 100644 --- a/tests/unit/utils/managers/test_context_manager.py +++ b/tests/unit/utils/managers/test_context_manager.py @@ -70,12 +70,12 @@ def test_copy_from_exists_fail(self, mocker): 'Please choose a different `table_name` or use ' 'if_exists="replace" to overwrite it.') - def test_copy_from_exists_replace(self, mocker): + def test_copy_from_exists_replace_drop_create(self, mocker): # Given mocker.patch('cartoframes.io.managers.context_manager._create_auth_client') mocker.patch.object(ContextManager, 'has_table', return_value=True) mocker.patch.object(ContextManager, 'get_schema', return_value='schema') - mock = mocker.patch.object(ContextManager, '_create_table_from_columns') + mock = mocker.patch.object(ContextManager, '_drop_create_table_from_columns') df = DataFrame({'A': [1]}) columns = [ColumnInfo('A', 'a', 'bigint', False)] @@ -84,7 +84,24 @@ def test_copy_from_exists_replace(self, mocker): cm.copy_from(df, 'TABLE NAME', 'replace') # Then - mock.assert_called_once_with('table_name', columns, 'schema', True) + mock.assert_called_once_with('table_name', 'schema', columns, True) + + def test_copy_from_exists_replace_truncate(self, mocker): + # Given + mocker.patch('cartoframes.io.managers.context_manager._create_auth_client') + mocker.patch.object(ContextManager, 'has_table', return_value=True) + mocker.patch.object(ContextManager, 'get_schema', return_value='schema') + mocker.patch.object(ContextManager, '_compare_columns', return_value=True) + mock = mocker.patch.object(ContextManager, '_truncate_table_from_columns') + df = DataFrame({'A': [1]}) + columns = [ColumnInfo('A', 'a', 'bigint', False)] + + # When + cm = ContextManager(self.credentials) + cm.copy_from(df, 'TABLE NAME', 'replace') + + # Then + mock.assert_called_once_with('table_name', 'schema', columns, True) def test_internal_copy_from(self, mocker): # Given diff --git a/tests/unit/utils/test_columns.py b/tests/unit/utils/test_columns.py index 183d3fbad..879fa719a 100644 --- a/tests/unit/utils/test_columns.py +++ b/tests/unit/utils/test_columns.py @@ -6,7 +6,7 @@ from geopandas import GeoDataFrame from cartoframes.utils.geom_utils import set_geometry -from cartoframes.utils.columns import Column, ColumnInfo, get_dataframe_columns_info, normalize_names +from cartoframes.utils.columns import ColumnInfo, get_dataframe_columns_info, normalize_names class TestColumns(object): @@ -54,19 +54,6 @@ def setup_method(self): 'longcolumnshouldbesplittedsomehowanditellyouwhereitsgonnabe_1', '_all'] - def test_normalize(self): - other_cols = [] - for c, a in zip(self.cols, self.cols_ans): - # changed cols should match answers - column = Column(c) - a_column = Column(a) - column.normalize(other_cols) - a_column.normalize(other_cols) - assert column.name == a - # already sql-normed cols should match themselves - assert a_column.name == a - other_cols.append(column.name) - def test_normalize_names(self): assert normalize_names(self.cols) == self.cols_ans @@ -84,7 +71,7 @@ def test_column_info_with_geom(self): assert dataframe_columns_info == [ ColumnInfo('Address', 'address', 'text', False), ColumnInfo('City', 'city', 'text', False), - ColumnInfo('the_geom', 'the_geom', 'geometry(Point, 4326)', True) + ColumnInfo('the_geom', 'the_geom', 'geometry(Geometry, 4326)', True) ] def test_column_info_without_geom(self): @@ -110,7 +97,7 @@ def test_column_info_basic_troubled_names(self): assert dataframe_columns_info == [ ColumnInfo('cartodb_id', 'cartodb_id', 'bigint', False), - ColumnInfo('the_geom', 'the_geom', 'geometry(Point, 4326)', True) + ColumnInfo('the_geom', 'the_geom', 'geometry(Geometry, 4326)', True) ] def test_column_info_geometry_troubled_names(self): @@ -123,6 +110,6 @@ def test_column_info_geometry_troubled_names(self): assert dataframe_columns_info == [ ColumnInfo('Geom', 'geom', 'text', False), - ColumnInfo('the_geom', 'the_geom', 'geometry(Point, 4326)', True), + ColumnInfo('the_geom', 'the_geom', 'geometry(Geometry, 4326)', True), ColumnInfo('g-e-o-m-e-t-r-y', 'g_e_o_m_e_t_r_y', 'text', False) ] From a024dcc9b01f1f1941b0871ec51a70dcf724d73c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Arroyo=20Torrens?= Date: Thu, 7 May 2020 22:52:02 +0200 Subject: [PATCH 4/6] Update api key creation message --- cartoframes/data/clients/auth_api_client.py | 2 +- cartoframes/viz/kuviz.py | 10 ++++++---- tests/unit/auth/test_auth_api_client.py | 6 ++++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/cartoframes/data/clients/auth_api_client.py b/cartoframes/data/clients/auth_api_client.py index 21cca1965..0c563ebf4 100644 --- a/cartoframes/data/clients/auth_api_client.py +++ b/cartoframes/data/clients/auth_api_client.py @@ -41,7 +41,7 @@ def create_api_key(self, sources, apis=['sql', 'maps'], permissions=['select'], else: raise e - return api_key.token, tables_names + return api_key.name, api_key.token, tables_names def _get_table_dict(schema, name, permissions): diff --git a/cartoframes/viz/kuviz.py b/cartoframes/viz/kuviz.py index bd2c0bb9c..ec627693a 100644 --- a/cartoframes/viz/kuviz.py +++ b/cartoframes/viz/kuviz.py @@ -78,12 +78,14 @@ def _create_maps_api_keys(self, layers): private_sources = [layer.source for layer in layers if not layer.source.is_public()] if len(private_sources) > 0: - maps_api_key, private_tables_names = self._auth_api_client.create_api_key(private_sources, ['maps']) + key_name, key_value, private_tables_names = self._auth_api_client.create_api_key( + private_sources, ['maps'] + ) log.info( 'The map has been published. ' - 'The "{0}" Maps API key is being used for these datasets {1}. ' - 'You can manage your API keys on your account'.format(maps_api_key, private_tables_names)) - return maps_api_key + 'The "{0}" Maps API key with value "{1}" is being used for these datasets {2}. ' + 'You can manage your API keys on your account.'.format(key_name, key_value, private_tables_names)) + return key_value return DEFAULT_PUBLIC diff --git a/tests/unit/auth/test_auth_api_client.py b/tests/unit/auth/test_auth_api_client.py index 557ebac8c..3e245bcb2 100644 --- a/tests/unit/auth/test_auth_api_client.py +++ b/tests/unit/auth/test_auth_api_client.py @@ -33,8 +33,9 @@ def test_create_api_key(self, mocker): api_key_name = 'fake_name' auth_api_client = AuthAPIClient() - token, tables = auth_api_client.create_api_key([source], name=api_key_name) + name, token, tables = auth_api_client.create_api_key([source], name=api_key_name) + assert name == api_key_name assert token == TOKEN_MOCK def test_create_api_key_several_sources(self, mocker): @@ -44,6 +45,7 @@ def test_create_api_key_several_sources(self, mocker): api_key_name = 'fake_name' auth_api_client = AuthAPIClient() - token, tables = auth_api_client.create_api_key([source, source, source], name=api_key_name) + name, token, tables = auth_api_client.create_api_key([source, source, source], name=api_key_name) + assert name == api_key_name assert token == TOKEN_MOCK From b0926d9a9bcb5fe122316183c85bf93529687ddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Arroyo=20Torrens?= Date: Thu, 7 May 2020 23:02:23 +0200 Subject: [PATCH 5/6] Update publish viz examples --- .../publish_visualization_gdf.ipynb | 12 +++++++++++- .../publish_visualization_layout.ipynb | 4 ++-- .../publish_visualization_private_table.ipynb | 2 +- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/publish_and_share/publish_visualization_gdf.ipynb b/examples/publish_and_share/publish_visualization_gdf.ipynb index 2b5e382e2..0e896f46e 100644 --- a/examples/publish_and_share/publish_visualization_gdf.ipynb +++ b/examples/publish_and_share/publish_visualization_gdf.ipynb @@ -92,6 +92,16 @@ "text": [ "Success! Data uploaded to table \"table_name\" correctly\n" ] + }, + { + "data": { + "text/plain": [ + "'table_name'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -118,7 +128,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The map has been published. The \"YTAEiju-Utii8mgtERiEkA\" Maps API key is being used for these datasets ['table_name']. You can manage your API keys on your account\n" + "The map has been published. The \"cartoframes_0a8a065b2cc026c5c33ee3dc269afcf1\" Maps API key with value \"YTAEiju-Utii8mgtERiEkA\" is being used for these datasets ['table_name']. You can manage your API keys on your account.\n" ] }, { diff --git a/examples/publish_and_share/publish_visualization_layout.ipynb b/examples/publish_and_share/publish_visualization_layout.ipynb index aaf619e3c..385d7490c 100644 --- a/examples/publish_and_share/publish_visualization_layout.ipynb +++ b/examples/publish_and_share/publish_visualization_layout.ipynb @@ -34,7 +34,7 @@ "layout_viz = Layout([\n", " Map(Layer('public_table')),\n", " Map(Layer('private_table'))\n", - "], is_static=True)" + "])" ] }, { @@ -46,7 +46,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The map has been published. The \"7rj9ftFsOKUjSotnygh2jg\" Maps API key is being used for these datasets ['private_table']. You can manage your API keys on your account\n" + "The map has been published. The \"cartoframes_997c05771fd4e0916de49826722e51cd\" Maps API key with value \"7rj9ftFsOKUjSotnygh2jg\" is being used for these datasets ['private_table']. You can manage your API keys on your account.\n" ] }, { diff --git a/examples/publish_and_share/publish_visualization_private_table.ipynb b/examples/publish_and_share/publish_visualization_private_table.ipynb index 47097f006..84c5ae9bc 100644 --- a/examples/publish_and_share/publish_visualization_private_table.ipynb +++ b/examples/publish_and_share/publish_visualization_private_table.ipynb @@ -44,7 +44,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "The map has been published. The \"7rj9ftFsOKUjSotnygh2jg\" Maps API key is being used for these datasets ['private_table']. You can manage your API keys on your account\n" + "The map has been published. The \"cartoframes_997c05771fd4e0916de49826722e51cd\" Maps API key with value \"7rj9ftFsOKUjSotnygh2jg\" is being used for these datasets ['private_table']. You can manage your API keys on your account.\n" ] }, { From afd32dc27d1a907534665fc9f21e9b060c267a1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jes=C3=BAs=20Arroyo=20Torrens?= Date: Fri, 8 May 2020 15:58:09 +0200 Subject: [PATCH 6/6] Refactor duplicated options --- cartoframes/io/carto.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cartoframes/io/carto.py b/cartoframes/io/carto.py index 567275776..0c9a62e1a 100644 --- a/cartoframes/io/carto.py +++ b/cartoframes/io/carto.py @@ -13,6 +13,7 @@ GEOM_COLUMN_NAME = 'the_geom' +IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] @send_metrics('data_downloaded') @@ -92,7 +93,6 @@ def to_carto(dataframe, table_name, credentials=None, if_exists='fail', geom_col if not is_valid_str(table_name): raise ValueError('Wrong table name. You should provide a valid table name.') - IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] if if_exists not in IF_EXISTS_OPTIONS: raise ValueError('Wrong option for the `if_exists` param. You should provide: {}.'.format( ', '.join(IF_EXISTS_OPTIONS))) @@ -227,7 +227,6 @@ def copy_table(table_name, new_table_name, credentials=None, if_exists='fail', l if not is_valid_str(new_table_name): raise ValueError('Wrong new table name. You should provide a valid table name.') - IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] if if_exists not in IF_EXISTS_OPTIONS: raise ValueError('Wrong option for the `if_exists` param. You should provide: {}.'.format( ', '.join(IF_EXISTS_OPTIONS))) @@ -261,7 +260,6 @@ def create_table_from_query(query, new_table_name, credentials=None, if_exists=' if not is_valid_str(new_table_name): raise ValueError('Wrong new table name. You should provide a valid table name.') - IF_EXISTS_OPTIONS = ['fail', 'replace', 'append'] if if_exists not in IF_EXISTS_OPTIONS: raise ValueError('Wrong option for the `if_exists` param. You should provide: {}.'.format( ', '.join(IF_EXISTS_OPTIONS)))