Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(dao): Replace save/overwrite with create/update respectively #24467

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions superset/annotation_layers/annotations/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ def __init__(self, data: dict[str, Any]):
def run(self) -> Model:
self.validate()
try:
annotation = AnnotationDAO.create(self._properties)
return AnnotationDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationCreateFailedError() from ex
return annotation

def validate(self) -> None:
exceptions: list[ValidationError] = []
Expand Down
3 changes: 1 addition & 2 deletions superset/annotation_layers/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ def __init__(self, data: dict[str, Any]):
def run(self) -> Model:
self.validate()
try:
annotation_layer = AnnotationLayerDAO.create(self._properties)
return AnnotationLayerDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerCreateFailedError() from ex
return annotation_layer

def validate(self) -> None:
exceptions: list[ValidationError] = []
Expand Down
3 changes: 1 addition & 2 deletions superset/charts/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ def run(self) -> Model:
try:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
chart = ChartDAO.create(self._properties)
return ChartDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ChartCreateFailedError() from ex
return chart

def validate(self) -> None:
exceptions = []
Expand Down
80 changes: 47 additions & 33 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from sqlalchemy.orm import Session

from superset.daos.exceptions import (
DAOConfigError,
DAOCreateFailedError,
DAODeleteFailedError,
DAOUpdateFailedError,
Expand Down Expand Up @@ -130,57 +129,72 @@ def find_one_or_none(cls, **filter_by: Any) -> T | None:
return query.filter_by(**filter_by).one_or_none()

@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic for creating models
:raises: DAOCreateFailedError
def create(
cls,
item: T | None = None,
Copy link
Member Author

@john-bodley john-bodley Jun 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To accommodate the save functionality I extended the create method to accept a predefined object.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be possible to set item as T | dict[str, Any] | None and remove the attributes parameter?

Copy link
Member Author

@john-bodley john-bodley Jul 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michael-s-molina that's an interesting idea, though I wonder if we're overloading item here. I was likely just trying to replicate the same behavior (familiarity) as the update method which needs to have either/or both the item and attributes—well technically you're just updating an existing item, but it seems cleaner (from a DRY perspective) to have the base update method perform said logic.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically I can call create with both item and attributes set as None. It feels really weird for me.

Copy link
Member Author

@john-bodley john-bodley Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@michael-s-molina sorry for the delay in getting back to you on this. You mention you could do something of the form,

chart = ChartDAO.create()

which would try to create a new chart and persist it to the database. This operation would succeed if the model attributes can be nullable. I realize this isn't a typical workflow, but I'm not entirely sure it's wrong per se.

BTW I totally agree there's room for improvement here—I was even toiling with the idea that maybe create and update should be handled by a single upsert method—however I do sense this PR helps ensure that the code is more consistent, i.e., the save and override methods have been removed and both the create and update methods have the same function signature.

attributes: dict[str, Any] | None = None,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used the term "attributes" rather than "properties" as this is the lingo that SQLAlchemy uses—combined with the fact you use the setattr method for setting them.

commit: bool = True,
) -> T:
"""
if cls.model_cls is None:
raise DAOConfigError()
model = cls.model_cls() # pylint: disable=not-callable
for key, value in properties.items():
setattr(model, key, value)
try:
db.session.add(model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return model
Create an object from the specified item and/or attributes.

@classmethod
def save(cls, instance_model: T, commit: bool = True) -> None:
"""
Generic for saving models
:raises: DAOCreateFailedError
:param item: The object to create
:param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises DAOCreateFailedError: If the creation failed
"""
if cls.model_cls is None:
raise DAOConfigError()

if not item:
item = cls.model_cls() # type: ignore # pylint: disable=not-callable

if attributes:
for key, value in attributes.items():
setattr(item, key, value)

try:
db.session.add(instance_model)
db.session.add(item)

if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex

return item # type: ignore

@classmethod
def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
def update(
cls,
item: T | None = None,
Copy link
Member Author

@john-bodley john-bodley Jun 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The item and/or attributes are now optional which provides more flexibility.

attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T:
"""
Generic update a model
:raises: DAOCreateFailedError
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously this comment was wrong.

Update an object from the specified item and/or attributes.

:param item: The object to update
:param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises DAOUpdateFailedError: If the updating failed
"""
for key, value in properties.items():
setattr(model, key, value)

if not item:
item = cls.model_cls() # type: ignore # pylint: disable=not-callable

if attributes:
for key, value in attributes.items():
setattr(item, key, value)

try:
db.session.merge(model)
db.session.merge(item)

if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOUpdateFailedError(exception=ex) from ex
return model

return item # type: ignore

@classmethod
def delete(cls, items: T | list[T], commit: bool = True) -> None:
Expand Down
13 changes: 0 additions & 13 deletions superset/daos/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=arguments-renamed
from __future__ import annotations

import logging
Expand Down Expand Up @@ -54,18 +53,6 @@ def delete(cls, items: Slice | list[Slice], commit: bool = True) -> None:
db.session.rollback()
raise ex

@staticmethod
def save(slc: Slice, commit: bool = True) -> None:
db.session.add(slc)
if commit:
db.session.commit()

@staticmethod
def overwrite(slc: Slice, commit: bool = True) -> None:
db.session.merge(slc)
if commit:
db.session.commit()

@staticmethod
def favorited_ids(charts: list[Slice]) -> list[FavStar]:
ids = [chart.id for chart in charts]
Expand Down
55 changes: 29 additions & 26 deletions superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
from typing import Any

from flask import g
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError

from superset import is_feature_enabled, security_manager
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAOConfigError, DAOCreateFailedError
from superset.dashboards.commands.exceptions import (
DashboardAccessDeniedError,
DashboardForbiddenError,
Expand Down Expand Up @@ -403,35 +401,40 @@ def upsert(dashboard: Dashboard, allowed_domains: list[str]) -> EmbeddedDashboar
return embedded

@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Any:
def create(
cls,
item: EmbeddedDashboardDAO | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Any:
"""
Use EmbeddedDashboardDAO.upsert() instead.
At least, until we are ok with more than one embedded instance per dashboard.
At least, until we are ok with more than one embedded item per dashboard.
"""
raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")


class FilterSetDAO(BaseDAO[FilterSet]):
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
raise DAOConfigError()
model = FilterSet()
setattr(model, NAME_FIELD, properties[NAME_FIELD])
setattr(model, JSON_METADATA_FIELD, properties[JSON_METADATA_FIELD])
setattr(model, DESCRIPTION_FIELD, properties.get(DESCRIPTION_FIELD, None))
setattr(
model,
OWNER_ID_FIELD,
properties.get(OWNER_ID_FIELD, properties[DASHBOARD_ID_FIELD]),
)
setattr(model, OWNER_TYPE_FIELD, properties[OWNER_TYPE_FIELD])
setattr(model, DASHBOARD_ID_FIELD, properties[DASHBOARD_ID_FIELD])
try:
db.session.add(model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError() from ex
return model
def create(
cls,
item: FilterSet | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> FilterSet:
if not item:
item = FilterSet()

if attributes:
setattr(item, NAME_FIELD, attributes[NAME_FIELD])
setattr(item, JSON_METADATA_FIELD, attributes[JSON_METADATA_FIELD])
setattr(item, DESCRIPTION_FIELD, attributes.get(DESCRIPTION_FIELD, None))
setattr(
item,
OWNER_ID_FIELD,
attributes.get(OWNER_ID_FIELD, attributes[DASHBOARD_ID_FIELD]),
)
setattr(item, OWNER_TYPE_FIELD, attributes[OWNER_TYPE_FIELD])
setattr(item, DASHBOARD_ID_FIELD, attributes[DASHBOARD_ID_FIELD])

return super().create(item, commit=commit)
35 changes: 20 additions & 15 deletions superset/daos/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from typing import Any, Optional
from typing import Any

from superset.daos.base import BaseDAO
from superset.databases.filters import DatabaseFilter
Expand All @@ -37,8 +39,8 @@ class DatabaseDAO(BaseDAO[Database]):
@classmethod
def update(
cls,
model: Database,
properties: dict[str, Any],
item: Database | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Database:
"""
Expand All @@ -50,13 +52,14 @@ def update(

The masked values should be unmasked before the database is updated.
"""
if "encrypted_extra" in properties:
properties["encrypted_extra"] = model.db_engine_spec.unmask_encrypted_extra(
model.encrypted_extra,
properties["encrypted_extra"],

if item and attributes and "encrypted_extra" in attributes:
attributes["encrypted_extra"] = item.db_engine_spec.unmask_encrypted_extra(
item.encrypted_extra,
attributes["encrypted_extra"],
)

return super().update(model, properties, commit)
return super().update(item, attributes, commit)

@staticmethod
def validate_uniqueness(database_name: str) -> bool:
Expand All @@ -74,7 +77,7 @@ def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
return not db.session.query(database_query.exists()).scalar()

@staticmethod
def get_database_by_name(database_name: str) -> Optional[Database]:
def get_database_by_name(database_name: str) -> Database | None:
return (
db.session.query(Database)
.filter(Database.database_name == database_name)
Expand Down Expand Up @@ -129,7 +132,7 @@ def get_related_objects(cls, database_id: int) -> dict[str, Any]:
}

@classmethod
def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
def get_ssh_tunnel(cls, database_id: int) -> SSHTunnel | None:
ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database_id)
Expand All @@ -143,8 +146,8 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
@classmethod
def update(
cls,
model: SSHTunnel,
properties: dict[str, Any],
item: SSHTunnel | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SSHTunnel:
"""
Expand All @@ -156,7 +159,9 @@ def update(
The masked values should be unmasked before the ssh tunnel is updated.
"""
# ID cannot be updated so we remove it if present in the payload
properties.pop("id", None)
properties = unmask_password_info(properties, model)

return super().update(model, properties, commit)
if item and attributes:
attributes.pop("id", None)
attributes = unmask_password_info(attributes, item)

return super().update(item, attributes, commit)
Loading