Skip to content

Commit

Permalink
feat: add support for catalogs
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida committed May 13, 2024
1 parent 4f51f05 commit 90bbf7f
Show file tree
Hide file tree
Showing 14 changed files with 504 additions and 62 deletions.
22 changes: 2 additions & 20 deletions superset/db_engine_specs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -706,29 +706,11 @@ Hive and Trino:
4. Table
5. Column

If the database supports catalogs, then the DB engine spec should have the `supports_catalog` class attribute set to true.
If the database supports catalogs, then the DB engine spec should have the `supports_catalog` class attribute set to true. It should also implement the `get_default_catalog` method, so that the proper permissions can be created when datasets are added.

### Dynamic catalog

Superset has no support for multiple catalogs. A given SQLAlchemy URI connects to a single catalog, and it's impossible to browse other catalogs, or change the catalog. This means that datasets can only be added for the main catalog of the database. For example, with this Postgres SQLAlchemy URI:

```
postgresql://admin:password123@db.example.org:5432/db
```

Here, datasets can only be added to the `db` catalog (which Postgres calls a "database").

One confusing problem is that many databases allow querying across catalogs in SQL Lab. For example, with BigQuery one can write:

```sql
SELECT * FROM project.schema.table
```

This means that **even though the database is configured for a given catalog (project), users can query other projects**. This is a common workaround for creating datasets in catalogs other than the catalog configured in the database: just create a virtual dataset.

Ideally we would want users to be able to choose the catalog when using SQL Lab and when creating datasets. In order to do that, DB engine specs need to implement a method that rewrites the SQLAlchemy URI depending on the desired catalog. This method already exists, and is the same method used for dynamic schemas, `adjust_engine_params`, but currently there are no UI affordances for choosing a catalog.

Before the UI is implemented Superset still needs to implement support for catalogs in its security manager. But in the meantime, it's possible for DB engine spec developers to support dynamic catalogs, by setting `supports_dynamic_catalog` to true and implementing `adjust_engine_params` to handle a catalog.
Superset support for multiple catalogs. Since, in general, a given SQLAlchemy URI connects only to a single catalog, it requires DB engine specs to implement the `adjust_engine_params` method to rewrite the URL to connect to a different catalog, similar to how dynamic schemas work. Additionally, DB engine specs should also implement the `get_catalog_names` method, so that users can browse the available catalogs.

### SSH tunneling

Expand Down
34 changes: 33 additions & 1 deletion superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from sqlalchemy import column, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.sql import sqltypes

from superset import sql_parse
Expand Down Expand Up @@ -127,7 +128,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met

allows_hidden_cc_in_orderby = True

supports_catalog = False
supports_catalog = supports_dynamic_catalog = True

"""
https://www.python.org/dev/peps/pep-0249/#arraysize
Expand Down Expand Up @@ -459,6 +460,24 @@ def estimate_query_cost( # pylint: disable=too-many-arguments
for statement in statements
]

@classmethod
def get_default_catalog(cls, database: Database) -> str | None:
"""
Get the default catalog.
"""
url = database.url_object

# The SQLAlchemy driver accepts both `bigquery://project` (where the project is
# technically a host) and `bigquery:///project` (where it's a database). But
# both can be missing, and the project is inferred from the authentication
# credentials.
if project := url.host or url.database:
return project

with database.get_sqla_engine() as engine:
client = cls._get_client(engine)
return client.project

@classmethod
def get_catalog_names(
cls,
Expand All @@ -477,6 +496,19 @@ def get_catalog_names(

return {project.project_id for project in projects}

@classmethod
def adjust_engine_params(
cls,
uri: URL,
connect_args: dict[str, Any],
catalog: str | None = None,
schema: str | None = None,
) -> tuple[URL, dict[str, Any]]:
if catalog:
uri = uri.set(host=catalog, database="")

return uri, connect_args

@classmethod
def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True
Expand Down
57 changes: 36 additions & 21 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=too-many-lines

from __future__ import annotations

import contextlib
Expand Down Expand Up @@ -165,6 +167,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
"""

