Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift #5487

Merged
merged 17 commits into from
Aug 3, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions superset/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,26 @@
INFER_COL_TYPES_SAMPLE_SIZE = 100


def dedup(l, suffix='__'):
def dedup(l, suffix='__', case_sensitive=True):
    """De-duplicates a list of string by suffixing a counter

    Always returns the same number of entries as provided, and always
    returns unique values. Case sensitive comparison by default.

    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
    foo,bar,bar__1,bar__2
    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
    foo,bar,bar__1,bar__2,Bar
    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False)))
    foo,bar,bar__1,bar__2,Bar__3
    """
    result = []
    # Maps each name (lowercased when case-insensitive) to how many
    # duplicates of it have been seen so far.
    occurrences = {}
    for name in l:
        key = name if case_sensitive else name.lower()
        count = occurrences.get(key)
        if count is None:
            # First sighting: keep the name untouched.
            occurrences[key] = 0
            result.append(name)
        else:
            # Duplicate: bump the counter and suffix it onto the name.
            occurrences[key] = count + 1
            result.append(name + suffix + str(count + 1))
    return result

Expand All @@ -70,7 +73,9 @@ def __init__(self, data, cursor_description, db_engine_spec):
if cursor_description:
column_names = [col[0] for col in cursor_description]

self.column_names = dedup(column_names)
case_sensitive = db_engine_spec.consistent_case_sensitivity
self.column_names = dedup(column_names,
case_sensitive=case_sensitive)

data = data or []
self.df = (
Expand Down
65 changes: 64 additions & 1 deletion superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
consistent_case_sensitivity = True # do results have same case as qry for col names?

@classmethod
def get_time_grains(cls):
Expand Down Expand Up @@ -318,7 +319,6 @@ def select_star(cls, my_db, table_name, engine, schema=None, limit=100,

if show_cols:
fields = [sqla.column(c.get('name')) for c in cols]
full_table_name = table_name
quote = engine.dialect.identifier_preparer.quote
if schema:
full_table_name = quote(schema) + '.' + quote(table_name)
Expand Down Expand Up @@ -366,6 +366,57 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username):
def execute(cursor, query, async_=False):
    """Execute ``query`` on the given DB-API ``cursor``.

    :param cursor: an open DB-API cursor
    :param query: SQL string to execute
    :param async_: unused here; a hook so engine specs can opt into
        asynchronous execution. NOTE(review): this parameter was named
        ``async``, which became a reserved keyword in Python 3.7 and makes
        the module fail to parse on modern interpreters; renamed to
        ``async_`` (same position, so positional callers are unaffected).
    """
    cursor.execute(query)

@classmethod
def adjust_df_column_names(cls, df, fd):
    """Based of fields in form_data, return dataframe with new column names

    Usually sqla engines return column names whose case matches that of the
    original query. For example:
        SELECT 1 as col1, 2 as COL2, 3 as Col_3
    will usually result in the following df.columns:
        ['col1', 'COL2', 'Col_3'].
    For these engines there is no need to adjust the dataframe column names
    (default behavior). However, some engines (at least Snowflake, Oracle and
    Redshift) return column names with different case than in the original
    query, usually all uppercase. For these the column names need to be
    adjusted to correspond to the case of the fields specified in the form
    data for Viz to work properly. This adjustment can be done here.
    """
    # Engines that echo back the query's column case need no adjustment.
    if not cls.consistent_case_sensitivity:
        return cls.align_df_col_names_with_form_data(df, fd)
    return df

@staticmethod
def align_df_col_names_with_form_data(df, fd):
    """Helper function to rename columns that have changed case during query.

    Returns a dataframe where column names have been adjusted to correspond
    with column names in form data (case insensitive). Examples:
        dataframe: 'col1', form_data: 'col1' -> no change
        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
    """
    # Collect every column name the form data refers to, including the
    # implicit temporal column alias.
    form_data_cols = (
        utils.get_metric_names(fd.get('metrics', [])) +
        fd.get('groupby', []) +
        [utils.DTTM_ALIAS]
    )
    exact_names = set(form_data_cols)
    by_lowercase = {name.lower(): name for name in form_data_cols}

    # Only rename dataframe columns that don't already match exactly but
    # do match a form-data field case-insensitively.
    rename_cols = {}
    for df_col in df.columns:
        if df_col in exact_names:
            continue
        target = by_lowercase.get(df_col.lower())
        if target:
            rename_cols[df_col] = target

    return df.rename(index=str, columns=rename_cols)


class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
Expand Down Expand Up @@ -414,6 +465,7 @@ def get_table_names(cls, schema, inspector):

class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
consistent_case_sensitivity = False
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
Expand All @@ -434,18 +486,29 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
'P1Y': "DATE_TRUNC('YEAR', {col})",
}

