chore: Cleanup database sessions (#10427)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
john-bodley and John Bodley authored Jul 31, 2020
1 parent 7ff1757 commit 7645fc8
Showing 39 changed files with 496 additions and 645 deletions.
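
The recurring pattern in these diffs: rather than instantiating a session with `db.session()` or threading a `Session` argument through helper signatures, code queries and commits through Flask-SQLAlchemy's scoped `db.session` directly. A minimal, self-contained sketch of that pattern (the Flask app and `User` model below are illustrative, not part of Superset):

```python
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
db = SQLAlchemy(app)


class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50))


with app.app_context():
    db.create_all()

    # Before: session = db.session(); session.add(...); session.commit()
    # After: use the scoped session directly -- db.session proxies the same
    # thread-local session, so nothing needs to be created or passed around.
    db.session.add(User(name="alice"))
    db.session.commit()
    print(db.session.query(User).count())  # -> 1
```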
12 changes: 5 additions & 7 deletions superset/cli.py
@@ -197,18 +197,17 @@ def set_database_uri(database_name: str, uri: str) -> None:
)
def refresh_druid(datasource: str, merge: bool) -> None:
"""Refresh druid datasources"""
session = db.session()
from superset.connectors.druid.models import DruidCluster

for cluster in session.query(DruidCluster).all():
for cluster in db.session.query(DruidCluster).all():
try:
cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
except Exception as ex: # pylint: disable=broad-except
print("Error while processing cluster '{}'\n{}".format(cluster, str(ex)))
logger.exception(ex)
cluster.metadata_last_refreshed = datetime.now()
print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
session.commit()
db.session.commit()


@superset.command()
@@ -250,7 +249,7 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None:
logger.info("Importing dashboard from file %s", file_)
try:
with file_.open() as data_stream:
dashboard_import_export.import_dashboards(db.session, data_stream)
dashboard_import_export.import_dashboards(data_stream)
except Exception as ex: # pylint: disable=broad-except
logger.error("Error when importing dashboard from file %s", file_)
logger.error(ex)
@@ -268,7 +267,7 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
"""Export dashboards to JSON"""
from superset.utils import dashboard_import_export

data = dashboard_import_export.export_dashboards(db.session)
data = dashboard_import_export.export_dashboards()
if print_stdout or not dashboard_file:
print(data)
if dashboard_file:
@@ -321,7 +320,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
try:
with file_.open() as data_stream:
dict_import_export.import_from_dict(
db.session, yaml.safe_load(data_stream), sync=sync_array
yaml.safe_load(data_stream), sync=sync_array
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Error when importing datasources from file %s", file_)
@@ -360,7 +359,6 @@ def export_datasources(
from superset.utils import dict_import_export

data = dict_import_export.export_to_dict(
session=db.session,
recursive=True,
back_references=back_references,
include_defaults=include_defaults,
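
The dashboard import/export helpers used by the CLI above now follow the same convention and no longer accept a session. A hedged usage sketch (assumes a Superset application context; the file path is illustrative):

```python
from superset.utils import dashboard_import_export

# export_dashboards() now reads from the scoped db.session internally
# and returns the dashboards serialized as JSON.
data = dashboard_import_export.export_dashboards()

# import_dashboards() likewise takes only the open data stream.
with open("dashboards.json") as data_stream:
    dashboard_import_export.import_dashboards(data_stream)
```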
6 changes: 2 additions & 4 deletions superset/commands/utils.py
@@ -25,7 +25,7 @@
)
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.extensions import security_manager


def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]:
@@ -50,8 +50,6 @@ def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]:

def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
try:
return ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
return ConnectorRegistry.get_datasource(datasource_type, datasource_id)
except (NoResultFound, KeyError):
raise DatasourceNotFoundValidationError()
4 changes: 2 additions & 2 deletions superset/common/query_context.py
@@ -23,7 +23,7 @@
import numpy as np
import pandas as pd

from superset import app, cache, db, security_manager
from superset import app, cache, security_manager
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
@@ -64,7 +64,7 @@ def __init__(  # pylint: disable=too-many-arguments
result_format: Optional[utils.ChartDataResultFormat] = None,
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
str(datasource["type"]), int(datasource["id"])
)
self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force
35 changes: 14 additions & 21 deletions superset/connectors/connector_registry.py
@@ -17,7 +17,9 @@
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING

from sqlalchemy import or_
from sqlalchemy.orm import Session, subqueryload
from sqlalchemy.orm import subqueryload

from superset.extensions import db

if TYPE_CHECKING:
# pylint: disable=unused-import
@@ -43,50 +45,45 @@ def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:

@classmethod
def get_datasource(
cls, datasource_type: str, datasource_id: int, session: Session
cls, datasource_type: str, datasource_id: int
) -> "BaseDatasource":
return (
session.query(cls.sources[datasource_type])
db.session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one()
)

@classmethod
def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
def get_all_datasources(cls) -> List["BaseDatasource"]:
datasources: List["BaseDatasource"] = []
for source_type in ConnectorRegistry.sources:
source_class = ConnectorRegistry.sources[source_type]
qry = session.query(source_class)
qry = db.session.query(source_class)
qry = source_class.default_query(qry)
datasources.extend(qry.all())
return datasources

@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
session: Session,
datasource_type: str,
datasource_name: str,
schema: str,
database_name: str,
) -> Optional["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[datasource_type]
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
datasource_name, schema, database_name
)

@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
session: Session,
database: "Database",
permissions: Set[str],
schema_perms: Set[str],
cls, database: "Database", permissions: Set[str], schema_perms: Set[str],
) -> List["BaseDatasource"]:
# TODO(bogdan): add unit test
datasource_class = ConnectorRegistry.sources[database.type]
return (
session.query(datasource_class)
db.session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
@@ -99,12 +96,12 @@ def query_datasources_by_permissions(  # pylint: disable=invalid-name

@classmethod
def get_eager_datasource(
cls, session: Session, datasource_type: str, datasource_id: int
cls, datasource_type: str, datasource_id: int
) -> "BaseDatasource":
"""Returns datasource with columns and metrics."""
datasource_class = ConnectorRegistry.sources[datasource_type]
return (
session.query(datasource_class)
db.session.query(datasource_class)
.options(
subqueryload(datasource_class.columns),
subqueryload(datasource_class.metrics),
@@ -115,13 +112,9 @@ def get_eager_datasource(

@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: "Database",
datasource_name: str,
schema: Optional[str] = None,
cls, database: "Database", datasource_name: str, schema: Optional[str] = None,
) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=schema
database, datasource_name, schema=schema
)
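
With the `session` parameter removed from `ConnectorRegistry`, call sites shrink to the arguments that actually identify the datasource. A hedged usage sketch (assumes a running Superset application context; the datasource type and id are illustrative):

```python
from superset.connectors.connector_registry import ConnectorRegistry

# Before this commit each lookup threaded a Session through, e.g.
#   ConnectorRegistry.get_datasource(datasource_type, datasource_id, db.session)
# After it, the registry queries the scoped db.session itself.
datasource = ConnectorRegistry.get_datasource("table", 1)
eager = ConnectorRegistry.get_eager_datasource("table", 1)
all_datasources = ConnectorRegistry.get_all_datasources()
```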
52 changes: 24 additions & 28 deletions superset/connectors/druid/models.py
@@ -45,7 +45,7 @@
UniqueConstraint,
)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import expression
from sqlalchemy_utils import EncryptedType

@@ -223,9 +223,8 @@ def refresh(
Fetches metadata for the specified datasources and
merges to the Superset database
"""
session = db.session
ds_list = (
session.query(DruidDatasource)
db.session.query(DruidDatasource)
.filter(DruidDatasource.cluster_id == self.id)
.filter(DruidDatasource.datasource_name.in_(datasource_names))
)
@@ -234,8 +233,8 @@ def refresh(
datasource = ds_map.get(ds_name, None)
if not datasource:
datasource = DruidDatasource(datasource_name=ds_name)
with session.no_autoflush:
session.add(datasource)
with db.session.no_autoflush:
db.session.add(datasource)
flasher(_("Adding new datasource [{}]").format(ds_name), "success")
ds_map[ds_name] = datasource
elif refresh_all:
@@ -245,7 +244,7 @@ def refresh(
continue
datasource.cluster = self
datasource.merge_flag = merge_flag
session.flush()
db.session.flush()

# Prepare multithreaded executation
pool = ThreadPool()
@@ -259,7 +258,7 @@ def refresh(
cols = metadata[i]
if cols:
col_objs_list = (
session.query(DruidColumn)
db.session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(DruidColumn.column_name.in_(cols.keys()))
)
@@ -272,15 +271,15 @@ def refresh(
col_obj = DruidColumn(
datasource_id=datasource.id, column_name=col
)
with session.no_autoflush:
session.add(col_obj)
with db.session.no_autoflush:
db.session.add(col_obj)
col_obj.type = cols[col]["type"]
col_obj.datasource = datasource
if col_obj.type == "STRING":
col_obj.groupby = True
col_obj.filterable = True
datasource.refresh_metrics()
session.commit()
db.session.commit()

@hybrid_property
def perm(self) -> str:
@@ -390,7 +389,7 @@ def lookup_obj(lookup_column: DruidColumn) -> Optional[DruidColumn]:
.first()
)

return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
return import_datasource.import_simple_obj(i_column, lookup_obj)


class DruidMetric(Model, BaseMetric):
@@ -459,7 +458,7 @@ def lookup_obj(lookup_metric: DruidMetric) -> Optional[DruidMetric]:
.first()
)

return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
return import_datasource.import_simple_obj(i_metric, lookup_obj)


druiddatasource_user = Table(
@@ -635,7 +634,7 @@ def lookup_cluster(d: DruidDatasource) -> Optional[DruidCluster]:
return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()

return import_datasource.import_datasource(
db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
i_datasource, lookup_cluster, lookup_datasource, import_time
)

def latest_metadata(self) -> Optional[Dict[str, Any]]:
@@ -705,9 +704,10 @@ def sync_to_db_from_config(
refresh: bool = True,
) -> None:
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session
datasource = (
session.query(cls).filter_by(datasource_name=druid_config["name"]).first()
db.session.query(cls)
.filter_by(datasource_name=druid_config["name"])
.first()
)
# Create a new datasource.
if not datasource:
@@ -718,13 +718,13 @@ def sync_to_db_from_config(
changed_by_fk=user.id,
created_by_fk=user.id,
)
session.add(datasource)
db.session.add(datasource)
elif not refresh:
return

dimensions = druid_config["dimensions"]
col_objs = (
session.query(DruidColumn)
db.session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(DruidColumn.column_name.in_(dimensions))
)
@@ -741,10 +741,10 @@ def sync_to_db_from_config(
type="STRING",
datasource=datasource,
)
session.add(col_obj)
db.session.add(col_obj)
# Import Druid metrics
metric_objs = (
session.query(DruidMetric)
db.session.query(DruidMetric)
.filter(DruidMetric.datasource_id == datasource.id)
.filter(
DruidMetric.metric_name.in_(
@@ -777,8 +777,8 @@ def sync_to_db_from_config(
% druid_config["name"]
),
)
session.add(metric_obj)
session.commit()
db.session.add(metric_obj)
db.session.commit()

@staticmethod
def time_offset(granularity: Granularity) -> int:
@@ -788,10 +788,10 @@ def time_offset(granularity: Granularity) -> int:

@classmethod
def get_datasource_by_name(
cls, session: Session, datasource_name: str, schema: str, database_name: str
cls, datasource_name: str, schema: str, database_name: str
) -> Optional["DruidDatasource"]:
query = (
session.query(cls)
db.session.query(cls)
.join(DruidCluster)
.filter(cls.datasource_name == datasource_name)
.filter(DruidCluster.cluster_name == database_name)
@@ -1724,11 +1724,7 @@ def get_having_filters(

@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
cls, database: Database, datasource_name: str, schema: Optional[str] = None,
) -> List["DruidDatasource"]:
return []

5 changes: 2 additions & 3 deletions superset/connectors/druid/views.py
@@ -365,11 +365,10 @@ def refresh_datasources(  # pylint: disable=no-self-use
self, refresh_all: bool = True
) -> FlaskResponse:
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidCluster = ConnectorRegistry.sources[ # pylint: disable=invalid-name
"druid"
].cluster_class
for cluster in session.query(DruidCluster).all():
for cluster in db.session.query(DruidCluster).all():
cluster_name = cluster.cluster_name
valid_cluster = True
try:
@@ -391,7 +390,7 @@ def refresh_datasources(  # pylint: disable=no-self-use
),
"info",
)
session.commit()
db.session.commit()
return redirect("/druiddatasourcemodelview/list/")

@has_access