supports_dynamic_schema = True
supports_catalog = supports_dynamic_catalog = True

column_type_mappings = (
(
Expand Down Expand Up @@ -295,6 +298,24 @@ def convert_dttm(
def epoch_to_dttm(cls) -> str:
return "from_unixtime({col})"

@classmethod
def get_default_catalog(cls, database: "Database") -> str | None:
"""
Return the default catalog.
"""
return database.url_object.database.split("/")[0]

@classmethod
def get_catalog_names(
cls,
database: Database,
inspector: Inspector,
) -> set[str]:
"""
Get all catalogs.
"""
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}

@classmethod
def adjust_engine_params(
cls,
Expand All @@ -303,14 +324,22 @@ def adjust_engine_params(
catalog: str | None = None,
schema: str | None = None,
) -> tuple[URL, dict[str, Any]]:
database = uri.database
if schema and database:
if uri.database and "/" in uri.database:
current_catalog, current_schema = uri.database.split("/", 1)
else:
current_catalog, current_schema = uri.database, None

if schema:
schema = parse.quote(schema, safe="")
if "/" in database:
database = database.split("/")[0] + "/" + schema
else:
database += "/" + schema
uri = uri.set(database=database)

adjusted_database = "/".join(
[
catalog or current_catalog or "",
schema or current_schema or "",
]
).rstrip("/")

uri = uri.set(database=adjusted_database)

return uri, connect_args

Expand Down Expand Up @@ -648,8 +677,6 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
engine_name = "Presto"
allows_alias_to_source_column = False

supports_catalog = False

custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__(
Expand Down Expand Up @@ -812,17 +839,6 @@ def get_view_names(
results = cursor.fetchall()
return {row[0] for row in results}

@classmethod
def get_catalog_names(
cls,
database: Database,
inspector: Inspector,
) -> set[str]:
"""
Get all catalogs.
"""
return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")}

@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
Expand Down Expand Up @@ -1248,7 +1264,6 @@ def get_extra_table_metadata(
),
}

# flake8 is not matching `Optional[str]` to `Any` for some reason...
metadata["view"] = cast(
Any,
cls.get_create_view(database, table.schema, table.table),
Expand Down
28 changes: 21 additions & 7 deletions superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
sqlalchemy_uri_placeholder = "snowflake://"

supports_dynamic_schema = True
supports_catalog = False
supports_catalog = supports_dynamic_catalog = True

_time_grain_expressions = {
None: "{col}",
Expand Down Expand Up @@ -144,12 +144,19 @@ def adjust_engine_params(
catalog: Optional[str] = None,
schema: Optional[str] = None,
) -> tuple[URL, dict[str, Any]]:
database = uri.database
if "/" in database:
database = database.split("/")[0]
if schema:
schema = parse.quote(schema, safe="")
uri = uri.set(database=f"{database}/{schema}")
if "/" in uri.database:
current_catalog, current_schema = uri.database.split("/", 1)
else:
current_catalog, current_schema = uri.database, None

adjusted_database = "/".join(
[
catalog or current_catalog,
schema or current_schema or "",
]
).rstrip("/")

uri = uri.set(database=adjusted_database)

return uri, connect_args

Expand All @@ -169,6 +176,13 @@ def get_schema_from_engine_params(

return parse.unquote(database.split("/")[1])

@classmethod
def get_default_catalog(cls, database: "Database") -> Optional[str]:
"""
Return the default catalog.
"""
return database.url_object.database.split("/")[0]

@classmethod
def get_catalog_names(
cls,
Expand Down
8 changes: 4 additions & 4 deletions superset/migrations/shared/catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class Slice(Base):
schema_perm = sa.Column(sa.String(1000))


def upgrade_catalog_perms(engine: str | None = None) -> None:
def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Update models when catalogs are introduced in a DB engine spec.
Expand All @@ -102,7 +102,7 @@ def upgrade_catalog_perms(engine: str | None = None) -> None:
for database in session.query(Database).all():
db_engine_spec = database.db_engine_spec
if (
engine and db_engine_spec.engine != engine
engines and db_engine_spec.engine not in engines
) or not db_engine_spec.supports_catalog:
continue

Expand Down Expand Up @@ -166,7 +166,7 @@ def upgrade_catalog_perms(engine: str | None = None) -> None:
session.commit()


def downgrade_catalog_perms(engine: str | None = None) -> None:
def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
"""
Reverse the process of `upgrade_catalog_perms`.
"""
Expand All @@ -175,7 +175,7 @@ def downgrade_catalog_perms(engine: str | None = None) -> None:
for database in session.query(Database).all():
db_engine_spec = database.db_engine_spec
if (
engine and db_engine_spec.engine != engine
engines and db_engine_spec.engine not in engines
) or not db_engine_spec.supports_catalog:
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def upgrade():
"slices",
sa.Column("catalog_perm", sa.String(length=1000), nullable=True),
)
upgrade_catalog_perms(engine="postgresql")
upgrade_catalog_perms(engines={"postgresql"})


def downgrade():
op.drop_column("slices", "catalog_perm")
op.drop_column("tables", "catalog_perm")
downgrade_catalog_perms(engine="postgresql")
downgrade_catalog_perms(engines={"postgresql"})
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@


def upgrade():
upgrade_catalog_perms(engine="databricks")
upgrade_catalog_perms(engines={"databricks"})


def downgrade():
downgrade_catalog_perms(engine="databricks")
downgrade_catalog_perms(engines={"databricks"})
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Enable catalog in BigQuery/Presto/Trino/Snowflake
Revision ID: 87ffc36f9842
Revises: 4081be5b6b74
Create Date: 2024-05-09 18:44:43.289445
"""

from superset.migrations.shared.catalogs import (
downgrade_catalog_perms,
upgrade_catalog_perms,
)

# revision identifiers, used by Alembic.
revision = "87ffc36f9842"
down_revision = "4081be5b6b74"


def upgrade():
upgrade_catalog_perms(engines={"trino", "presto", "bigquery", "snowflake"})


def downgrade():
downgrade_catalog_perms(engines={"trino", "presto", "bigquery", "snowflake"})
2 changes: 1 addition & 1 deletion tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3281,7 +3281,7 @@ def test_available(self, app, get_available_engine_specs):
"sqlalchemy_uri_placeholder": "bigquery://{project_id}",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"supports_dynamic_catalog": True,
"disable_ssh_tunneling": True,
},
},
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_impersonate_user_presto(self, mocked_create_engine):
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "presto://gamma@localhost"
assert str(call_args[0][0]) == "presto://gamma@localhost/"

assert call_args[1]["connect_args"] == {
"protocol": "https",
Expand All @@ -180,7 +180,7 @@ def test_impersonate_user_presto(self, mocked_create_engine):
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "presto://localhost"
assert str(call_args[0][0]) == "presto://localhost/"

assert call_args[1]["connect_args"] == {
"protocol": "https",
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_impersonate_user_trino(self, mocked_create_engine):
model._get_sqla_engine()
call_args = mocked_create_engine.call_args

assert str(call_args[0][0]) == "trino://localhost"
assert str(call_args[0][0]) == "trino://localhost/"
assert call_args[1]["connect_args"]["user"] == "gamma"

model = Database(
Expand All @@ -239,7 +239,7 @@ def test_impersonate_user_trino(self, mocked_create_engine):

assert (
str(call_args[0][0])
== "trino://original_user:original_user_password@localhost"
== "trino://original_user:original_user_password@localhost/"
)
assert call_args[1]["connect_args"]["user"] == "gamma"

Expand Down
Loading

0 comments on commit 90bbf7f

Please sign in to comment.