Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: extract and narrow orm.object_session() #17582

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/unit/utils/db/test_orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sqlalchemy.orm import object_session

from warehouse.db import Model
from warehouse.utils.db.orm import NoSessionError, orm_session_from_obj


def test_orm_session_from_obj_raises_with_no_session():

class FakeObject(Model):
__tablename__ = "fake_object"

obj = FakeObject()
# Confirm that the object does not have a session with the built-in
assert object_session(obj) is None

with pytest.raises(NoSessionError):
orm_session_from_obj(obj)
3 changes: 2 additions & 1 deletion warehouse/accounts/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from warehouse.observations.models import HasObservations, HasObservers, ObservationKind
from warehouse.sitemap.models import SitemapMixin
from warehouse.utils.attrs import make_repr
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import TZDateTime, bool_false, datetime_now

if TYPE_CHECKING:
Expand Down Expand Up @@ -236,7 +237,7 @@ def has_primary_verified_email(self):

@property
def recent_events(self):
session = orm.object_session(self)
session = orm_session_from_obj(self)
last_ninety = datetime.datetime.now() - datetime.timedelta(days=90)
return (
session.query(User.Event)
Expand Down
5 changes: 2 additions & 3 deletions warehouse/cache/origin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

from itertools import chain

from sqlalchemy.orm.session import Session

from warehouse import db
from warehouse.cache.origin.derivers import html_cache_deriver
from warehouse.cache.origin.interfaces import IOriginCache
from warehouse.utils.db import orm_session_from_obj


@db.listens_for(db.Session, "after_flush")
Expand Down Expand Up @@ -139,7 +138,7 @@ def register_origin_cache_keys(config, klass, cache_keys=None, purge_keys=None):

def receive_set(attribute, config, target):
cache_keys = config.registry["cache_keys"]
session = Session.object_session(target)
session = orm_session_from_obj(target)
purges = session.info.setdefault("warehouse.cache.origin.purges", set())
key_maker = cache_keys[attribute]
keys = key_maker(target).purge
Expand Down
6 changes: 3 additions & 3 deletions warehouse/email/ses/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from sqlalchemy.dialects.postgresql import JSONB, UUID as PG_UUID
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm.session import object_session

from warehouse import db
from warehouse.accounts.models import Email as EmailAddress, UnverifyReasons
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import bool_false, datetime_now

MAX_TRANSIENT_BOUNCES = 5
Expand Down Expand Up @@ -217,9 +217,9 @@ def _get_email(self):
if self._email_message.missing:
return

db = object_session(self._email_message)
session = orm_session_from_obj(self._email_message)
email = (
db.query(EmailAddress)
session.query(EmailAddress)
.filter(EmailAddress.email == self._email_message.to)
.first()
)
Expand Down
4 changes: 2 additions & 2 deletions warehouse/legacy/api/xmlrpc/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from pyramid.exceptions import ConfigurationError
from sqlalchemy.orm.base import NO_VALUE
from sqlalchemy.orm.session import Session
from urllib3.util import parse_url

from warehouse import db
Expand All @@ -23,6 +22,7 @@
from warehouse.legacy.api.xmlrpc.cache.fncache import RedisLru
from warehouse.legacy.api.xmlrpc.cache.interfaces import IXMLRPCCache
from warehouse.legacy.api.xmlrpc.cache.services import NullXMLRPCCache, RedisXMLRPCCache
from warehouse.utils.db import orm_session_from_obj

__all__ = ["RedisLru"]

Expand All @@ -32,7 +32,7 @@

def receive_set(attribute, config, target):
cache_keys = config.registry["cache_keys"]
session = Session.object_session(target)
session = orm_session_from_obj(target)
purges = session.info.setdefault("warehouse.legacy.api.xmlrpc.cache.purges", set())
key_maker = cache_keys[attribute]
keys = key_maker(target).purge
Expand Down
14 changes: 5 additions & 9 deletions warehouse/organizations/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from warehouse.authnz import Permissions
from warehouse.events.models import HasEvents
from warehouse.utils.attrs import make_repr
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import TZDateTime, bool_false, datetime_now

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -332,22 +333,17 @@ class Organization(OrganizationMixin, HasEvents, db.Model):
@property
def owners(self):
"""Return all users who are owners of the organization."""
session = orm_session_from_obj(self)
owner_roles = (
orm.object_session(self)
.query(User.id)
session.query(User.id)
.join(OrganizationRole.user)
.filter(
OrganizationRole.role_name == OrganizationRoleType.Owner,
OrganizationRole.organization == self,
)
.subquery()
)
return (
orm.object_session(self)
.query(User)
.join(owner_roles, User.id == owner_roles.c.id)
.all()
)
return session.query(User).join(owner_roles, User.id == owner_roles.c.id).all()

def record_event(self, *, tag, request: Request = None, additional=None):
"""Record organization name in events in case organization is ever deleted."""
Expand All @@ -358,7 +354,7 @@ def record_event(self, *, tag, request: Request = None, additional=None):
)

