Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure SQLAlchemy sessions are closed #25031

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions superset/models/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,31 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -

session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection)
new_user = session.query(User).filter_by(id=target.id).first()

# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
session.commit()

# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
try:
new_user = session.query(User).filter_by(id=target.id).first()

# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)

# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
finally:
session.close()


sqla.event.listen(User, "after_insert", copy_dashboard)
Expand Down Expand Up @@ -414,13 +417,12 @@ def export_dashboards( # pylint: disable=too-many-locals
"native_filter_configuration", []
)
for native_filter in native_filter_configuration:
session = db.session()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's unclear why a new session was instantiated (and never closed) here. On line #433 the Flask-SQLAlchemy session is used and thus it seems prudent (and hopefully) safe to use db.session.

for target in native_filter.get("targets", []):
id_ = target.get("datasetId")
if id_ is None:
continue
datasource = DatasourceDAO.get_datasource(
session, utils.DatasourceType.TABLE, id_
db.session, utils.DatasourceType.TABLE, id_
)
datasource_ids.add((datasource.id, datasource.type))

Expand Down
126 changes: 69 additions & 57 deletions superset/tags/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,17 +170,19 @@ def after_insert(
) -> None:
session = Session(bind=connection)

# add `owner:` tags
cls._add_owners(session, target)
try:
# add `owner:` tags
cls._add_owners(session, target)

# add `type:` tags
tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)

session.commit()
# add `type:` tags
tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)
session.commit()
finally:
session.close()

@classmethod
def after_update(
Expand All @@ -191,25 +193,27 @@ def after_update(
) -> None:
session = Session(bind=connection)

# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
Tag.type == TagTypes.owner,
try:
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
Tag.type == TagTypes.owner,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)

# add `owner:` tags
cls._add_owners(session, target)

session.commit()
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
finally:
session.close()

@classmethod
def after_delete(
Expand All @@ -220,13 +224,16 @@ def after_delete(
) -> None:
session = Session(bind=connection)

# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()
try:
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()

session.commit()
session.commit()
finally:
session.close()


class ChartUpdater(ObjectUpdater):
Expand Down Expand Up @@ -267,35 +274,40 @@ def after_insert(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)

session.commit()
try:
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)
session.commit()
finally:
session.close()

@classmethod
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
Tag.type == TagTypes.favorited_by,
Tag.name == name,
try:
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
Tag.type == TagTypes.favorited_by,
Tag.name == name,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)

session.commit()
session.commit()
finally:
session.close()
106 changes: 59 additions & 47 deletions superset/tasks/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods

def get_payloads(self) -> list[dict[str, int]]:
session = db.create_scoped_session()
charts = session.query(Slice).all()

try:
charts = session.query(Slice).all()
finally:
session.close()

return [get_payload(chart) for chart in charts]

Expand Down Expand Up @@ -129,20 +133,24 @@ def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()

records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))

try:
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = (
session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
finally:
session.close()
return payloads


Expand Down Expand Up @@ -172,42 +180,46 @@ def get_payloads(self) -> list[dict[str, int]]:
payloads = []
session = db.create_scoped_session()

tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]

# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
try:
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]

# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))

# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))

# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))

chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))
finally:
session.close()
return payloads


Expand Down
Loading