diff --git a/superset/dataframe.py b/superset/dataframe.py index 30ba4c776bfcc..2fecad9264147 100644 --- a/superset/dataframe.py +++ b/superset/dataframe.py @@ -27,23 +27,26 @@ INFER_COL_TYPES_SAMPLE_SIZE = 100 -def dedup(l, suffix='__'): +def dedup(l, suffix='__', case_sensitive=True): """De-duplicates a list of string by suffixing a counter Always returns the same number of entries as provided, and always returns - unique values. + unique values. Case sensitive comparison by default. - >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar']))) - foo,bar,bar__1,bar__2 + >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar']))) + foo,bar,bar__1,bar__2,Bar + >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False))) + foo,bar,bar__1,bar__2,Bar__3 """ new_l = [] seen = {} for s in l: - if s in seen: - seen[s] += 1 - s += suffix + str(seen[s]) + s_fixed_case = s if case_sensitive else s.lower() + if s_fixed_case in seen: + seen[s_fixed_case] += 1 + s += suffix + str(seen[s_fixed_case]) else: - seen[s] = 0 + seen[s_fixed_case] = 0 new_l.append(s) return new_l @@ -70,7 +73,9 @@ def __init__(self, data, cursor_description, db_engine_spec): if cursor_description: column_names = [col[0] for col in cursor_description] - self.column_names = dedup(column_names) + case_sensitive = db_engine_spec.consistent_case_sensitivity + self.column_names = dedup(column_names, + case_sensitive=case_sensitive) data = data or [] self.df = ( diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 4967d30647cad..c07910af4d2fd 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -101,6 +101,7 @@ class BaseEngineSpec(object): time_secondary_columns = False inner_joins = True allows_subquery = True + consistent_case_sensitivity = True # do results have same case as qry for col names? 
@classmethod def get_time_grains(cls): @@ -318,7 +319,6 @@ def select_star(cls, my_db, table_name, engine, schema=None, limit=100, if show_cols: fields = [sqla.column(c.get('name')) for c in cols] - full_table_name = table_name quote = engine.dialect.identifier_preparer.quote if schema: full_table_name = quote(schema) + '.' + quote(table_name) @@ -366,6 +366,57 @@ def get_configuration_for_impersonation(cls, uri, impersonate_user, username): def execute(cursor, query, async=False): cursor.execute(query) + @classmethod + def adjust_df_column_names(cls, df, fd): + """Based on fields in form_data, return dataframe with new column names + + Usually sqla engines return column names whose case matches that of the + original query. For example: + SELECT 1 as col1, 2 as COL2, 3 as Col_3 + will usually result in the following df.columns: + ['col1', 'COL2', 'Col_3']. + For these engines there is no need to adjust the dataframe column names + (default behavior). However, some engines (at least Snowflake, Oracle and + Redshift) return column names with different case than in the original query, + usually all uppercase. For these the column names need to be adjusted to + correspond to the case of the fields specified in the form data for Viz + to work properly. This adjustment can be done here. + """ + if cls.consistent_case_sensitivity: + return df + else: + return cls.align_df_col_names_with_form_data(df, fd) + + @staticmethod + def align_df_col_names_with_form_data(df, fd): + """Helper function to rename columns that have changed case during query. + + Returns a dataframe where column names have been adjusted to correspond with + column names in form data (case insensitive). 
Examples: + dataframe: 'col1', form_data: 'col1' -> no change + dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1' + dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1' + """ + + columns = set() + lowercase_mapping = {} + + metrics = utils.get_metric_names(fd.get('metrics', [])) + groupby = fd.get('groupby', []) + other_cols = [utils.DTTM_ALIAS] + for col in metrics + groupby + other_cols: + columns.add(col) + lowercase_mapping[col.lower()] = col + + rename_cols = {} + for col in df.columns: + if col not in columns: + orig_col = lowercase_mapping.get(col.lower()) + if orig_col: + rename_cols[col] = orig_col + + return df.rename(index=str, columns=rename_cols) + class PostgresBaseEngineSpec(BaseEngineSpec): """ Abstract class for Postgres 'like' databases """ @@ -414,6 +465,7 @@ def get_table_names(cls, schema, inspector): class SnowflakeEngineSpec(PostgresBaseEngineSpec): engine = 'snowflake' + consistent_case_sensitivity = False time_grain_functions = { None: '{col}', 'PT1S': "DATE_TRUNC('SECOND', {col})", @@ -434,6 +486,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): 'P1Y': "DATE_TRUNC('YEAR', {col})", } + @classmethod + def adjust_database_uri(cls, uri, selected_schema=None): + database = uri.database + if '/' in uri.database: + database = uri.database.split('/')[0] + if selected_schema: + uri.database = database + '/' + selected_schema + return uri + class VerticaEngineSpec(PostgresBaseEngineSpec): engine = 'vertica' @@ -441,11 +502,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec): class RedshiftEngineSpec(PostgresBaseEngineSpec): engine = 'redshift' + consistent_case_sensitivity = False class OracleEngineSpec(PostgresBaseEngineSpec): engine = 'oracle' limit_method = LimitMethod.WRAP_SQL + consistent_case_sensitivity = False time_grain_functions = { None: '{col}', diff --git a/superset/viz.py b/superset/viz.py index 770d7eaff09d8..7ba8645bbcfa2 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -153,7 
+153,7 @@ def run_extra_queries(self): def handle_nulls(self, df): fillna = self.get_fillna_for_columns(df.columns) - df = df.fillna(fillna) + return df.fillna(fillna) def get_fillna_for_col(self, col): """Returns the value to use as filler for a specific Column.type""" @@ -217,7 +217,7 @@ def get_df(self, query_obj=None): self.df_metrics_to_num(df, query_obj.get('metrics') or []) df.replace([np.inf, -np.inf], np.nan) - self.handle_nulls(df) + df = self.handle_nulls(df) return df @staticmethod @@ -382,6 +382,9 @@ def get_df_payload(self, query_obj=None): if query_obj and not is_loaded: try: df = self.get_df(query_obj) + if hasattr(self.datasource.database, 'db_engine_spec'): + db_engine_spec = self.datasource.database.db_engine_spec + df = db_engine_spec.adjust_df_column_names(df, self.form_data) if self.status != utils.QueryStatus.FAILED: stats_logger.incr('loaded_from_source') is_loaded = True diff --git a/tests/dataframe_test.py b/tests/dataframe_test.py index fdba431491a1d..a773f08c4e376 100644 --- a/tests/dataframe_test.py +++ b/tests/dataframe_test.py @@ -16,12 +16,16 @@ def test_dedup(self): ['foo', 'bar'], ) self.assertEquals( - dedup(['foo', 'bar', 'foo', 'bar']), - ['foo', 'bar', 'foo__1', 'bar__1'], + dedup(['foo', 'bar', 'foo', 'bar', 'Foo']), + ['foo', 'bar', 'foo__1', 'bar__1', 'Foo'], ) self.assertEquals( - dedup(['foo', 'bar', 'bar', 'bar']), - ['foo', 'bar', 'bar__1', 'bar__2'], + dedup(['foo', 'bar', 'bar', 'bar', 'Bar']), + ['foo', 'bar', 'bar__1', 'bar__2', 'Bar'], + ) + self.assertEquals( + dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False), + ['foo', 'bar', 'bar__1', 'bar__2', 'Bar__3'], ) def test_get_columns_basic(self):