Skip to content

Commit

Permalink
Show Presto views as views, not tables (#8243)
Browse files Browse the repository at this point in the history
* WIP

* Implement views in Presto

* Clean up

* Fix CSS

* Fix unit tests

* Add types to database

* Fix circular import
  • Loading branch information
betodealmeida authored Sep 18, 2019
1 parent 4088a84 commit 12fb8e7
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 17 deletions.
3 changes: 3 additions & 0 deletions superset/assets/src/components/TableSelector.css
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@
border-bottom: 1px solid #f2f2f2;
margin: 15px 0;
}
.TableLabel {
white-space: nowrap;
}
2 changes: 1 addition & 1 deletion superset/assets/src/components/TableSelector.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ export default class TableSelector extends React.PureComponent {
onMouseEnter={() => focusOption(option)}
style={style}
>
<span>
<span className="TableLabel">
<span className="m-r-5">
<small className="text-muted">
<i className={`fa fa-${option.type === 'view' ? 'eye' : 'table'}`} />
Expand Down
14 changes: 11 additions & 3 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import hashlib
import os
import re
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING, Union

from flask import g
from flask_babel import lazy_gettext as _
Expand All @@ -40,6 +40,10 @@
from superset import app, db, sql_parse
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database


class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
Expand Down Expand Up @@ -538,7 +542,9 @@ def get_schema_names(cls, inspector: Inspector) -> List[str]:
return sorted(inspector.get_schema_names())

@classmethod
def get_table_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
"""
Get all tables from schema
Expand All @@ -552,7 +558,9 @@ def get_table_names(cls, inspector: Inspector, schema: Optional[str]) -> List[st
return sorted(tables)

@classmethod
def get_view_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
def get_view_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
"""
Get all views from schema
Expand Down
8 changes: 6 additions & 2 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@
# under the License.
# pylint: disable=C,R,W
from datetime import datetime
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, TYPE_CHECKING

from sqlalchemy.dialects.postgresql.base import PGInspector

from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database


class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
Expand Down Expand Up @@ -64,7 +68,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):

@classmethod
def get_table_names(
cls, inspector: PGInspector, schema: Optional[str]
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
) -> List[str]:
"""Need to consider foreign tables for PostgreSQL"""
tables = inspector.get_table_names(schema)
Expand Down
40 changes: 37 additions & 3 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import re
import textwrap
import time
from typing import Any, cast, Dict, List, Optional, Set, Tuple
from typing import Any, cast, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
from urllib import parse

import simplejson as json
Expand All @@ -40,6 +40,10 @@
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database

QueryStatus = utils.QueryStatus
config = app.config

Expand Down Expand Up @@ -128,14 +132,44 @@ def get_allow_cost_estimate(cls, version: str = None) -> bool:
return version is not None and StrictVersion(version) >= StrictVersion("0.319")

@classmethod
def get_view_names(cls, inspector: Inspector, schema: Optional[str]) -> List[str]:
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
tables = super().get_table_names(database, inspector, schema)
if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
return tables

views = set(cls.get_view_names(database, inspector, schema))
actual_tables = set(tables) - views
return list(actual_tables)

@classmethod
def get_view_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
"""Returns an empty list
get_table_names() function returns all table names and view names,
and get_view_names() is not implemented in sqlalchemy_presto.py
https://github.com/dropbox/PyHive/blob/e25fc8440a0686bbb7a5db5de7cb1a77bdb4167a/pyhive/sqlalchemy_presto.py
"""
return []
if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"):
return []

if schema:
sql = "SELECT table_name FROM information_schema.views WHERE table_schema=%(schema)s"
params = {"schema": schema}
else:
sql = "SELECT table_name FROM information_schema.views"
params = {}

engine = cls.get_engine(database, schema=schema)
with closing(engine.raw_connection()) as conn:
with closing(conn.cursor()) as cursor:
cursor.execute(sql, params)
results = cursor.fetchall()

return [row[0] for row in results]

@classmethod
def _create_column_info(cls, name: str, data_type: str) -> dict:
Expand Down
10 changes: 8 additions & 2 deletions superset/db_engine_specs/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
# under the License.
# pylint: disable=C,R,W
from datetime import datetime
from typing import List
from typing import List, TYPE_CHECKING

from sqlalchemy.engine.reflection import Inspector

from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
# prevent circular imports
from superset.models.core import Database


class SqliteEngineSpec(BaseEngineSpec):
engine = "sqlite"
Expand Down Expand Up @@ -79,6 +83,8 @@ def convert_dttm(cls, target_type: str, dttm: datetime) -> str:
return "'{}'".format(iso)

@classmethod
def get_table_names(cls, inspector: Inspector, schema: str) -> List[str]:
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: str
) -> List[str]:
"""Need to disregard the schema for Sqlite"""
return sorted(inspector.get_table_names())
4 changes: 2 additions & 2 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ def get_all_table_names_in_schema(
"""
try:
tables = self.db_engine_spec.get_table_names(
inspector=self.inspector, schema=schema
database=self, inspector=self.inspector, schema=schema
)
return [
utils.DatasourceName(table=table, schema=schema) for table in tables
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def get_all_view_names_in_schema(
"""
try:
views = self.db_engine_spec.get_view_names(
inspector=self.inspector, schema=schema
database=self, inspector=self.inspector, schema=schema
)
return [utils.DatasourceName(table=view, schema=schema) for view in views]
except Exception as e:
Expand Down
1 change: 1 addition & 0 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,7 @@ def get_datasource_label(ds_name: utils.DatasourceName) -> str:
for vn in views[:max_views]
]
)
table_options.sort(key=lambda value: value["label"])
payload = {"tableLength": len(tables) + len(views), "options": table_options}
return json_success(json.dumps(payload))

Expand Down
12 changes: 8 additions & 4 deletions tests/db_engine_specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ def test_engine_time_grain_validity(self):
self.assertSetEqual(defined_grains, intersection, engine)

def test_presto_get_view_names_return_empty_list(self):
self.assertEquals([], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY))
self.assertEquals(
[], PrestoEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
)

def verify_presto_column(self, column, expected_results):
inspector = mock.Mock()
Expand Down Expand Up @@ -877,7 +879,9 @@ def test_presto_where_latest_partition(self):
self.assertEqual("SELECT \nWHERE ds = '01-01-19' AND hour = 1", query_result)

def test_hive_get_view_names_return_empty_list(self):
self.assertEquals([], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY))
self.assertEquals(
[], HiveEngineSpec.get_view_names(mock.ANY, mock.ANY, mock.ANY)
)

def test_bigquery_sqla_column_label(self):
label = BigQueryEngineSpec.make_label_compatible(column("Col").name)
Expand Down Expand Up @@ -952,15 +956,15 @@ def test_get_table_names(self):
ie. when try_remove_schema_from_table_name == True. """
base_result_expected = ["table", "table_2"]
base_result = BaseEngineSpec.get_table_names(
schema="schema", inspector=inspector
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(base_result_expected, base_result)

""" Make sure postgres doesn't try to remove schema name from table name
ie. when try_remove_schema_from_table_name == False. """
pg_result_expected = ["schema.table", "table_2", "table_3"]
pg_result = PostgresEngineSpec.get_table_names(
schema="schema", inspector=inspector
database=mock.ANY, schema="schema", inspector=inspector
)
self.assertListEqual(pg_result_expected, pg_result)

Expand Down

0 comments on commit 12fb8e7

Please sign in to comment.