Skip to content

Commit

Permalink
refactor: Cleanup user get_id/get_user_id (#20492)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and John Bodley authored Jun 25, 2022
1 parent c56e37c commit 3483446
Show file tree
Hide file tree
Showing 27 changed files with 182 additions and 137 deletions.
2 changes: 1 addition & 1 deletion superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def favorite_status(self, **kwargs: Any) -> Response:
charts = ChartDAO.find_by_ids(requested_ids)
if not charts:
return self.response_404()
favorited_chart_ids = ChartDAO.favorited_ids(charts, g.user.get_id())
favorited_chart_ids = ChartDAO.favorited_ids(charts)
res = [
{"id": request_id, "value": request_id in favorited_chart_ids}
for request_id in requested_ids
Expand Down
5 changes: 3 additions & 2 deletions superset/charts/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName
from superset.models.slice import Slice
from superset.utils.core import get_user_id

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
Expand Down Expand Up @@ -70,15 +71,15 @@ def overwrite(slc: Slice, commit: bool = True) -> None:
db.session.commit()

@staticmethod
def favorited_ids(charts: List[Slice], current_user_id: int) -> List[FavStar]:
def favorited_ids(charts: List[Slice]) -> List[FavStar]:
ids = [chart.id for chart in charts]
return [
star.obj_id
for star in db.session.query(FavStar.obj_id)
.filter(
FavStar.class_name == FavStarClassName.CHART,
FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id,
FavStar.user_id == get_user_id(),
)
.all()
]
6 changes: 3 additions & 3 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Dict, Optional, TYPE_CHECKING

import simplejson
from flask import current_app, g, make_response, request, Response
from flask import current_app, make_response, request, Response
from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _
from marshmallow import ValidationError
Expand All @@ -44,7 +44,7 @@
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import create_zip, json_int_dttm_ser
from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser
from superset.views.base import CsvResponse, generate_download_headers
from superset.views.base_api import statsd_metrics

Expand Down Expand Up @@ -324,7 +324,7 @@ def _run_async(
except AsyncQueryTokenException:
return self.response_401()

result = async_command.run(form_data, g.user.get_id())
result = async_command.run(form_data, get_user_id())
return self.response(202, **result)

def _send_chart_response(
Expand Down
2 changes: 1 addition & 1 deletion superset/charts/data/commands/create_async_job_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def validate(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]

def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]:
def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
4 changes: 1 addition & 3 deletions superset/dashboards/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,7 @@ def favorite_status(self, **kwargs: Any) -> Response:
dashboards = DashboardDAO.find_by_ids(requested_ids)
if not dashboards:
return self.response_404()
favorited_dashboard_ids = DashboardDAO.favorited_ids(
dashboards, g.user.get_id()
)
favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards)
res = [
{"id": request_id, "value": request_id in favorited_dashboard_ids}
for request_id in requested_ids
Expand Down
7 changes: 3 additions & 4 deletions superset/dashboards/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from superset.models.core import FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_user_id
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -274,17 +275,15 @@ def set_dash_metadata( # pylint: disable=too-many-locals
return dashboard

@staticmethod
def favorited_ids(
dashboards: List[Dashboard], current_user_id: int
) -> List[FavStar]:
def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
ids = [dash.id for dash in dashboards]
return [
star.obj_id
for star in db.session.query(FavStar.obj_id)
.filter(
FavStar.class_name == FavStarClassName.DASHBOARD,
FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id,
FavStar.user_id == get_user_id(),
)
.all()
]
12 changes: 5 additions & 7 deletions superset/dashboards/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from superset.models.embedded_dashboard import EmbeddedDashboard
from superset.models.slice import Slice
from superset.security.guest_token import GuestTokenResourceType, GuestUser
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter, is_user_admin
from superset.views.base_api import BaseFavoriteFilter

Expand Down Expand Up @@ -57,9 +58,9 @@ def apply(self, query: Query, value: Any) -> Query:
return query.filter(
or_(
Dashboard.created_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(),
== get_user_id(),
Dashboard.changed_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(),
== get_user_id(),
)
)

Expand Down Expand Up @@ -126,17 +127,14 @@ def apply(self, query: Query, value: Any) -> Query:

users_favorite_dash_query = db.session.query(FavStar.obj_id).filter(
and_(
FavStar.user_id == security_manager.user_model.get_user_id(),
FavStar.user_id == get_user_id(),
FavStar.class_name == "Dashboard",
)
)
owner_ids_query = (
db.session.query(Dashboard.id)
.join(Dashboard.owners)
.filter(
security_manager.user_model.id
== security_manager.user_model.get_user_id()
)
.filter(security_manager.user_model.id == get_user_id())
)

feature_flagged_filters = []
Expand Down
11 changes: 8 additions & 3 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetTemplateException
from superset.extensions import feature_flag_manager
from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters
from superset.utils.core import (
convert_legacy_filters_into_adhoc,
get_user_id,
merge_extra_filters,
)
from superset.utils.memoized import memoized

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,9 +119,10 @@ def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
"""

if hasattr(g, "user") and g.user:
id_ = get_user_id()
if add_to_cache_keys:
self.cache_key_wrapper(g.user.get_id())
return g.user.get_id()
self.cache_key_wrapper(id_)
return id_
return None

def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:
Expand Down
2 changes: 1 addition & 1 deletion superset/key_value/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ def get_uuid_namespace(seed: str) -> UUID:


def get_owner(user: User) -> Optional[int]:
return user.get_user_id() if not user.is_anonymous else None
return user.id if not user.is_anonymous else None
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy.orm import relationship

from superset import db
from superset.utils.core import get_user_id

Base = declarative_base()

Expand Down Expand Up @@ -63,17 +64,10 @@ class User(Base):


class AuditMixin:
@classmethod
def get_user_id(cls):
try:
return g.user.id
except Exception:
return None

@declared_attr
def created_by_fk(cls):
return Column(
Integer, ForeignKey("ab_user.id"), default=cls.get_user_id, nullable=False
Integer, ForeignKey("ab_user.id"), default=get_user_id, nullable=False
)

@declared_attr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy.ext.declarative import declarative_base, declared_attr

from superset.models.tags import ObjectTypes, TagTypes
from superset.utils.core import get_user_id

Base = declarative_base()

Expand All @@ -54,7 +55,7 @@ def created_by_fk(self) -> Column:
return Column(
Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
default=get_user_id,
nullable=True,
)

Expand All @@ -63,8 +64,8 @@ def changed_by_fk(self) -> Column:
return Column(
Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
onupdate=self.get_user_id,
default=get_user_id,
onupdate=get_user_id,
nullable=True,
)

Expand Down
7 changes: 4 additions & 3 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from sqlalchemy_utils import UUIDType

from superset.common.db_query_status import QueryStatus
from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -384,7 +385,7 @@ def created_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
default=self.get_user_id,
default=get_user_id,
nullable=True,
)

Expand All @@ -393,8 +394,8 @@ def changed_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
default=self.get_user_id,
onupdate=self.get_user_id,
default=get_user_id,
onupdate=get_user_id,
nullable=True,
)

Expand Down
4 changes: 2 additions & 2 deletions superset/queries/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# under the License.
from typing import Any

from flask import g
from flask_sqlalchemy import BaseQuery

from superset import security_manager
from superset.models.sql_lab import Query
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter


Expand All @@ -33,5 +33,5 @@ def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
:returns: query
"""
if not security_manager.can_access_all_queries():
query = query.filter(Query.user_id == g.user.get_user_id())
query = query.filter(Query.user_id == get_user_id())
return query
9 changes: 4 additions & 5 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
GuestTokenUser,
GuestUser,
)
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType
from superset.utils.core import DatasourceName, get_user_id, RowLevelSecurityFilterType
from superset.utils.urls import get_url_host

if TYPE_CHECKING:
Expand Down Expand Up @@ -529,7 +529,7 @@ def user_view_menu_names(self, permission_name: str) -> Set[str]:
view_menu_names = (
base_query.join(assoc_user_role)
.join(self.user_model)
.filter(self.user_model.id == g.user.get_id())
.filter(self.user_model.id == get_user_id())
.filter(self.permission_model.name == permission_name)
).all()
return {s.name for s in view_menu_names}
Expand Down Expand Up @@ -1252,10 +1252,9 @@ def get_rls_cache_key(self, datasource: "BaseDatasource") -> List[str]:

@staticmethod
def raise_for_user_activity_access(user_id: int) -> None:
user = g.user if g.user and g.user.get_id() else None
if not user or (
if not get_user_id() or (
not current_app.config["ENABLE_BROAD_ACTIVITY_ACCESS"]
and user_id != user.id
and user_id != get_user_id()
):
raise SupersetSecurityException(
SupersetError(
Expand Down
10 changes: 2 additions & 8 deletions superset/sqllab/sqllab_execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
from superset.utils.core import apply_max_row_limit
from superset.utils.core import apply_max_row_limit, get_user_id
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name

Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, query_params: Dict[str, Any]):
self.create_table_as_select = None
self.database = None
self._init_from_query_params(query_params)
self.user_id = self._get_user_id()
self.user_id = get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])

def set_query(self, query: Query) -> None:
Expand Down Expand Up @@ -111,12 +111,6 @@ def _get_limit_param(query_params: Dict[str, Any]) -> int:
limit = 0
return limit

def _get_user_id(self) -> Optional[int]: # pylint: disable=no-self-use
try:
return g.user.get_id() if g.user else None
except RuntimeError:
return None

def is_run_asynchronous(self) -> bool:
return self.async_flag

Expand Down
18 changes: 7 additions & 11 deletions superset/utils/async_query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import jwt
import redis
from flask import Flask, g, request, Request, Response, session
from flask import Flask, request, Request, Response, session

from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)

Expand All @@ -35,12 +37,12 @@ class AsyncQueryJobException(Exception):


def build_job_metadata(
channel_id: str, job_id: str, user_id: Optional[str], **kwargs: Any
channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any
) -> Dict[str, Any]:
return {
"channel_id": channel_id,
"job_id": job_id,
"user_id": int(user_id) if user_id else None,
"user_id": user_id,
"status": kwargs.get("status"),
"errors": kwargs.get("errors", []),
"result_url": kwargs.get("result_url"),
Expand Down Expand Up @@ -113,13 +115,7 @@ def init_app(self, app: Flask) -> None:

@app.after_request
def validate_session(response: Response) -> Response:
user_id = None

try:
user_id = g.user.get_id()
user_id = int(user_id)
except Exception: # pylint: disable=broad-except
pass
user_id = get_user_id()

reset_token = (
not request.cookies.get(self._jwt_cookie_name)
Expand Down Expand Up @@ -161,7 +157,7 @@ def parse_jwt_from_request(self, req: Request) -> Dict[str, Any]:
logger.warning("Parse jwt failed", exc_info=True)
raise AsyncQueryTokenException("Failed to parse token") from ex

def init_job(self, channel_id: str, user_id: Optional[str]) -> Dict[str, Any]:
def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]:
job_id = str(uuid.uuid4())
return build_job_metadata(
channel_id, job_id, user_id, status=self.STATUS_PENDING
Expand Down
Loading

0 comments on commit 3483446

Please sign in to comment.