Skip to content

Commit

Permalink
fix: Security manager incorrect calls (#29884)
Browse files Browse the repository at this point in the history
(cherry picked from commit d497dca)
  • Loading branch information
michael-s-molina authored and sadpandajoe committed Aug 23, 2024
1 parent 8f93ad7 commit b0a2aea
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 93 deletions.
55 changes: 53 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
import dataclasses
import logging
import re
from collections import defaultdict
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Callable, cast
from typing import Any, Callable, cast, Optional, Union

import dateutil.parser
import numpy as np
Expand Down Expand Up @@ -69,7 +70,7 @@
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause

from superset import app, db, security_manager
from superset import app, db, is_feature_enabled, security_manager
from superset.commands.dataset.exceptions import DatasetNotFoundError
from superset.common.db_query_status import QueryStatus
from superset.connectors.sqla.utils import (
Expand Down Expand Up @@ -712,6 +713,56 @@ def get_datasource_by_name(
) -> BaseDatasource | None:
raise NotImplementedError()

def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
raise NotImplementedError()

def text(self, clause: str) -> TextClause:
raise NotImplementedError()

def get_sqla_row_level_filters(
self,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
Flask global namespace.
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
template_processor = template_processor or self.get_template_processor()

all_filters: list[TextClause] = []
filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
if filter_.group_key:
filter_groups[filter_.group_key].append(clause)
else:
all_filters.append(clause)

if is_feature_enabled("EMBEDDED_SUPERSET"):
for rule in security_manager.get_guest_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(rule['clause'])})"
)
all_filters.append(clause)

grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
all_filters.extend(grouped_filters)
return all_filters
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in RLS filters: %(msg)s",
msg=ex.message,
)
) from ex


class AnnotationDatasource(BaseDatasource):
"""Dummy object so we can query annotations using 'Viz' objects just like
Expand Down
2 changes: 1 addition & 1 deletion superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,6 +2274,6 @@ def schemas_access_for_file_upload(self, pk: int) -> Response:
# otherwise the database should have been filtered out
# in CsvToDatabaseForm
schemas_allowed_processed = security_manager.get_schemas_accessible_by_user(
database, schemas_allowed, True
database, database.get_default_catalog(), schemas_allowed, True
)
return self.response(200, schemas=schemas_allowed_processed)
9 changes: 5 additions & 4 deletions superset/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory
from superset.extensions.ssh import SSHManagerFactory
from superset.extensions.stats_logger import BaseStatsLoggerManager
from superset.security.manager import SupersetSecurityManager
from superset.utils.cache_manager import CacheManager
from superset.utils.encrypt import EncryptedFieldFactory
from superset.utils.feature_flag_manager import FeatureFlagManager
Expand Down Expand Up @@ -84,9 +85,9 @@ def get_files(bundle: str, asset_type: str = "js") -> list[str]:
return {
"js_manifest": lambda bundle: get_files(bundle, "js"),
"css_manifest": lambda bundle: get_files(bundle, "css"),
"assets_prefix": self.app.config["STATIC_ASSETS_PREFIX"]
if self.app
else "",
"assets_prefix": (
self.app.config["STATIC_ASSETS_PREFIX"] if self.app else ""
),
}

def parse_manifest_json(self) -> None:
Expand Down Expand Up @@ -132,7 +133,7 @@ def init_app(self, app: Flask) -> None:
migrate = Migrate()
profiling = ProfilingExtension()
results_backend_manager = ResultsBackendManager()
security_manager = LocalProxy(lambda: appbuilder.sm)
security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm)
ssh_manager_factory = SSHManagerFactory()
stats_logger_manager = BaseStatsLoggerManager()
talisman = Talisman()
48 changes: 6 additions & 42 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import logging
import re
import uuid
from collections import defaultdict
from collections.abc import Hashable
from datetime import datetime, timedelta
from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING, Union
Expand Down Expand Up @@ -52,7 +51,7 @@
from sqlalchemy.sql.selectable import Alias, TableClause
from sqlalchemy_utils import UUIDType

from superset import app, db, is_feature_enabled, security_manager
from superset import app, db, is_feature_enabled
from superset.advanced_data_type.types import AdvancedDataTypeResponse
from superset.common.db_query_status import QueryStatus
from superset.common.utils.time_range_utils import get_since_until_from_time_range
Expand Down Expand Up @@ -806,47 +805,12 @@ def get_fetch_values_predicate(

def get_sqla_row_level_filters(
self,
template_processor: Optional[BaseTemplateProcessor] = None,
template_processor: Optional[BaseTemplateProcessor] = None, # pylint: disable=unused-argument
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
Flask global namespace.
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
template_processor = template_processor or self.get_template_processor()

all_filters: list[TextClause] = []
filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
if filter_.group_key:
filter_groups[filter_.group_key].append(clause)
else:
all_filters.append(clause)

if is_feature_enabled("EMBEDDED_SUPERSET"):
for rule in security_manager.get_guest_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(rule['clause'])})"
)
all_filters.append(clause)

grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
all_filters.extend(grouped_filters)
return all_filters
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in RLS filters: %(msg)s",
msg=ex.message,
)
) from ex
# TODO: We should refactor this mixin and remove this method
# as it exists in the BaseDatasource and is not applicable
# for datasources of type query
return []

def _process_sql_expression(
self,
Expand Down
4 changes: 1 addition & 3 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2646,9 +2646,7 @@ def has_guest_access(self, dashboard: "Dashboard") -> bool:
return False

dashboards = [
r
for r in user.resources
if r["type"] == GuestTokenResourceType.DASHBOARD.value
r for r in user.resources if r["type"] == GuestTokenResourceType.DASHBOARD
]

# TODO (embedded): remove this check once uuids are rolled out
Expand Down
36 changes: 32 additions & 4 deletions tests/integration_tests/security/guest_token_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,29 @@ def setUp(self) -> None:
self.authorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": "dashboard", "id": str(self.embedded.uuid)}],
"resources": [
{
"type": GuestTokenResourceType.DASHBOARD,
"id": str(self.embedded.uuid),
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)
self.unauthorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [
{"type": "dashboard", "id": "06383667-3e02-4e5e-843f-44e9c5896b6c"}
{
"type": GuestTokenResourceType.DASHBOARD,
"id": "06383667-3e02-4e5e-843f-44e9c5896b6c",
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)

Expand Down Expand Up @@ -247,15 +261,29 @@ def setUp(self) -> None:
self.authorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": "dashboard", "id": str(self.embedded.uuid)}],
"resources": [
{
"type": GuestTokenResourceType.DASHBOARD,
"id": str(self.embedded.uuid),
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)
self.unauthorized_guest = security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [
{"type": "dashboard", "id": "06383667-3e02-4e5e-843f-44e9c5896b6c"}
{
"type": GuestTokenResourceType.DASHBOARD,
"id": "06383667-3e02-4e5e-843f-44e9c5896b6c",
}
],
"iat": 10,
"exp": 20,
"rls_rules": [],
}
)
self.chart = self.get_slice("Girls")
Expand Down
9 changes: 8 additions & 1 deletion tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,8 +646,15 @@ def guest_user_with_rls(self, rules: Optional[list[Any]] = None) -> GuestUser:
return security_manager.get_guest_user_from_token(
{
"user": {},
"resources": [{"type": GuestTokenResourceType.DASHBOARD.value}],
"resources": [
{
"type": GuestTokenResourceType.DASHBOARD,
"id": "06383667-3e02-4e5e-843f-44e9c5896b6c",
}
],
"rls_rules": rules,
"iat": 10,
"exp": 20,
}
)

Expand Down
42 changes: 28 additions & 14 deletions tests/unit_tests/charts/commands/importers/v1/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def test_import_chart(mocker: MockerFixture, session_with_schema: Session) -> No
Test importing a chart.
"""

