Skip to content

Commit

Permalink
fix: Ensure SQLAlchemy sessions are closed (#25031)
Browse files Browse the repository at this point in the history
(cherry picked from commit adaab35)
  • Loading branch information
john-bodley authored and michael-s-molina committed Aug 24, 2023
1 parent 931e1b2 commit ad89ea5
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 127 deletions.
48 changes: 25 additions & 23 deletions superset/models/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,31 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -

session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection)
new_user = session.query(User).filter_by(id=target.id).first()

# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
session.commit()

# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
try:
new_user = session.query(User).filter_by(id=target.id).first()

# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)

# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
finally:
session.close()


sqla.event.listen(User, "after_insert", copy_dashboard)
Expand Down Expand Up @@ -411,13 +414,12 @@ def export_dashboards( # pylint: disable=too-many-locals
"native_filter_configuration", []
)
for native_filter in native_filter_configuration:
session = db.session()
for target in native_filter.get("targets", []):
id_ = target.get("datasetId")
if id_ is None:
continue
datasource = DatasourceDAO.get_datasource(
session, utils.DatasourceType.TABLE, id_
db.session, utils.DatasourceType.TABLE, id_
)
datasource_ids.add((datasource.id, datasource.type))

Expand Down
126 changes: 69 additions & 57 deletions superset/tags/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,19 @@ def after_insert(
) -> None:
session = Session(bind=connection)

# add `owner:` tags
cls._add_owners(session, target)
try:
# add `owner:` tags
cls._add_owners(session, target)

# add `type:` tags
tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)

session.commit()
# add `type:` tags
tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)
session.commit()
finally:
session.close()

@classmethod
def after_update(
Expand All @@ -177,25 +179,27 @@ def after_update(
) -> None:
session = Session(bind=connection)

# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
Tag.type == TagTypes.owner,
try:
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
Tag.type == TagTypes.owner,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)

# add `owner:` tags
cls._add_owners(session, target)

session.commit()
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
finally:
session.close()

@classmethod
def after_delete(
Expand All @@ -206,13 +210,16 @@ def after_delete(
) -> None:
session = Session(bind=connection)

# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()
try:
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()

session.commit()
session.commit()
finally:
session.close()


class ChartUpdater(ObjectUpdater):
Expand Down Expand Up @@ -253,35 +260,40 @@ def after_insert(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)

session.commit()
try:
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)
session.commit()
finally:
session.close()

@classmethod
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
Tag.type == TagTypes.favorited_by,
Tag.name == name,
try:
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
Tag.type == TagTypes.favorited_by,
Tag.name == name,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)

session.commit()
session.commit()
finally:
session.close()
106 changes: 59 additions & 47 deletions superset/tasks/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods

def get_payloads(self) -> list[dict[str, int]]:
session = db.create_scoped_session()
charts = session.query(Slice).all()

try:
charts = session.query(Slice).all()
finally:
session.close()

return [get_payload(chart) for chart in charts]

Expand Down Expand Up @@ -129,20 +133,24 @@ def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()

records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))

try:
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = (
session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
finally:
session.close()
return payloads


Expand Down Expand Up @@ -172,42 +180,46 @@ def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()

tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]

# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
try:
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]

# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))

# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))

# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))

chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))
finally:
session.close()
return payloads


Expand Down

0 comments on commit ad89ea5

Please sign in to comment.