diff --git a/granian/server/mt.py b/granian/server/mt.py index c63b163f..2a02e198 100644 --- a/granian/server/mt.py +++ b/granian/server/mt.py @@ -148,6 +148,7 @@ def _spawn_asgi_lifespan_worker( loop, _future_watcher_wrapper(wcallback), impl_asyncio=task_impl == TaskImpl.asyncio ) serve(scheduler, loop, shutdown_event) + loop.run_until_complete(lifespan_handler.shutdown()) @staticmethod def _spawn_rsgi_worker( diff --git a/src/asgi/callbacks.rs b/src/asgi/callbacks.rs index 438a16fc..a321628a 100644 --- a/src/asgi/callbacks.rs +++ b/src/asgi/callbacks.rs @@ -13,7 +13,7 @@ use super::{ use crate::{ callbacks::ArcCBScheduler, http::{response_500, HTTPResponse}, - runtime::RuntimeRef, + runtime::{Runtime, RuntimeRef}, utils::log_application_callable_exception, ws::{HyperWebsocket, UpgradeData}, }; @@ -35,9 +35,9 @@ macro_rules! callback_impl_done_ws { } macro_rules! callback_impl_done_err { - ($self:expr, $err:expr) => { + ($self:expr, $py:expr, $err:expr) => { $self.done(); - log_application_callable_exception($err); + log_application_callable_exception($py, $err); }; } @@ -72,8 +72,8 @@ impl CallbackWatcherHTTP { callback_impl_done_http!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -106,8 +106,8 @@ impl CallbackWatcherWebsocket { callback_impl_done_ws!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -138,7 +138,6 @@ impl CallbackWatcherWebsocket { // } // } -#[cfg(not(Py_GIL_DISABLED))] #[inline] pub(crate) fn call_http( cb: ArcCBScheduler, @@ -149,12 +148,11 @@ pub(crate) fn call_http( req: hyper::http::request::Parts, body: hyper::body::Incoming, ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, body, tx); + let protocol = HTTPProtocol::new(rt.clone(), body, tx); let scheme: Arc = scheme.into(); - let _ = brt.run(move || { + rt.spawn_blocking(move |py| { scope_native_parts!( req, server_addr, @@ -165,45 +163,18 @@ pub(crate) fn call_http( server, client ); - Python::with_gil(|py| { - let scope = build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); - }); - - rx -} - -#[cfg(Py_GIL_DISABLED)] -#[inline] -pub(crate) fn call_http( - cb: ArcCBScheduler, - rt: RuntimeRef, - server_addr: SocketAddr, - client_addr: SocketAddr, - scheme: &str, - req: hyper::http::request::Parts, - body: hyper::body::Incoming, -) -> oneshot::Receiver { - let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, body, tx); - let scheme: Arc = scheme.into(); - - scope_native_parts!( - req, - server_addr, - client_addr, - path, - query_string, - version, - server, - client - ); - Python::with_gil(|py| { - let scope = build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); + cb.get().schedule( + py, + Py::new( + py, + CallbackWatcherHTTP::new( + py, + protocol, + build_scope_http(py, &req, version, server, client, &scheme, &path, query_string).unwrap(), + ), + ) + .unwrap(), + ); }); rx @@ -221,12 +192,11 @@ pub(crate) fn call_ws( req: hyper::http::request::Parts, upgrade: UpgradeData, ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); + let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade); let scheme: Arc = scheme.into(); - let _ = brt.run(move || { + rt.spawn_blocking(move |py| { scope_native_parts!( req, server_addr, @@ -237,46 +207,18 @@ pub(crate) fn call_ws( server, client ); - Python::with_gil(|py| { - let scope = build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); - }); - - rx -} - -#[cfg(Py_GIL_DISABLED)] -#[inline] -pub(crate) fn call_ws( - cb: ArcCBScheduler, - rt: RuntimeRef, - server_addr: SocketAddr, - client_addr: SocketAddr, - scheme: &str, - ws: HyperWebsocket, - req: hyper::http::request::Parts, - upgrade: UpgradeData, -) -> oneshot::Receiver { - let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); - let scheme: Arc = scheme.into(); - - scope_native_parts!( - req, - server_addr, - client_addr, - path, - query_string, - version, - server, - client - ); - Python::with_gil(|py| { - let scope = build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(); - let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); + cb.get().schedule( + py, + Py::new( + py, + CallbackWatcherWebsocket::new( + py, + protocol, + build_scope_ws(py, &req, version, server, client, &scheme, &path, query_string).unwrap(), + ), + ) + .unwrap(), + ); }); rx diff --git a/src/blocking.rs b/src/blocking.rs index 0fc36e94..6926b686 100644 --- a/src/blocking.rs +++ b/src/blocking.rs @@ -1,50 +1,103 @@ use crossbeam_channel as channel; +use pyo3::prelude::*; use std::thread; pub(crate) struct BlockingTask { - inner: Box, + inner: Box, } impl BlockingTask { pub fn new(inner: T) -> BlockingTask where - T: FnOnce() + Send + 'static, + T: FnOnce(Python) + Send + 'static, { Self { inner: Box::new(inner) } } - pub fn run(self) { - (self.inner)(); + pub fn run(self, py: Python) { + (self.inner)(py); } } -#[derive(Clone)] pub(crate) struct BlockingRunner { queue: channel::Sender, + #[cfg(Py_GIL_DISABLED)] + sig: channel::Sender<()>, } impl BlockingRunner { + #[cfg(not(Py_GIL_DISABLED))] pub fn new() -> Self { let queue = blocking_thread(); Self { queue } } + #[cfg(Py_GIL_DISABLED)] + pub fn new() -> Self { + let (sigtx, sigrx) = channel::bounded(1); + let queue = blocking_thread(sigrx); + Self { queue, sig: sigtx } + } + pub fn run(&self, task: T) -> Result<(), channel::SendError> where - T: FnOnce() + Send + 'static, + T: FnOnce(Python) + Send + 'static, { self.queue.send(BlockingTask::new(task)) } + + #[cfg(Py_GIL_DISABLED)] + pub fn shutdown(&self) { + _ = self.sig.send(()); + } } -fn bloking_loop(queue: channel::Receiver) { +#[cfg(not(Py_GIL_DISABLED))] +fn blocking_loop(queue: channel::Receiver) { while let Ok(task) = queue.recv() { - task.run(); + Python::with_gil(|py| task.run(py)); } } +// NOTE: for some reason, on no-gil callback watchers are not GCd until following req. +// It's not clear atm wether this is an issue with pyo3, CPython itself, or smth +// different in terms of pointers due to multi-threaded environment. +// Thus, we need a signal to manually stop the loop and let the server shutdown. +// The following function would be the intended one if we hadn't the issue just described. +// +// #[cfg(Py_GIL_DISABLED)] +// fn blocking_loop(queue: channel::Receiver) { +// Python::with_gil(|py| { +// while let Ok(task) = queue.recv() { +// task.run(py); +// } +// }); +// } +#[cfg(Py_GIL_DISABLED)] +fn blocking_loop(queue: channel::Receiver, sig: channel::Receiver<()>) { + Python::with_gil(|py| loop { + crossbeam_channel::select! { + recv(queue) -> task => match task { + Ok(task) => task.run(py), + _ => break, + }, + recv(sig) -> _ => break + } + }); +} + +#[cfg(not(Py_GIL_DISABLED))] fn blocking_thread() -> channel::Sender { let (qtx, qrx) = channel::unbounded(); - thread::spawn(|| bloking_loop(qrx)); + thread::spawn(|| blocking_loop(qrx)); + + qtx +} + +#[cfg(Py_GIL_DISABLED)] +fn blocking_thread(sig: channel::Receiver<()>) -> channel::Sender { + let (qtx, qrx) = channel::unbounded(); + thread::spawn(|| blocking_loop(qrx, sig)); + qtx } diff --git a/src/callbacks.rs b/src/callbacks.rs index 59c0524b..a3366007 100644 --- a/src/callbacks.rs +++ b/src/callbacks.rs @@ -31,13 +31,15 @@ pub(crate) struct CallbackScheduler { #[cfg(not(PyPy))] impl CallbackScheduler { #[inline] - pub(crate) fn schedule(&self, _py: Python, watcher: &PyObject) { + pub(crate) fn schedule(&self, py: Python, watcher: Py) { let cbarg = watcher.as_ptr(); let sched = self.schedule_fn.get().unwrap().as_ptr(); unsafe { pyo3::ffi::PyObject_CallOneArg(sched, cbarg); } + + watcher.drop_ref(py); } #[inline] @@ -130,13 +132,15 @@ impl CallbackScheduler { #[cfg(PyPy)] impl CallbackScheduler { #[inline] - pub(crate) fn schedule(&self, py: Python, watcher: &PyObject) { + pub(crate) fn schedule(&self, py: Python, watcher: Py) { let cbarg = (watcher,).into_pyobject(py).unwrap().into_ptr(); let sched = self.schedule_fn.get().unwrap().as_ptr(); unsafe { pyo3::ffi::PyObject_CallObject(sched, cbarg); } + + watcher.drop_ref(py); } #[inline] @@ -508,8 +512,9 @@ impl PyIterAwaitable { } #[inline] - pub(crate) fn set_result(&self, py: Python, result: FutureResultToPy) { - let _ = self.result.set(result.into_pyobject(py).map(Bound::unbind)); + pub(crate) fn set_result(pyself: Py, py: Python, result: FutureResultToPy) { + _ = pyself.get().result.set(result.into_pyobject(py).map(Bound::unbind)); + pyself.drop_ref(py); } } @@ -583,18 +588,22 @@ impl PyFutureAwaitable { ) .is_err() { + pyself.drop_ref(py); return; } - let ack = rself.ack.read().unwrap(); - if let Some((cb, ctx)) = &*ack { - let _ = rself.event_loop.clone_ref(py).call_method( - py, - pyo3::intern!(py, "call_soon_threadsafe"), - (cb, pyself.clone_ref(py)), - Some(ctx.bind(py)), - ); + { + let ack = rself.ack.read().unwrap(); + if let Some((cb, ctx)) = &*ack { + _ = rself.event_loop.clone_ref(py).call_method( + py, + pyo3::intern!(py, "call_soon_threadsafe"), + (cb, pyself.clone_ref(py)), + Some(ctx.bind(py)), + ); + } } + pyself.drop_ref(py); } } diff --git a/src/conversion.rs b/src/conversion.rs index 1f5c32ff..aceed557 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -31,8 +31,8 @@ pub(crate) enum FutureResultToPy { Bytes(hyper::body::Bytes), ASGIMessage(crate::asgi::types::ASGIMessageType), ASGIWSMessage(tokio_tungstenite::tungstenite::Message), + RSGIWSAccept(crate::rsgi::io::RSGIWebsocketTransport), RSGIWSMessage(tokio_tungstenite::tungstenite::Message), - Py(PyObject), } impl<'p> IntoPyObject<'p> for FutureResultToPy { @@ -47,8 +47,8 @@ impl<'p> IntoPyObject<'p> for FutureResultToPy { Self::Bytes(inner) => inner.into_pyobject(py), Self::ASGIMessage(message) => crate::asgi::conversion::message_into_py(py, message), Self::ASGIWSMessage(message) => crate::asgi::conversion::ws_message_into_py(py, message), + Self::RSGIWSAccept(obj) => obj.into_bound_py_any(py), Self::RSGIWSMessage(message) => crate::rsgi::conversion::ws_message_into_py(py, message), - Self::Py(obj) => Ok(obj.into_bound(py)), } } } diff --git a/src/rsgi/callbacks.rs b/src/rsgi/callbacks.rs index e5998005..53e2be58 100644 --- a/src/rsgi/callbacks.rs +++ b/src/rsgi/callbacks.rs @@ -8,7 +8,7 @@ use super::{ }; use crate::{ callbacks::ArcCBScheduler, - runtime::RuntimeRef, + runtime::{Runtime, RuntimeRef}, utils::log_application_callable_exception, ws::{HyperWebsocket, UpgradeData}, }; @@ -28,9 +28,9 @@ macro_rules! callback_impl_done_ws { } macro_rules! callback_impl_done_err { - ($self:expr, $err:expr) => { + ($self:expr, $py:expr, $err:expr) => { $self.done(); - log_application_callable_exception($err); + log_application_callable_exception($py, $err); }; } @@ -65,8 +65,8 @@ impl CallbackWatcherHTTP { callback_impl_done_http!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -99,8 +99,8 @@ impl CallbackWatcherWebsocket { callback_impl_done_ws!(self); } - fn err(&self, err: Bound) { - callback_impl_done_err!(self, &PyErr::from_value(err)); + fn err(&self, py: Python, err: Bound) { + callback_impl_done_err!(self, py, &PyErr::from_value(err)); } fn taskref(&self, py: Python, task: PyObject) { @@ -108,7 +108,6 @@ impl CallbackWatcherWebsocket { } } -#[cfg(not(Py_GIL_DISABLED))] #[inline] pub(crate) fn call_http( cb: ArcCBScheduler, @@ -116,63 +115,17 @@ pub(crate) fn call_http( body: hyper::body::Incoming, scope: HTTPScope, ) -> oneshot::Receiver { - let brt = rt.innerb.clone(); let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, tx, body); + let protocol = HTTPProtocol::new(rt.clone(), tx, body); - let _ = brt.run(move || { - Python::with_gil(|py| { - let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); + rt.spawn_blocking(move |py| { + cb.get() + .schedule(py, Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap()); }); rx } -#[cfg(Py_GIL_DISABLED)] -#[inline] -pub(crate) fn call_http( - cb: ArcCBScheduler, - rt: RuntimeRef, - body: hyper::body::Incoming, - scope: HTTPScope, -) -> oneshot::Receiver { - let (tx, rx) = oneshot::channel(); - let protocol = HTTPProtocol::new(rt, tx, body); - - Python::with_gil(|py| { - let watcher = Py::new(py, CallbackWatcherHTTP::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); - - rx -} - -#[cfg(not(Py_GIL_DISABLED))] -#[inline] -pub(crate) fn call_ws( - cb: ArcCBScheduler, - rt: RuntimeRef, - ws: HyperWebsocket, - upgrade: UpgradeData, - scope: WebsocketScope, -) -> oneshot::Receiver { - let brt = rt.innerb.clone(); - let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); - - let _ = brt.run(move || { - Python::with_gil(|py| { - let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); - }); - }); - - rx -} - -#[cfg(Py_GIL_DISABLED)] #[inline] pub(crate) fn call_ws( cb: ArcCBScheduler, @@ -182,11 +135,13 @@ pub(crate) fn call_ws( scope: WebsocketScope, ) -> oneshot::Receiver { let (tx, rx) = oneshot::channel(); - let protocol = WebsocketProtocol::new(rt, tx, ws, upgrade); + let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade); - Python::with_gil(|py| { - let watcher = Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(); - cb.get().schedule(py, watcher.as_any()); + rt.spawn_blocking(move |py| { + cb.get().schedule( + py, + Py::new(py, CallbackWatcherWebsocket::new(py, protocol, scope)).unwrap(), + ); }); rx diff --git a/src/rsgi/http.rs b/src/rsgi/http.rs index ee3278ba..400cb755 100644 --- a/src/rsgi/http.rs +++ b/src/rsgi/http.rs @@ -1,3 +1,4 @@ +use futures::sink::SinkExt; use http_body_util::BodyExt; use hyper::{header::SERVER as HK_SERVER, http::response::Builder as ResponseBuilder, StatusCode}; use std::net::SocketAddr; @@ -99,7 +100,7 @@ macro_rules! handle_request_with_ws { let tx_ref = restx.clone(); match $handler_ws(callback, rt, ws, UpgradeData::new(res, restx), scope).await { - Ok((status, consumed, handle)) => match (consumed, handle) { + Ok((status, consumed, stream)) => match (consumed, stream) { (false, _) => { let _ = tx_ref .send( @@ -111,8 +112,8 @@ macro_rules! handle_request_with_ws { ) .await; } - (true, Some(handle)) => { - let _ = handle.await; + (true, Some(mut stream)) => { + let _ = stream.close().await; } _ => {} }, diff --git a/src/rsgi/io.rs b/src/rsgi/io.rs index 3728d029..65321d7c 100644 --- a/src/rsgi/io.rs +++ b/src/rsgi/io.rs @@ -4,7 +4,7 @@ use hyper::body; use pyo3::{prelude::*, pybacked::PyBackedStr}; use std::{ borrow::Cow, - sync::{atomic, Arc, Mutex, RwLock}, + sync::{Arc, Mutex, RwLock}, }; use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex}; use tokio_tungstenite::tungstenite::Message; @@ -15,11 +15,11 @@ use super::{ }; use crate::{ conversion::FutureResultToPy, - runtime::{future_into_py_futlike, Runtime, RuntimeRef}, - ws::{HyperWebsocket, UpgradeData, WSRxStream, WSStream, WSTxStream}, + runtime::{future_into_py_futlike, RuntimeRef}, + ws::{HyperWebsocket, UpgradeData, WSRxStream, WSTxStream}, }; -pub(crate) type WebsocketDetachedTransport = (i32, bool, Option>); +pub(crate) type WebsocketDetachedTransport = (i32, bool, Option); #[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIHTTPStreamTransport { @@ -183,38 +183,18 @@ impl RSGIHTTPProtocol { #[pyclass(frozen, module = "granian._granian")] pub(crate) struct RSGIWebsocketTransport { rt: RuntimeRef, - tx: Arc>, + tx: Arc>>, rx: Arc>, - closed: atomic::AtomicBool, } impl RSGIWebsocketTransport { - pub fn new(rt: RuntimeRef, transport: WSStream) -> Self { - let (tx, rx) = transport.split(); + pub fn new(rt: RuntimeRef, tx: Arc>>, rx: WSRxStream) -> Self { Self { rt, - tx: Arc::new(AsyncMutex::new(tx)), + tx, rx: Arc::new(AsyncMutex::new(rx)), - closed: false.into(), } } - - pub fn close(&self) -> Option> { - if self.closed.load(atomic::Ordering::Relaxed) { - return None; - } - self.closed.store(true, atomic::Ordering::Relaxed); - - let tx = self.tx.clone(); - let handle = self.rt.spawn(async move { - if let Ok(mut tx) = tx.try_lock() { - if let Err(err) = tx.close().await { - log::info!("Failed to close websocket with error {:?}", err); - } - } - }); - Some(handle) - } } #[pymethods] @@ -241,11 +221,13 @@ impl RSGIWebsocketTransport { let bdata: Box<[u8]> = data.into(); future_into_py_futlike(self.rt.clone(), py, async move { - if let Ok(mut stream) = transport.try_lock() { - return match stream.send(bdata[..].into()).await { - Ok(()) => FutureResultToPy::None, - _ => FutureResultToPy::Err(error_stream!()), - }; + if let Ok(mut guard) = transport.try_lock() { + if let Some(stream) = &mut *guard { + return match stream.send(bdata[..].into()).await { + Ok(()) => FutureResultToPy::None, + _ => FutureResultToPy::Err(error_stream!()), + }; + } } FutureResultToPy::Err(error_proto!()) }) @@ -255,11 +237,13 @@ impl RSGIWebsocketTransport { let transport = self.tx.clone(); future_into_py_futlike(self.rt.clone(), py, async move { - if let Ok(mut stream) = transport.try_lock() { - return match stream.send(data.into()).await { - Ok(()) => FutureResultToPy::None, - _ => FutureResultToPy::Err(error_stream!()), - }; + if let Ok(mut guard) = transport.try_lock() { + if let Some(stream) = &mut *guard { + return match stream.send(data.into()).await { + Ok(()) => FutureResultToPy::None, + _ => FutureResultToPy::Err(error_stream!()), + }; + } } FutureResultToPy::Err(error_proto!()) }) @@ -272,7 +256,7 @@ pub(crate) struct RSGIWebsocketProtocol { tx: Mutex>>, websocket: Arc>, upgrade: RwLock>, - transport: Arc>>>, + transport: Arc>>, } impl RSGIWebsocketProtocol { @@ -287,7 +271,7 @@ impl RSGIWebsocketProtocol { tx: Mutex::new(Some(tx)), websocket: Arc::new(AsyncMutex::new(websocket)), upgrade: RwLock::new(Some(upgrade)), - transport: Arc::new(Mutex::new(None)), + transport: Arc::new(AsyncMutex::new(None)), } } @@ -304,7 +288,7 @@ impl RSGIWebsocketProtocol { let mut handle = None; if let Ok(mut transport) = self.transport.try_lock() { if let Some(transport) = transport.take() { - handle = transport.get().close(); + handle = Some(transport); } } @@ -322,12 +306,16 @@ impl RSGIWebsocketProtocol { match upgrade.send(None).await { Ok(()) => match (&mut *ws).await { Ok(stream) => { - let mut trx = itransport.lock().unwrap(); - Python::with_gil(|py| { - let pytransport = Py::new(py, RSGIWebsocketTransport::new(rth, stream)).unwrap(); - *trx = Some(pytransport.clone_ref(py)); - FutureResultToPy::Py(pytransport.into_any()) - }) + let (stx, srx) = stream.split(); + { + let mut guard = itransport.lock().await; + *guard = Some(stx); + } + FutureResultToPy::RSGIWSAccept(RSGIWebsocketTransport::new( + rth.clone(), + itransport.clone(), + srx, + )) } _ => FutureResultToPy::Err(error_proto!()), }, diff --git a/src/rsgi/mod.rs b/src/rsgi/mod.rs index 4e8e24af..94a16a8f 100644 --- a/src/rsgi/mod.rs +++ b/src/rsgi/mod.rs @@ -4,7 +4,7 @@ mod callbacks; pub(crate) mod conversion; mod errors; mod http; -mod io; +pub(crate) mod io; pub(crate) mod serve; mod types; diff --git a/src/runtime.rs b/src/runtime.rs index e822b267..d7e76e55 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -34,8 +34,9 @@ pub trait Runtime: Send + 'static { where F: Future + Send + 'static; - #[cfg(not(Py_GIL_DISABLED))] - fn blocking(&self) -> BlockingRunner; + fn spawn_blocking(&self, task: F) + where + F: FnOnce(Python) + Send + 'static; } pub trait ContextExt: Runtime { @@ -43,42 +44,49 @@ pub trait ContextExt: Runtime { } pub(crate) struct RuntimeWrapper { - rt: tokio::runtime::Runtime, - br: BlockingRunner, + pub inner: tokio::runtime::Runtime, + br: Arc, pr: Arc, } impl RuntimeWrapper { pub fn new(blocking_threads: usize, py_loop: Arc) -> Self { Self { - rt: default_runtime(blocking_threads), - br: BlockingRunner::new(), + inner: default_runtime(blocking_threads), + br: BlockingRunner::new().into(), pr: py_loop, } } pub fn with_runtime(rt: tokio::runtime::Runtime, py_loop: Arc) -> Self { Self { - rt, - br: BlockingRunner::new(), + inner: rt, + br: BlockingRunner::new().into(), pr: py_loop, } } pub fn handler(&self) -> RuntimeRef { - RuntimeRef::new(self.rt.handle().clone(), self.br.clone(), self.pr.clone()) + RuntimeRef::new(self.inner.handle().clone(), self.br.clone(), self.pr.clone()) + } +} + +#[cfg(Py_GIL_DISABLED)] +impl Drop for RuntimeWrapper { + fn drop(&mut self) { + self.br.shutdown(); } } #[derive(Clone)] pub struct RuntimeRef { pub inner: tokio::runtime::Handle, - pub innerb: BlockingRunner, + innerb: Arc, innerp: Arc, } impl RuntimeRef { - pub fn new(rt: tokio::runtime::Handle, br: BlockingRunner, pyloop: Arc) -> Self { + pub fn new(rt: tokio::runtime::Handle, br: Arc, pyloop: Arc) -> Self { Self { inner: rt, innerb: br, @@ -104,9 +112,11 @@ impl Runtime for RuntimeRef { self.inner.spawn(fut) } - #[cfg(not(Py_GIL_DISABLED))] - fn blocking(&self) -> BlockingRunner { - self.innerb.clone() + fn spawn_blocking(&self, task: F) + where + F: FnOnce(Python) + Send + 'static, + { + _ = self.innerb.run(task); } } @@ -146,31 +156,6 @@ pub(crate) fn init_runtime_st(blocking_threads: usize, py_loop: Arc) - // It consumes more cpu-cycles than `future_into_py_futlike`, // but for "quick" operations it's something like 12% faster. #[allow(unused_must_use)] -#[cfg(not(Py_GIL_DISABLED))] -pub(crate) fn future_into_py_iter(rt: R, py: Python, fut: F) -> PyResult> -where - R: Runtime + ContextExt + Clone, - F: Future + Send + 'static, -{ - let aw = Py::new(py, PyIterAwaitable::new())?; - let py_fut = aw.clone_ref(py); - let rb = rt.blocking(); - - rt.spawn(async move { - let result = fut.await; - let _ = rb.run(move || { - Python::with_gil(|py| { - aw.get().set_result(py, result); - drop(aw); - }); - }); - }); - - Ok(py_fut.into_any().into_bound(py)) -} - -#[allow(unused_must_use)] -#[cfg(Py_GIL_DISABLED)] pub(crate) fn future_into_py_iter(rt: R, py: Python, fut: F) -> PyResult> where R: Runtime + ContextExt + Clone, @@ -178,13 +163,11 @@ where { let aw = Py::new(py, PyIterAwaitable::new())?; let py_fut = aw.clone_ref(py); + let rth = rt.clone(); rt.spawn(async move { let result = fut.await; - Python::with_gil(|py| { - aw.get().set_result(py, result); - drop(aw); - }); + rth.spawn_blocking(move |py| PyIterAwaitable::set_result(aw, py, result)); }); Ok(py_fut.into_any().into_bound(py)) @@ -196,33 +179,7 @@ where // It won't consume more cpu-cycles than standard asyncio implementation, // and for "long" operations it's something like 6% faster than `future_into_py_iter`. #[allow(unused_must_use)] -#[cfg(all(unix, not(Py_GIL_DISABLED)))] -pub(crate) fn future_into_py_futlike(rt: R, py: Python, fut: F) -> PyResult> -where - R: Runtime + ContextExt + Clone, - F: Future + Send + 'static, -{ - let event_loop = rt.py_event_loop(py); - let (aw, cancel_tx) = PyFutureAwaitable::new(event_loop).to_spawn(py)?; - let py_fut = aw.clone_ref(py); - let rb = rt.blocking(); - - rt.spawn(async move { - tokio::select! { - result = fut => { - let _ = rb.run(move || Python::with_gil(|py| PyFutureAwaitable::set_result(aw, py, result))); - }, - () = cancel_tx.notified() => { - let _ = rb.run(move || Python::with_gil(|_| drop(aw))); - } - } - }); - - Ok(py_fut.into_any().into_bound(py)) -} - -#[allow(unused_must_use)] -#[cfg(all(unix, Py_GIL_DISABLED))] +#[cfg(unix)] pub(crate) fn future_into_py_futlike(rt: R, py: Python, fut: F) -> PyResult> where R: Runtime + ContextExt + Clone, @@ -231,15 +188,12 @@ where let event_loop = rt.py_event_loop(py); let (aw, cancel_tx) = PyFutureAwaitable::new(event_loop).to_spawn(py)?; let py_fut = aw.clone_ref(py); + let rth = rt.clone(); rt.spawn(async move { tokio::select! { - result = fut => { - Python::with_gil(|py| PyFutureAwaitable::set_result(aw, py, result)); - }, - () = cancel_tx.notified() => { - Python::with_gil(|_| drop(aw)); - } + result = fut => rth.spawn_blocking(move |py| PyFutureAwaitable::set_result(aw, py, result)), + () = cancel_tx.notified() => rth.spawn_blocking(move |py| aw.drop_ref(py)), } }); @@ -247,59 +201,7 @@ where } #[allow(unused_must_use)] -#[cfg(all(windows, not(Py_GIL_DISABLED)))] -pub(crate) fn future_into_py_futlike(rt: R, py: Python, fut: F) -> PyResult> -where - R: Runtime + ContextExt + Clone, - F: Future + Send + 'static, -{ - let event_loop = rt.py_event_loop(py); - let event_loop_ref = event_loop.clone_ref(py); - let cancel_tx = Arc::new(tokio::sync::Notify::new()); - let rb = rt.blocking(); - - let py_fut = event_loop.call_method0(py, pyo3::intern!(py, "create_future"))?; - py_fut.call_method1( - py, - pyo3::intern!(py, "add_done_callback"), - (PyFutureDoneCallback { - cancel_tx: cancel_tx.clone(), - },), - )?; - let fut_ref = py_fut.clone_ref(py); - - rt.spawn(async move { - tokio::select! { - result = fut => { - let _ = rb.run(move || { - Python::with_gil(|py| { - let pyres = result.into_pyobject(py).map(Bound::unbind); - let (cb, value) = match pyres { - Ok(val) => (fut_ref.getattr(py, pyo3::intern!(py, "set_result")).unwrap(), val), - Err(err) => (fut_ref.getattr(py, pyo3::intern!(py, "set_exception")).unwrap(), err.into_py_any(py).unwrap()) - }; - let _ = event_loop_ref.call_method1(py, pyo3::intern!(py, "call_soon_threadsafe"), (PyFutureResultSetter, cb, value)); - drop(fut_ref); - drop(event_loop_ref); - }); - }); - }, - () = cancel_tx.notified() => { - let _ = rb.run(move || { - Python::with_gil(|_| { - drop(fut_ref); - drop(event_loop_ref); - }); - }); - } - } - }); - - Ok(py_fut.into_bound(py)) -} - -#[allow(unused_must_use)] -#[cfg(all(windows, Py_GIL_DISABLED))] +#[cfg(windows)] pub(crate) fn future_into_py_futlike(rt: R, py: Python, fut: F) -> PyResult> where R: Runtime + ContextExt + Clone, @@ -308,6 +210,7 @@ where let event_loop = rt.py_event_loop(py); let event_loop_ref = event_loop.clone_ref(py); let cancel_tx = Arc::new(tokio::sync::Notify::new()); + let rth = rt.clone(); let py_fut = event_loop.call_method0(py, pyo3::intern!(py, "create_future"))?; py_fut.call_method1( @@ -322,21 +225,21 @@ where rt.spawn(async move { tokio::select! { result = fut => { - Python::with_gil(|py| { + rth.spawn_blocking(move |py| { let pyres = result.into_pyobject(py).map(Bound::unbind); let (cb, value) = match pyres { Ok(val) => (fut_ref.getattr(py, pyo3::intern!(py, "set_result")).unwrap(), val), Err(err) => (fut_ref.getattr(py, pyo3::intern!(py, "set_exception")).unwrap(), err.into_py_any(py).unwrap()) }; let _ = event_loop_ref.call_method1(py, pyo3::intern!(py, "call_soon_threadsafe"), (PyFutureResultSetter, cb, value)); - drop(fut_ref); - drop(event_loop_ref); + fut_ref.drop_ref(py); + event_loop_ref.drop_ref(py); }); }, () = cancel_tx.notified() => { - Python::with_gil(|_| { - drop(fut_ref); - drop(event_loop_ref); + rth.spawn_blocking(move |py| { + fut_ref.drop_ref(py); + event_loop_ref.drop_ref(py); }); } } @@ -352,9 +255,8 @@ pub(crate) fn empty_future_into_py(py: Python) -> PyResult> { } #[allow(unused_must_use)] -pub(crate) fn run_until_complete(rt: R, event_loop: Bound, fut: F) -> PyResult<()> +pub(crate) fn run_until_complete(rt: RuntimeWrapper, event_loop: Bound, fut: F) -> PyResult<()> where - R: Runtime + ContextExt + Clone, F: Future> + Send + 'static, { let result_tx = Arc::new(Mutex::new(None)); @@ -364,7 +266,7 @@ where let loop_tx = event_loop.clone().unbind(); let future_tx = py_fut.clone().unbind(); - rt.spawn(async move { + rt.inner.spawn(async move { let _ = fut.await; if let Ok(mut result) = result_tx.lock() { *result = Some(()); @@ -375,8 +277,8 @@ where Python::with_gil(move |py| { let res_method = future_tx.getattr(py, "set_result").unwrap(); let _ = loop_tx.call_method(py, "call_soon_threadsafe", (res_method, py.None()), None); - drop(future_tx); - drop(loop_tx); + future_tx.drop_ref(py); + loop_tx.drop_ref(py); }); }); @@ -390,5 +292,5 @@ pub(crate) fn block_on_local(rt: &RuntimeWrapper, local: LocalSet, fut: F) where F: Future + 'static, { - local.block_on(&rt.rt, fut); + local.block_on(&rt.inner, fut); } diff --git a/src/utils.rs b/src/utils.rs index 4ad10e8a..0fee64f8 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,4 @@ -use pyo3::types::PyTracebackMethods; +use pyo3::{prelude::*, types::PyTracebackMethods}; pub(crate) fn header_contains_value( headers: &hyper::HeaderMap, @@ -41,13 +41,11 @@ fn trim_end(data: &[u8]) -> &[u8] { } #[inline] -pub(crate) fn log_application_callable_exception(err: &pyo3::PyErr) { - let tb = pyo3::Python::with_gil(|py| { - let tb = match err.traceback(py).map(|t| t.format()) { - Some(Ok(tb)) => tb, - _ => String::new(), - }; - format!("{tb}{err}") - }); - log::error!("Application callable raised an exception\n{tb}"); +pub(crate) fn log_application_callable_exception(py: Python, err: &pyo3::PyErr) { + let tb = match err.traceback(py).map(|t| t.format()) { + Some(Ok(tb)) => tb, + _ => String::new(), + }; + let errs = format!("{tb}{err}"); + log::error!("Application callable raised an exception\n{errs}"); } diff --git a/src/workers.rs b/src/workers.rs index 09e71dbe..42a265e3 100644 --- a/src/workers.rs +++ b/src/workers.rs @@ -610,7 +610,7 @@ macro_rules! serve_rth { let rth = rt.handler(); let mut srx = signal.get().rx.lock().unwrap().take().unwrap(); - let main_loop = crate::runtime::run_until_complete(rt.handler(), event_loop.clone(), async move { + let main_loop = crate::runtime::run_until_complete(rt, event_loop.clone(), async move { crate::workers::loop_match!( http_mode, http_upgrades, @@ -633,13 +633,10 @@ macro_rules! serve_rth { Ok(()) }); - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); + } } }; } @@ -674,7 +671,7 @@ macro_rules! serve_rth_ssl { let rth = rt.handler(); let mut srx = signal.get().rx.lock().unwrap().take().unwrap(); - let main_loop = crate::runtime::run_until_complete(rt.handler(), event_loop.clone(), async move { + let main_loop = crate::runtime::run_until_complete(rt, event_loop.clone(), async move { crate::workers::loop_match_tls!( http_mode, http_upgrades, @@ -698,13 +695,10 @@ macro_rules! serve_rth_ssl { Ok(()) }); - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); + } } }; } @@ -751,8 +745,6 @@ macro_rules! serve_wth_inner { ); log::info!("Stopping worker-{} runtime-{}", $wid, thread_id + 1); - - Python::with_gil(|_| drop(callback_wrapper)); }); Python::with_gil(|_| drop(rt)); @@ -780,7 +772,7 @@ macro_rules! serve_wth { let rtm = crate::runtime::init_runtime_mt(1, 1, std::sync::Arc::new(event_loop.clone().unbind())); let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); - let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop.clone(), async move { + let main_loop = crate::runtime::run_until_complete(rtm, event_loop.clone(), async move { let _ = pyrx.changed().await; stx.send(true).unwrap(); log::info!("Stopping worker-{}", worker_id); @@ -790,13 +782,10 @@ macro_rules! serve_wth { Ok(()) }); - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); + } } }; } @@ -870,7 +859,7 @@ macro_rules! serve_wth_ssl { let rtm = crate::runtime::init_runtime_mt(1, 1, std::sync::Arc::new(event_loop.clone().unbind())); let mut pyrx = signal.get().rx.lock().unwrap().take().unwrap(); - let main_loop = crate::runtime::run_until_complete(rtm.handler(), event_loop.clone(), async move { + let main_loop = crate::runtime::run_until_complete(rtm, event_loop.clone(), async move { let _ = pyrx.changed().await; stx.send(true).unwrap(); log::info!("Stopping worker-{}", worker_id); @@ -880,13 +869,10 @@ macro_rules! serve_wth_ssl { Ok(()) }); - match main_loop { - Ok(()) => {} - Err(err) => { - log::error!("{}", err); - std::process::exit(1); - } - }; + if let Err(err) = main_loop { + log::error!("{}", err); + std::process::exit(1); + } } }; } diff --git a/src/wsgi/callbacks.rs b/src/wsgi/callbacks.rs index 5946292a..254f24b8 100644 --- a/src/wsgi/callbacks.rs +++ b/src/wsgi/callbacks.rs @@ -95,7 +95,7 @@ fn run_callback( environ.update(headers.into_py_dict(py).unwrap().as_mapping())?; if let Err(err) = callback.call1(py, (proto.clone_ref(py), environ)) { - log_application_callable_exception(&err); + log_application_callable_exception(py, &err); if let Some(tx) = proto.get().tx() { let _ = tx.send((500, HeaderMap::new(), empty_body())); } diff --git a/src/wsgi/io.rs b/src/wsgi/io.rs index bccc9fa8..0affe63a 100644 --- a/src/wsgi/io.rs +++ b/src/wsgi/io.rs @@ -82,7 +82,7 @@ impl WSGIProtocol { }, Err(err) => { if !err.is_instance_of::(py) { - log_application_callable_exception(&err); + log_application_callable_exception(py, &err); } let _ = body.call_method0(pyo3::intern!(py, "close")); closed = true; diff --git a/src/wsgi/serve.rs b/src/wsgi/serve.rs index 14ae1721..4c3efc21 100644 --- a/src/wsgi/serve.rs +++ b/src/wsgi/serve.rs @@ -40,7 +40,7 @@ impl WSGIWorker { let rth = rt.handler(); let (stx, mut srx) = tokio::sync::watch::channel(false); - let main_loop = rt.handler().inner.spawn(async move { + let main_loop = rt.inner.spawn(async move { crate::workers::loop_match!( http_mode, http_upgrades, @@ -145,7 +145,7 @@ impl WSGIWorker { let rth = rt.handler(); let (stx, mut srx) = tokio::sync::watch::channel(false); - rt.handler().inner.spawn(async move { + rt.inner.spawn(async move { crate::workers::loop_match_tls!( http_mode, http_upgrades,