Add support for period character in table names (#7453)
* Move schema name handling in table names from frontend to backend

* Rename all_schema_names to get_all_schema_names

* Fix js errors

* Fix additional js linting errors

* Refactor datasource getters and fix linting errors

* Update js unit tests

* Add python unit test for get_table_names method

* Add python unit test for get_table_names method

* Fix js linting error
villebro authored May 26, 2019
1 parent 47ba2ad commit f7d3413
Showing 11 changed files with 148 additions and 137 deletions.
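The heart of the change: a table option's value is now an object carrying schema and table as separate fields rather than a single 'schema.table' string, because splitting on the period is ambiguous once a table name itself contains one. A minimal standalone Python sketch of that ambiguity (illustrative only, not code from this commit; the names are made up):

def split_on_first_period(qualified_name):
    # Old-style handling: treat everything before the first '.' as the schema.
    pieces = qualified_name.split('.')
    if len(pieces) == 1:
        return None, pieces[0]
    return pieces[0], '.'.join(pieces[1:])

# 'main.my.table' could be schema 'main' with table 'my.table', or a table
# literally named 'main.my.table' in the default schema; the string alone
# cannot say which.
print(split_on_first_period('main.my.table'))  # ('main', 'my.table') -- a guess

# Passing the two parts separately removes the guesswork.
option_value = {'schema': 'main', 'table': 'my.table'}
print(option_value['schema'], option_value['table'])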
@@ -208,19 +208,20 @@ describe('TableSelector', () => {

   it('test 1', () => {
     wrapper.instance().changeTable({
-      value: 'birth_names',
+      value: { schema: 'main', table: 'birth_names' },
       label: 'birth_names',
     });
     expect(wrapper.state().tableName).toBe('birth_names');
   });

-  it('test 2', () => {
+  it('should call onTableChange with schema from table object', () => {
+    wrapper.setProps({ schema: null });
     wrapper.instance().changeTable({
-      value: 'main.my_table',
-      label: 'my_table',
+      value: { schema: 'other_schema', table: 'my_table' },
+      label: 'other_schema.my_table',
     });
     expect(mockedProps.onTableChange.getCall(0).args[0]).toBe('my_table');
-    expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('main');
+    expect(mockedProps.onTableChange.getCall(0).args[1]).toBe('other_schema');
   });
 });

6 changes: 3 additions & 3 deletions superset/assets/spec/javascripts/sqllab/fixtures.js
@@ -329,15 +329,15 @@ export const databases = {
 export const tables = {
   options: [
     {
-      value: 'birth_names',
+      value: { schema: 'main', table: 'birth_names' },
       label: 'birth_names',
     },
     {
-      value: 'energy_usage',
+      value: { schema: 'main', table: 'energy_usage' },
       label: 'energy_usage',
     },
     {
-      value: 'wb_health_population',
+      value: { schema: 'main', table: 'wb_health_population' },
       label: 'wb_health_population',
     },
   ],
15 changes: 4 additions & 11 deletions superset/assets/src/SqlLab/components/SqlEditorLeftBar.jsx
@@ -83,17 +83,10 @@ export default class SqlEditorLeftBar extends React.PureComponent {
       this.setState({ tableName: '' });
       return;
     }
-    const namePieces = tableOpt.value.split('.');
-    let tableName = namePieces[0];
-    let schemaName = this.props.queryEditor.schema;
-    if (namePieces.length === 1) {
-      this.setState({ tableName });
-    } else {
-      schemaName = namePieces[0];
-      tableName = namePieces[1];
-      this.setState({ tableName });
-      this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
-    }
+    const schemaName = tableOpt.value.schema;
+    const tableName = tableOpt.value.table;
+    this.setState({ tableName });
+    this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
     this.props.actions.addTable(this.props.queryEditor, tableName, schemaName);
   }

9 changes: 2 additions & 7 deletions superset/assets/src/components/TableSelector.jsx
@@ -170,13 +170,8 @@ export default class TableSelector extends React.PureComponent {
       this.setState({ tableName: '' });
       return;
     }
-    const namePieces = tableOpt.value.split('.');
-    let tableName = namePieces[0];
-    let schemaName = this.props.schema;
-    if (namePieces.length > 1) {
-      schemaName = namePieces[0];
-      tableName = namePieces[1];
-    }
+    const schemaName = tableOpt.value.schema;
+    const tableName = tableOpt.value.table;
     if (this.props.tableNameSticky) {
       this.setState({ tableName }, this.onChange);
     }
4 changes: 2 additions & 2 deletions superset/cli.py
@@ -288,9 +288,9 @@ def update_datasources_cache():
         if database.allow_multi_schema_metadata_fetch:
             print('Fetching {} datasources ...'.format(database.name))
             try:
-                database.all_table_names_in_database(
+                database.get_all_table_names_in_database(
                     force=True, cache=True, cache_timeout=24 * 60 * 60)
-                database.all_view_names_in_database(
+                database.get_all_view_names_in_database(
                     force=True, cache=True, cache_timeout=24 * 60 * 60)
             except Exception as e:
                 print('{}'.format(str(e)))
84 changes: 41 additions & 43 deletions superset/db_engine_specs.py
@@ -122,6 +122,7 @@ class BaseEngineSpec(object):
     force_column_alias_quotes = False
     arraysize = 0
     max_column_name_length = 0
+    try_remove_schema_from_table_name = True

     @classmethod
     def get_time_expr(cls, expr, pdf, time_grain, grain):
@@ -279,33 +280,32 @@ def convert_dttm(cls, target_type, dttm):
         return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        """Returns a list of tables [schema1.table1, schema2.table2, ...]
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        """Returns a list of all tables or views in database.
-        Datasource_type can be 'table' or 'view'.
-        Empty schema corresponds to the list of full names of the all
-        tables or views: <schema>.<result_set_name>.
+        :param db: Database instance
+        :param datasource_type: Datasource_type can be 'table' or 'view'
+        :return: List of all datasources in database or schema
         """
-        schemas = db.all_schema_names(cache=db.schema_cache_enabled,
-                                      cache_timeout=db.schema_cache_timeout,
-                                      force=True)
-        all_result_sets = []
+        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
+                                          cache_timeout=db.schema_cache_timeout,
+                                          force=True)
+        all_datasources: List[utils.DatasourceName] = []
         for schema in schemas:
             if datasource_type == 'table':
-                all_datasource_names = db.all_table_names_in_schema(
+                all_datasources += db.get_all_table_names_in_schema(
                     schema=schema, force=True,
                     cache=db.table_cache_enabled,
                     cache_timeout=db.table_cache_timeout)
             elif datasource_type == 'view':
-                all_datasource_names = db.all_view_names_in_schema(
+                all_datasources += db.get_all_view_names_in_schema(
                     schema=schema, force=True,
                     cache=db.table_cache_enabled,
                     cache_timeout=db.table_cache_timeout)
             else:
                 raise Exception(f'Unsupported datasource_type: {datasource_type}')
-            all_result_sets += [
-                '{}.{}'.format(schema, t) for t in all_datasource_names]
-        return all_result_sets
+        return all_datasources

     @classmethod
     def handle_cursor(cls, cursor, query, session):
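The return type `utils.DatasourceName` used above is not defined in the hunks loaded on this page. A hedged sketch of how such a structure might look, with field names inferred from the keyword arguments in the Presto hunk further down (`schema=`, `table=`):

from typing import List, NamedTuple

class DatasourceName(NamedTuple):  # hypothetical stand-in for utils.DatasourceName
    schema: str
    table: str

def format_labels(names: List[DatasourceName]) -> List[str]:
    # Joining with '.' becomes a display-only concern; code that needs the
    # real schema/table split keeps the structured value.
    return ['{}.{}'.format(n.schema, n.table) for n in names]

names = [DatasourceName(schema='main', table='my.table')]
print(format_labels(names))  # ['main.my.table']
print(names[0].table)        # 'my.table' -- the period survives intact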
@@ -352,11 +352,17 @@ def get_schema_names(cls, inspector):

     @classmethod
     def get_table_names(cls, inspector, schema):
-        return sorted(inspector.get_table_names(schema))
+        tables = inspector.get_table_names(schema)
+        if schema and cls.try_remove_schema_from_table_name:
+            tables = [re.sub(f'^{schema}\\.', '', table) for table in tables]
+        return sorted(tables)

     @classmethod
     def get_view_names(cls, inspector, schema):
-        return sorted(inspector.get_view_names(schema))
+        views = inspector.get_view_names(schema)
+        if schema and cls.try_remove_schema_from_table_name:
+            views = [re.sub(f'^{schema}\\.', '', view) for view in views]
+        return sorted(views)

     @classmethod
     def get_columns(cls, inspector: Inspector, table_name: str, schema: str) -> list:
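For clarity, a small standalone demonstration of the prefix stripping guarded by the new `try_remove_schema_from_table_name` flag; the schema and table names below are invented for illustration:

import re

schema = 'main'
# Some dialects return table names already qualified with the schema.
raw_names = ['main.birth_names', 'main.my.table', 'energy_usage']
stripped = [re.sub(f'^{schema}\\.', '', name) for name in raw_names]
print(sorted(stripped))  # ['birth_names', 'energy_usage', 'my.table']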
@@ -528,6 +534,7 @@ def convert_dttm(cls, target_type, dttm):
 class PostgresEngineSpec(PostgresBaseEngineSpec):
     engine = 'postgresql'
     max_column_name_length = 63
+    try_remove_schema_from_table_name = False

     @classmethod
     def get_table_names(cls, inspector, schema):
@@ -685,29 +692,25 @@ def epoch_to_dttm(cls):
         return "datetime({col}, 'unixepoch')"

     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        schemas = db.all_schema_names(cache=db.schema_cache_enabled,
-                                      cache_timeout=db.schema_cache_timeout,
-                                      force=True)
-        all_result_sets = []
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        schemas = db.get_all_schema_names(cache=db.schema_cache_enabled,
+                                          cache_timeout=db.schema_cache_timeout,
+                                          force=True)
         schema = schemas[0]
         if datasource_type == 'table':
-            all_datasource_names = db.all_table_names_in_schema(
+            return db.get_all_table_names_in_schema(
                 schema=schema, force=True,
                 cache=db.table_cache_enabled,
                 cache_timeout=db.table_cache_timeout)
         elif datasource_type == 'view':
-            all_datasource_names = db.all_view_names_in_schema(
+            return db.get_all_view_names_in_schema(
                 schema=schema, force=True,
                 cache=db.table_cache_enabled,
                 cache_timeout=db.table_cache_timeout)
         else:
             raise Exception(f'Unsupported datasource_type: {datasource_type}')
-
-        all_result_sets += [
-            '{}.{}'.format(schema, t) for t in all_datasource_names]
-        return all_result_sets

     @classmethod
     def convert_dttm(cls, target_type, dttm):
         iso = dttm.isoformat().replace('T', ' ')
@@ -1107,24 +1110,19 @@ def epoch_to_dttm(cls):
         return 'from_unixtime({col})'

     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        """Returns a list of tables [schema1.table1, schema2.table2, ...]
-        Datasource_type can be 'table' or 'view'.
-        Empty schema corresponds to the list of full names of the all
-        tables or views: <schema>.<result_set_name>.
-        """
-        result_set_df = db.get_df(
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        datasource_df = db.get_df(
             """SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S
                ORDER BY concat(table_schema, '.', table_name)""".format(
                 datasource_type.upper(),
             ),
             None)
-        result_sets = []
-        for unused, row in result_set_df.iterrows():
-            result_sets.append('{}.{}'.format(
-                row['table_schema'], row['table_name']))
-        return result_sets
+        datasource_names: List[utils.DatasourceName] = []
+        for unused, row in datasource_df.iterrows():
+            datasource_names.append(utils.DatasourceName(
+                schema=row['table_schema'], table=row['table_name']))
+        return datasource_names

     @classmethod
     def extra_table_metadata(cls, database, table_name, schema_name):
@@ -1385,9 +1383,9 @@ def patch(cls):
         hive.Cursor.fetch_logs = patched_hive.fetch_logs

     @classmethod
-    def fetch_result_sets(cls, db, datasource_type):
-        return BaseEngineSpec.fetch_result_sets(
-            db, datasource_type)
+    def get_all_datasource_names(cls, db, datasource_type: str) \
+            -> List[utils.DatasourceName]:
+        return BaseEngineSpec.get_all_datasource_names(db, datasource_type)

     @classmethod
     def fetch_data(cls, cursor, limit):
[Diff truncated: the remaining changed files were not loaded on this page.]
