Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add --task-impl option #468

Merged
merged 2 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ Options:
--loop [auto|asyncio|rloop|uvloop]
Event loop implementation [env var:
GRANIAN_LOOP; default: (auto)]
--task-impl [auto|rust|asyncio]
Async task implementation to use [env var:
GRANIAN_TASK_IMPL; default: (auto)]
--backlog INTEGER RANGE Maximum number of connections to hold in
backlog (globally) [env var:
GRANIAN_BACKLOG; default: 1024; x>=128]
Expand Down
15 changes: 11 additions & 4 deletions granian/_futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ def __init__(self, loop, ctx, cb, aio_tenter, aio_texit):
self._schedule_fn = _cbsched_schedule(loop, ctx, self._run, cb)


def _new_cbscheduler(loop, cb):
return _CBScheduler(
loop, contextvars.copy_context(), cb, partial(_aio_taskenter, loop), partial(_aio_taskleave, loop)
)
class _CBSchedulerAIO(_BaseCBScheduler):
__slots__ = []

def __init__(self, loop, ctx, cb, aio_tenter, aio_texit):
super().__init__()
self._schedule_fn = _cbsched_schedule(loop, ctx, loop.create_task, cb)


def _new_cbscheduler(loop, cb, impl_asyncio=False):
_cls = _CBSchedulerAIO if impl_asyncio else _CBScheduler
return _cls(loop, contextvars.copy_context(), cb, partial(_aio_taskenter, loop), partial(_aio_taskleave, loop))


def _cbsched_schedule(loop, ctx, run, cb):
Expand Down
5 changes: 5 additions & 0 deletions granian/_imports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
try:
import anyio
except ImportError:
anyio = None

try:
import setproctitle
except ImportError:
Expand Down
10 changes: 9 additions & 1 deletion granian/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import click

from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes
from .errors import FatalError
from .http import HTTP1Settings, HTTP2Settings
from .log import LogLevels
Expand Down Expand Up @@ -77,6 +77,12 @@ def option(*param_decls: str, cls: Optional[Type[click.Option]] = None, **attrs:
help='Threading mode to use',
)
@option('--loop', type=EnumType(Loops), default=Loops.auto, help='Event loop implementation')
@option(
'--task-impl',
type=EnumType(TaskImpl),
default=TaskImpl.auto,
help='Async task implementation to use',
)
@option(
'--backlog',
type=click.IntRange(128),
Expand Down Expand Up @@ -261,6 +267,7 @@ def cli(
blocking_threads: Optional[int],
threading_mode: ThreadModes,
loop: Loops,
task_impl: TaskImpl,
backlog: int,
backpressure: Optional[int],
http1_buffer_size: int,
Expand Down Expand Up @@ -316,6 +323,7 @@ def cli(
blocking_threads=blocking_threads,
threading_mode=threading_mode,
loop=loop,
task_impl=task_impl,
http=http,
websockets=websockets,
backlog=backlog,
Expand Down
6 changes: 6 additions & 0 deletions granian/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ class Loops(StrEnum):
asyncio = 'asyncio'
rloop = 'rloop'
uvloop = 'uvloop'


class TaskImpl(StrEnum):
auto = 'auto'
rust = 'rust'
asyncio = 'asyncio'
30 changes: 24 additions & 6 deletions granian/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from ._futures import _future_watcher_wrapper, _new_cbscheduler
from ._granian import ASGIWorker, RSGIWorker, WSGIWorker
from ._imports import setproctitle, watchfiles
from ._imports import anyio, setproctitle, watchfiles
from ._internal import load_target
from ._signals import set_main_signals
from .asgi import LifespanProtocol, _callback_wrapper as _asgi_call_wrap
from .constants import HTTPModes, Interfaces, Loops, ThreadModes
from .constants import HTTPModes, Interfaces, Loops, TaskImpl, ThreadModes
from .errors import ConfigurationError, PidFileError
from .http import HTTP1Settings, HTTP2Settings
from .log import DEFAULT_ACCESSLOG_FMT, LogLevels, configure_logging, logger
Expand Down Expand Up @@ -81,6 +81,7 @@ def __init__(
blocking_threads: Optional[int] = None,
threading_mode: ThreadModes = ThreadModes.workers,
loop: Loops = Loops.auto,
task_impl: TaskImpl = TaskImpl.auto,
http: HTTPModes = HTTPModes.auto,
websockets: bool = True,
backlog: int = 1024,
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(
self.threads = max(1, threads)
self.threading_mode = threading_mode
self.loop = loop
self.task_impl = task_impl
self.http = http
self.websockets = websockets
self.backlog = max(128, backlog)
Expand Down Expand Up @@ -188,6 +190,7 @@ def _spawn_asgi_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -225,7 +228,9 @@ def _spawn_asgi_worker(
*ssl_ctx,
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(wcallback))
scheduler = _new_cbscheduler(
loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)

@staticmethod
Expand All @@ -239,6 +244,7 @@ def _spawn_asgi_lifespan_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -283,7 +289,9 @@ def _spawn_asgi_lifespan_worker(
*ssl_ctx,
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(wcallback))
scheduler = _new_cbscheduler(
loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)
loop.run_until_complete(lifespan_handler.shutdown())

Expand All @@ -298,6 +306,7 @@ def _spawn_rsgi_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -343,7 +352,9 @@ def _spawn_rsgi_worker(
*ssl_ctx,
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _future_watcher_wrapper(callback))
scheduler = _new_cbscheduler(
loop, _future_watcher_wrapper(callback), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)
callback_del(loop)

Expand All @@ -358,6 +369,7 @@ def _spawn_wsgi_worker(
blocking_threads: int,
backpressure: int,
threading_mode: ThreadModes,
task_impl: TaskImpl,
http_mode: HTTPModes,
http1_settings: Optional[HTTP1Settings],
http2_settings: Optional[HTTP2Settings],
Expand Down Expand Up @@ -385,7 +397,9 @@ def _spawn_wsgi_worker(
worker_id, sfd, threads, blocking_threads, backpressure, http_mode, http1_settings, http2_settings, *ssl_ctx
)
serve = getattr(worker, {ThreadModes.runtime: 'serve_rth', ThreadModes.workers: 'serve_wth'}[threading_mode])
scheduler = _new_cbscheduler(loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt))
scheduler = _new_cbscheduler(
loop, _wsgi_call_wrap(callback, scope_opts, log_access_fmt), impl_asyncio=task_impl == TaskImpl.asyncio
)
serve(scheduler, loop, shutdown_event)
shutdown_event.qs.wait()

Expand Down Expand Up @@ -416,6 +430,7 @@ def _spawn_proc(self, idx, target, callback_loader, socket_loader) -> Worker:
self.blocking_threads,
self.backpressure,
self.threading_mode,
self.task_impl,
self.http,
self.http1_settings,
self.http2_settings,
Expand Down Expand Up @@ -713,5 +728,8 @@ def serve(
logger.error('Workers lifetime cannot be less than 60 seconds')
raise ConfigurationError('workers_lifetime')

if self.task_impl == TaskImpl.auto:
self.task_impl = TaskImpl.asyncio if anyio is not None else TaskImpl.rust

serve_method = self._serve_with_reloader if self.reload_on_changes else self._serve
serve_method(spawn_target, target_loader)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ extend-ignore = [
'S110', # except pass is fine
]
flake8-quotes = { inline-quotes = 'single', multiline-quotes = 'double' }
mccabe = { max-complexity = 13 }
mccabe = { max-complexity = 14 }

[tool.ruff.format]
quote-style = 'single'
Expand Down
Loading