mocker.patch.object(security_manager, "can_access", return_value=True)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)

config = copy.deepcopy(chart_config)
config["datasource_id"] = 1
Expand All @@ -89,7 +91,7 @@ def test_import_chart(mocker: MockerFixture, session_with_schema: Session) -> No
assert chart.external_url is None

# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
mock_can_access.assert_called_once_with("can_write", "Chart")


def test_import_chart_managed_externally(
Expand All @@ -98,7 +100,9 @@ def test_import_chart_managed_externally(
"""
Test importing a chart that is managed externally.
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)

config = copy.deepcopy(chart_config)
config["datasource_id"] = 1
Expand All @@ -111,7 +115,7 @@ def test_import_chart_managed_externally(
assert chart.external_url == "https://example.org/my_chart"

# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
mock_can_access.assert_called_once_with("can_write", "Chart")


def test_import_chart_without_permission(
Expand All @@ -121,7 +125,9 @@ def test_import_chart_without_permission(
"""
Test importing a chart when a user doesn't have permissions to create.
"""
mocker.patch.object(security_manager, "can_access", return_value=False)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=False
)

config = copy.deepcopy(chart_config)
config["datasource_id"] = 1
Expand All @@ -134,7 +140,7 @@ def test_import_chart_without_permission(
== "Chart doesn't exist and user doesn't have permission to create charts"
)
# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
mock_can_access.assert_called_once_with("can_write", "Chart")


def test_filter_chart_annotations(session: Session) -> None:
Expand Down Expand Up @@ -162,8 +168,12 @@ def test_import_existing_chart_without_permission(
"""
Test importing a chart when a user doesn't have permissions to modify.
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
mocker.patch.object(security_manager, "can_access_chart", return_value=False)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)
mock_can_access_chart = mocker.patch.object(
security_manager, "can_access_chart", return_value=False
)

slice = (
session_with_data.query(Slice)
Expand All @@ -180,8 +190,8 @@ def test_import_existing_chart_without_permission(
)

# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
security_manager.can_access_chart.assert_called_once_with(slice)
mock_can_access.assert_called_once_with("can_write", "Chart")
mock_can_access_chart.assert_called_once_with(slice)


def test_import_existing_chart_with_permission(
Expand All @@ -191,8 +201,12 @@ def test_import_existing_chart_with_permission(
"""
Test importing a chart that exists when a user has access permission to that chart.
"""
mocker.patch.object(security_manager, "can_access", return_value=True)
mocker.patch.object(security_manager, "can_access_chart", return_value=True)
mock_can_access = mocker.patch.object(
security_manager, "can_access", return_value=True
)
mock_can_access_chart = mocker.patch.object(
security_manager, "can_access_chart", return_value=True
)

admin = User(
first_name="Alice",
Expand All @@ -215,5 +229,5 @@ def test_import_existing_chart_with_permission(
with override_user(admin):
import_chart(config, overwrite=True)
# Assert that the can write to chart was checked
security_manager.can_access.assert_called_once_with("can_write", "Chart")
security_manager.can_access_chart.assert_called_once_with(slice)
mock_can_access.assert_called_once_with("can_write", "Chart")
mock_can_access_chart.assert_called_once_with(slice)
Loading

0 comments on commit b0a2aea

Please sign in to comment.