Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove unneeded complexity in migration #19022

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 7 additions & 89 deletions superset/migrations/versions/b8d3a24d9131_new_dataset_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,22 @@
"""

import json
from typing import Any, Dict, List, Optional, Type
from typing import List
from uuid import uuid4

import sqlalchemy as sa
from alembic import op
from sqlalchemy import and_, inspect, or_
from sqlalchemy.engine import create_engine, Engine
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.exc import ArgumentError
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy_utils import UUIDType

from superset import app, db, db_engine_specs
from superset import app, db
from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES
from superset.extensions import encrypted_field_factory, security_manager
from superset.extensions import encrypted_field_factory
from superset.sql_parse import ParsedQuery
from superset.utils.memoized import memoized

# revision identifiers, used by Alembic.
revision = "b8d3a24d9131"
Expand Down Expand Up @@ -78,86 +75,6 @@ class Database(Base):
)
server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True)

@property
def sqlalchemy_uri_decrypted(self) -> str:
try:
url = make_url(self.sqlalchemy_uri)
except (ArgumentError, ValueError):
return "dialect://invalid_uri"
if custom_password_store:
url.password = custom_password_store(url)
else:
url.password = self.password
return str(url)

@property
def backend(self) -> str:
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
return sqlalchemy_url.get_backend_name() # pylint: disable=no-member

@classmethod
@memoized
def get_db_engine_spec_for_backend(
cls, backend: str
) -> Type[db_engine_specs.BaseEngineSpec]:
engines = db_engine_specs.get_engine_specs()
return engines.get(backend, db_engine_specs.BaseEngineSpec)

@property
def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
return self.get_db_engine_spec_for_backend(self.backend)

def get_extra(self) -> Dict[str, Any]:
return self.db_engine_spec.get_extra_params(self)

def get_effective_user(
self, object_url: URL, user_name: Optional[str] = None,
) -> Optional[str]:
effective_username = None
if self.impersonate_user:
effective_username = object_url.username
if user_name:
effective_username = user_name

return effective_username

def get_encrypted_extra(self) -> Dict[str, Any]:
return json.loads(self.encrypted_extra) if self.encrypted_extra else {}

@memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra"))
def get_sqla_engine(self, schema: Optional[str] = None) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url, "admin")
# If using MySQL or Presto for example, will set url.username
self.db_engine_spec.modify_url_for_impersonation(
sqlalchemy_url, self.impersonate_user, effective_username
)

params = extra.get("engine_params", {})
connect_args = params.get("connect_args", {})
if self.impersonate_user:
self.db_engine_spec.update_impersonation_config(
connect_args, str(sqlalchemy_url), effective_username
)

if connect_args:
params["connect_args"] = connect_args

params.update(self.get_encrypted_extra())

if DB_CONNECTION_MUTATOR:
sqlalchemy_url, params = DB_CONNECTION_MUTATOR(
sqlalchemy_url,
params,
effective_username,
security_manager,
"migration",
)

return create_engine(sqlalchemy_url, **params)


class TableColumn(Base):

Expand Down Expand Up @@ -325,8 +242,9 @@ def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals
)
if not database:
return
engine = database.get_sqla_engine(schema=target.schema)
conditional_quote = engine.dialect.identifier_preparer.quote
url = make_url(database.sqlalchemy_uri)
dialect_class = url.get_dialect()
conditional_quote = dialect_class().identifier_preparer.quote

# create columns
columns = []
Expand Down