diff --git a/superset/assets/spec/javascripts/components/TableSelector_spec.jsx b/superset/assets/spec/javascripts/components/TableSelector_spec.jsx index 70e2cca1f1925..13665927556d3 100644 --- a/superset/assets/spec/javascripts/components/TableSelector_spec.jsx +++ b/superset/assets/spec/javascripts/components/TableSelector_spec.jsx @@ -208,19 +208,20 @@ describe('TableSelector', () => { it('test 1', () => { wrapper.instance().changeTable({ - value: 'birth_names', + value: { schema: 'main', table: 'birth_names' }, label: 'birth_names', }); expect(wrapper.state().tableName).toBe('birth_names'); }); - it('test 2', () => { + it('should call onTableChange with schema from table object', () => { + wrapper.setProps({ schema: null }); wrapper.instance().changeTable({ - value: 'main.my_table', - label: 'my_table', + value: { schema: 'other_schema', table: 'my_table' }, + label: 'other_schema.my_table', }); expect(mockedProps.onTableChange.getCall(0).args[0]).toBe('my_table'); - expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('main'); + expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('other_schema'); }); }); diff --git a/superset/assets/spec/javascripts/sqllab/fixtures.js b/superset/assets/spec/javascripts/sqllab/fixtures.js index 6471be1286556..f43f43f550a68 100644 --- a/superset/assets/spec/javascripts/sqllab/fixtures.js +++ b/superset/assets/spec/javascripts/sqllab/fixtures.js @@ -329,15 +329,15 @@ export const databases = { export const tables = { options: [ { - value: 'birth_names', + value: { schema: 'main', table: 'birth_names' }, label: 'birth_names', }, { - value: 'energy_usage', + value: { schema: 'main', table: 'energy_usage' }, label: 'energy_usage', }, { - value: 'wb_health_population', + value: { schema: 'main', table: 'wb_health_population' }, label: 'wb_health_population', }, ], diff --git a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx index 9d0796cd8ce74..43ea4873a0f73 100644 --- a/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx +++ b/superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx @@ -83,17 +83,10 @@ export default class SqlEditorLeftBar extends React.PureComponent { this.setState({ tableName: '' }); return; } - const namePieces = tableOpt.value.split('.'); - let tableName = namePieces[0]; - let schemaName = this.props.queryEditor.schema; - if (namePieces.length === 1) { - this.setState({ tableName }); - } else { - schemaName = namePieces[0]; - tableName = namePieces[1]; - this.setState({ tableName }); - this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName); - } + const schemaName = tableOpt.value.schema; + const tableName = tableOpt.value.table; + this.setState({ tableName }); + this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName); this.props.actions.addTable(this.props.queryEditor, tableName, schemaName); } diff --git a/superset/assets/src/components/TableSelector.jsx b/superset/assets/src/components/TableSelector.jsx index ba2cebb2799d8..940e1c274b93a 100644 --- a/superset/assets/src/components/TableSelector.jsx +++ b/superset/assets/src/components/TableSelector.jsx @@ -170,13 +170,8 @@ export default class TableSelector extends React.PureComponent { this.setState({ tableName: '' }); return; } - const namePieces = tableOpt.value.split('.'); - let tableName = namePieces[0]; - let schemaName = this.props.schema; - if (namePieces.length > 1) { - schemaName = namePieces[0]; - tableName = namePieces[1]; - } + const schemaName = tableOpt.value.schema; + const tableName = tableOpt.value.table; if (this.props.tableNameSticky) { this.setState({ tableName }, this.onChange); } diff --git a/superset/cli.py b/superset/cli.py index 7b441b47bdf9e..edb0102400571 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -288,9 +288,9 @@ def update_datasources_cache(): if database.allow_multi_schema_metadata_fetch: print('Fetching {} datasources ...'.format(database.name)) try: - database.all_table_names_in_database( + database.get_all_table_names_in_database( force=True, cache=True, cache_timeout=24 * 60 * 60) - database.all_view_names_in_database( + database.get_all_view_names_in_database( force=True, cache=True, cache_timeout=24 * 60 * 60) except Exception as e: print('{}'.format(str(e))) diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index 35a591fa10202..67aba1264b9e5 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -122,6 +122,7 @@ class BaseEngineSpec(object): force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 + try_remove_schema_from_table_name = True @classmethod def get_time_expr(cls, expr, pdf, time_grain, grain): @@ -279,33 +280,32 @@ def convert_dttm(cls, target_type, dttm): return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S')) @classmethod - def fetch_result_sets(cls, db, datasource_type): - """Returns a list of tables [schema1.table1, schema2.table2, ...] + def get_all_datasource_names(cls, db, datasource_type: str) \ + -> List[utils.DatasourceName]: + """Returns a list of all tables or views in database. - Datasource_type can be 'table' or 'view'. - Empty schema corresponds to the list of full names of the all - tables or views: .. + :param db: Database instance + :param datasource_type: Datasource_type can be 'table' or 'view' + :return: List of all datasources in database or schema """ - schemas = db.all_schema_names(cache=db.schema_cache_enabled, - cache_timeout=db.schema_cache_timeout, - force=True) - all_result_sets = [] + schemas = db.get_all_schema_names(cache=db.schema_cache_enabled, + cache_timeout=db.schema_cache_timeout, + force=True) + all_datasources: List[utils.DatasourceName] = [] for schema in schemas: if datasource_type == 'table': - all_datasource_names = db.all_table_names_in_schema( + all_datasources += db.get_all_table_names_in_schema( schema=schema, force=True, cache=db.table_cache_enabled, cache_timeout=db.table_cache_timeout) elif datasource_type == 'view': - all_datasource_names = db.all_view_names_in_schema( + all_datasources += db.get_all_view_names_in_schema( schema=schema, force=True, cache=db.table_cache_enabled, cache_timeout=db.table_cache_timeout) else: raise Exception(f'Unsupported datasource_type: {datasource_type}') - all_result_sets += [ - '{}.{}'.format(schema, t) for t in all_datasource_names] - return all_result_sets + return all_datasources @classmethod def handle_cursor(cls, cursor, query, session): @@ -352,11 +352,17 @@ def get_schema_names(cls, inspector): @classmethod def get_table_names(cls, inspector, schema): - return sorted(inspector.get_table_names(schema)) + tables = inspector.get_table_names(schema) + if schema and cls.try_remove_schema_from_table_name: + tables = [re.sub(f'^{schema}\\.', '', table) for table in tables] + return sorted(tables) @classmethod def get_view_names(cls, inspector, schema): - return sorted(inspector.get_view_names(schema)) + views = inspector.get_view_names(schema) + if schema and cls.try_remove_schema_from_table_name: + views = [re.sub(f'^{schema}\\.', '', view) for view in views] + return sorted(views) @classmethod def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list: @@ -528,6 +534,7 @@ def convert_dttm(cls, target_type, dttm): class PostgresEngineSpec(PostgresBaseEngineSpec): engine = 'postgresql' max_column_name_length = 63 + try_remove_schema_from_table_name = False @classmethod def get_table_names(cls, inspector, schema): @@ -685,29 +692,25 @@ def epoch_to_dttm(cls): return "datetime({col}, 'unixepoch')" @classmethod - def fetch_result_sets(cls, db, datasource_type): - schemas = db.all_schema_names(cache=db.schema_cache_enabled, - cache_timeout=db.schema_cache_timeout, - force=True) - all_result_sets = [] + def get_all_datasource_names(cls, db, datasource_type: str) \ + -> List[utils.DatasourceName]: + schemas = db.get_all_schema_names(cache=db.schema_cache_enabled, + cache_timeout=db.schema_cache_timeout, + force=True) schema = schemas[0] if datasource_type == 'table': - all_datasource_names = db.all_table_names_in_schema( + return db.get_all_table_names_in_schema( schema=schema, force=True, cache=db.table_cache_enabled, cache_timeout=db.table_cache_timeout) elif datasource_type == 'view': - all_datasource_names = db.all_view_names_in_schema( + return db.get_all_view_names_in_schema( schema=schema, force=True, cache=db.table_cache_enabled, cache_timeout=db.table_cache_timeout) else: raise Exception(f'Unsupported datasource_type: {datasource_type}') - all_result_sets += [ - '{}.{}'.format(schema, t) for t in all_datasource_names] - return all_result_sets - @classmethod def convert_dttm(cls, target_type, dttm): iso = dttm.isoformat().replace('T', ' ') @@ -1107,24 +1110,19 @@ def epoch_to_dttm(cls): return 'from_unixtime({col})' @classmethod - def fetch_result_sets(cls, db, datasource_type): - """Returns a list of tables [schema1.table1, schema2.table2, ...] - - Datasource_type can be 'table' or 'view'. - Empty schema corresponds to the list of full names of the all - tables or views: .. - """ - result_set_df = db.get_df( + def get_all_datasource_names(cls, db, datasource_type: str) \ + -> List[utils.DatasourceName]: + datasource_df = db.get_df( """SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S ORDER BY concat(table_schema, '.', table_name)""".format( datasource_type.upper(), ), None) - result_sets = [] - for unused, row in result_set_df.iterrows(): - result_sets.append('{}.{}'.format( - row['table_schema'], row['table_name'])) - return result_sets + datasource_names: List[utils.DatasourceName] = [] + for unused, row in datasource_df.iterrows(): + datasource_names.append(utils.DatasourceName( + schema=row['table_schema'], table=row['table_name'])) + return datasource_names @classmethod def extra_table_metadata(cls, database, table_name, schema_name): @@ -1385,9 +1383,9 @@ def patch(cls): hive.Cursor.fetch_logs = patched_hive.fetch_logs @classmethod - def fetch_result_sets(cls, db, datasource_type): - return BaseEngineSpec.fetch_result_sets( - db, datasource_type) + def get_all_datasource_names(cls, db, datasource_type: str) \ + -> List[utils.DatasourceName]: + return BaseEngineSpec.get_all_datasource_names(db, datasource_type) @classmethod def fetch_data(cls, cursor, limit): diff --git a/superset/models/core.py b/superset/models/core.py index e16a234bfd723..047a3ddb11b11 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -23,6 +23,7 @@ import json import logging import textwrap +from typing import List from flask import escape, g, Markup, request from flask_appbuilder import Model @@ -65,6 +66,7 @@ PASSWORD_MASK = 'X' * 10 + def set_related_perm(mapper, connection, target): # noqa src_class = target.cls_model id_ = target.datasource_id @@ -184,7 +186,7 @@ def clone(self): description=self.description, cache_timeout=self.cache_timeout) - @datasource.getter + @datasource.getter # type: ignore @utils.memoized def get_datasource(self): return ( @@ -210,7 +212,7 @@ def datasource_edit_url(self): datasource = self.datasource return datasource.url if datasource else None - @property + @property # type: ignore @utils.memoized def viz(self): d = json.loads(self.params) @@ -930,100 +932,87 @@ def inspector(self): @cache_util.memoized_func( key=lambda *args, **kwargs: 'db:{}:schema:None:table_list', attribute_in_key='id') - def all_table_names_in_database(self, cache=False, - cache_timeout=None, force=False): + def get_all_table_names_in_database(self, cache: bool = False, + cache_timeout: bool = None, + force=False) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] - return self.db_engine_spec.fetch_result_sets(self, 'table') + return self.db_engine_spec.get_all_datasource_names(self, 'table') @cache_util.memoized_func( key=lambda *args, **kwargs: 'db:{}:schema:None:view_list', attribute_in_key='id') - def all_view_names_in_database(self, cache=False, - cache_timeout=None, force=False): + def get_all_view_names_in_database(self, cache: bool = False, + cache_timeout: bool = None, + force: bool = False) -> List[utils.DatasourceName]: """Parameters need to be passed as keyword arguments.""" if not self.allow_multi_schema_metadata_fetch: return [] - return self.db_engine_spec.fetch_result_sets(self, 'view') + return self.db_engine_spec.get_all_datasource_names(self, 'view') @cache_util.memoized_func( key=lambda *args, **kwargs: 'db:{{}}:schema:{}:table_list'.format( kwargs.get('schema')), attribute_in_key='id') - def all_table_names_in_schema(self, schema, cache=False, - cache_timeout=None, force=False): + def get_all_table_names_in_schema(self, schema: str, cache: bool = False, + cache_timeout: int = None, force: bool = False): """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param schema: schema name - :type schema: str :param cache: whether cache is enabled for the function - :type cache: bool :param cache_timeout: timeout in seconds for the cache - :type cache_timeout: int :param force: whether to force refresh the cache - :type force: bool - :return: table list - :rtype: list + :return: list of tables """ - tables = [] try: tables = self.db_engine_spec.get_table_names( inspector=self.inspector, schema=schema) + return [utils.DatasourceName(table=table, schema=schema) for table in tables] except Exception as e: logging.exception(e) - return tables @cache_util.memoized_func( key=lambda *args, **kwargs: 'db:{{}}:schema:{}:view_list'.format( kwargs.get('schema')), attribute_in_key='id') - def all_view_names_in_schema(self, schema, cache=False, - cache_timeout=None, force=False): + def get_all_view_names_in_schema(self, schema: str, cache: bool = False, + cache_timeout: int = None, force: bool = False): """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param schema: schema name - :type schema: str :param cache: whether cache is enabled for the function - :type cache: bool :param cache_timeout: timeout in seconds for the cache - :type cache_timeout: int :param force: whether to force refresh the cache - :type force: bool - :return: view list - :rtype: list + :return: list of views """ - views = [] try: views = self.db_engine_spec.get_view_names( inspector=self.inspector, schema=schema) + return [utils.DatasourceName(table=view, schema=schema) for view in views] except Exception as e: logging.exception(e) - return views @cache_util.memoized_func( key=lambda *args, **kwargs: 'db:{}:schema_list', attribute_in_key='id') - def all_schema_names(self, cache=False, cache_timeout=None, force=False): + def get_all_schema_names(self, cache: bool = False, cache_timeout: int = None, + force: bool = False) -> List[str]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in cache_util.memoized_func decorator. :param cache: whether cache is enabled for the function - :type cache: bool :param cache_timeout: timeout in seconds for the cache - :type cache_timeout: int :param force: whether to force refresh the cache - :type force: bool :return: schema list - :rtype: list """ return self.db_engine_spec.get_schema_names(self.inspector) @@ -1232,7 +1221,7 @@ def username(self): def datasource(self): return self.get_datasource - @datasource.getter + @datasource.getter # type: ignore @utils.memoized def get_datasource(self): # pylint: disable=no-member diff --git a/superset/security.py b/superset/security.py index f8ae0571560ec..89eab5d53bf16 100644 --- a/superset/security.py +++ b/superset/security.py @@ -17,6 +17,7 @@ # pylint: disable=C,R,W """A set of constants and methods to manage permissions and security""" import logging +from typing import List from flask import g from flask_appbuilder.security.sqla import models as ab_models @@ -26,6 +27,7 @@ from superset import sql_parse from superset.connectors.connector_registry import ConnectorRegistry from superset.exceptions import SupersetSecurityException +from superset.utils.core import DatasourceName class SupersetSecurityManager(SecurityManager): @@ -240,7 +242,9 @@ def schemas_accessible_by_user(self, database, schemas, hierarchical=True): subset.add(t.schema) return sorted(list(subset)) - def accessible_by_user(self, database, datasource_names, schema=None): + def get_datasources_accessible_by_user( + self, database, datasource_names: List[DatasourceName], + schema: str = None) -> List[DatasourceName]: from superset import db if self.database_access(database) or self.all_datasource_access(): return datasource_names diff --git a/superset/utils/core.py b/superset/utils/core.py index 3b4145793939a..2defa70dd179e 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -32,7 +32,7 @@ import smtplib import sys from time import struct_time -from typing import List, Optional, Tuple +from typing import List, NamedTuple, Optional, Tuple from urllib.parse import unquote_plus import uuid import zlib @@ -1100,3 +1100,8 @@ def MediumText() -> Variant: def shortid() -> str: return '{}'.format(uuid.uuid4())[-12:] + + +class DatasourceName(NamedTuple): + table: str + schema: str diff --git a/superset/views/core.py b/superset/views/core.py index 883a2d95a92ca..0a6ddefa18e04 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -22,7 +22,7 @@ import os import re import traceback -from typing import List # noqa: F401 +from typing import Dict, List # noqa: F401 from urllib import parse from flask import ( @@ -311,7 +311,7 @@ def pre_add(self, db): db.set_sqlalchemy_uri(db.sqlalchemy_uri) security_manager.add_permission_view_menu('database_access', db.perm) # adding a new database we always want to force refresh schema list - for schema in db.all_schema_names(): + for schema in db.get_all_schema_names(): security_manager.add_permission_view_menu( 'schema_access', security_manager.get_schema_perm(db, schema)) @@ -1545,7 +1545,7 @@ def schemas(self, db_id, force_refresh='false'): .first() ) if database: - schemas = database.all_schema_names( + schemas = database.get_all_schema_names( cache=database.schema_cache_enabled, cache_timeout=database.schema_cache_timeout, force=force_refresh) @@ -1570,50 +1570,57 @@ def tables(self, db_id, schema, substr, force_refresh='false'): database = db.session.query(models.Database).filter_by(id=db_id).one() if schema: - table_names = database.all_table_names_in_schema( + tables = database.get_all_table_names_in_schema( schema=schema, force=force_refresh, cache=database.table_cache_enabled, - cache_timeout=database.table_cache_timeout) - view_names = database.all_view_names_in_schema( + cache_timeout=database.table_cache_timeout) or [] + views = database.get_all_view_names_in_schema( schema=schema, force=force_refresh, cache=database.table_cache_enabled, - cache_timeout=database.table_cache_timeout) + cache_timeout=database.table_cache_timeout) or [] else: - table_names = database.all_table_names_in_database( + tables = database.get_all_table_names_in_database( cache=True, force=False, cache_timeout=24 * 60 * 60) - view_names = database.all_view_names_in_database( + views = database.get_all_view_names_in_database( cache=True, force=False, cache_timeout=24 * 60 * 60) - table_names = security_manager.accessible_by_user(database, table_names, schema) - view_names = security_manager.accessible_by_user(database, view_names, schema) + tables = security_manager.get_datasources_accessible_by_user( + database, tables, schema) + views = security_manager.get_datasources_accessible_by_user( + database, views, schema) + + def get_datasource_label(ds_name: utils.DatasourceName) -> str: + return ds_name.table if schema else f'{ds_name.schema}.{ds_name.table}' if substr: - table_names = [tn for tn in table_names if substr in tn] - view_names = [vn for vn in view_names if substr in vn] + tables = [tn for tn in tables if substr in get_datasource_label(tn)] + views = [vn for vn in views if substr in get_datasource_label(vn)] if not schema and database.default_schemas: - def get_schema(tbl_or_view_name): - return tbl_or_view_name.split('.')[0] if '.' in tbl_or_view_name else None - user_schema = g.user.email.split('@')[0] valid_schemas = set(database.default_schemas + [user_schema]) - table_names = [tn for tn in table_names if get_schema(tn) in valid_schemas] - view_names = [vn for vn in view_names if get_schema(vn) in valid_schemas] + tables = [tn for tn in tables if tn.schema in valid_schemas] + views = [vn for vn in views if vn.schema in valid_schemas] - max_items = config.get('MAX_TABLE_NAMES') or len(table_names) - total_items = len(table_names) + len(view_names) - max_tables = len(table_names) - max_views = len(view_names) + max_items = config.get('MAX_TABLE_NAMES') or len(tables) + total_items = len(tables) + len(views) + max_tables = len(tables) + max_views = len(views) if total_items and substr: - max_tables = max_items * len(table_names) // total_items - max_views = max_items * len(view_names) // total_items - - table_options = [{'value': tn, 'label': tn} - for tn in table_names[:max_tables]] - table_options.extend([{'value': vn, 'label': '[view] {}'.format(vn)} - for vn in view_names[:max_views]]) + max_tables = max_items * len(tables) // total_items + max_views = max_items * len(views) // total_items + + def get_datasource_value(ds_name: utils.DatasourceName) -> Dict[str, str]: + return {'schema': ds_name.schema, 'table': ds_name.table} + + table_options = [{'value': get_datasource_value(tn), + 'label': get_datasource_label(tn)} + for tn in tables[:max_tables]] + table_options.extend([{'value': get_datasource_value(vn), + 'label': f'[view] {get_datasource_label(vn)}'} + for vn in views[:max_views]]) payload = { - 'tableLength': len(table_names) + len(view_names), + 'tableLength': len(tables) + len(views), 'options': table_options, } return json_success(json.dumps(payload)) diff --git a/tests/db_engine_specs_test.py b/tests/db_engine_specs_test.py index e0d914f0b51e4..e190014e1e909 100644 --- a/tests/db_engine_specs_test.py +++ b/tests/db_engine_specs_test.py @@ -464,3 +464,22 @@ def test_mssql_where_clause_n_prefix(self): query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True})) query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa self.assertEqual(query, query_expected) + + def test_get_table_names(self): + inspector = mock.Mock() + inspector.get_table_names = mock.Mock(return_value=['schema.table', 'table_2']) + inspector.get_foreign_table_names = mock.Mock(return_value=['table_3']) + + """ Make sure base engine spec removes schema name from table name + ie. when try_remove_schema_from_table_name == True. """ + base_result_expected = ['table', 'table_2'] + base_result = db_engine_specs.BaseEngineSpec.get_table_names( + schema='schema', inspector=inspector) + self.assertListEqual(base_result_expected, base_result) + + """ Make sure postgres doesn't try to remove schema name from table name + ie. when try_remove_schema_from_table_name == False. """ + pg_result_expected = ['schema.table', 'table_2', 'table_3'] + pg_result = db_engine_specs.PostgresEngineSpec.get_table_names( + schema='schema', inspector=inspector) + self.assertListEqual(pg_result_expected, pg_result)