@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Point ``uri`` at ``selected_schema``, mutating it in place.

    Snowflake SQLAlchemy URIs encode the target as ``database/schema``;
    keep only the database portion and append the requested schema.
    Returns the (possibly mutated) ``uri``.
    """
    base_database = uri.database
    if '/' in base_database:
        # Strip any schema already present in the URI.
        base_database = base_database.split('/')[0]
    if selected_schema:
        uri.database = base_database + '/' + selected_schema
    return uri


class VerticaEngineSpec(PostgresBaseEngineSpec):
    # Vertica is handled entirely by the Postgres-like base spec; only the
    # SQLAlchemy dialect name differs.
    engine = 'vertica'


class RedshiftEngineSpec(PostgresBaseEngineSpec):
    engine = 'redshift'
    # Redshift may return result-set column names whose case differs from
    # the query text, so dataframe columns need re-aligning with the
    # form_data field names before visualization.
    consistent_case_sensitivity = False


class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
consistent_case_sensitivity = False

time_grain_functions = {
None: '{col}',
Expand Down
7 changes: 5 additions & 2 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def run_extra_queries(self):

def handle_nulls(self, df):
fillna = self.get_fillna_for_columns(df.columns)
df = df.fillna(fillna)
return df.fillna(fillna)

def get_fillna_for_col(self, col):
"""Returns the value to use as filler for a specific Column.type"""
Expand Down Expand Up @@ -217,7 +217,7 @@ def get_df(self, query_obj=None):
self.df_metrics_to_num(df, query_obj.get('metrics') or [])

df.replace([np.inf, -np.inf], np.nan)
self.handle_nulls(df)
df = self.handle_nulls(df)
return df

@staticmethod
Expand Down Expand Up @@ -382,6 +382,9 @@ def get_df_payload(self, query_obj=None):
if query_obj and not is_loaded:
try:
df = self.get_df(query_obj)
if hasattr(self.datasource.database, 'db_engine_spec'):
db_engine_spec = self.datasource.database.db_engine_spec
df = db_engine_spec.adjust_df_column_names(df, self.form_data)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source')
is_loaded = True
Expand Down
12 changes: 8 additions & 4 deletions tests/dataframe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ def test_dedup(self):
['foo', 'bar'],
)
self.assertEquals(
dedup(['foo', 'bar', 'foo', 'bar']),
['foo', 'bar', 'foo__1', 'bar__1'],
dedup(['foo', 'bar', 'foo', 'bar', 'Foo']),
['foo', 'bar', 'foo__1', 'bar__1', 'Foo'],
)
self.assertEquals(
dedup(['foo', 'bar', 'bar', 'bar']),
['foo', 'bar', 'bar__1', 'bar__2'],
dedup(['foo', 'bar', 'bar', 'bar', 'Bar']),
['foo', 'bar', 'bar__1', 'bar__2', 'Bar'],
)
self.assertEquals(
dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False),
['foo', 'bar', 'bar__1', 'bar__2', 'Bar__3'],
)

def test_get_columns_basic(self):
Expand Down