Skip to content

Commit

Permalink
⚰️ refactor(sqla): remove utils.methodcaller
Browse files Browse the repository at this point in the history
  • Loading branch information
ProgramRipper committed Dec 6, 2023
1 parent b732577 commit cc73a0f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 34 deletions.
6 changes: 3 additions & 3 deletions nonebot_plugin_orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _create_engine(engine: str | URL | dict[str, Any] | AsyncEngine) -> AsyncEng
else:
url = engine

return create_async_engine(make_url(url), **options)
return create_async_engine(url, **options)


def _init_engines():
Expand Down Expand Up @@ -204,7 +204,7 @@ def _init_table():
_plugins = {}

_get_plugin_by_module_name = lru_cache(None)(get_plugin_by_module_name)
for model in get_subclasses(Model):
for model in set(get_subclasses(Model)):
table: Table | None = getattr(model, "__table__", None)

if table is None or (bind_key := table.info.get("bind_key")) is None:
Expand All @@ -229,7 +229,7 @@ def _init_logger():
"sqlalchemy": log_level,
**{
_qual_logger_name_for_cls(cls): echo_log_level
for cls in get_subclasses(Identified)
for cls in set(get_subclasses(Identified))
},
}

Expand Down
3 changes: 2 additions & 1 deletion nonebot_plugin_orm/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import repeat
from typing import Any, cast
from dataclasses import dataclass
from operator import methodcaller
from inspect import Parameter, isclass

from pydantic.fields import FieldInfo
Expand All @@ -15,7 +16,7 @@
from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult

from .model import Model
from .utils import Option, methodcaller, compile_dependency, generic_issubclass
from .utils import Option, compile_dependency, generic_issubclass

if sys.version_info >= (3, 9):
from typing import Annotated
Expand Down
40 changes: 10 additions & 30 deletions nonebot_plugin_orm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from itertools import repeat
from contextlib import suppress
from typing import Any, TypeVar
from operator import methodcaller
from typing_extensions import Annotated
from dataclasses import field, dataclass
from inspect import Parameter, Signature, isclass
Expand Down Expand Up @@ -95,29 +96,10 @@ class Option:
calls: list[methodcaller] = field(default_factory=list)


class methodcaller:
__slots__ = ("_name", "_args", "_kwargs")

def __init__(self, name, /, *args, **kwargs):
self._name = name
if not isinstance(self._name, str):
raise TypeError("method name must be a string")
self._args = args
self._kwargs = kwargs

def __call__(self, obj):
return getattr(obj, self._name)(*self._args, **self._kwargs)

def __eq__(self, value: object, /) -> bool:
return isinstance(value, methodcaller) and all(
getattr(self, attr) == getattr(value, attr) for attr in self.__slots__
)


def compile_dependency(statement: ExecutableReturnsRows, option: Option) -> Any:
from . import async_scoped_session

async def dependency(*, __session: async_scoped_session, **params: Any):
async def __dependency(*, __session: async_scoped_session, **params: Any):
if option.stream:
result = await __session.stream(statement, params)
else:
Expand All @@ -137,7 +119,7 @@ async def dependency(*, __session: async_scoped_session, **params: Any):

return result

dependency.__signature__ = Signature(
__dependency.__signature__ = Signature(
[
Parameter(
"__session", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
Expand All @@ -150,7 +132,7 @@ async def dependency(*, __session: async_scoped_session, **params: Any):
]
)

return Depends(dependency)
return Depends(__dependency)


def generic_issubclass(scls: Any, cls: Any) -> Any:
Expand Down Expand Up @@ -239,16 +221,16 @@ def is_editable(plugin: Plugin) -> bool:
with suppress(PackageNotFoundError):
dist = distribution(plugin.name.replace("_", "-"))

if not (dist or plugin.module.__file__ is None):
if not dist and plugin.module.__file__:
path = Path(plugin.module.__file__)
for name in pkgs.get(plugin.module_name.split(".")[0], ()):
dist = distribution(name)
if path in map(methodcaller("locate"), dist.files or ()):
if path in (file.locate() for file in dist.files or ()):
break
else:
dist = None

if dist is None:
if not dist:
return True

# https://github.com/pdm-project/pdm/blob/fee1e6bffd7de30315e2134e19f9a6f58e15867c/src/pdm/utils.py#L361-L374
Expand All @@ -263,12 +245,10 @@ def is_editable(plugin: Plugin) -> bool:
return direct_url_data.get("dir_info", {}).get("editable", False)


def get_subclasses(cls: type[_T]) -> set[type[_T]]:
subclasses = set()
def get_subclasses(cls: type[_T]) -> Generator[type[_T], None, None]:
yield from cls.__subclasses__()
for subclass in cls.__subclasses__():
subclasses.add(subclass)
subclasses.update(get_subclasses(subclass))
return subclasses
yield from get_subclasses(subclass)


if sys.version_info >= (3, 10):
Expand Down

0 comments on commit cc73a0f

Please sign in to comment.