Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(row-level-security): add base filter type and filter grouping #10946

Merged
merged 4 commits into from
Sep 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions superset/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ def init_views(self) -> None:
if self.config["ENABLE_ROW_LEVEL_SECURITY"]:
appbuilder.add_view(
RowLevelSecurityFiltersModelView,
"Row Level Security Filters",
label=__("Row level security filters"),
"Row Level Security",
label=__("Row level security"),
villebro marked this conversation as resolved.
Show resolved Hide resolved
category="Security",
category_label=__("Security"),
icon="fa-lock",
Expand Down
28 changes: 18 additions & 10 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
import json
import logging
from collections import OrderedDict
from collections import defaultdict, OrderedDict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
Expand All @@ -35,6 +35,7 @@
Column,
DateTime,
desc,
Enum,
ForeignKey,
Integer,
or_,
Expand Down Expand Up @@ -92,8 +93,8 @@ class MetadataResult:


class AnnotationDatasource(BaseDatasource):
""" Dummy object so we can query annotations using 'Viz' objects just like
regular datasources.
"""Dummy object so we can query annotations using 'Viz' objects just like
regular datasources.
"""

cache_timeout = 0
Expand Down Expand Up @@ -798,11 +799,14 @@ def _get_sqla_row_level_filters(
:returns: A list of SQL clauses to be ANDed together.
:rtype: List[str]
"""
filters_grouped: Dict[Union[int, str], List[str]] = defaultdict(list)
try:
return [
text("({})".format(template_processor.process_template(f.clause)))
for f in security_manager.get_rls_filters(self)
]
for filter_ in security_manager.get_rls_filters(self):
clause = text(
f"({template_processor.process_template(filter_.clause)})"
)
filters_grouped[filter_.group_key or filter_.id].append(clause)
villebro marked this conversation as resolved.
Show resolved Hide resolved
return [or_(*clauses) for clauses in filters_grouped.values()]
except TemplateError as ex:
raise QueryObjectValidationError(
_("Error in jinja expression in RLS filters: %(msg)s", msg=ex.message,)
Expand Down Expand Up @@ -1371,9 +1375,9 @@ def import_obj(
) -> int:
"""Imports the datasource from the object to the database.

Metrics and columns and datasource will be overrided if exists.
This function can be used to import/export dashboards between multiple
superset instances. Audit metadata isn't copies over.
Metrics and columns and datasource will be overrided if exists.
This function can be used to import/export dashboards between multiple
superset instances. Audit metadata isn't copies over.
"""

def lookup_sqlatable(table_: "SqlaTable") -> "SqlaTable":
Expand Down Expand Up @@ -1506,6 +1510,10 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):

__tablename__ = "row_level_security_filters"
id = Column(Integer, primary_key=True)
filter_type = Column(
Enum(*[filter_type.value for filter_type in utils.RowLevelSecurityFilterType])
)
group_key = Column(String(255), nullable=True)
roles = relationship(
security_manager.role_model,
secondary=RLSFilterRoles,
Expand Down
60 changes: 52 additions & 8 deletions superset/connectors/sqla/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import logging
import re
from dataclasses import dataclass, field
from typing import Dict, List, Union
from typing import Any, cast, Dict, List, Union

from flask import flash, Markup, redirect
from flask import current_app, flash, Markup, redirect
from flask_appbuilder import CompactCRUDMixin, expose
from flask_appbuilder.actions import action
from flask_appbuilder.fieldwidgets import Select2Widget
Expand All @@ -41,6 +41,7 @@
DatasourceFilter,
DeleteMixin,
ListWidgetWithCheckboxes,
SupersetListWidget,
SupersetModelView,
validate_sqlatable,
YamlExportMixin,
Expand Down Expand Up @@ -241,30 +242,73 @@ class SqlMetricInlineView( # pylint: disable=too-many-ancestors
edit_form_extra_fields = add_form_extra_fields


class RowLevelSecurityListWidget(
SupersetListWidget
): # pylint: disable=too-few-public-methods
template = "superset/models/rls/list.html"

def __init__(self, **kwargs: Any):
kwargs["appbuilder"] = current_app.appbuilder
super().__init__(**kwargs)


class RowLevelSecurityFiltersModelView( # pylint: disable=too-many-ancestors
SupersetModelView, DeleteMixin
):
datamodel = SQLAInterface(models.RowLevelSecurityFilter)

list_widget = cast(SupersetListWidget, RowLevelSecurityListWidget)
Copy link
Member Author

@villebro villebro Sep 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason mypy flags Type[RowLevelSecurityListWidget] as being incompatible with Type[SupersetListWidget] which is required for list_widget, despite RowLevelSecurityListWidget extending from SupersetListWidget. Either I'm missing something fundamental here, or mypy is just flagging a false positive.


list_title = _("Row level security filter")
show_title = _("Show Row level security filter")
add_title = _("Add Row level security filter")
edit_title = _("Edit Row level security filter")

list_columns = ["tables", "roles", "clause", "creator", "modified"]
order_columns = ["tables", "clause", "modified"]
edit_columns = ["tables", "roles", "clause"]
list_columns = [
"filter_type",
"tables",
"roles",
"group_key",
"clause",
"creator",
"modified",
]
order_columns = ["filter_type", "group_key", "clause", "modified"]
edit_columns = ["filter_type", "tables", "roles", "group_key", "clause"]
show_columns = edit_columns
search_columns = ("tables", "roles", "clause")
search_columns = ("filter_type", "tables", "roles", "group_key", "clause")
add_columns = edit_columns
base_order = ("changed_on", "desc")
description_columns = {
"filter_type": _(
"Regular filters add where clauses to queries if a user belongs to a "
"role referenced in the filter. Base filters apply filters to all queries "
"except the roles defined in the filter, and can be used to define what "
"users can see if no RLS filters within a filter group apply to them."
),
"tables": _("These are the tables this filter will be applied to."),
"roles": _("These are the roles this filter will be applied to."),
"roles": _(
"For regular filters, these are the roles this filter will be "
"applied to. For base filters, these are the roles that the "
"filter DOES NOT apply to, e.g. Admin if admin should see all "
"data."
),
"group_key": _(
"Filters with the same group key will be ORed together within the group, "
"while different filter groups will be ANDed together. Undefined group "
"keys are treated as unique groups, i.e. are not grouped together. "
"For example, if a table has three filters, of which two are for "
"departments Finance and Marketing (group key = 'department'), and one "
"refers to the region Europe (group key = 'region'), the filter clause "
"would apply the filter (department = 'Finance' OR department = "
"'Marketing') AND (region = 'Europe')."
),
"clause": _(
"This is the condition that will be added to the WHERE clause. "
"For example, to only return rows for a particular client, "
"you might put in: client_id = 9"
"you might define a regular filter with the clause `client_id = 9`. To "
"display no rows unless a user belongs to a RLS filter role, a base "
"filter can be created with the clause `1 = 0` (always false)."
),
}
label_columns = {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add rls filter type and grouping key

Revision ID: e5ef6828ac4e
Revises: ae19b4ee3692
Create Date: 2020-09-15 18:22:40.130985

"""

# revision identifiers, used by Alembic.
revision = "e5ef6828ac4e"
down_revision = "ae19b4ee3692"

import sqlalchemy as sa
from alembic import op

from superset.utils import core as utils


def upgrade():
with op.batch_alter_table("row_level_security_filters") as batch_op:
batch_op.add_column(sa.Column("filter_type", sa.VARCHAR(255), nullable=True)),
villebro marked this conversation as resolved.
Show resolved Hide resolved
batch_op.add_column(sa.Column("group_key", sa.VARCHAR(255), nullable=True)),
batch_op.create_index(
op.f("ix_row_level_security_filters_filter_type"),
["filter_type"],
unique=False,
)

bind = op.get_bind()
metadata = sa.MetaData(bind=bind)
filters = sa.Table("row_level_security_filters", metadata, autoload=True)
statement = filters.update().values(
filter_type=utils.RowLevelSecurityFilterType.REGULAR.value
)
bind.execute(statement)


def downgrade():
with op.batch_alter_table("row_level_security_filters") as batch_op:
batch_op.drop_index(op.f("ix_row_level_security_filters_filter_type"),)
batch_op.drop_column("filter_type")
batch_op.drop_column("group_key")
46 changes: 38 additions & 8 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ViewMenuModelView,
)
from flask_appbuilder.widgets import ListWidget
from sqlalchemy import or_
from sqlalchemy import and_, or_
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery
Expand All @@ -46,7 +46,7 @@
from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.utils.core import DatasourceName
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType

if TYPE_CHECKING:
from superset.common.query_context import QueryContext
Expand All @@ -62,16 +62,16 @@

class SupersetSecurityListWidget(ListWidget):
"""
Redeclaring to avoid circular imports
Redeclaring to avoid circular imports
"""

template = "superset/fab_overrides/list.html"


class SupersetRoleListWidget(ListWidget):
"""
Role model view from FAB already uses a custom list widget override
So we override the override
Role model view from FAB already uses a custom list widget override
So we override the override
"""

template = "superset/fab_overrides/list_role.html"
Expand Down Expand Up @@ -1012,8 +1012,23 @@ def get_rls_filters( # pylint: disable=no-self-use
.filter(assoc_user_role.c.user_id == g.user.id)
.subquery()
)
filter_roles = (
regular_filter_roles = (
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.REGULAR
)
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
)
base_filter_roles = (
self.get_session.query(RLSFilterRoles.c.rls_filter_id)
.join(RowLevelSecurityFilter)
.filter(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.BASE
)
.filter(RLSFilterRoles.c.role_id.in_(user_roles))
.subquery()
)
Expand All @@ -1024,10 +1039,25 @@ def get_rls_filters( # pylint: disable=no-self-use
)
query = (
self.get_session.query(
RowLevelSecurityFilter.id, RowLevelSecurityFilter.clause
RowLevelSecurityFilter.id,
RowLevelSecurityFilter.group_key,
RowLevelSecurityFilter.clause,
)
.filter(RowLevelSecurityFilter.id.in_(filter_tables))
.filter(RowLevelSecurityFilter.id.in_(filter_roles))
.filter(
or_(
and_(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.REGULAR,
RowLevelSecurityFilter.id.in_(regular_filter_roles),
),
and_(
RowLevelSecurityFilter.filter_type
== RowLevelSecurityFilterType.BASE,
RowLevelSecurityFilter.id.notin_(base_filter_roles),
),
)
)
)
return query.all()
return []
Expand Down
Loading