def __acl__(self):
session = orm.object_session(self)
session = orm_session_from_obj(self)

acls = [
(
Expand Down
37 changes: 16 additions & 21 deletions warehouse/packaging/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from warehouse.sitemap.models import SitemapMixin
from warehouse.utils import dotted_navigator, wheel
from warehouse.utils.attrs import make_repr
from warehouse.utils.db import orm_session_from_obj
from warehouse.utils.db.types import bool_false, bool_true, datetime_now

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -257,7 +258,7 @@ class Project(SitemapMixin, HasEvents, HasObservations, db.Model):
)

def __getitem__(self, version):
session = orm.object_session(self)
session = orm_session_from_obj(self)
canonical_version = packaging.utils.canonicalize_version(version)

try:
Expand Down Expand Up @@ -288,7 +289,7 @@ def __getitem__(self, version):
raise KeyError from None

def __acl__(self):
session = orm.object_session(self)
session = orm_session_from_obj(self)
acls = [
# TODO: Similar to `warehouse.accounts.models.User.__acl__`, we express the
# permissions here in terms of the permissions that the user has on
Expand Down Expand Up @@ -417,42 +418,36 @@ def documentation_url(self):
@property
def owners(self):
"""Return all users who are owners of the project."""
session = orm_session_from_obj(self)
owner_roles = (
orm.object_session(self)
.query(User.id)
session.query(User.id)
.join(Role.user)
.filter(Role.role_name == "Owner", Role.project == self)
.subquery()
)
return (
orm.object_session(self)
.query(User)
.join(owner_roles, User.id == owner_roles.c.id)
.all()
)
return session.query(User).join(owner_roles, User.id == owner_roles.c.id).all()

@property
def maintainers(self):
"""Return all users who are maintainers of the project."""
session = orm_session_from_obj(self)
maintainer_roles = (
orm.object_session(self)
.query(User.id)
session.query(User.id)
.join(Role.user)
.filter(Role.role_name == "Maintainer", Role.project == self)
.subquery()
)
return (
orm.object_session(self)
.query(User)
session.query(User)
.join(maintainer_roles, User.id == maintainer_roles.c.id)
.all()
)

@property
def all_versions(self):
session = orm_session_from_obj(self)
return (
orm.object_session(self)
.query(
session.query(
Release.version,
Release.created,
Release.is_prerelease,
Expand All @@ -466,9 +461,9 @@ def all_versions(self):

@property
def latest_version(self):
session = orm_session_from_obj(self)
return (
orm.object_session(self)
.query(Release.version, Release.created, Release.is_prerelease)
session.query(Release.version, Release.created, Release.is_prerelease)
.filter(Release.project == self, Release.yanked.is_(False))
.order_by(Release.is_prerelease.nullslast(), Release._pypi_ordering.desc())
.first()
Expand All @@ -477,7 +472,7 @@ def latest_version(self):
@property
def active_releases(self):
return (
orm.object_session(self)
orm_session_from_obj(self)
.query(Release)
.filter(Release.project == self, Release.yanked.is_(False))
.order_by(Release._pypi_ordering.desc())
Expand All @@ -487,7 +482,7 @@ def active_releases(self):
@property
def yanked_releases(self):
return (
orm.object_session(self)
orm_session_from_obj(self)
.query(Release)
.filter(Release.project == self, Release.yanked.is_(True))
.order_by(Release._pypi_ordering.desc())
Expand Down Expand Up @@ -747,7 +742,7 @@ def __table_args__(cls): # noqa
uploaded_via: Mapped[str | None]

def __getitem__(self, filename: str) -> File:
session: orm.Session = orm.object_session(self) # type: ignore[assignment]
session = orm_session_from_obj(self)

try:
return (
Expand Down
3 changes: 2 additions & 1 deletion warehouse/utils/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from warehouse.utils.db.orm import orm_session_from_obj
from warehouse.utils.db.query_printer import print_query

__all__ = ["print_query"]
__all__ = ["orm_session_from_obj", "print_query"]
33 changes: 33 additions & 0 deletions warehouse/utils/db/orm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ORM utilities."""

from sqlalchemy.orm import Session, object_session


class NoSessionError(Exception):
"""Raised when there is no active SQLAlchemy session"""


def orm_session_from_obj(obj) -> Session:
"""
Returns the session from the ORM object.

Adds guard, but it should never happen.
The guard helps with type hinting, as the object_session function
returns Optional[Session] type.
"""
session = object_session(obj)
if not session:
raise NoSessionError("Object does not have a session")
return session
Loading