diff --git a/superset/__init__.py b/superset/__init__.py
index 8ab8ded5e80db..b9cc6b0041e69 100644
--- a/superset/__init__.py
+++ b/superset/__init__.py
@@ -29,6 +29,7 @@
# In production mode, add log handler to sys.stderr.
app.logger.addHandler(logging.StreamHandler())
app.logger.setLevel(logging.INFO)
+logging.getLogger('pyhive.presto').setLevel(logging.INFO)
db = SQLA(app)
@@ -36,6 +37,8 @@
utils.pessimistic_connection_handling(db.engine.pool)
cache = Cache(app, config=app.config.get('CACHE_CONFIG'))
+tables_cache = Cache(app, config=app.config.get('TABLE_NAMES_CACHE_CONFIG'))
+
migrate = Migrate(app, db, directory=APP_DIR + "/migrations")
diff --git a/superset/assets/javascripts/SqlLab/actions.js b/superset/assets/javascripts/SqlLab/actions.js
index 0e72c209a3944..d7d20f49d3e1c 100644
--- a/superset/assets/javascripts/SqlLab/actions.js
+++ b/superset/assets/javascripts/SqlLab/actions.js
@@ -213,9 +213,9 @@ export function mergeTable(table, query) {
return { type: MERGE_TABLE, table, query };
}
-export function addTable(query, tableName) {
+export function addTable(query, tableName, schemaName) {
return function (dispatch) {
- let url = `/superset/table/${query.dbId}/${tableName}/${query.schema}/`;
+ let url = `/superset/table/${query.dbId}/${tableName}/${schemaName}/`;
$.get(url, (data) => {
const dataPreviewQuery = {
id: shortid.generate(),
@@ -232,7 +232,7 @@ export function addTable(query, tableName) {
Object.assign(data, {
dbId: query.dbId,
queryEditorId: query.id,
- schema: query.schema,
+ schema: schemaName,
expanded: true,
}), dataPreviewQuery)
);
@@ -248,12 +248,12 @@ export function addTable(query, tableName) {
);
});
- url = `/superset/extra_table_metadata/${query.dbId}/${tableName}/${query.schema}/`;
+ url = `/superset/extra_table_metadata/${query.dbId}/${tableName}/${schemaName}/`;
$.get(url, (data) => {
const table = {
dbId: query.dbId,
queryEditorId: query.id,
- schema: query.schema,
+ schema: schemaName,
name: tableName,
};
Object.assign(table, data);
diff --git a/superset/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx b/superset/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx
index 21d5f2bfceecc..c36d659ca6744 100644
--- a/superset/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx
+++ b/superset/assets/javascripts/SqlLab/components/SqlEditorLeftBar.jsx
@@ -30,8 +30,8 @@ class SqlEditorLeftBar extends React.PureComponent {
};
}
componentWillMount() {
- this.fetchSchemas();
- this.fetchTables();
+ this.fetchSchemas(this.props.queryEditor.dbId);
+ this.fetchTables(this.props.queryEditor.dbId, this.props.queryEditor.schema);
}
onChange(db) {
const val = (db) ? db.value : null;
@@ -58,22 +58,51 @@ class SqlEditorLeftBar extends React.PureComponent {
resetState() {
this.props.actions.resetState();
}
- fetchTables(dbId, schema) {
- const actualDbId = dbId || this.props.queryEditor.dbId;
- if (actualDbId) {
- const actualSchema = schema || this.props.queryEditor.schema;
- this.setState({ tableLoading: true });
- this.setState({ tableOptions: [] });
- const url = `/superset/tables/${actualDbId}/${actualSchema}`;
+ getTableNamesBySubStr(input) {
+ if (!this.props.queryEditor.dbId || !input) {
+ return Promise.resolve({ options: [] });
+ }
+ const url = `/superset/tables/${this.props.queryEditor.dbId}/\
+${this.props.queryEditor.schema}/${input}`;
+ return $.get(url).then((data) => ({ options: data.options }));
+ }
+ // TODO: move fetching methods to the actions.
+ fetchTables(dbId, schema, substr) {
+ if (dbId) {
+ this.setState({ tableLoading: true, tableOptions: [] });
+ const url = `/superset/tables/${dbId}/${schema}/${substr}/`;
$.get(url, (data) => {
- let tableOptions = data.tables.map((s) => ({ value: s, label: s }));
- const views = data.views.map((s) => ({ value: s, label: '[view] ' + s }));
- tableOptions = [...tableOptions, ...views];
- this.setState({ tableOptions });
- this.setState({ tableLoading: false });
+ this.setState({
+ tableLoading: false,
+ tableOptions: data.options,
+ tableLength: data.tableLength,
+ });
});
}
}
+ changeTable(tableOpt) {
+ if (!tableOpt) {
+ this.setState({ tableName: '' });
+ return;
+ }
+ const namePieces = tableOpt.value.split('.');
+ let tableName = namePieces[0];
+ let schemaName = this.props.queryEditor.schema;
+ if (namePieces.length === 1) {
+ this.setState({ tableName });
+ } else {
+ schemaName = namePieces[0];
+ tableName = namePieces[1];
+ this.setState({ tableName });
+ this.props.actions.queryEditorSetSchema(this.props.queryEditor, schemaName);
+ this.fetchTables(this.props.queryEditor.dbId, schemaName);
+ }
+ this.setState({ tableLoading: true });
+ // TODO: handle setting the tableLoading state depending on success or
+ // failure of the addTable async call in the action.
+ this.props.actions.addTable(this.props.queryEditor, tableName, schemaName);
+ this.setState({ tableLoading: false });
+ }
changeSchema(schemaOpt) {
const schema = (schemaOpt) ? schemaOpt.value : null;
this.props.actions.queryEditorSetSchema(this.props.queryEditor, schema);
@@ -95,14 +124,6 @@ class SqlEditorLeftBar extends React.PureComponent {
closePopover(ref) {
this.refs[ref].hide();
}
- changeTable(tableOpt) {
- const tableName = tableOpt.value;
- const qe = this.props.queryEditor;
-
- this.setState({ tableLoading: true });
- this.props.actions.addTable(qe, tableName);
- this.setState({ tableLoading: false });
- }
render() {
let networkAlert = null;
if (!this.props.networkOn) {
@@ -118,6 +139,8 @@ class SqlEditorLeftBar extends React.PureComponent {
dataEndpoint="/databaseasync/api/read?_flt_0_expose_in_sqllab=1"
onChange={this.onChange.bind(this)}
value={this.props.queryEditor.dbId}
+ databaseId={this.props.queryEditor.dbId}
+ actions={this.props.actions}
valueRenderer={(o) => (
  <div>Database: {o.label}</div>
)}
@@ -126,8 +149,6 @@ class SqlEditorLeftBar extends React.PureComponent {
mutator={this.dbMutator.bind(this)}
placeholder="Select a database"
/>
-
-
-
+ {this.props.queryEditor.schema &&
+
+ }
+ {!this.props.queryEditor.schema &&
+
+ }
diff --git a/superset/cache_util.py b/superset/cache_util.py
new file mode 100644
index 0000000000000..ef8835c55cce5
--- /dev/null
+++ b/superset/cache_util.py
@@ -0,0 +1,27 @@
+from superset import tables_cache
+from flask import request
+
+
+def view_cache_key(*unused_args, **unused_kwargs):
+ args_hash = hash(frozenset(request.args.items()))
+ return 'view/{}/{}'.format(request.path, args_hash)
+
+
+def memoized_func(timeout=5 * 60, key=view_cache_key):
+ """Use this decorator to cache functions that have predefined first arg.
+
+ memoized_func uses simple_cache and stored the data in memory.
+ Key is a callable function that takes function arguments and
+ returns the caching key.
+ """
+ def wrap(f):
+ def wrapped_f(cls, *args, **kwargs):
+ cache_key = key(*args, **kwargs)
+ o = tables_cache.get(cache_key)
+ if o is not None:
+ return o
+ o = f(cls, *args, **kwargs)
+ tables_cache.set(cache_key, o, timeout=timeout)
+ return o
+ return wrapped_f
+ return wrap
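
A standalone sketch of the contract, for reference: the first argument
(cls) is deliberately left out of the key, so subclasses share cache
entries for the same arguments. A plain dict stands in for tables_cache
here, the timeout is ignored, and Db/Spec are hypothetical:

    _fake_cache = {}

    def memoized_func(timeout=5 * 60, key=None):
        def wrap(f):
            def wrapped_f(cls, *args, **kwargs):
                cache_key = key(*args, **kwargs)  # `cls` never reaches `key`
                if cache_key in _fake_cache:
                    return _fake_cache[cache_key]
                result = f(cls, *args, **kwargs)
                _fake_cache[cache_key] = result
                return result
            return wrapped_f
        return wrap

    class Db(object):  # hypothetical stand-in for models.Database
        def __init__(self, id):
            self.id = id

    class Spec(object):
        @classmethod
        @memoized_func(key=lambda *args: 'db:{}:{}'.format(args[0].id, args[1]))
        def fetch_result_sets(cls, db, datasource_type):
            print('cache miss')
            return {'': []}

    Spec.fetch_result_sets(Db(1), 'table')  # prints 'cache miss'
    Spec.fetch_result_sets(Db(1), 'table')  # served from the cache
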
diff --git a/superset/config.py b/superset/config.py
index fa87ff5739c6d..78064bfc1eb8c 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -153,6 +153,7 @@
CACHE_DEFAULT_TIMEOUT = None
CACHE_CONFIG = {'CACHE_TYPE': 'null'}
+TABLE_NAMES_CACHE_CONFIG = {'CACHE_TYPE': 'null'}
# CORS Options
ENABLE_CORS = False
@@ -209,6 +210,9 @@
SQL_MAX_ROW = 1000000
DISPLAY_SQL_MAX_ROW = 1000
+# Maximum number of tables/views displayed in the dropdown window in SQL Lab.
+MAX_TABLE_NAMES = 3000
+
# If defined, shows this text in an alert-warning box in the navbar
# one example use case may be "STAGING" to make it clear that this is
# not the production version of the site.
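
Both knobs are meant to be overridden in superset_config.py. A sketch
assuming a Redis backend (the URL, key prefix, and limit below are
illustrative, not defaults):

    # superset_config.py
    TABLE_NAMES_CACHE_CONFIG = {
        'CACHE_TYPE': 'redis',
        'CACHE_DEFAULT_TIMEOUT': 600,
        'CACHE_KEY_PREFIX': 'tables_',
        'CACHE_REDIS_URL': 'redis://localhost:6379/1',
    }
    MAX_TABLE_NAMES = 1000  # trim the SQL Lab dropdown further
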
diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py
index 8bcb1a304931b..60d43bc8a7820 100644
--- a/superset/db_engine_specs.py
+++ b/superset/db_engine_specs.py
@@ -16,12 +16,13 @@
from __future__ import print_function
from __future__ import unicode_literals
-from collections import namedtuple
+from collections import namedtuple, defaultdict
+from flask_babel import lazy_gettext as _
import inspect
import textwrap
import time
-from flask_babel import lazy_gettext as _
+from superset import cache_util
Grain = namedtuple('Grain', 'name label function')
@@ -54,6 +55,33 @@ def extra_table_metadata(cls, database, table_name, schema_name):
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
+ @classmethod
+ @cache_util.memoized_func(
+ timeout=600,
+ key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]))
+ def fetch_result_sets(cls, db, datasource_type):
+ """Returns the dictionary {schema : [result_set_name]}.
+
+ Datasource_type can be 'table' or 'view'.
+        An empty schema key corresponds to the list of full names of all
+        the tables or views: <schema>.<result_set_name>.
+        """
+ schemas = db.inspector.get_schema_names()
+ result_sets = {}
+ all_result_sets = []
+ for schema in schemas:
+ if datasource_type == 'table':
+ result_sets[schema] = sorted(
+ db.inspector.get_table_names(schema))
+ elif datasource_type == 'view':
+ result_sets[schema] = sorted(
+ db.inspector.get_view_names(schema))
+ all_result_sets += [
+ '{}.{}'.format(schema, t) for t in result_sets[schema]]
+ if all_result_sets:
+ result_sets[""] = all_result_sets
+ return result_sets
+
@classmethod
def handle_cursor(cls, cursor, query, session):
"""Handle a live cursor between the execute and fetchall calls
@@ -221,6 +249,28 @@ def show_partition_pql(
{limit_clause}
""").format(**locals())
+ @classmethod
+ @cache_util.memoized_func(
+ timeout=600,
+ key=lambda *args, **kwargs: 'db:{}:{}'.format(args[0].id, args[1]))
+ def fetch_result_sets(cls, db, datasource_type):
+ """Returns the dictionary {schema : [result_set_name]}.
+
+ Datasource_type can be 'table' or 'view'.
+        An empty schema key corresponds to the list of full names of all
+        the tables or views: <schema>.<result_set_name>.
+        """
+ result_set_df = db.get_df(
+ """SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S
+ ORDER BY concat(table_schema, '.', table_name)""".format(
+ datasource_type.upper()), None)
+ result_sets = defaultdict(list)
+ for unused, row in result_set_df.iterrows():
+ result_sets[row['table_schema']].append(row['table_name'])
+ result_sets[""].append('{}.{}'.format(
+ row['table_schema'], row['table_name']))
+ return result_sets
+
@classmethod
def extra_table_metadata(cls, database, table_name, schema_name):
indexes = database.get_indexes(table_name, schema_name)
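
For a database with schemas main and staging, fetch_result_sets(db,
'table') returns a dict shaped like the following (the table names are
illustrative):

    result_sets = {
        'main': ['birth_names', 'energy_usage'],
        'staging': ['tmp_load'],
        # the empty key aggregates <schema>.<table> names for the
        # no-schema-selected case in SQL Lab
        '': ['main.birth_names', 'main.energy_usage', 'staging.tmp_load'],
    }
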
diff --git a/superset/models.py b/superset/models.py
index ae0ba525c7844..99af95499dc8f 100644
--- a/superset/models.py
+++ b/superset/models.py
@@ -845,13 +845,19 @@ def inspector(self):
return sqla.inspect(engine)
def all_table_names(self, schema=None):
+ if not schema:
+ tables_dict = self.db_engine_spec.fetch_result_sets(self, 'table')
+ return tables_dict.get("", [])
return sorted(self.inspector.get_table_names(schema))
def all_view_names(self, schema=None):
+ if not schema:
+ views_dict = self.db_engine_spec.fetch_result_sets(self, 'view')
+ return views_dict.get("", [])
views = []
try:
views = self.inspector.get_view_names(schema)
- except Exception as e:
+ except Exception:
pass
return views
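
The net effect: with no schema, all_table_names answers from the cached
fetch_result_sets aggregate; with an explicit schema it still walks the
inspector. A usage sketch, assuming `database` is a models.Database row:

    full_names = database.all_table_names()       # cached, e.g. ['main.birth_names', ...]
    main_only = database.all_table_names('main')  # live inspector, e.g. ['birth_names', ...]
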
diff --git a/superset/source_registry.py b/superset/source_registry.py
index ff64265d4788c..df91762564bbd 100644
--- a/superset/source_registry.py
+++ b/superset/source_registry.py
@@ -44,25 +44,14 @@ def get_datasource_by_name(cls, session, datasource_type, datasource_name,
return db_ds[0]
@classmethod
- def query_datasources_by_name(
- cls, session, database, datasource_name, schema=None):
+ def query_datasources_by_permissions(cls, session, database, permissions):
datasource_class = SourceRegistry.sources[database.type]
- if database.type == 'table':
- query = (
- session.query(datasource_class)
- .filter_by(database_id=database.id)
- .filter_by(table_name=datasource_name))
- if schema:
- query = query.filter_by(schema=schema)
- return query.all()
- if database.type == 'druid':
- return (
- session.query(datasource_class)
- .filter_by(cluster_name=database.id)
- .filter_by(datasource_name=datasource_name)
- .all()
- )
- return None
+ return (
+ session.query(datasource_class)
+ .filter_by(database_id=database.id)
+ .filter(datasource_class.perm.in_(permissions))
+ .all()
+ )
@classmethod
def get_eager_datasource(cls, session, datasource_type, datasource_id):
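
A call sketch: the caller resolves the user's datasource_access
permission strings once and fetches everything they cover in a single
query, instead of one query per datasource name. The perm values below
are illustrative of Superset's '[database].[table](id:N)' format:

    perms = {'[main].[birth_names](id:1)', '[main].[energy_usage](id:2)'}
    visible = SourceRegistry.query_datasources_by_permissions(
        db.session, database, perms)
    # one IN (...) query; only datasources whose .perm is in `perms` return
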
diff --git a/superset/utils.py b/superset/utils.py
index d6f0e484210e6..7e6da2a1ca8c6 100644
--- a/superset/utils.py
+++ b/superset/utils.py
@@ -8,7 +8,6 @@
import functools
import json
import logging
-import markdown as md
import numpy
import os
import parsedatetime
@@ -33,6 +32,7 @@
)
from flask_appbuilder._compat import as_unicode
from flask_babel import gettext as __
+import markdown as md
from past.builtins import basestring
from pydruid.utils.having import Having
from sqlalchemy import event, exc
@@ -122,6 +122,10 @@ def __get__(self, obj, objtype):
return functools.partial(self.__call__, obj)
+def js_string_to_python(item):
+ return None if item in ('null', 'undefined') else item
+
+
class DimSelector(Having):
def __init__(self, **args):
# Just a hack to prevent any exceptions
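
The helper normalizes the placeholder strings a JS client leaves in URL
path segments:

    js_string_to_python('null')       # -> None
    js_string_to_python('undefined')  # -> None
    js_string_to_python('public')     # -> 'public'
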
diff --git a/superset/views.py b/superset/views.py
index c9b4aea096e38..ce98b967198b6 100755
--- a/superset/views.py
+++ b/superset/views.py
@@ -36,8 +36,8 @@
import superset
from superset import (
- app, appbuilder, cache, db, models, sm, sql_lab, sql_parse,
- results_backend, security, viz, utils,
+ appbuilder, cache, db, models, viz, utils, app,
+ sm, sql_lab, sql_parse, results_backend, security,
)
from superset.utils import has_access
from superset.source_registry import SourceRegistry
@@ -83,12 +83,12 @@ def datasource_access(self, datasource, user=None):
def datasource_access_by_name(
self, database, datasource_name, schema=None):
- if (self.database_access(database) or
- self.all_datasource_access()):
+ if self.database_access(database) or self.all_datasource_access():
return True
schema_perm = utils.get_schema_perm(database, schema)
- if schema and utils.can_access(sm, 'schema_access', schema_perm, g.user):
+ if schema and utils.can_access(
+ sm, 'schema_access', schema_perm, g.user):
return True
datasources = SourceRegistry.query_datasources_by_name(
@@ -116,6 +116,31 @@ def rejected_datasources(self, sql, database, schema):
t for t in superset_query.tables if not
self.datasource_access_by_fullname(database, t, schema)]
+ def accessible_by_user(self, database, datasource_names, schema=None):
+ if self.database_access(database) or self.all_datasource_access():
+ return datasource_names
+
+ schema_perm = utils.get_schema_perm(database, schema)
+ if schema and utils.can_access(
+ sm, 'schema_access', schema_perm, g.user):
+ return datasource_names
+
+ role_ids = set([role.id for role in g.user.roles])
+ # TODO: cache user_perms or user_datasources
+ user_pvms = (
+ db.session.query(ab_models.PermissionView)
+ .join(ab_models.Permission)
+ .filter(ab_models.Permission.name == 'datasource_access')
+ .filter(ab_models.PermissionView.role.any(
+ ab_models.Role.id.in_(role_ids)))
+ .all()
+ )
+ user_perms = set([pvm.view_menu.name for pvm in user_pvms])
+ user_datasources = SourceRegistry.query_datasources_by_permissions(
+ db.session, database, user_perms)
+ full_names = set([d.full_name for d in user_datasources])
+ return [d for d in datasource_names if d in full_names]
+
class ListWidgetWithCheckboxes(ListWidget):
"""An alternative to list view that renders Boolean fields as checkboxes
@@ -165,6 +190,11 @@ def json_error_response(msg, status=None):
json.dumps(data), status=status, mimetype="application/json")
+def json_success(json_msg, status=None):
+ status = status if status else 200
+ return Response(json_msg, status=status, mimetype="application/json")
+
+
def api(f):
"""
A decorator to label an endpoint as an API. Catches uncaught exceptions and
@@ -175,13 +205,7 @@ def wraps(self, *args, **kwargs):
return f(self, *args, **kwargs)
except Exception as e:
logging.exception(e)
- resp = Response(
- json.dumps({
- 'message': get_error_msg()
- }),
- status=500,
- mimetype="application/json")
- return resp
+ return json_error_response(get_error_msg())
return functools.update_wrapper(wraps, f)
@@ -1459,28 +1483,18 @@ def explore_json(self, datasource_type, datasource_id):
return json_error_response(utils.error_msg_from_exception(e))
if not self.datasource_access(viz_obj.datasource):
- return Response(
- json.dumps(
- {'error': DATASOURCE_ACCESS_ERR}),
- status=404,
- mimetype="application/json")
+ return json_error_response(DATASOURCE_ACCESS_ERR, status=404)
payload = {}
- status = 200
try:
payload = viz_obj.get_payload()
except Exception as e:
logging.exception(e)
- status = 500
return json_error_response(utils.error_msg_from_exception(e))
-
if payload.get('status') == QueryStatus.FAILED:
- status = 500
+ return json_error_response(viz_obj.json_dumps(payload))
- return Response(
- viz_obj.json_dumps(payload),
- status=status,
- mimetype="application/json")
+ return json_success(viz_obj.json_dumps(payload))
@expose("/import_dashboards", methods=['GET', 'POST'])
@log_this
@@ -1645,12 +1659,7 @@ def filter(self, datasource_type, datasource_id, column):
except Exception as e:
flash(str(e), "danger")
return redirect(error_redirect)
- status = 200
- payload = obj.get_values_for_column(column)
- return Response(
- payload,
- status=status,
- mimetype="application/json")
+ return json_success(obj.get_values_for_column(column))
def save_or_overwrite_slice(
self, args, slc, slice_add_perm, slice_edit_perm):
@@ -1761,7 +1770,7 @@ def checkbox(self, model_view, id_, attr, value):
if obj:
setattr(obj, attr, value == 'true')
db.session.commit()
- return Response("OK", mimetype="application/json")
+ return json_success("OK")
@api
@has_access_api
@@ -1779,52 +1788,56 @@ def activity_per_day(self):
)
payload = {str(time.mktime(dt.timetuple())):
ccount for dt, ccount in qry if dt}
- return Response(json.dumps(payload), mimetype="application/json")
+ return json_success(json.dumps(payload))
@api
@has_access_api
- @expose("/all_tables/")
- def all_tables(self, db_id):
- """Endpoint that returns all tables and views from the database"""
+ @expose("/schemas/")
+ def schemas(self, db_id):
database = (
db.session
.query(models.Database)
.filter_by(id=db_id)
.one()
)
- all_tables = []
- all_views = []
- schemas = database.all_schema_names()
- for schema in schemas:
- all_tables.extend(database.all_table_names(schema=schema))
- all_views.extend(database.all_view_names(schema=schema))
- if not schemas:
- all_tables.extend(database.all_table_names())
- all_views.extend(database.all_view_names())
-
return Response(
- json.dumps({"tables": all_tables, "views": all_views}),
+ json.dumps({'schemas': database.all_schema_names()}),
mimetype="application/json")
@api
@has_access_api
- @expose("/tables//")
- def tables(self, db_id, schema):
+ @expose("/tables////")
+ def tables(self, db_id, schema, substr):
"""endpoint to power the calendar heatmap on the welcome page"""
- schema = None if schema in ('null', 'undefined') else schema
- database = (
- db.session
- .query(models.Database)
- .filter_by(id=db_id)
- .one()
- )
- tables = [t for t in database.all_table_names(schema) if
- self.datasource_access_by_name(database, t, schema=schema)]
- views = [v for v in database.all_view_names(schema) if
- self.datasource_access_by_name(database, v, schema=schema)]
- payload = {'tables': tables, 'views': views}
- return Response(
- json.dumps(payload), mimetype="application/json")
+ schema = utils.js_string_to_python(schema)
+ substr = utils.js_string_to_python(substr)
+ database = db.session.query(models.Database).filter_by(id=db_id).one()
+ table_names = self.accessible_by_user(
+ database, database.all_table_names(schema), schema)
+ view_names = self.accessible_by_user(
+ database, database.all_view_names(schema), schema)
+
+ if substr:
+ table_names = [tn for tn in table_names if substr in tn]
+ view_names = [vn for vn in view_names if substr in vn]
+
+ max_items = config.get('MAX_TABLE_NAMES') or len(table_names)
+ total_items = len(table_names) + len(view_names)
+ max_tables = len(table_names)
+ max_views = len(view_names)
+ if total_items and substr:
+ max_tables = max_items * len(table_names) // total_items
+ max_views = max_items * len(view_names) // total_items
+
+ table_options = [{'value': tn, 'label': tn}
+ for tn in table_names[:max_tables]]
+ table_options.extend([{'value': vn, 'label': '[view] {}'.format(vn)}
+ for vn in view_names[:max_views]])
+ payload = {
+ 'tableLength': len(table_names) + len(view_names),
+ 'options': table_options,
+ }
+ return json_success(json.dumps(payload))
@api
@has_access_api
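
The MAX_TABLE_NAMES cap splits the dropdown proportionally between
tables and views when a substring search still matches too much. Worked
numbers, assuming the default cap of 3000 with 4000 matching tables and
2000 matching views:

    max_items = 3000                                     # MAX_TABLE_NAMES
    table_count, view_count = 4000, 2000
    total_items = table_count + view_count               # 6000
    max_tables = max_items * table_count // total_items  # 2000
    max_views = max_items * view_count // total_items    # 1000
    # dropdown: first 2000 tables + first 1000 views = 3000 options
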
@@ -1848,8 +1861,7 @@ def copy_dash(self, dashboard_id):
session.commit()
dash_json = dash.json_data
session.close()
- return Response(
- dash_json, mimetype="application/json")
+ return json_success(dash_json)
@api
@has_access_api
@@ -1932,11 +1944,9 @@ def testconn(self):
engine.connect()
return json.dumps(engine.table_names(), indent=4)
except Exception as e:
- return Response((
+ return json_error_response((
"Connection failed!\n\n"
- "The error message returned was:\n{}").format(e),
- status=500,
- mimetype="application/json")
+ "The error message returned was:\n{}").format(e))
@api
@has_access_api
@@ -1980,9 +1990,8 @@ def recent_activity(self, user_id):
'item_title': item_title,
'time': log.Log.dttm,
})
- return Response(
- json.dumps(payload, default=utils.json_int_dttm_ser),
- mimetype="application/json")
+ return json_success(
+ json.dumps(payload, default=utils.json_int_dttm_ser))
@api
@has_access_api
@@ -2020,9 +2029,8 @@ def fave_dashboards(self, user_id):
d['creator_url'] = '/superset/profile/{}/'.format(
user.username)
payload.append(d)
- return Response(
- json.dumps(payload, default=utils.json_int_dttm_ser),
- mimetype="application/json")
+ return json_success(
+ json.dumps(payload, default=utils.json_int_dttm_ser))
@api
@has_access_api
@@ -2050,9 +2058,8 @@ def created_dashboards(self, user_id):
'url': o.url,
'dttm': o.changed_on,
} for o in qry.all()]
- return Response(
- json.dumps(payload, default=utils.json_int_dttm_ser),
- mimetype="application/json")
+ return json_success(
+ json.dumps(payload, default=utils.json_int_dttm_ser))
@api
@has_access_api
@@ -2076,9 +2083,8 @@ def created_slices(self, user_id):
'url': o.slice_url,
'dttm': o.changed_on,
} for o in qry.all()]
- return Response(
- json.dumps(payload, default=utils.json_int_dttm_ser),
- mimetype="application/json")
+ return json_success(
+ json.dumps(payload, default=utils.json_int_dttm_ser))
@api
@has_access_api
@@ -2116,9 +2122,8 @@ def fave_slices(self, user_id):
d['creator_url'] = '/superset/profile/{}/'.format(
user.username)
payload.append(d)
- return Response(
- json.dumps(payload, default=utils.json_int_dttm_ser),
- mimetype="application/json")
+ return json_success(
+ json.dumps(payload, default=utils.json_int_dttm_ser))
@api
@has_access_api
@@ -2162,12 +2167,9 @@ def warm_up_cache(self):
obj.get_json(force=True)
except Exception as e:
return json_error_response(utils.error_msg_from_exception(e))
- return Response(
- json.dumps(
- [{"slice_id": session.id, "slice_name": session.slice_name}
- for session in slices]),
- status=200,
- mimetype="application/json")
+ return json_success(json.dumps(
+ [{"slice_id": session.id, "slice_name": session.slice_name}
+ for session in slices]))
@expose("/favstar////")
def favstar(self, class_name, obj_id, action):
@@ -2195,9 +2197,7 @@ def favstar(self, class_name, obj_id, action):
else:
count = len(favs)
session.commit()
- return Response(
- json.dumps({'count': count}),
- mimetype="application/json")
+ return json_success(json.dumps({'count': count}))
@has_access
@expose("/dashboard//")
@@ -2362,7 +2362,7 @@ def sqllab_viz(self):
@expose("/table////")
@log_this
def table(self, database_id, table_name, schema):
- schema = None if schema in ('null', 'undefined') else schema
+ schema = utils.js_string_to_python(schema)
mydb = db.session.query(models.Database).filter_by(id=database_id).one()
cols = []
indexes = []
@@ -2373,9 +2373,7 @@ def table(self, database_id, table_name, schema):
primary_key = mydb.get_pk_constraint(table_name, schema)
foreign_keys = mydb.get_foreign_keys(table_name, schema)
except Exception as e:
- return Response(
- json.dumps({'error': utils.error_msg_from_exception(e)}),
- mimetype="application/json")
+ return json_error_response(utils.error_msg_from_exception(e))
keys = []
if primary_key and primary_key.get('constrained_columns'):
primary_key['column_names'] = primary_key.pop('constrained_columns')
@@ -2413,17 +2411,17 @@ def table(self, database_id, table_name, schema):
'foreignKeys': foreign_keys,
'indexes': keys,
}
- return Response(json.dumps(tbl), mimetype="application/json")
+ return json_success(json.dumps(tbl))
@has_access
@expose("/extra_table_metadata////")
@log_this
def extra_table_metadata(self, database_id, table_name, schema):
- schema = None if schema in ('null', 'undefined') else schema
+ schema = utils.js_string_to_python(schema)
mydb = db.session.query(models.Database).filter_by(id=database_id).one()
payload = mydb.db_engine_spec.extra_table_metadata(
mydb, table_name, schema)
- return Response(json.dumps(payload), mimetype="application/json")
+ return json_success(json.dumps(payload))
@has_access
@expose("/select_star///")
@@ -2470,35 +2468,27 @@ def results(self, key):
return json_error_response("Results backend isn't configured")
blob = results_backend.get(key)
- if blob:
- query = (
- db.session.query(models.Query)
- .filter_by(results_key=key)
- .one()
+ if not blob:
+ return json_error_response(
+ "Data could not be retrieved. "
+ "You may want to re-run the query.",
+ status=410
            )
- rejected_tables = self.rejected_datasources(
- query.sql, query.database, query.schema)
- if rejected_tables:
- return json_error_response(get_datasource_access_error_msg(
- '{}'.format(rejected_tables)))
- payload = zlib.decompress(blob)
- display_limit = app.config.get('DISPLAY_SQL_MAX_ROW', None)
- if display_limit:
- payload_json = json.loads(payload)
- payload_json['data'] = payload_json['data'][:display_limit]
- return Response(
- json.dumps(payload_json, default=utils.json_iso_dttm_ser),
- status=200, mimetype="application/json")
- else:
- return Response(
- json.dumps({
- 'error': (
- "Data could not be retrieved. You may want to "
- "re-run the query."
- )
- }),
- status=410,
- mimetype="application/json")
+
+ query = db.session.query(models.Query).filter_by(results_key=key).one()
+ rejected_tables = self.rejected_datasources(
+ query.sql, query.database, query.schema)
+ if rejected_tables:
+ return json_error_response(get_datasource_access_error_msg(
+ '{}'.format(rejected_tables)))
+
+ payload = zlib.decompress(blob)
+ display_limit = app.config.get('DISPLAY_SQL_MAX_ROW', None)
+ if display_limit:
+ payload_json = json.loads(payload)
+ payload_json['data'] = payload_json['data'][:display_limit]
+ return json_success(
+ json.dumps(payload_json, default=utils.json_iso_dttm_ser))
@has_access_api
@expose("/sql_json/", methods=['POST', 'GET'])
@@ -2555,12 +2545,9 @@ def sql_json(self):
sql_lab.get_sql_results.delay(
query_id, return_results=False,
store_results=not query.select_as_cta)
- return Response(
- json.dumps({'query': query.to_dict()},
- default=utils.json_int_dttm_ser,
- allow_nan=False),
- status=202, # Accepted
- mimetype="application/json")
+ return json_success(json.dumps(
+ {'query': query.to_dict()}, default=utils.json_int_dttm_ser,
+ allow_nan=False), status=202)
# Sync request.
try:
@@ -2575,14 +2562,8 @@ def sql_json(self):
data = sql_lab.get_sql_results(query_id, return_results=True)
except Exception as e:
logging.exception(e)
- return Response(
- json.dumps({'error': "{}".format(e)}),
- status=500,
- mimetype="application/json")
- return Response(
- data,
- status=200,
- mimetype="application/json")
+ return json_error_response("{}".format(e))
+ return json_success(data)
@has_access
@expose("/csv/")
@@ -2637,21 +2618,15 @@ def fetch_datasource_metadata(self):
# Check permission for datasource
if not self.datasource_access(datasource):
return json_error_response(DATASOURCE_ACCESS_ERR)
-
- return Response(
- json.dumps(datasource.data),
- mimetype="application/json"
- )
+ return json_success(json.dumps(datasource.data))
@has_access
@expose("/queries/")
def queries(self, last_updated_ms):
"""Get the updated queries."""
if not g.user.get_id():
- return Response(
- json.dumps({'error': "Please login to access the queries."}),
- status=403,
- mimetype="application/json")
+ return json_error_response(
+ "Please login to access the queries.", status=403)
# Unix time, milliseconds.
last_updated_ms_int = int(float(last_updated_ms)) if last_updated_ms else 0
@@ -2668,10 +2643,8 @@ def queries(self, last_updated_ms):
.all()
)
dict_queries = {q.client_id: q.to_dict() for q in sql_queries}
- return Response(
- json.dumps(dict_queries, default=utils.json_int_dttm_ser),
- status=200,
- mimetype="application/json")
+ return json_success(
+ json.dumps(dict_queries, default=utils.json_int_dttm_ser))
@has_access
@expose("/search_queries")
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 61b5e4c1f64be..7df53c567ecdc 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -556,14 +556,6 @@ def test_fetch_datasource_metadata(self):
for k in keys:
self.assertIn(k, resp.keys())
- def test_fetch_all_tables(self):
- self.login(username='admin')
- database = self.get_main_database(db.session)
- url = '/superset/all_tables/{}'.format(database.id)
- resp = json.loads(self.get_resp(url))
- self.assertIn('tables', resp)
- self.assertIn('views', resp)
-
def test_user_profile(self):
self.login(username='admin')
slc = self.get_slice("Girls", db.session)
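
A replacement test for the new endpoint could look roughly like this
(a sketch in the style of tests/core_tests.py; the 'null' placeholders
and the exact assertions are assumptions):

    def test_fetch_table_names(self):
        self.login(username='admin')
        database = self.get_main_database(db.session)
        url = '/superset/tables/{}/null/null/'.format(database.id)
        resp = json.loads(self.get_resp(url))
        self.assertIn('options', resp)
        self.assertIn('tableLength', resp)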