Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read using copy #570

Merged
merged 20 commits into from
Apr 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Release 2019-xx-xx
Updates

- max line length from 80 to 120
- Rewrite context.read method using COPY TO (#570)

0.9.2
-----
Expand Down
15 changes: 3 additions & 12 deletions cartoframes/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,13 @@ def _is_org_user(self):
# is an org user if first item is not `public`
return res['rows'][0]['unnest'] != 'public'

def read(self, table_name, limit=None, index='cartodb_id',
decode_geom=False, shared_user=None):
def read(self, table_name, limit=None, decode_geom=False, shared_user=None):
"""Read a table from CARTO into a pandas DataFrames.

Args:
table_name (str): Name of table in user's CARTO account.
limit (int, optional): Read only `limit` lines from
`table_name`. Defaults to ``None``, which reads the full table.
index (str, optional): Not currently in use.
decode_geom (bool, optional): Decodes CARTO's geometries into a
`Shapely <https://github.com/Toblerity/Shapely>`__
object that can be used, for example, in `GeoPandas
Expand All @@ -236,16 +234,9 @@ def read(self, table_name, limit=None, index='cartodb_id',
# choose schema (default user - org or standalone - or shared)
schema = 'public' if not self.is_org else (
shared_user or self.creds.username())
query = 'SELECT * FROM "{schema}"."{table_name}"'.format(
table_name=table_name,
schema=schema)
if limit is not None:
if isinstance(limit, int) and (limit >= 0):
query += ' LIMIT {limit}'.format(limit=limit)
else:
raise ValueError("`limit` parameter must an integer >= 0")

return self.query(query, decode_geom=decode_geom)
dataset = Dataset(self, table_name, schema)
return dataset.download(limit, decode_geom)

@utils.temp_ignore_warnings
def tables(self):
Expand Down
69 changes: 68 additions & 1 deletion cartoframes/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ class Dataset(object):
PUBLIC = 'public'
LINK = 'link'

def __init__(self, carto_context, table_name, df=None):
def __init__(self, carto_context, table_name, schema='public', df=None):
self.cc = carto_context
self.table_name = _norm_colname(table_name)
self.schema = schema
self.df = df
warn('Table will be named `{}`'.format(table_name))

Expand All @@ -41,6 +42,15 @@ def upload(self, with_lonlat=None, if_exists='fail'):

return self

def download(self, limit=None, decode_geom=False):
    """Download the table into a pandas DataFrame through a COPY TO query.

    Args:
        limit (int, optional): Maximum number of rows to read. ``None``
            (the default) adds no ``LIMIT`` clause to the query.
        decode_geom (bool, optional): Forwarded to
            ``_clean_dataframe_from_carto`` to request geometry decoding.

    Returns:
        pandas.DataFrame: The CSV stream parsed by pandas and then passed
        through ``_clean_dataframe_from_carto``.
    """
    # Column metadata is needed twice: to build the column list of the
    # query and to post-process the resulting frame.
    columns = self._get_table_columns()

    copy_query = self._get_read_query(columns, limit)
    raw_df = pd.read_csv(self.cc.copy_client.copyto_stream(copy_query))

    return _clean_dataframe_from_carto(raw_df, columns, decode_geom)

def exists(self):
"""Checks to see if table exists"""
try:
Expand Down Expand Up @@ -130,6 +140,33 @@ def _create_table_query(self, with_lonlat=None):
create_query = '''CREATE TABLE {table_name} ({cols})'''.format(table_name=self.table_name, cols=cols)
return create_query

def _get_read_query(self, table_columns, limit=None):
    """Create the read (COPY TO) query.

    Args:
        table_columns (dict): Column names mapped to their metadata. Only
            the keys are used here. ``None`` or an empty mapping falls back
            to selecting ``*``.
        limit (int, optional): Maximum number of rows to read. ``None``
            (the default) adds no ``LIMIT`` clause.

    Returns:
        str: A ``COPY (...) TO stdout`` statement producing CSV with a
        header row.

    Raises:
        ValueError: If `limit` is not a non-negative integer.
    """
    # bool is a subclass of int; reject it so `LIMIT True` is never emitted.
    if limit is not None:
        is_valid = isinstance(limit, int) and not isinstance(limit, bool) and limit >= 0
        if not is_valid:
            raise ValueError("`limit` parameter must be an integer >= 0")

    # Filter instead of list.remove(): remove() raises ValueError when the
    # table has no `the_geom_webmercator` column. Names are double-quoted
    # (embedded quotes doubled), matching the quoted schema and table.
    query_columns = [
        '"{}"'.format(name.replace('"', '""'))
        for name in (table_columns or {})
        if name != 'the_geom_webmercator'
    ]

    query = 'SELECT {columns} FROM "{schema}"."{table_name}"'.format(
        table_name=self.table_name,
        schema=self.schema,
        # An empty column list would be invalid SQL, so select everything.
        columns=', '.join(query_columns) or '*')

    if limit is not None:
        query += ' LIMIT {limit}'.format(limit=limit)

    return 'COPY ({query}) TO stdout WITH (FORMAT csv, HEADER true)'.format(query=query)

def _get_table_columns(self):
    """Get column names and types from a table.

    Sends a zero-row ``SELECT *`` through the context's SQL client so that
    only the response metadata is retrieved.

    Returns:
        The ``fields`` entry of the SQL client response, or ``None`` when
        the response has no such entry.
    """
    metadata_query = 'SELECT * FROM "{schema}"."{table}" limit 0'.format(
        schema=self.schema,
        table=self.table_name)
    response = self.cc.sql_client.send(metadata_query)

    return response['fields'] if 'fields' in response else None


def _norm_colname(colname):
"""Given an arbitrary column name, translate to a SQL-normalized column
Expand Down Expand Up @@ -251,3 +288,33 @@ def _decode_geom(ewkb):
except Exception:
pass
return None


def _clean_dataframe_from_carto(df, table_columns, decode_geom=False):
    """Clean a DataFrame with a dataset from CARTO:
        - use cartodb_id as DataFrame index
        - process date columns
        - decode geom

    The frame is modified in place and also returned.

    Args:
        df (pandas.DataFrame): DataFrame with a dataset from CARTO.
        table_columns (dict): column names and types from a table. Each
            value must expose a ``'type'`` key. ``None`` is treated as
            "no column metadata" and skips the date processing.
        decode_geom (bool, optional): Decodes CARTO's geometries into a
            `Shapely <https://github.com/Toblerity/Shapely>`__
            object that can be used, for example, in `GeoPandas
            <http://geopandas.org/>`__.

    Returns:
        pandas.DataFrame
    """
    if 'cartodb_id' in df.columns:
        df.set_index('cartodb_id', inplace=True)

    for column_name, column_info in (table_columns or {}).items():
        if column_info['type'] != 'date':
            continue
        # The metadata can list columns that are not in the frame (for
        # example ones left out of the query, or the one now used as index).
        if column_name not in df.columns:
            continue
        try:
            df[column_name] = pd.to_datetime(df[column_name])
        except (ValueError, TypeError, OverflowError):
            # Best effort, as with the deprecated errors='ignore': values
            # that cannot be parsed leave the column as downloaded.
            continue

    if decode_geom:
        df['geometry'] = df.the_geom.apply(_decode_geom)

    return df
6 changes: 3 additions & 3 deletions test/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ def test_cartocontext_read(self):
api_key=self.apikey)
# fails if limit is smaller than zero
with self.assertRaises(ValueError):
df = cc.read('sea_horses', limit=-10)
df = cc.read(self.test_read_table, limit=-10)
# fails if not an int
with self.assertRaises(ValueError):
df = cc.read('sea_horses', limit=3.14159)
df = cc.read(self.test_read_table, limit=3.14159)
with self.assertRaises(ValueError):
df = cc.read('sea_horses', limit='acadia')
df = cc.read(self.test_read_table, limit='acadia')

# fails on non-existent table
with self.assertRaises(CartoException):
Expand Down