Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Cleanup user get_id/get_user_id #20492

Merged
merged 1 commit into from
Jun 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def favorite_status(self, **kwargs: Any) -> Response:
charts = ChartDAO.find_by_ids(requested_ids)
if not charts:
return self.response_404()
favorited_chart_ids = ChartDAO.favorited_ids(charts, g.user.get_id())
favorited_chart_ids = ChartDAO.favorited_ids(charts)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoiding passing around the global user. More on this in a later PR.

res = [
{"id": request_id, "value": request_id in favorited_chart_ids}
for request_id in requested_ids
Expand Down
5 changes: 3 additions & 2 deletions superset/charts/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName
from superset.models.slice import Slice
from superset.utils.core import get_user_id

if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
Expand Down Expand Up @@ -70,15 +71,15 @@ def overwrite(slc: Slice, commit: bool = True) -> None:
db.session.commit()

@staticmethod
def favorited_ids(charts: List[Slice], current_user_id: int) -> List[FavStar]:
def favorited_ids(charts: List[Slice]) -> List[FavStar]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we ever need to get a list of favorited charts for other users?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ktmud no. Also this logic is going to refactored in another PR. In general Superset does not handle users other than the current logged in user. There have been recent mechanisms added to override the user globally if needed.

ids = [chart.id for chart in charts]
return [
star.obj_id
for star in db.session.query(FavStar.obj_id)
.filter(
FavStar.class_name == FavStarClassName.CHART,
FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id,
FavStar.user_id == get_user_id(),
)
.all()
]
6 changes: 3 additions & 3 deletions superset/charts/data/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any, Dict, Optional, TYPE_CHECKING

import simplejson
from flask import current_app, g, make_response, request, Response
from flask import current_app, make_response, request, Response
from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _
from marshmallow import ValidationError
Expand All @@ -44,7 +44,7 @@
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import create_zip, json_int_dttm_ser
from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser
from superset.views.base import CsvResponse, generate_download_headers
from superset.views.base_api import statsd_metrics

Expand Down Expand Up @@ -324,7 +324,7 @@ def _run_async(
except AsyncQueryTokenException:
return self.response_401()

result = async_command.run(form_data, g.user.get_id())
result = async_command.run(form_data, get_user_id())
return self.response(202, **result)

def _send_chart_response(
Expand Down
2 changes: 1 addition & 1 deletion superset/charts/data/commands/create_async_job_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def validate(self, request: Request) -> None:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]

def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]:
def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
4 changes: 1 addition & 3 deletions superset/dashboards/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,7 @@ def favorite_status(self, **kwargs: Any) -> Response:
dashboards = DashboardDAO.find_by_ids(requested_ids)
if not dashboards:
return self.response_404()
favorited_dashboard_ids = DashboardDAO.favorited_ids(
dashboards, g.user.get_id()
)
favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards)
res = [
{"id": request_id, "value": request_id in favorited_dashboard_ids}
for request_id in requested_ids
Expand Down
7 changes: 3 additions & 4 deletions superset/dashboards/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from superset.models.core import FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_user_id
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -274,17 +275,15 @@ def set_dash_metadata( # pylint: disable=too-many-locals
return dashboard

@staticmethod
def favorited_ids(
dashboards: List[Dashboard], current_user_id: int
) -> List[FavStar]:
def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
ids = [dash.id for dash in dashboards]
return [
star.obj_id
for star in db.session.query(FavStar.obj_id)
.filter(
FavStar.class_name == FavStarClassName.DASHBOARD,
FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id,
FavStar.user_id == get_user_id(),
)
.all()
]
12 changes: 5 additions & 7 deletions superset/dashboards/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from superset.models.embedded_dashboard import EmbeddedDashboard
from superset.models.slice import Slice
from superset.security.guest_token import GuestTokenResourceType, GuestUser
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter, is_user_admin
from superset.views.base_api import BaseFavoriteFilter

Expand Down Expand Up @@ -57,9 +58,9 @@ def apply(self, query: Query, value: Any) -> Query:
return query.filter(
or_(
Dashboard.created_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(),
Copy link
Member Author

@john-bodley john-bodley Jun 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The get_user_id method is a classmethod and thus doesn't return the ID associated with the user object but rather the g.user.id. In this case they're the same but it's very misleading.

== get_user_id(),
Dashboard.changed_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(),
== get_user_id(),
)
)

Expand Down Expand Up @@ -126,17 +127,14 @@ def apply(self, query: Query, value: Any) -> Query:

users_favorite_dash_query = db.session.query(FavStar.obj_id).filter(
and_(
FavStar.user_id == security_manager.user_model.get_user_id(),
FavStar.user_id == get_user_id(),
FavStar.class_name == "Dashboard",
)
)
owner_ids_query = (
db.session.query(Dashboard.id)
.join(Dashboard.owners)
.filter(
security_manager.user_model.id
== security_manager.user_model.get_user_id()
)
.filter(security_manager.user_model.id == get_user_id())
)

feature_flagged_filters = []
Expand Down
11 changes: 8 additions & 3 deletions superset/jinja_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetTemplateException
from superset.extensions import feature_flag_manager
from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters
from superset.utils.core import (
convert_legacy_filters_into_adhoc,
get_user_id,
merge_extra_filters,
)
from superset.utils.memoized import memoized

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,9 +119,10 @@ def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
"""

if hasattr(g, "user") and g.user:
id_ = get_user_id()
if add_to_cache_keys:
self.cache_key_wrapper(g.user.get_id())
return g.user.get_id()
self.cache_key_wrapper(id_)
return id_
return None

def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:
Expand Down
2 changes: 1 addition & 1 deletion superset/key_value/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,4 @@ def get_uuid_namespace(seed: str) -> UUID:


def get_owner(user: User) -> Optional[int]:
return user.get_user_id() if not user.is_anonymous else None
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above. This is wrong, user.get_user_id does not return user.id—for non-anonymous users—but rather g.user.id. The reason this isn't an issue is that user is always g.user. I'm working on another PR which will replace get_owner with get_user_id.

return user.id if not user.is_anonymous else None
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy.orm import relationship

from superset import db
from superset.utils.core import get_user_id

Base = declarative_base()

Expand Down Expand Up @@ -63,17 +64,10 @@ class User(Base):


class AuditMixin:
@classmethod
def get_user_id(cls):
try:
return g.user.id
except Exception:
return None

@declared_attr
def created_by_fk(cls):
return Column(
Integer, ForeignKey("ab_user.id"), default=cls.get_user_id, nullable=False
Integer, ForeignKey("ab_user.id"), default=get_user_id, nullable=False
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same logic but ensuring we use the same get_user_id method.

)

@declared_attr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sqlalchemy.ext.declarative import declarative_base, declared_attr

from superset.models.tags import ObjectTypes, TagTypes
from superset.utils.core import get_user_id

Base = declarative_base()

Expand All @@ -54,7 +55,7 @@ def created_by_fk(self) -> Column:
return Column(
Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
default=get_user_id,
nullable=True,
)

Expand All @@ -63,8 +64,8 @@ def changed_by_fk(self) -> Column:
return Column(
Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
onupdate=self.get_user_id,
default=get_user_id,
onupdate=get_user_id,
nullable=True,
)

Expand Down
7 changes: 4 additions & 3 deletions superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from sqlalchemy_utils import UUIDType

from superset.common.db_query_status import QueryStatus
from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -384,7 +385,7 @@ def created_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
default=self.get_user_id,
default=get_user_id,
nullable=True,
)

Expand All @@ -393,8 +394,8 @@ def changed_by_fk(self) -> sa.Column:
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
default=self.get_user_id,
onupdate=self.get_user_id,
default=get_user_id,
onupdate=get_user_id,
nullable=True,
)

Expand Down
4 changes: 2 additions & 2 deletions superset/queries/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
# under the License.
from typing import Any

from flask import g
from flask_sqlalchemy import BaseQuery

from superset import security_manager
from superset.models.sql_lab import Query
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter


Expand All @@ -33,5 +33,5 @@ def apply(self, query: BaseQuery, value: Any) -> BaseQuery:
:returns: query
"""
if not security_manager.can_access_all_queries():
query = query.filter(Query.user_id == g.user.get_user_id())
query = query.filter(Query.user_id == get_user_id())
return query
9 changes: 4 additions & 5 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
GuestTokenUser,
GuestUser,
)
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType
from superset.utils.core import DatasourceName, get_user_id, RowLevelSecurityFilterType
from superset.utils.urls import get_url_host

if TYPE_CHECKING:
Expand Down Expand Up @@ -529,7 +529,7 @@ def user_view_menu_names(self, permission_name: str) -> Set[str]:
view_menu_names = (
base_query.join(assoc_user_role)
.join(self.user_model)
.filter(self.user_model.id == g.user.get_id())
.filter(self.user_model.id == get_user_id())
.filter(self.permission_model.name == permission_name)
).all()
return {s.name for s in view_menu_names}
Expand Down Expand Up @@ -1252,10 +1252,9 @@ def get_rls_cache_key(self, datasource: "BaseDatasource") -> List[str]:

@staticmethod
def raise_for_user_activity_access(user_id: int) -> None:
user = g.user if g.user and g.user.get_id() else None
if not user or (
if not get_user_id() or (
not current_app.config["ENABLE_BROAD_ACTIVITY_ACCESS"]
and user_id != user.id
and user_id != get_user_id()
):
raise SupersetSecurityException(
SupersetError(
Expand Down
10 changes: 2 additions & 8 deletions superset/sqllab/sqllab_execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
from superset.utils.core import apply_max_row_limit
from superset.utils.core import apply_max_row_limit, get_user_id
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name

Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, query_params: Dict[str, Any]):
self.create_table_as_select = None
self.database = None
self._init_from_query_params(query_params)
self.user_id = self._get_user_id()
self.user_id = get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])

def set_query(self, query: Query) -> None:
Expand Down Expand Up @@ -111,12 +111,6 @@ def _get_limit_param(query_params: Dict[str, Any]) -> int:
limit = 0
return limit

def _get_user_id(self) -> Optional[int]: # pylint: disable=no-self-use
try:
return g.user.get_id() if g.user else None
except RuntimeError:
return None

def is_run_asynchronous(self) -> bool:
return self.async_flag

Expand Down
18 changes: 7 additions & 11 deletions superset/utils/async_query_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@

import jwt
import redis
from flask import Flask, g, request, Request, Response, session
from flask import Flask, request, Request, Response, session

from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)

