diff --git a/bogdan.todo b/bogdan.todo
new file mode 100644
index 0000000000000..552a259279fad
--- /dev/null
+++ b/bogdan.todo
@@ -0,0 +1,4 @@
+1. [] implement the polling of the query results
+2. [] implement the retrieving of the CTA results
+3. [] implement parsing of the query to retrieve the table names
+
diff --git a/caravel/config.py b/caravel/config.py
index f922b9db02aaa..79c87d0a9215a 100644
--- a/caravel/config.py
+++ b/caravel/config.py
@@ -179,22 +179,7 @@
# Set this API key to enable Mapbox visualizations
MAPBOX_API_KEY = ""
-# Maximum number of rows returned in the SQL editor
-SQL_MAX_ROW = 1000
-# Default celery config is to use SQLA as a broker, in a production setting
-# you'll want to use a proper broker as specified here:
-# http://docs.celeryproject.org/en/latest/getting-started/brokers/index.html
-"""
-# Example:
-class CeleryConfig(object):
- BROKER_URL = 'sqla+sqlite:///celerydb.sqlite'
- CELERY_IMPORTS = ('caravel.tasks', )
- CELERY_RESULT_BACKEND = 'db+sqlite:///celery_results.sqlite'
- CELERY_ANNOTATIONS = {'tasks.add': {'rate_limit': '10/s'}}
-CELERY_CONFIG = CeleryConfig
-"""
-CELERY_CONFIG = None
try:
from caravel_config import * # noqa
@@ -203,4 +188,3 @@ class CeleryConfig(object):
if not CACHE_DEFAULT_TIMEOUT:
CACHE_DEFAULT_TIMEOUT = CACHE_CONFIG.get('CACHE_DEFAULT_TIMEOUT')
-
diff --git a/caravel/extract_table_names.py b/caravel/extract_table_names.py
new file mode 100644
index 0000000000000..4bc57074290a0
--- /dev/null
+++ b/caravel/extract_table_names.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) 2016 Andi Albrecht, albrecht.andi@gmail.com
+#
+# This example is part of python-sqlparse and is released under
+# the BSD License: http://www.opensource.org/licenses/bsd-license.php
+#
+# This example illustrates how to extract table names from nested
+# SELECT statements.
+#
+# See:
+# http://groups.google.com/group/sqlparse/browse_thread/thread/b0bd9a022e9d4895
+
+import sqlparse
+from sqlparse.sql import IdentifierList, Identifier
+from sqlparse.tokens import Keyword, DML
+
+
+def is_subselect(parsed):
+ if not parsed.is_group():
+ return False
+ for item in parsed.tokens:
+ if item.ttype is DML and item.value.upper() == 'SELECT':
+ return True
+ return False
+
+
+def extract_from_part(parsed):
+ from_seen = False
+ for item in parsed.tokens:
+ if from_seen:
+ if is_subselect(item):
+ for x in extract_from_part(item):
+ yield x
+ elif item.ttype is Keyword:
+                return
+ else:
+ yield item
+ elif item.ttype is Keyword and item.value.upper() == 'FROM':
+ from_seen = True
+
+
+def extract_table_identifiers(token_stream):
+ for item in token_stream:
+ if isinstance(item, IdentifierList):
+ for identifier in item.get_identifiers():
+ yield identifier.get_name()
+ elif isinstance(item, Identifier):
+ yield item.get_name()
+ # It's a bug to check for Keyword here, but in the example
+        # above some table names are identified as keywords...
+ elif item.ttype is Keyword:
+ yield item.value
+
+
+# TODO(bkyryliuk): add logic to support joins and unions.
+def extract_tables(sql):
+ stream = extract_from_part(sqlparse.parse(sql)[0])
+ return list(extract_table_identifiers(stream))
diff --git a/caravel/migrations/versions/ad82a75afd82_add_query_model.py b/caravel/migrations/versions/ad82a75afd82_add_query_model.py
index 4794f416de07f..4a53c4309e077 100644
--- a/caravel/migrations/versions/ad82a75afd82_add_query_model.py
+++ b/caravel/migrations/versions/ad82a75afd82_add_query_model.py
@@ -13,17 +13,26 @@
from alembic import op
import sqlalchemy as sa
+
def upgrade():
op.create_table('query',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('database_id', sa.Integer(), nullable=False),
- sa.Column('tmp_table_name', sa.String(length=64), nullable=True),
+ sa.Column('tmp_table_name', sa.String(length=256), nullable=True),
+        sa.Column('tab_name', sa.String(length=256), nullable=True),
sa.Column('user_id', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=16), nullable=True),
- sa.Column('name', sa.String(length=64), nullable=True),
- sa.Column('sql', sa.Text, nullable=True),
+ sa.Column('name', sa.String(length=256), nullable=True),
+ sa.Column('schema', sa.String(length=256), nullable=True),
+ sa.Column('sql', sa.Text(), nullable=True),
+ sa.Column('select_sql', sa.Text(), nullable=True),
+ sa.Column('executed_sql', sa.Text(), nullable=True),
sa.Column('limit', sa.Integer(), nullable=True),
+ sa.Column('limit_used', sa.Boolean(), nullable=True),
+ sa.Column('select_as_cta', sa.Boolean(), nullable=True),
+ sa.Column('select_as_cta_used', sa.Boolean(), nullable=True),
sa.Column('progress', sa.Integer(), nullable=True),
+ sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('start_time', sa.DateTime(), nullable=True),
sa.Column('end_time', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['database_id'], [u'dbs.id'], ),
diff --git a/caravel/models.py b/caravel/models.py
index 083f0b3c1ea07..feb2f6d13f011 100644
--- a/caravel/models.py
+++ b/caravel/models.py
@@ -379,7 +379,7 @@ class Database(Model, AuditMixinNullable):
sqlalchemy_uri = Column(String(1024))
password = Column(EncryptedType(String(1024), config.get('SECRET_KEY')))
cache_timeout = Column(Integer)
- select_as_create_table_as = Column(Boolean, default=True)
+ select_as_create_table_as = Column(Boolean, default=False)
extra = Column(Text, default=textwrap.dedent("""\
{
"metadata_params": {},
@@ -1711,6 +1711,16 @@ class FavStar(Model):
class QueryStatus:
+ def from_presto_states(self, presto_status):
+ if presto_status.lower() == 'running':
+ return QueryStatus.IN_PROGRESS
+        if presto_status.lower() == 'queued':
+            return QueryStatus.SCHEDULED
+        if presto_status.lower() == 'finished':
+            return QueryStatus.FINISHED
+        if presto_status.lower() == 'failed':
+            return QueryStatus.FAILED
+
SCHEDULED = 'SCHEDULED'
CANCELLED = 'CANCELLED'
IN_PROGRESS = 'IN_PROGRESS'
@@ -1729,18 +1739,28 @@ class Query(Model):
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
# Store the tmp table into the DB only if the user asks for it.
- tmp_table_name = Column(String(64))
+ tmp_table_name = Column(String(256))
user_id = Column(Integer, ForeignKey('ab_user.id'), nullable=True)
# models.QueryStatus
status = Column(String(16))
- name = Column(String(64))
+ name = Column(String(256))
+ tab_name = Column(String(256))
+ schema = Column(String(256))
sql = Column(Text)
- # Could be configured in the caravel config
+    # Query used to retrieve the results; populated
+    # only when select_as_cta_used is true.
+ select_sql = Column(Text)
+ executed_sql = Column(Text)
+ # Could be configured in the caravel config.
limit = Column(Integer)
+ limit_used = Column(Boolean)
+ select_as_cta = Column(Boolean)
+ select_as_cta_used = Column(Boolean)
# 1..100
progress = Column(Integer)
+ error_message = Column(Text)
start_time = Column(DateTime)
end_time = Column(DateTime)
diff --git a/caravel/tasks.py b/caravel/tasks.py
index c48e66997456a..38f4b0edd8a72 100644
--- a/caravel/tasks.py
+++ b/caravel/tasks.py
@@ -1,7 +1,7 @@
import celery
from caravel import models, app, utils
from datetime import datetime
-import logging
+
from sqlalchemy import create_engine, select, text
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.sql.expression import TextAsFrom
@@ -11,6 +11,173 @@
celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG'))
+@celery_app.task
+def get_sql_results(query_id):
+ """Executes the sql query returns the results."""
+    # Create a separate session; reusing db.session leads to
+    # concurrency issues.
+ session = get_session()
+ query = session.query(models.Query).filter_by(id=query_id).first()
+    result, db_to_query = None, None
+ try:
+ db_to_query = (
+ session.query(models.Database).filter_by(id=query.database_id)
+ .first()
+ )
+ except Exception as e:
+ result = fail_query(query, utils.error_msg_from_exception(e))
+
+    if not result and not db_to_query:
+ result = fail_query(query, "Database with id {0} is missing.".format(
+ query.database_id))
+
+ if not result:
+ result = get_sql_results_as_dict(db_to_query, query, session)
+ query.end_time = datetime.now()
+ session.flush()
+ return result
+
+
+# TODO(bkyryliuk): dump results somewhere for the webserver.
+def get_sql_results_as_dict(db_to_query, query, orm_session):
+ """Get the SQL query results from the give session and db connection."""
+ engine = db_to_query.get_sqla_engine(schema=query.schema)
+ query.executed_sql = query.sql.strip().strip(';')
+
+ # Limit enforced only for retrieving the data, not for the CTA queries.
+ query.select_as_cta_used = False
+ query.limit_used = False
+ if is_query_select(query.executed_sql):
+ if query.select_as_cta:
+ if not query.tmp_table_name:
+ query.tmp_table_name = 'tmp_{}_table_{}'.format(
+ query.user_id,
+ query.start_time.strftime('%Y_%m_%d_%H_%M_%S'))
+ query.executed_sql = create_table_as(
+ query.executed_sql, query.tmp_table_name)
+ query.select_as_cta_used = True
+ elif query.limit:
+ query.executed_sql = add_limit_to_the_sql(
+ query.executed_sql, query.limit, engine)
+ query.limit_used = True
+
+ # TODO(bkyryliuk): ensure that tmp table was created.
+ # Do not set tmp table name if table wasn't created.
+ if not query.select_as_cta_used:
+ query.tmp_table_name = None
+
+ backend = engine.url.get_backend_name()
+ if backend in ('presto', 'hive'):
+ result = get_sql_results_async(engine, query, orm_session)
+ else:
+ result = get_sql_results_sync(engine, query)
+
+ orm_session.flush()
+ return result
+
+
+def get_sql_results_async(engine, query, orm_session):
+ try:
+ result_proxy = engine.execute(query.executed_sql, schema=query.schema)
+ except Exception as e:
+ return fail_query(query, utils.error_msg_from_exception(e))
+
+ cursor = result_proxy.cursor
+ query_stats = cursor.poll()
+ query.status = models.QueryStatus.IN_PROGRESS
+ orm_session.flush()
+ # poll returns dict -- JSON status information or ``None``
+ # if the query is done
+ # https://github.com/dropbox/PyHive/blob/
+ # b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
+ while query_stats:
+ # Update the object and wait for the kill signal.
+ orm_session.refresh(query)
+ completed_splits = int(query_stats['stats']['completedSplits'])
+ total_splits = int(query_stats['stats']['totalSplits'])
+        progress = 100 * completed_splits / max(total_splits, 1)
+ if progress > query.progress:
+ query.progress = progress
+
+ orm_session.flush()
+ query_stats = cursor.poll()
+ # TODO(b.kyryliuk): check for the kill signal.
+
+ if query.select_as_cta_used:
+ select_star = (
+ select('*').select_from(query.tmp_table_name).
+ limit(query.limit)
+ )
+ # SQL code to preview the results
+ query.select_sql = str(select_star.compile(
+ engine, compile_kwargs={"literal_binds": True}))
+ try:
+ # override cursor value to reuse the data extraction down below.
+ result_proxy = engine.execute(
+ query.select_sql, schema=query.schema)
+ cursor = result_proxy.cursor
+ while cursor.poll():
+ # TODO: wait till the data is fetched
+ pass
+ except Exception as e:
+ return fail_query(query, utils.error_msg_from_exception(e))
+
+ response = fetch_response_from_cursor(result_proxy, query)
+ query.status = models.QueryStatus.FINISHED
+ orm_session.flush()
+ return response
+
+
+def get_sql_results_sync(engine, query):
+ # TODO(bkyryliuk): rewrite into eng.execute as queries different from
+ # select should be permitted too.
+ query.select_sql = query.sql
+ if query.select_as_cta_used:
+ try:
+ engine.execute(query.executed_sql, schema=query.schema)
+ except Exception as e:
+ return fail_query(query, utils.error_msg_from_exception(e))
+ select_star = (
+ select('*').select_from(query.tmp_table_name).
+ limit(query.limit)
+ )
+ query.select_sql = str(select_star.compile(
+ engine, compile_kwargs={"literal_binds": True}))
+ try:
+ result_proxy = engine.execute(
+ query.select_sql, schema=query.schema)
+ except Exception as e:
+ return fail_query(query, utils.error_msg_from_exception(e))
+ response = fetch_response_from_cursor(result_proxy, query)
+ query.status = models.QueryStatus.FINISHED
+ return response
+
+
+def fail_query(query, message):
+ query.error_message = message
+ query.status = models.QueryStatus.FAILED
+ return {
+ 'error': query.error_message,
+ 'status': query.status,
+ }
+
+
+# TODO(b.kyryliuk): find better way to pass the data.
+def fetch_response_from_cursor(result_proxy, query):
+ cols = [col[0] for col in result_proxy.cursor.description]
+ data = result_proxy.fetchall()
+ print("DELETEME")
+ print(data)
+ df = pd.DataFrame(data, columns=cols)
+ df = df.fillna(0)
+ return {
+ 'query_id': query.id,
+ 'columns': [c for c in df.columns],
+ 'data': df.to_dict(orient='records'),
+ 'status': models.QueryStatus.FINISHED,
+ }
+
+
def is_query_select(sql):
try:
return sqlparse.parse(sql)[0].get_type() == 'SELECT'
@@ -35,13 +202,12 @@ def get_tables():
pass
-def add_limit_to_the_query(sql, limit, eng):
+def add_limit_to_the_sql(sql, limit, eng):
# Treat as single sql statement in case of failure.
- sql_statements = [sql]
try:
sql_statements = [s for s in sqlparse.split(sql) if s]
except Exception as e:
- logging.info(
+ app.logger.info(
"Statement " + sql + "failed to be transformed to have the limit "
"with the exception" + e.message)
return sql
@@ -56,6 +222,8 @@ def add_limit_to_the_query(sql, limit, eng):
# create table works only for the single statement.
+# TODO(bkyryliuk): enforce that all the columns have names. Presto requires it
+# for the CTA operation.
def create_table_as(sql, table_name, override=False):
"""Reformats the query into the create table as query.
@@ -69,12 +237,11 @@ def create_table_as(sql, table_name, override=False):
# TODO(bkyryliuk): drop table if allowed, check the namespace and
# the permissions.
# Treat as single sql statement in case of failure.
- sql_statements = [sql]
try:
# Filter out empty statements.
sql_statements = [s for s in sqlparse.split(sql) if s]
except Exception as e:
- logging.info(
+ app.logger.info(
"Statement " + sql + "failed to be transformed as create table as "
"with the exception" + e.message)
return sql
@@ -95,125 +262,4 @@ def get_session():
engine = create_engine(
app.config.get('SQLALCHEMY_DATABASE_URI'), convert_unicode=True)
return scoped_session(sessionmaker(
- autocommit=False, autoflush=False, bind=engine))
-
-
-@celery_app.task
-def get_sql_results(database_id, sql, user_id, tmp_table_name="", schema=None):
- """Executes the sql query returns the results.
-
- :param database_id: integer
- :param sql: string, query that will be executed
- :param user_id: integer
- :param tmp_table_name: name of the table for CTA
- :param schema: string, name of the schema (used in presto)
- :return: dataframe, query result
- """
- # Create a separate session, reusing the db.session leads to the
- # concurrency issues.
- session = get_session()
- try:
- db_to_query = (
- session.query(models.Database).filter_by(id=database_id).first()
- )
- except Exception as e:
- return {
- 'error': utils.error_msg_from_exception(e),
- 'success': False,
- }
- if not db_to_query:
- return {
- 'error': "Database with id {0} is missing.".format(database_id),
- 'success': False,
- }
-
- # TODO(bkyryliuk): provide a way for the user to name the query.
- # TODO(bkyryliuk): run explain query to derive the tables and fill in the
- # table_ids
- # TODO(bkyryliuk): check the user permissions
- # TODO(bkyryliuk): store the tab name in the query model
- limit = app.config.get('SQL_MAX_ROW', None)
- start_time = datetime.now()
- if not tmp_table_name:
- tmp_table_name = 'tmp.{}_table_{}'.format(user_id, start_time)
- query = models.Query(
- user_id=user_id,
- database_id=database_id,
- limit=limit,
- name='{}'.format(start_time),
- sql=sql,
- start_time=start_time,
- tmp_table_name=tmp_table_name,
- status=models.QueryStatus.IN_PROGRESS,
- )
- session.add(query)
- session.commit()
- query_result = get_sql_results_as_dict(
- db_to_query, sql, query.tmp_table_name, schema=schema)
- query.end_time = datetime.now()
- if query_result['success']:
- query.status = models.QueryStatus.FINISHED
- else:
- query.status = models.QueryStatus.FAILED
- session.commit()
- # TODO(bkyryliuk): return the tmp table / query_id
- return query_result
-
-
-# TODO(bkyryliuk): merge the changes made in the carapal first
-# before merging this PR.
-def get_sql_results_as_dict(db_to_query, sql, tmp_table_name, schema=None):
- """Get the SQL query results from the give session and db connection.
-
- :param sql: string, query that will be executed
- :param db_to_query: models.Database to query, cannot be None
- :param tmp_table_name: name of the table for CTA
- :param schema: string, name of the schema (used in presto)
- :return: (dataframe, boolean), results and the status
- """
- eng = db_to_query.get_sqla_engine(schema=schema)
- sql = sql.strip().strip(';')
- # TODO(bkyryliuk): fix this case for multiple statements
- if app.config.get('SQL_MAX_ROW'):
- sql = add_limit_to_the_query(
- sql, app.config.get("SQL_MAX_ROW"), eng)
-
- cta_used = False
- if (app.config.get('SQL_SELECT_AS_CTA') and
- db_to_query.select_as_create_table_as and is_query_select(sql)):
- # TODO(bkyryliuk): figure out if the query is select query.
- sql = create_table_as(sql, tmp_table_name)
- cta_used = True
-
- if cta_used:
- try:
- eng.execute(sql)
- return {
- 'tmp_table': tmp_table_name,
- 'success': True,
- }
- except Exception as e:
- return {
- 'error': utils.error_msg_from_exception(e),
- 'success': False,
- }
-
- # otherwise run regular SQL query.
- # TODO(bkyryliuk): rewrite into eng.execute as queries different from
- # select should be permitted too.
- try:
- df = db_to_query.get_df(sql, schema)
- df = df.fillna(0)
- return {
- 'columns': [c for c in df.columns],
- 'data': df.to_dict(orient='records'),
- 'success': True,
- }
-
- except Exception as e:
- return {
- 'error': utils.error_msg_from_exception(e),
- 'success': False,
- }
-
-
+ autocommit=True, autoflush=False, bind=engine))
diff --git a/caravel/utils.py b/caravel/utils.py
index 9b784517c573f..668c80f493674 100644
--- a/caravel/utils.py
+++ b/caravel/utils.py
@@ -339,10 +339,17 @@ def error_msg_from_exception(e):
Database have different ways to handle exception. This function attempts
to make sense of the exception object and construct a human readable
sentence.
+
+ TODO(bkyryliuk): parse the Presto error message from the connection
+ created via create_engine.
+ engine = create_engine('presto://localhost:3506/silver') -
+ gives an e.message as the str(dict)
+ presto.connect("localhost", port=3506, catalog='silver') - as a dict.
+ The latter version is parsed correctly by this function.
"""
msg = ''
if hasattr(e, 'message'):
- if (type(e.message) is dict):
+ if type(e.message) is dict:
msg = e.message.get('message')
elif e.message:
msg = "{}".format(e.message)
diff --git a/caravel/views.py b/caravel/views.py
index 66e3697f48502..3cfbf338bd594 100755
--- a/caravel/views.py
+++ b/caravel/views.py
@@ -430,11 +430,13 @@ class DatabaseAsync(DatabaseView):
appbuilder.add_view_no_menu(DatabaseAsync)
+
class DatabaseTablesAsync(DatabaseView):
list_columns = ['id', 'all_table_names', 'all_schema_names']
appbuilder.add_view_no_menu(DatabaseTablesAsync)
+
class TableModelView(CaravelModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.SqlaTable)
list_columns = [
@@ -592,7 +594,8 @@ def add(self):
url = "/druiddatasourcemodelview/list/"
msg = _(
"Click on a datasource link to create a Slice, "
- "or click on a table link here "
+ "or click on a table link "
+ "here "
"to create a Slice for a table"
)
else:
@@ -866,7 +869,8 @@ def explore(self, datasource_type, datasource_id):
datasource_access = self.can_access(
'datasource_access', datasource.perm)
if not (all_datasource_access or datasource_access):
- flash(__("You don't seem to have access to this datasource"), "danger")
+ flash(__("You don't seem to have access to this datasource"),
+ "danger")
return redirect(error_redirect)
action = request.args.get('action')
@@ -943,7 +947,8 @@ def save_or_overwrite_slice(
del d['action']
del d['previous_viz_type']
- as_list = ('metrics', 'groupby', 'columns', 'all_columns', 'mapbox_label', 'order_by_cols')
+ as_list = ('metrics', 'groupby', 'columns', 'all_columns',
+ 'mapbox_label', 'order_by_cols')
for k in d:
v = d.get(k)
if k in as_list and not isinstance(v, list):
@@ -1054,7 +1059,8 @@ def activity_per_day(self):
.group_by(Log.dt)
.all()
)
- payload = {str(time.mktime(dt.timetuple())): ccount for dt, ccount in qry if dt}
+ payload = {str(time.mktime(dt.timetuple())):
+ ccount for dt, ccount in qry if dt}
return Response(json.dumps(payload), mimetype="application/json")
@api
@@ -1110,9 +1116,11 @@ def add_slices(self, dashboard_id):
data = json.loads(request.form.get('data'))
session = db.session()
Slice = models.Slice # noqa
- dash = session.query(models.Dashboard).filter_by(id=dashboard_id).first()
+ dash = (
+ session.query(models.Dashboard).filter_by(id=dashboard_id).first())
check_ownership(dash, raise_if_false=True)
- new_slices = session.query(Slice).filter(Slice.id.in_(data['slice_ids']))
+ new_slices = session.query(Slice).filter(
+ Slice.id.in_(data['slice_ids']))
dash.slices += new_slices
session.merge(dash)
session.commit()
@@ -1146,13 +1154,18 @@ def favstar(self, class_name, obj_id, action):
FavStar = models.FavStar # noqa
count = 0
favs = session.query(FavStar).filter_by(
- class_name=class_name, obj_id=obj_id, user_id=g.user.get_id()).all()
+ class_name=class_name, obj_id=obj_id,
+ user_id=g.user.get_id()).all()
if action == 'select':
if not favs:
session.add(
FavStar(
- class_name=class_name, obj_id=obj_id, user_id=g.user.get_id(),
- dttm=datetime.now()))
+ class_name=class_name,
+ obj_id=obj_id,
+ user_id=g.user.get_id(),
+ dttm=datetime.now()
+ )
+ )
count = 1
elif action == 'unselect':
for fav in favs:
@@ -1358,9 +1371,22 @@ def sql_json(self):
sql = request.form.get('sql')
database_id = request.form.get('database_id')
schema = request.form.get('schema')
+ tab_name = request.form.get('tab_name')
+ tmp_table_name = request.form.get('tmp_table_name')
+ select_as_cta = request.form.get('select_as_cta') == 'True'
+
session = db.session()
mydb = session.query(models.Database).filter_by(id=database_id).first()
+ if not mydb:
+ return Response(
+ json.dumps({
+                    'error': 'Database with id {} is missing.'.format(database_id),
+ 'status': models.QueryStatus.FAILED,
+ }),
+ status=500,
+ mimetype="application/json")
+
if not (self.can_access(
'all_datasource_access', 'all_datasource_access') or
self.can_access('database_access', mydb.perm)):
@@ -1368,19 +1394,78 @@ def sql_json(self):
"SQL Lab requires the `all_datasource_access` or "
"specific DB permission"))
- data = tasks.get_sql_results(database_id, sql, g.user.get_id(),
- schema=schema)
- if 'error' in data:
+ # DB select_as_create_table_as forces all queries to be
+ # select_as_cta.
+ if select_as_cta or mydb.select_as_create_table_as:
+ select_as_cta = True
+ start_time = datetime.now()
+ query_name = '{}_{}_{}'.format(
+ g.user.get_id(), tab_name, start_time.strftime('%M:%S:%f'))
+
+ query = models.Query(
+ database_id=database_id,
+ limit=app.config.get('SQL_MAX_ROW', None),
+ name=query_name,
+ sql=sql,
+ schema=schema,
+ # TODO(bkyryliuk): consider it being DB property.
+ select_as_cta=select_as_cta,
+ start_time=start_time,
+ status=models.QueryStatus.SCHEDULED,
+ tab_name=tab_name,
+ tmp_table_name=tmp_table_name,
+ user_id=g.user.get_id(),
+ )
+ session.add(query)
+ session.commit()
+
+ data = tasks.get_sql_results(query.id)
+ if data['status'] == models.QueryStatus.FAILED:
return Response(
- json.dumps(data),
+ json.dumps(
+ data, default=utils.json_int_dttm_ser, allow_nan=False),
status=500,
mimetype="application/json")
- if 'tmp_table' in data:
- # TODO(bkyryliuk): add query id to the response and implement the
- # endpoint to poll the status and results.
- return None
- return json.dumps(
- data, default=utils.json_int_dttm_ser, allow_nan=False)
+ print("DELETEME")
+ print(data)
+ return Response(
+ json.dumps(
+ data, default=utils.json_int_dttm_ser, allow_nan=False),
+ status=200,
+ mimetype="application/json")
+
+ @has_access
+ @expose("/query_progress/", methods=['GET'])
+ @log_this
+ def query_progress(self):
+ """Runs arbitrary sql and returns and json"""
+        query_id = request.args.get('query_id')
+ s = db.session()
+ query = s.query(models.Query).filter_by(id=query_id).first()
+ mydb = s.query(models.Database).filter_by(id=query.database_id).first()
+
+ if not (self.can_access(
+ 'all_datasource_access', 'all_datasource_access') or
+ self.can_access('database_access', mydb.perm)):
+ raise utils.CaravelSecurityException(_(
+ "SQL Lab requires the `all_datasource_access` or "
+ "specific DB permission"))
+
+ if query:
+ return Response(
+ json.dumps({
+ 'status': query.status,
+ 'progress': query.progress
+ }),
+ status=200,
+ mimetype="application/json")
+
+ return Response(
+ json.dumps({
+ 'error': "Query with id {} wasn't found".format(query_id),
+ }),
+ status=404,
+ mimetype="application/json")
@has_access
@expose("/refresh_datasources/")
diff --git a/setup.py b/setup.py
index ceb266d9fef18..07ae7b1750173 100644
--- a/setup.py
+++ b/setup.py
@@ -30,6 +30,7 @@
'pandas==0.18.1',
'parsedatetime==2.0.0',
'pydruid==0.3.0',
+ 'PyHive>=0.2.1',
'python-dateutil==2.5.3',
'requests==2.10.0',
'simplejson==3.8.2',
@@ -37,6 +38,8 @@
'sqlalchemy==1.0.13',
'sqlalchemy-utils==0.32.7',
'sqlparse==0.1.19',
+ 'thrift>=0.9.3',
+ 'thrift-sasl>=0.2.1',
'werkzeug==0.11.10',
],
extras_require={
diff --git a/tests/celery_tests.py b/tests/celery_tests.py
index e88ae0fca1c5b..89208bb458bc3 100644
--- a/tests/celery_tests.py
+++ b/tests/celery_tests.py
@@ -1,10 +1,9 @@
"""Unit tests for Caravel Celery worker"""
-import datetime
import imp
+import json
import subprocess
import os
import pandas as pd
-import time
import unittest
import caravel
@@ -116,36 +115,39 @@ class CeleryTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(CeleryTestCase, self).__init__(*args, **kwargs)
self.client = app.test_client()
- utils.init(caravel)
- admin = appbuilder.sm.find_user('admin')
- if not admin:
- appbuilder.sm.add_user(
- 'admin', 'admin', ' user', 'admin@fab.org',
- appbuilder.sm.find_role('Admin'),
- password='general')
- utils.init(caravel)
@classmethod
def setUpClass(cls):
try:
os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH'))
- except OSError:
- pass
+ except OSError as e:
+ app.logger.warn(str(e))
try:
os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH'))
- except OSError:
- pass
+ except OSError as e:
+ app.logger.warn(str(e))
+
+ utils.init(caravel)
+ admin = appbuilder.sm.find_user('admin')
+ if not admin:
+ appbuilder.sm.add_user(
+ 'admin', 'admin', ' user', 'admin@fab.org',
+ appbuilder.sm.find_role('Admin'),
+ password='general')
+ cli.load_examples(load_test_data=True)
worker_command = BASE_DIR + '/bin/caravel worker'
subprocess.Popen(
worker_command, shell=True, stdout=subprocess.PIPE)
- cli.load_examples(load_test_data=True)
@classmethod
def tearDownClass(cls):
+ main_db = db.session.query(models.Database).filter_by(
+ database_name="main").first()
+ main_db.get_sqla_engine().execute("DELETE FROM query;")
+
subprocess.call(
- "ps auxww | grep 'celeryd' | awk '{print $2}' | "
- "xargs kill -9",
+ "ps auxww | grep 'celeryd' | awk '{print $2}' | xargs kill -9",
shell=True
)
subprocess.call(
@@ -160,6 +162,30 @@ def setUp(self):
def tearDown(self):
pass
+ def login(self, username='admin', password='general'):
+ resp = self.client.post(
+ '/login/',
+ data=dict(username=username, password=password),
+ follow_redirects=True)
+ assert 'Welcome' in resp.data.decode('utf-8')
+
+ def logout(self):
+ self.client.get('/logout/', follow_redirects=True)
+
+ def run_sql(self, dbid, sql, select_as_cta='False', tmp_table_name='tmp'):
+ self.login()
+ resp = self.client.post(
+ '/caravel/sql_json/',
+ data=dict(
+ database_id=dbid,
+ sql=sql,
+ select_as_cta=select_as_cta,
+ tmp_table_name=tmp_table_name,
+ ),
+ )
+ self.logout()
+ return json.loads(resp.data.decode('utf-8'))
+
def test_add_limit_to_the_query(self):
query_session = tasks.get_session()
db_to_query = query_session.query(models.Database).filter_by(
@@ -167,7 +193,7 @@ def test_add_limit_to_the_query(self):
eng = db_to_query.get_sqla_engine()
select_query = "SELECT * FROM outer_space;"
- updated_select_query = tasks.add_limit_to_the_query(
+ updated_select_query = tasks.add_limit_to_the_sql(
select_query, 100, eng)
# Different DB engines have their own spacing while compiling
# the queries, that's why ' '.join(query.split()) is used.
@@ -178,7 +204,7 @@ def test_add_limit_to_the_query(self):
)
select_query_no_semicolon = "SELECT * FROM outer_space"
- updated_select_query_no_semicolon = tasks.add_limit_to_the_query(
+ updated_select_query_no_semicolon = tasks.add_limit_to_the_sql(
select_query_no_semicolon, 100, eng)
self.assertTrue(
"SELECT * FROM (SELECT * FROM outer_space) AS inner_qry "
@@ -187,19 +213,19 @@ def test_add_limit_to_the_query(self):
)
incorrect_query = "SMTH WRONG SELECT * FROM outer_space"
- updated_incorrect_query = tasks.add_limit_to_the_query(
+ updated_incorrect_query = tasks.add_limit_to_the_sql(
incorrect_query, 100, eng)
self.assertEqual(incorrect_query, updated_incorrect_query)
insert_query = "INSERT INTO stomach VALUES (beer, chips);"
- updated_insert_query = tasks.add_limit_to_the_query(
+ updated_insert_query = tasks.add_limit_to_the_sql(
insert_query, 100, eng)
self.assertEqual(insert_query, updated_insert_query)
multi_line_query = (
"SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';"
)
- updated_multi_line_query = tasks.add_limit_to_the_query(
+ updated_multi_line_query = tasks.add_limit_to_the_sql(
multi_line_query, 100, eng)
self.assertTrue(
"SELECT * FROM (SELECT * FROM planets WHERE "
@@ -208,13 +234,13 @@ def test_add_limit_to_the_query(self):
)
delete_query = "DELETE FROM planet WHERE name = 'Earth'"
- updated_delete_query = tasks.add_limit_to_the_query(
+ updated_delete_query = tasks.add_limit_to_the_sql(
delete_query, 100, eng)
self.assertEqual(delete_query, updated_delete_query)
create_table_as = (
"CREATE TABLE pleasure AS SELECT chocolate FROM lindt_store;\n")
- updated_create_table_as = tasks.add_limit_to_the_query(
+ updated_create_table_as = tasks.add_limit_to_the_sql(
create_table_as, 100, eng)
self.assertEqual(create_table_as, updated_create_table_as)
@@ -231,7 +257,7 @@ def test_add_limit_to_the_query(self):
"(B.TECH ,BE ,Degree ,MCA ,MiBA)\n "
"AND Having Brothers= Null AND Sisters = Null"
)
- updated_sql_procedure = tasks.add_limit_to_the_query(
+ updated_sql_procedure = tasks.add_limit_to_the_sql(
sql_procedure, 100, eng)
self.assertEqual(sql_procedure, updated_sql_procedure)
@@ -242,11 +268,15 @@ def test_run_async_query_delay_get(self):
# Case 1.
# DB #0 doesn't exist.
- result1 = tasks.get_sql_results.delay(
- 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_1').get()
+ result1 = self.run_sql(
+ 0,
+ 'SELECT * FROM dontexist',
+ tmp_table_name='tmp_table_1_a',
+ select_as_cta='True',
+ )
expected_result1 = {
'error': 'Database with id 0 is missing.',
- 'success': False
+ 'status': models.QueryStatus.FAILED,
}
self.assertEqual(
sorted(expected_result1.items()),
@@ -255,18 +285,17 @@ def test_run_async_query_delay_get(self):
session1 = db.create_scoped_session()
query1 = session1.query(models.Query).filter_by(
sql='SELECT * FROM dontexist').first()
- session1.close()
self.assertIsNone(query1)
+ session1.close()
# Case 2.
- session2 = db.create_scoped_session()
- query2 = session2.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist1').first()
- self.assertEqual(models.QueryStatus.FAILED, query2.status)
- session2.close()
-
- result2 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_1').get()
+ # Table doesn't exist.
+ result2 = self.run_sql(
+ 1,
+ 'SELECT * FROM dontexist1',
+ tmp_table_name='tmp_table_2_a',
+ select_as_cta='True',
+ )
self.assertTrue('error' in result2)
session2 = db.create_scoped_session()
query2 = session2.query(models.Query).filter_by(
@@ -275,13 +304,20 @@ def test_run_async_query_delay_get(self):
session2.close()
# Case 3.
+        # Table and DB exist, CTA call to the backend.
where_query = (
"SELECT name FROM ab_permission WHERE name='can_select_star'")
- result3 = tasks.get_sql_results.delay(
- 1, where_query, 1, tmp_table_name='tmp_3_1').get()
+ result3 = self.run_sql(
+ 1,
+ where_query,
+ tmp_table_name='tmp_table_3_a',
+ select_as_cta='True',
+ )
expected_result3 = {
- 'tmp_table': 'tmp_3_1',
- 'success': True
+ u'query_id': 2,
+ u'status': models.QueryStatus.FINISHED,
+ u'columns': [u'name'],
+ u'data': [{u'name': u'can_select_star'}],
}
self.assertEqual(
sorted(expected_result3.items()),
@@ -291,18 +327,24 @@ def test_run_async_query_delay_get(self):
query3 = session3.query(models.Query).filter_by(
sql=where_query).first()
session3.close()
- df3 = pd.read_sql_query(sql="SELECT * FROM tmp_3_1", con=eng)
+ df3 = pd.read_sql_query(sql="SELECT * FROM tmp_table_3_a", con=eng)
data3 = df3.to_dict(orient='records')
self.assertEqual(models.QueryStatus.FINISHED, query3.status)
self.assertEqual([{'name': 'can_select_star'}], data3)
# Case 4.
- result4 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM ab_permission WHERE id=666', 1,
- tmp_table_name='tmp_4_1').get()
+        # Table and DB exist, CTA call to the backend, no data.
+ result4 = self.run_sql(
+ 1,
+ 'SELECT * FROM ab_permission WHERE id=666',
+ tmp_table_name='tmp_table_4_a',
+ select_as_cta='True',
+ )
expected_result4 = {
- 'tmp_table': 'tmp_4_1',
- 'success': True
+ u'query_id': 3,
+ u'status': models.QueryStatus.FINISHED,
+ u'columns': [u'id', u'name'],
+ u'data': [],
}
self.assertEqual(
sorted(expected_result4.items()),
@@ -312,88 +354,30 @@ def test_run_async_query_delay_get(self):
query4 = session4.query(models.Query).filter_by(
sql='SELECT * FROM ab_permission WHERE id=666').first()
session4.close()
- df4 = pd.read_sql_query(sql="SELECT * FROM tmp_4_1", con=eng)
+ df4 = pd.read_sql_query(sql="SELECT * FROM tmp_table_4_a", con=eng)
data4 = df4.to_dict(orient='records')
self.assertEqual(models.QueryStatus.FINISHED, query4.status)
self.assertEqual([], data4)
# Case 5.
- # Return the data directly if DB select_as_create_table_as is False.
- main_db.select_as_create_table_as = False
- db.session.commit()
- result5 = tasks.get_sql_results.delay(
- 1, where_query, 1, tmp_table_name='tmp_5_1').get()
+        # Table and DB exist, select without CTA.
+ result5 = self.run_sql(
+ 1,
+ where_query,
+ tmp_table_name='tmp_table_5_a',
+ select_as_cta='False',
+ )
expected_result5 = {
- 'columns': ['name'],
- 'data': [{'name': 'can_select_star'}],
- 'success': True
+ u'query_id': 4,
+ u'columns': [u'name'],
+ u'data': [{u'name': u'can_select_star'}],
+ u'status': models.QueryStatus.FINISHED,
}
self.assertEqual(
sorted(expected_result5.items()),
sorted(result5.items())
)
- def test_run_async_query_delay(self):
- celery_task1 = tasks.get_sql_results.delay(
- 0, 'SELECT * FROM dontexist', 1, tmp_table_name='tmp_1_2')
- celery_task2 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM dontexist1', 1, tmp_table_name='tmp_2_2')
- where_query = (
- "SELECT name FROM ab_permission WHERE name='can_select_star'")
- celery_task3 = tasks.get_sql_results.delay(
- 1, where_query, 1, tmp_table_name='tmp_3_2')
- celery_task4 = tasks.get_sql_results.delay(
- 1, 'SELECT * FROM ab_permission WHERE id=666', 1,
- tmp_table_name='tmp_4_2')
-
- time.sleep(1)
-
- # DB #0 doesn't exist.
- expected_result1 = {
- 'error': 'Database with id 0 is missing.',
- 'success': False
- }
- self.assertEqual(
- sorted(expected_result1.items()),
- sorted(celery_task1.get().items())
- )
- session2 = db.create_scoped_session()
- query2 = session2.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist1').first()
- self.assertEqual(models.QueryStatus.FAILED, query2.status)
- self.assertTrue('error' in celery_task2.get())
- expected_result3 = {
- 'tmp_table': 'tmp_3_2',
- 'success': True
- }
- self.assertEqual(
- sorted(expected_result3.items()),
- sorted(celery_task3.get().items())
- )
- expected_result4 = {
- 'tmp_table': 'tmp_4_2',
- 'success': True
- }
- self.assertEqual(
- sorted(expected_result4.items()),
- sorted(celery_task4.get().items())
- )
-
- session = db.create_scoped_session()
- query1 = session.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist').first()
- self.assertIsNone(query1)
- query2 = session.query(models.Query).filter_by(
- sql='SELECT * FROM dontexist1').first()
- self.assertEqual(models.QueryStatus.FAILED, query2.status)
- query3 = session.query(models.Query).filter_by(
- sql=where_query).first()
- self.assertEqual(models.QueryStatus.FINISHED, query3.status)
- query4 = session.query(models.Query).filter_by(
- sql='SELECT * FROM ab_permission WHERE id=666').first()
- self.assertEqual(models.QueryStatus.FINISHED, query4.status)
- session.close()
-
if __name__ == '__main__':
unittest.main()
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 48d26c16e9606..87bce398b9e7c 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -321,7 +321,7 @@ def run_sql(self, sql, user_name):
)
resp = self.client.post(
'/caravel/sql_json/',
- data=dict(database_id=dbid, sql=sql),
+ data=dict(database_id=dbid, sql=sql, select_as_create_as=False),
)
self.logout()
return json.loads(resp.data.decode('utf-8'))
@@ -340,9 +340,9 @@ def test_sql_json_has_access(self):
db.session.commit()
main_db_permission_view = (
db.session.query(ab_models.PermissionView)
- .join(ab_models.ViewMenu)
- .filter(ab_models.ViewMenu.name == '[main].(id:1)')
- .first()
+ .join(ab_models.ViewMenu)
+ .filter(ab_models.ViewMenu.name == '[main].(id:1)')
+ .first()
)
astronaut = sm.add_role("Astronaut")
sm.add_permission_role(astronaut, main_db_permission_view)
@@ -361,6 +361,8 @@ def test_sql_json_has_access(self):
def test_sql_json(self):
data = self.run_sql("SELECT * FROM ab_user", 'admin')
+ print("self.run_sql")
+ print(str(data))
assert len(data['data']) > 0
data = self.run_sql("SELECT * FROM unexistant_table", 'admin')