Expand All @@ -35,12 +37,12 @@ class AsyncQueryJobException(Exception):


def build_job_metadata(
channel_id: str, job_id: str, user_id: Optional[str], **kwargs: Any
channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any
) -> Dict[str, Any]:
return {
"channel_id": channel_id,
"job_id": job_id,
"user_id": int(user_id) if user_id else None,
"user_id": user_id,
"status": kwargs.get("status"),
"errors": kwargs.get("errors", []),
"result_url": kwargs.get("result_url"),
Expand Down Expand Up @@ -113,13 +115,7 @@ def init_app(self, app: Flask) -> None:

@app.after_request
def validate_session(response: Response) -> Response:
user_id = None

try:
user_id = g.user.get_id()
user_id = int(user_id)
except Exception: # pylint: disable=broad-except
pass
user_id = get_user_id()

reset_token = (
not request.cookies.get(self._jwt_cookie_name)
Expand Down Expand Up @@ -161,7 +157,7 @@ def parse_jwt_from_request(self, req: Request) -> Dict[str, Any]:
logger.warning("Parse jwt failed", exc_info=True)
raise AsyncQueryTokenException("Failed to parse token") from ex

def init_job(self, channel_id: str, user_id: Optional[str]) -> Dict[str, Any]:
def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]:
job_id = str(uuid.uuid4())
return build_job_metadata(
channel_id, job_id, user_id, status=self.STATUS_PENDING
Expand Down
Loading