From 9589a868037e669740d9202a756d397e208fc502 Mon Sep 17 00:00:00 2001 From: "Nathan (Blaise) Bruer" Date: Fri, 26 Apr 2024 16:30:19 -0500 Subject: [PATCH] All tokio::spawn and related functions must use nativelink's version We now enforce all locations in our code base to use one of the `nativelink-util::task` when trying to do a task that might require an operation that changes threads. --- .bazelrc | 3 +- BUILD.bazel | 1 + clippy.toml | 16 +++ nativelink-macro/src/lib.rs | 6 +- .../src/cache_lookup_scheduler.rs | 3 +- .../src/default_scheduler_factory.rs | 3 +- nativelink-scheduler/src/grpc_scheduler.rs | 9 +- nativelink-scheduler/src/simple_scheduler.rs | 59 +++++---- nativelink-service/src/bytestream_server.rs | 11 +- nativelink-service/src/health_server.rs | 57 +++++---- nativelink-service/src/worker_api_server.rs | 13 +- .../tests/bytestream_server_test.rs | 54 ++++---- nativelink-store/src/compression_store.rs | 19 +-- nativelink-store/src/filesystem_store.rs | 116 ++++++++--------- nativelink-store/src/redis_store.rs | 13 +- .../tests/compression_store_test.rs | 7 +- .../tests/filesystem_store_test.rs | 65 ++++++---- nativelink-store/tests/memory_store_test.rs | 7 +- nativelink-store/tests/s3_store_test.rs | 9 +- nativelink-store/tests/verify_store_test.rs | 14 ++- nativelink-util/BUILD.bazel | 1 + nativelink-util/src/common.rs | 30 ----- nativelink-util/src/connection_manager.rs | 3 +- nativelink-util/src/digest_hasher.rs | 8 +- nativelink-util/src/fs.rs | 4 +- nativelink-util/src/lib.rs | 1 + nativelink-util/src/task.rs | 119 ++++++++++++++++++ nativelink-worker/src/local_worker.rs | 6 +- .../src/running_actions_manager.rs | 51 ++++---- .../tests/utils/local_worker_test_utils.rs | 5 +- src/bin/nativelink.rs | 76 +++++------ 31 files changed, 465 insertions(+), 324 deletions(-) create mode 100644 clippy.toml create mode 100644 nativelink-util/src/task.rs diff --git a/.bazelrc b/.bazelrc index f93b59e43..1ca20a393 100644 --- a/.bazelrc +++ b/.bazelrc @@ -44,7 +44,8 @@ build --aspects=@rules_rust//rust:defs.bzl%rustfmt_aspect build --aspects=@rules_rust//rust:defs.bzl%rust_clippy_aspect # TODO(aaronmondal): Extend these flags until we can run with clippy::pedantic. -build --@rules_rust//:clippy_flags=-D,clippy::uninlined_format_args +build --@rules_rust//:clippy_flags=-Dwarnings,-Dclippy::uninlined_format_args +build --@rules_rust//:clippy.toml=//:clippy.toml test --@rules_rust//:rustfmt.toml=//:.rustfmt.toml diff --git a/BUILD.bazel b/BUILD.bazel index 619f65ae5..b3af22cc3 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -3,6 +3,7 @@ load("@rules_rust//rust:defs.bzl", "rust_binary") exports_files( [ ".rustfmt.toml", + "clippy.toml", ], visibility = ["//visibility:public"], ) diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 000000000..263d33777 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,16 @@ +disallowed-methods = [ + { path = "tokio::spawn", reason = "use `nativelink-util::task::spawn` or `nativelink-util::task::background_spawn` instead" }, + { path = "tokio::task::spawn", reason = "use `nativelink-util::task::spawn` or `nativelink-util::task::background_spawn` instead" }, + { path = "tokio::task::spawn_blocking", reason = "use `nativelink-util::task::spawn_blocking` instead" }, + { path = "tokio::task::block_in_place", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::task::spawn_local", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::runtime::Builder::new_current_thread", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::runtime::Builder::new_multi_thread", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::runtime::Builder::new_multi_thread_alt", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::runtime::Runtime::new", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::runtime::Runtime::spawn", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::runtime::Runtime::spawn_blocking", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "tokio::runtime::Runtime::block_on", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "std::thread::spawn", reason = "use one of the `nativelink-util::task` functions instead" }, + { path = "std::thread::Builder::new", reason = "use one of the `nativelink-util::task` functions instead" }, +] diff --git a/nativelink-macro/src/lib.rs b/nativelink-macro/src/lib.rs index 840d8e35b..a9527d868 100644 --- a/nativelink-macro/src/lib.rs +++ b/nativelink-macro/src/lib.rs @@ -31,9 +31,13 @@ pub fn nativelink_test(attr: TokenStream, item: TokenStream) -> TokenStream { let expanded = quote! { #(#fn_attr)* + #[allow(clippy::disallowed_methods)] #[tokio::test(#attr)] async fn #fn_name(#fn_inputs) #fn_output { - #fn_block + #[warn(clippy::disallowed_methods)] + { + #fn_block + } } }; diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs index fc243616a..a88927396 100644 --- a/nativelink-scheduler/src/cache_lookup_scheduler.rs +++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs @@ -27,6 +27,7 @@ use nativelink_store::grpc_store::GrpcStore; use nativelink_util::action_messages::{ ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, }; +use nativelink_util::background_spawn; use nativelink_util::common::DigestInfo; use nativelink_util::store_trait::Store; use parking_lot::{Mutex, MutexGuard}; @@ -158,7 +159,7 @@ impl ActionScheduler for CacheLookupScheduler { let ac_store = self.ac_store.clone(); let action_scheduler = self.action_scheduler.clone(); // We need this spawn because we are returning a stream and this spawn will populate the stream's data. - tokio::spawn(async move { + background_spawn!("cache_lookup_scheduler_add_action", async move { // If our spawn ever dies, we will remove the action from the cache_check_actions map. let _scope_guard = scope_guard; diff --git a/nativelink-scheduler/src/default_scheduler_factory.rs b/nativelink-scheduler/src/default_scheduler_factory.rs index 8ebb7233b..2dc94ddeb 100644 --- a/nativelink-scheduler/src/default_scheduler_factory.rs +++ b/nativelink-scheduler/src/default_scheduler_factory.rs @@ -19,6 +19,7 @@ use std::time::Duration; use nativelink_config::schedulers::SchedulerConfig; use nativelink_error::{Error, ResultExt}; use nativelink_store::store_manager::StoreManager; +use nativelink_util::background_spawn; use nativelink_util::metrics_utils::Registry; use tokio::time::interval; @@ -117,7 +118,7 @@ fn inner_scheduler_factory( fn start_cleanup_timer(action_scheduler: &Arc) { let weak_scheduler = Arc::downgrade(action_scheduler); - tokio::spawn(async move { + background_spawn!("default_scheduler_factory_cleanup_timer", async move { let mut ticker = interval(Duration::from_secs(1)); loop { ticker.tick().await; diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 8f012dba3..fec4ac217 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -32,7 +32,7 @@ use nativelink_util::action_messages::{ }; use nativelink_util::connection_manager::ConnectionManager; use nativelink_util::retry::{Retrier, RetryResult}; -use nativelink_util::tls_utils; +use nativelink_util::{background_spawn, tls_utils}; use parking_lot::Mutex; use rand::rngs::OsRng; use rand::Rng; @@ -40,7 +40,7 @@ use tokio::select; use tokio::sync::watch; use tokio::time::sleep; use tonic::{Request, Streaming}; -use tracing::{error_span, event, Instrument, Level}; +use tracing::{event, Level}; use crate::action_scheduler::ActionScheduler; use crate::platform_property_manager::PlatformPropertyManager; @@ -119,7 +119,7 @@ impl GrpcScheduler { .err_tip(|| "Recieving response from upstream scheduler")? { let (tx, rx) = watch::channel(Arc::new(initial_response.try_into()?)); - tokio::spawn(async move { + background_spawn!("grpc_scheduler_stream_state", async move { loop { select!( _ = tx.closed() => { @@ -157,8 +157,7 @@ impl GrpcScheduler { } ) } - } - .instrument(error_span!("stream_state"))); + }); return Ok(rx); } Err(make_err!( diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index 1c08ec041..caa773190 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -34,9 +34,10 @@ use nativelink_util::metrics_utils::{ MetricsComponent, Registry, }; use nativelink_util::platform_properties::PlatformPropertyValue; +use nativelink_util::spawn; +use nativelink_util::task::JoinHandleDropGuard; use parking_lot::{Mutex, MutexGuard}; use tokio::sync::{watch, Notify}; -use tokio::task::JoinHandle; use tokio::time::Duration; use tracing::{event, Level}; @@ -689,8 +690,9 @@ impl SimpleSchedulerImpl { pub struct SimpleScheduler { inner: Arc>, platform_property_manager: Arc, - task_worker_matching_future: JoinHandle<()>, metrics: Arc, + // Triggers `drop()`` call if scheduler is dropped. + _task_worker_matching_future: JoinHandleDropGuard<()>, } impl SimpleScheduler { @@ -758,29 +760,32 @@ impl SimpleScheduler { Self { inner, platform_property_manager, - task_worker_matching_future: tokio::spawn(async move { - // Break out of the loop only when the inner is dropped. - loop { - tasks_or_workers_change_notify.notified().await; - match weak_inner.upgrade() { - // Note: According to `parking_lot` documentation, the default - // `Mutex` implementation is eventual fairness, so we don't - // really need to worry about this thread taking the lock - // starving other threads too much. - Some(inner_mux) => { - let mut inner = inner_mux.lock(); - let timer = metrics_for_do_try_match.do_try_match.begin_timer(); - inner.do_try_match(); - timer.measure(); - } - // If the inner went away it means the scheduler is shutting - // down, so we need to resolve our future. - None => return, - }; - on_matching_engine_run().await; + _task_worker_matching_future: spawn!( + "simple_scheduler_task_worker_matching", + async move { + // Break out of the loop only when the inner is dropped. + loop { + tasks_or_workers_change_notify.notified().await; + match weak_inner.upgrade() { + // Note: According to `parking_lot` documentation, the default + // `Mutex` implementation is eventual fairness, so we don't + // really need to worry about this thread taking the lock + // starving other threads too much. + Some(inner_mux) => { + let mut inner = inner_mux.lock(); + let timer = metrics_for_do_try_match.do_try_match.begin_timer(); + inner.do_try_match(); + timer.measure(); + } + // If the inner went away it means the scheduler is shutting + // down, so we need to resolve our future. + None => return, + }; + on_matching_engine_run().await; + } + // Unreachable. } - // Unreachable. - }), + ), metrics, } } @@ -982,12 +987,6 @@ impl WorkerScheduler for SimpleScheduler { } } -impl Drop for SimpleScheduler { - fn drop(&mut self) { - self.task_worker_matching_future.abort(); - } -} - impl MetricsComponent for SimpleScheduler { fn gather_metrics(&self, c: &mut CollectorState) { self.metrics.gather_metrics(c); diff --git a/nativelink-service/src/bytestream_server.rs b/nativelink-service/src/bytestream_server.rs index b94f080cb..c165ebcbe 100644 --- a/nativelink-service/src/bytestream_server.rs +++ b/nativelink-service/src/bytestream_server.rs @@ -40,9 +40,10 @@ use nativelink_util::buf_channel::{ use nativelink_util::common::DigestInfo; use nativelink_util::proto_stream_utils::WriteRequestStreamWrapper; use nativelink_util::resource_info::ResourceInfo; +use nativelink_util::spawn; use nativelink_util::store_trait::{Store, UploadSizeInfo}; +use nativelink_util::task::JoinHandleDropGuard; use parking_lot::Mutex; -use tokio::task::AbortHandle; use tokio::time::sleep; use tonic::{Request, Response, Status, Streaming}; use tracing::{enabled, error_span, event, instrument, Instrument, Level}; @@ -110,15 +111,14 @@ impl<'a> Drop for ActiveStreamGuard<'a> { let sleep_fn = self.bytestream_server.sleep_fn.clone(); active_uploads_slot.1 = Some(IdleStream { stream_state, - abort_timeout_handle: tokio::spawn(async move { + _timeout_streaam_drop_guard: spawn!("bytestream_idle_stream_timeout", async move { (*sleep_fn)().await; if let Some(active_uploads) = weak_active_uploads.upgrade() { let mut active_uploads = active_uploads.lock(); event!(Level::INFO, msg = "Removing idle stream", uuid = ?uuid); active_uploads.remove(&uuid); } - }) - .abort_handle(), + }), }); } } @@ -129,7 +129,7 @@ impl<'a> Drop for ActiveStreamGuard<'a> { #[derive(Debug)] struct IdleStream { stream_state: StreamState, - abort_timeout_handle: AbortHandle, + _timeout_streaam_drop_guard: JoinHandleDropGuard<()>, } impl IdleStream { @@ -138,7 +138,6 @@ impl IdleStream { bytes_received: Arc, bytestream_server: &ByteStreamServer, ) -> ActiveStreamGuard<'_> { - self.abort_timeout_handle.abort(); ActiveStreamGuard { stream_state: Some(self.stream_state), bytes_received, diff --git a/nativelink-service/src/health_server.rs b/nativelink-service/src/health_server.rs index 1bd9af94a..dfce9d18e 100644 --- a/nativelink-service/src/health_server.rs +++ b/nativelink-service/src/health_server.rs @@ -22,7 +22,9 @@ use hyper::{Body, Request, Response, StatusCode}; use nativelink_util::health_utils::{ HealthRegistry, HealthStatus, HealthStatusDescription, HealthStatusReporter, }; +use nativelink_util::task::instrument_future; use tower::Service; +use tracing::error_span; /// Content type header value for JSON. const JSON_CONTENT_TYPE: &str = "application/json; charset=utf-8"; @@ -49,34 +51,37 @@ impl Service> for HealthServer { fn call(&mut self, _req: Request) -> Self::Future { let health_registry = self.health_registry.clone(); - Box::pin(async move { - let health_status_descriptions: Vec = - health_registry.health_status_report().collect().await; - match serde_json5::to_string(&health_status_descriptions) { - Ok(body) => { - let contains_failed_report = - health_status_descriptions.iter().any(|description| { - matches!(description.status, HealthStatus::Failed { .. }) - }); - let status_code = if contains_failed_report { - StatusCode::SERVICE_UNAVAILABLE - } else { - StatusCode::OK - }; + Box::pin(instrument_future( + async move { + let health_status_descriptions: Vec = + health_registry.health_status_report().collect().await; + match serde_json5::to_string(&health_status_descriptions) { + Ok(body) => { + let contains_failed_report = + health_status_descriptions.iter().any(|description| { + matches!(description.status, HealthStatus::Failed { .. }) + }); + let status_code = if contains_failed_report { + StatusCode::SERVICE_UNAVAILABLE + } else { + StatusCode::OK + }; - Ok(Response::builder() - .status(status_code) + Ok(Response::builder() + .status(status_code) + .header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE)) + .body(Body::from(body)) + .unwrap()) + } + + Err(e) => Ok(Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) .header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE)) - .body(Body::from(body)) - .unwrap()) + .body(Body::from(format!("Internal Failure: {e:?}"))) + .unwrap()), } - - Err(e) => Ok(Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header(CONTENT_TYPE, HeaderValue::from_static(JSON_CONTENT_TYPE)) - .body(Body::from(format!("Internal Failure: {e:?}"))) - .unwrap()), - } - }) + }, + error_span!("health_server_call"), + )) } } diff --git a/nativelink-service/src/worker_api_server.rs b/nativelink-service/src/worker_api_server.rs index 47bcca454..a4a731c73 100644 --- a/nativelink-service/src/worker_api_server.rs +++ b/nativelink-service/src/worker_api_server.rs @@ -29,13 +29,14 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: }; use nativelink_scheduler::worker::{Worker, WorkerId}; use nativelink_scheduler::worker_scheduler::WorkerScheduler; +use nativelink_util::background_spawn; use nativelink_util::action_messages::ActionInfoHashKey; use nativelink_util::common::DigestInfo; use nativelink_util::platform_properties::PlatformProperties; use tokio::sync::mpsc; use tokio::time::interval; use tonic::{Request, Response, Status}; -use tracing::{error_span, event, instrument, Instrument, Level}; +use tracing::{event, instrument, Level}; use uuid::Uuid; pub type ConnectWorkerStream = @@ -58,7 +59,7 @@ impl WorkerApiServer { // event our ExecutionServer dies. Our scheduler is a weak ref, so the spawn will // eventually see the Arc went away and return. let weak_scheduler = Arc::downgrade(scheduler); - tokio::spawn(async move { + background_spawn!("worker_api_server", async move { let mut ticker = interval(Duration::from_secs(1)); loop { ticker.tick().await; @@ -70,18 +71,14 @@ impl WorkerApiServer { if let Err(err) = scheduler.remove_timedout_workers(timestamp.as_secs()).await { - event!( - Level::ERROR, - ?err, - "Failed to remove_timedout_workers", - ); + event!(Level::ERROR, ?err, "Failed to remove_timedout_workers",); } } // If we fail to upgrade, our service is probably destroyed, so return. None => return, } } - }.instrument(error_span!("worker_api_server"))); + }); } Self::new_with_now_fn( diff --git a/nativelink-service/tests/bytestream_server_test.rs b/nativelink-service/tests/bytestream_server_test.rs index cf4fc431c..f2a249181 100644 --- a/nativelink-service/tests/bytestream_server_test.rs +++ b/nativelink-service/tests/bytestream_server_test.rs @@ -26,8 +26,10 @@ use nativelink_service::bytestream_server::ByteStreamServer; use nativelink_store::default_store_factory::store_factory; use nativelink_store::store_manager::StoreManager; use nativelink_util::common::{encode_stream_proto, DigestInfo}; +use nativelink_util::spawn; +use nativelink_util::task::JoinHandleDropGuard; use prometheus_client::registry::Registry; -use tokio::task::{yield_now, JoinHandle}; +use tokio::task::yield_now; use tonic::{Request, Response}; const INSTANCE_NAME: &str = "foo_instance_name"; @@ -100,10 +102,13 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { - let response_future = bs_server.write(Request::new(stream)); - response_future.await - }); + let join_handle = spawn!( + "chunked_stream_receives_all_data_write_stream", + async move { + let response_future = bs_server.write(Request::new(stream)); + response_future.await + }, + ); (tx, join_handle) }; // Send data. @@ -185,7 +190,7 @@ pub mod write_tests { ) -> Result< ( Sender, - JoinHandle<( + JoinHandleDropGuard<( Result, tonic::Status>, ByteStreamServer, )>, @@ -202,7 +207,7 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { + let join_handle = spawn!("resume_write_success_write_stream", async move { let response_future = bs_server.write(Request::new(stream)); (response_future.await, bs_server) }); @@ -281,7 +286,7 @@ pub mod write_tests { ) -> Result< ( Sender, - JoinHandle<( + JoinHandleDropGuard<( Result, tonic::Status>, ByteStreamServer, )>, @@ -298,7 +303,7 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { + let join_handle = spawn!("restart_write_success_write_stream", async move { let response_future = bs_server.write(Request::new(stream)); (response_future.await, bs_server) }); @@ -383,7 +388,7 @@ pub mod write_tests { ) -> Result< ( Sender, - JoinHandle<( + JoinHandleDropGuard<( Result, tonic::Status>, ByteStreamServer, )>, @@ -400,10 +405,13 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { - let response_future = bs_server.write(Request::new(stream)); - (response_future.await, bs_server) - }); + let join_handle = spawn!( + "restart_mid_stream_write_success_write_stream", + async move { + let response_future = bs_server.write(Request::new(stream)); + (response_future.await, bs_server) + }, + ); Ok((tx, join_handle)) } let (mut tx, join_handle) = setup_stream(bs_server).await?; @@ -573,7 +581,7 @@ pub mod write_tests { ) -> Result< ( Sender, - JoinHandle<( + JoinHandleDropGuard<( Result, tonic::Status>, ByteStreamServer, )>, @@ -590,7 +598,7 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { + let join_handle = spawn!("out_of_order_data_fails_write_stream", async move { let response_future = bs_server.write(Request::new(stream)); (response_future.await, bs_server) }); @@ -658,7 +666,7 @@ pub mod write_tests { ) -> Result< ( Sender, - JoinHandle<( + JoinHandleDropGuard<( Result, tonic::Status>, ByteStreamServer, )>, @@ -675,7 +683,7 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { + let join_handle = spawn!("upload_zero_byte_chunk_write_stream", async move { let response_future = bs_server.write(Request::new(stream)); (response_future.await, bs_server) }); @@ -723,7 +731,7 @@ pub mod write_tests { ) -> Result< ( Sender, - JoinHandle<( + JoinHandleDropGuard<( Result, tonic::Status>, ByteStreamServer, )>, @@ -740,7 +748,7 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { + let join_handle = spawn!("disallow_negative_write_offset_write_stream", async move { let response_future = bs_server.write(Request::new(stream)); (response_future.await, bs_server) }); @@ -781,7 +789,7 @@ pub mod write_tests { ) -> Result< ( Sender, - JoinHandle<( + JoinHandleDropGuard<( Result, tonic::Status>, ByteStreamServer, )>, @@ -798,7 +806,7 @@ pub mod write_tests { None, ); - let join_handle = tokio::spawn(async move { + let join_handle = spawn!("out_of_sequence_write_write_stream", async move { let response_future = bs_server.write(Request::new(stream)); (response_future.await, bs_server) }); @@ -1030,7 +1038,7 @@ pub mod query_tests { ); let bs_server_clone = bs_server.clone(); - let join_handle = tokio::spawn(async move { + let join_handle = spawn!("query_write_status_smoke_test_write_stream", async move { let response_future = bs_server_clone.write(Request::new(stream)); response_future.await }); diff --git a/nativelink-store/src/compression_store.rs b/nativelink-store/src/compression_store.rs index 442ab2004..fcbae2a8f 100644 --- a/nativelink-store/src/compression_store.rs +++ b/nativelink-store/src/compression_store.rs @@ -27,9 +27,10 @@ use nativelink_error::{error_if, make_err, Code, Error, ResultExt}; use nativelink_util::buf_channel::{ make_buf_channel_pair, DropCloserReadHalf, DropCloserWriteHalf, }; -use nativelink_util::common::{DigestInfo, JoinHandleDropGuard}; +use nativelink_util::common::DigestInfo; use nativelink_util::health_utils::{default_health_status_indicator, HealthStatusIndicator}; use nativelink_util::metrics_utils::Registry; +use nativelink_util::spawn; use nativelink_util::store_trait::{Store, UploadSizeInfo}; use serde::{Deserialize, Serialize}; @@ -262,7 +263,7 @@ impl Store for CompressionStore { let (mut tx, rx) = make_buf_channel_pair(); let inner_store = self.inner_store.clone(); - let update_fut = JoinHandleDropGuard::new(tokio::spawn(async move { + let update_fut = spawn!("compression_store_update_spawn", async move { Pin::new(inner_store.as_ref()) .update( digest, @@ -271,15 +272,15 @@ impl Store for CompressionStore { ) .await .err_tip(|| "Inner store update in compression store failed") - })) - .map(|result| { - match result.err_tip(|| "Failed to run compression update spawn") { + }) + .map( + |result| match result.err_tip(|| "Failed to run compression update spawn") { Ok(inner_result) => { inner_result.err_tip(|| "Compression underlying store update failed") } Err(e) => Err(e), - } - }); + }, + ); let write_fut = async move { { @@ -405,12 +406,12 @@ impl Store for CompressionStore { let (tx, mut rx) = make_buf_channel_pair(); let inner_store = self.inner_store.clone(); - let get_part_fut = JoinHandleDropGuard::new(tokio::spawn(async move { + let get_part_fut = spawn!("compression_store_get_part_spawn", async move { Pin::new(inner_store.as_ref()) .get_part(digest, tx, 0, None) .await .err_tip(|| "Inner store get in compression store failed") - })) + }) .map( |result| match result.err_tip(|| "Failed to run compression get spawn") { Ok(inner_result) => { diff --git a/nativelink-store/src/filesystem_store.rs b/nativelink-store/src/filesystem_store.rs index dad11e1a3..7d806a690 100644 --- a/nativelink-store/src/filesystem_store.rs +++ b/nativelink-store/src/filesystem_store.rs @@ -35,11 +35,11 @@ use nativelink_util::evicting_map::{EvictingMap, LenEntry}; use nativelink_util::health_utils::{HealthRegistryBuilder, HealthStatus, HealthStatusIndicator}; use nativelink_util::metrics_utils::{Collector, CollectorState, MetricsComponent, Registry}; use nativelink_util::store_trait::{Store, StoreOptimizations, UploadSizeInfo}; +use nativelink_util::{background_spawn, spawn_blocking}; use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom}; -use tokio::task::spawn_blocking; use tokio::time::{sleep, timeout, Sleep}; use tokio_stream::wrappers::ReadDirStream; -use tracing::{event, trace_span, Instrument, Level}; +use tracing::{event, Level}; use crate::cas_utils::is_zero_digest; @@ -112,21 +112,18 @@ impl Drop for EncodedFilePath { shared_context .active_drop_spawns .fetch_add(1, Ordering::Relaxed); - tokio::spawn( - async move { - event!(Level::INFO, ?file_path, "File deleted",); - let result = fs::remove_file(&file_path) - .await - .err_tip(|| format!("Failed to remove file {file_path:?}")); - if let Err(err) = result { - event!(Level::ERROR, ?file_path, ?err, "Failed to delete file",); - } - shared_context - .active_drop_spawns - .fetch_sub(1, Ordering::Relaxed); + background_spawn!("filesystem_delete_file", async move { + event!(Level::INFO, ?file_path, "File deleted",); + let result = fs::remove_file(&file_path) + .await + .err_tip(|| format!("Failed to remove file {file_path:?}")); + if let Err(err) = result { + event!(Level::ERROR, ?file_path, ?err, "Failed to delete file",); } - .instrument(trace_span!("delete_file")), - ); + shared_context + .active_drop_spawns + .fetch_sub(1, Ordering::Relaxed); + }); } } @@ -326,7 +323,7 @@ impl LenEntry for FileEntryImpl { let result = self .get_file_path_locked(move |full_content_path| async move { let full_content_path = full_content_path.to_os_string(); - spawn_blocking(move || { + spawn_blocking!("filesystem_touch_set_mtime", move || { set_file_atime(&full_content_path, FileTime::now()).err_tip(|| { format!("Failed to touch file in filesystem store {full_content_path:?}") }) @@ -668,52 +665,49 @@ impl FilesystemStore { // We need to guarantee that this will get to the end even if the parent future is dropped. // See: https://github.com/TraceMachina/nativelink/issues/495 - tokio::spawn( - async move { - let mut encoded_file_path = entry.get_encoded_file_path().write().await; - let final_path = get_file_path_raw( - &PathType::Content, - encoded_file_path.shared_context.as_ref(), - &digest, - ); + background_spawn!("filesystem_store_emplace_file", async move { + let mut encoded_file_path = entry.get_encoded_file_path().write().await; + let final_path = get_file_path_raw( + &PathType::Content, + encoded_file_path.shared_context.as_ref(), + &digest, + ); - evicting_map.insert(digest, entry.clone()).await; - - let from_path = encoded_file_path.get_file_path(); - // Internally tokio spawns fs commands onto a blocking thread anyways. - // Since we are already on a blocking thread, we just need the `fs` wrapper to manage - // an open-file permit (ensure we don't open too many files at once). - let result = (rename_fn)(&from_path, &final_path) - .err_tip(|| format!("Failed to rename temp file to final path {final_path:?}")); - - // In the event our move from temp file to final file fails we need to ensure we remove - // the entry from our map. - // Remember: At this point it is possible for another thread to have a reference to - // `entry`, so we can't delete the file, only drop() should ever delete files. - if let Err(err) = result { - event!( - Level::ERROR, - ?err, - ?from_path, - ?final_path, - "Failed to rename file", - ); - // Warning: To prevent deadlock we need to release our lock or during `remove_if()` - // it will call `unref()`, which triggers a write-lock on `encoded_file_path`. - drop(encoded_file_path); - // It is possible that the item in our map is no longer the item we inserted, - // So, we need to conditionally remove it only if the pointers are the same. - evicting_map - .remove_if(&digest, |map_entry| Arc::::ptr_eq(map_entry, &entry)) - .await; - return Err(err); - } - encoded_file_path.path_type = PathType::Content; - encoded_file_path.digest = digest; - Ok(()) + evicting_map.insert(digest, entry.clone()).await; + + let from_path = encoded_file_path.get_file_path(); + // Internally tokio spawns fs commands onto a blocking thread anyways. + // Since we are already on a blocking thread, we just need the `fs` wrapper to manage + // an open-file permit (ensure we don't open too many files at once). + let result = (rename_fn)(&from_path, &final_path) + .err_tip(|| format!("Failed to rename temp file to final path {final_path:?}")); + + // In the event our move from temp file to final file fails we need to ensure we remove + // the entry from our map. + // Remember: At this point it is possible for another thread to have a reference to + // `entry`, so we can't delete the file, only drop() should ever delete files. + if let Err(err) = result { + event!( + Level::ERROR, + ?err, + ?from_path, + ?final_path, + "Failed to rename file", + ); + // Warning: To prevent deadlock we need to release our lock or during `remove_if()` + // it will call `unref()`, which triggers a write-lock on `encoded_file_path`. + drop(encoded_file_path); + // It is possible that the item in our map is no longer the item we inserted, + // So, we need to conditionally remove it only if the pointers are the same. + evicting_map + .remove_if(&digest, |map_entry| Arc::::ptr_eq(map_entry, &entry)) + .await; + return Err(err); } - .instrument(trace_span!("emplace_file")), - ) + encoded_file_path.path_type = PathType::Content; + encoded_file_path.digest = digest; + Ok(()) + }) .await .err_tip(|| "Failed to create spawn in filesystem store update_file")? } diff --git a/nativelink-store/src/redis_store.rs b/nativelink-store/src/redis_store.rs index 21fc52938..75370d7f6 100644 --- a/nativelink-store/src/redis_store.rs +++ b/nativelink-store/src/redis_store.rs @@ -22,6 +22,7 @@ use async_trait::async_trait; use bytes::Bytes; use futures::future::{BoxFuture, FutureExt, Shared}; use nativelink_error::{error_if, make_err, Code, Error, ResultExt}; +use nativelink_util::background_spawn; use nativelink_util::buf_channel::{DropCloserReadHalf, DropCloserWriteHalf}; use nativelink_util::common::DigestInfo; use nativelink_util::health_utils::{HealthRegistryBuilder, HealthStatus, HealthStatusIndicator}; @@ -29,7 +30,6 @@ use nativelink_util::metrics_utils::{Collector, CollectorState, MetricsComponent use nativelink_util::store_trait::{Store, UploadSizeInfo}; use redis::aio::{ConnectionLike, ConnectionManager}; use redis::AsyncCommands; -use tracing::{error_span, Instrument}; use crate::cas_utils::is_zero_digest; @@ -74,14 +74,11 @@ impl RedisStore { let conn_fut_clone = conn_fut.clone(); // Start connecting to redis, but don't block our construction on it. - tokio::spawn( - async move { - if let Err(e) = conn_fut_clone.await { - make_err!(Code::Unavailable, "Failed to connect to Redis: {:?}", e); - } + background_spawn!("redis_initial_connection", async move { + if let Err(e) = conn_fut_clone.await { + make_err!(Code::Unavailable, "Failed to connect to Redis: {:?}", e); } - .instrument(error_span!("redis_initial_connection")), - ); + }); let lazy_conn = LazyConnection::Future(conn_fut); diff --git a/nativelink-store/tests/compression_store_test.rs b/nativelink-store/tests/compression_store_test.rs index 8d908e04a..e3f127541 100644 --- a/nativelink-store/tests/compression_store_test.rs +++ b/nativelink-store/tests/compression_store_test.rs @@ -28,7 +28,8 @@ use nativelink_store::compression_store::{ }; use nativelink_store::memory_store::MemoryStore; use nativelink_util::buf_channel::make_buf_channel_pair; -use nativelink_util::common::{DigestInfo, JoinHandleDropGuard}; +use nativelink_util::common::DigestInfo; +use nativelink_util::spawn; use nativelink_util::store_trait::{Store, UploadSizeInfo}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -526,13 +527,13 @@ mod compression_store_tests { let (mut writer, mut reader) = make_buf_channel_pair(); - let _drop_guard = JoinHandleDropGuard::new(tokio::spawn(async move { + let _drop_guard = spawn!("get_part_is_zero_digest", async move { let _ = store .as_ref() .get_part_ref(digest, &mut writer, 0, None) .await .err_tip(|| "Failed to get_part_ref"); - })); + }); let file_data = reader .consume(Some(1024)) diff --git a/nativelink-store/tests/filesystem_store_test.rs b/nativelink-store/tests/filesystem_store_test.rs index 17773fc94..95ab9b753 100644 --- a/nativelink-store/tests/filesystem_store_test.rs +++ b/nativelink-store/tests/filesystem_store_test.rs @@ -38,9 +38,11 @@ use nativelink_store::filesystem_store::{ digest_from_filename, EncodedFilePath, FileEntry, FileEntryImpl, FilesystemStore, }; use nativelink_util::buf_channel::make_buf_channel_pair; -use nativelink_util::common::{fs, DigestInfo, JoinHandleDropGuard}; +use nativelink_util::common::{fs, DigestInfo}; use nativelink_util::evicting_map::LenEntry; use nativelink_util::store_trait::{Store, UploadSizeInfo}; +use nativelink_util::task::instrument_future; +use nativelink_util::{background_spawn, spawn}; use once_cell::sync::Lazy; use rand::{thread_rng, Rng}; use sha2::{Digest, Sha256}; @@ -165,19 +167,26 @@ impl Drop for TestFileEntry 0 { + while shared_context.active_drop_spawns.load(Ordering::Acquire) > 0 { tokio::task::yield_now().await; } - }); - }); + }, + tracing::error_span!("test_file_entry_drop"), + ); + #[allow(clippy::disallowed_methods)] + let thread_handle = { + std::thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + rt.block_on(fut) + }) + }; thread_handle.join().unwrap(); // At this point we can guarantee our file drop spawn has completed. Hooks::on_drop(self); @@ -412,11 +421,14 @@ mod filesystem_store_tests { let (writer, mut reader) = make_buf_channel_pair(); let store_clone = store.clone(); let digest1_clone = digest1; - tokio::spawn(async move { - Pin::new(store_clone.as_ref()) - .get(digest1_clone, writer) - .await - }); + background_spawn!( + "file_continues_to_stream_on_content_replace_test_store_get", + async move { + Pin::new(store_clone.as_ref()) + .get(digest1_clone, writer) + .await + }, + ); { // Check to ensure our first byte has been received. The future should be stalled here. @@ -542,7 +554,10 @@ mod filesystem_store_tests { let mut reader = { let (writer, reader) = make_buf_channel_pair(); let store_clone = store.clone(); - tokio::spawn(async move { Pin::new(store_clone.as_ref()).get(digest1, writer).await }); + background_spawn!( + "file_gets_cleans_up_on_cache_eviction_store_get", + async move { Pin::new(store_clone.as_ref()).get(digest1, writer).await }, + ); reader }; // Ensure we have received 1 byte in our buffer. This will ensure we have a reference to @@ -888,7 +903,7 @@ mod filesystem_store_tests { struct LocalHooks {} impl FileEntryHooks for LocalHooks { fn on_drop(_file_entry: &Fe) { - tokio::spawn(FILE_DELETED_BARRIER.wait()); + background_spawn!("rename_on_insert_fails_due_to_filesystem_error_proper_cleanup_happens_local_hooks_on_drop", FILE_DELETED_BARRIER.wait()); } } @@ -976,9 +991,12 @@ mod filesystem_store_tests { fs::remove_dir_all(&content_path).await?; // Because send_eof() waits for shutdown of the rx side, we cannot just await in this thread. - tokio::spawn(async move { - tx.send_eof().unwrap(); - }); + background_spawn!( + "rename_on_insert_fails_due_to_filesystem_error_proper_cleanup_happens_send_eof", + async move { + tx.send_eof().unwrap(); + }, + ); // Now finish waiting on update(). This should reuslt in an error because we deleted our dest // folder. @@ -1008,7 +1026,6 @@ mod filesystem_store_tests { None, "Entry should not be in store" ); - Ok(()) } @@ -1044,11 +1061,11 @@ mod filesystem_store_tests { let store_clone = store.clone(); let digest_clone = digest; - let _drop_guard = JoinHandleDropGuard::new(tokio::spawn(async move { + let _drop_guard = spawn!("get_part_timeout_test_get", async move { Pin::new(store_clone.as_ref()) .get(digest_clone, writer) .await - })); + }); let file_data = reader .consume(Some(1024)) @@ -1090,12 +1107,12 @@ mod filesystem_store_tests { let store_clone = store.clone(); let (mut writer, mut reader) = make_buf_channel_pair(); - let _drop_guard = JoinHandleDropGuard::new(tokio::spawn(async move { + let _drop_guard = spawn!("get_part_is_zero_digest_get_part_ref", async move { let _ = Pin::new(store_clone.as_ref()) .get_part_ref(digest, &mut writer, 0, None) .await .err_tip(|| "Failed to get_part_ref"); - })); + }); let file_data = reader .consume(Some(1024)) diff --git a/nativelink-store/tests/memory_store_test.rs b/nativelink-store/tests/memory_store_test.rs index f9cd0ad25..4dcfe74c6 100644 --- a/nativelink-store/tests/memory_store_test.rs +++ b/nativelink-store/tests/memory_store_test.rs @@ -21,7 +21,8 @@ use nativelink_error::{Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_store::memory_store::MemoryStore; use nativelink_util::buf_channel::make_buf_channel_pair; -use nativelink_util::common::{DigestInfo, JoinHandleDropGuard}; +use nativelink_util::common::DigestInfo; +use nativelink_util::spawn; use nativelink_util::store_trait::Store; use sha2::{Digest, Sha256}; @@ -255,12 +256,12 @@ mod memory_store_tests { let store_clone = store.clone(); let (mut writer, mut reader) = make_buf_channel_pair(); - let _drop_guard = JoinHandleDropGuard::new(tokio::spawn(async move { + let _drop_guard = spawn!("get_part_is_zero_digest", async move { let _ = Pin::new(store_clone.as_ref()) .get_part_ref(digest, &mut writer, 0, None) .await .err_tip(|| "Failed to get_part_ref"); - })); + }); let file_data = reader .consume(Some(1024)) diff --git a/nativelink-store/tests/s3_store_test.rs b/nativelink-store/tests/s3_store_test.rs index 6e737d8c8..f6c5f43f5 100644 --- a/nativelink-store/tests/s3_store_test.rs +++ b/nativelink-store/tests/s3_store_test.rs @@ -30,7 +30,8 @@ use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_store::s3_store::S3Store; use nativelink_util::buf_channel::make_buf_channel_pair; -use nativelink_util::common::{DigestInfo, JoinHandleDropGuard}; +use nativelink_util::common::DigestInfo; +use nativelink_util::spawn; use nativelink_util::store_trait::{Store, UploadSizeInfo}; use sha2::{Digest, Sha256}; @@ -242,7 +243,7 @@ mod s3_store_tests { let send_data_copy = send_data.clone(); // Create spawn that is responsible for sending the stream of data // to the S3Store and processing/forwarding to the S3 backend. - let spawn_fut = tokio::spawn(async move { + let spawn_fut = spawn!("simple_update_ac", async move { tokio::try_join!(update_fut, async move { for i in 0..CONTENT_LENGTH { tx.send(send_data_copy.slice(i..(i + 1))).await?; @@ -622,12 +623,12 @@ mod s3_store_tests { let store_clone = store.clone(); let (mut writer, mut reader) = make_buf_channel_pair(); - let _drop_guard = JoinHandleDropGuard::new(tokio::spawn(async move { + let _drop_guard = spawn!("get_part_is_zero_digest", async move { let _ = Pin::new(store_clone.as_ref()) .get_part_ref(digest, &mut writer, 0, None) .await .err_tip(|| "Failed to get_part_ref"); - })); + }); let file_data = reader .consume(Some(1024)) diff --git a/nativelink-store/tests/verify_store_test.rs b/nativelink-store/tests/verify_store_test.rs index bcef1316d..537a104d9 100644 --- a/nativelink-store/tests/verify_store_test.rs +++ b/nativelink-store/tests/verify_store_test.rs @@ -25,6 +25,7 @@ mod verify_store_tests { use nativelink_store::verify_store::VerifyStore; use nativelink_util::buf_channel::make_buf_channel_pair; use nativelink_util::common::DigestInfo; + use nativelink_util::spawn; use nativelink_util::store_trait::{Store, UploadSizeInfo}; use pretty_assertions::assert_eq; // Must be declared in every module. @@ -158,11 +159,14 @@ mod verify_store_tests { let digest = DigestInfo::try_new(VALID_HASH1, 6).unwrap(); let digest_clone = digest; - let future = tokio::spawn(async move { - Pin::new(&store_owned) - .update(digest_clone, rx, UploadSizeInfo::ExactSize(6)) - .await - }); + let future = spawn!( + "verify_size_true_suceeds_on_multi_chunk_stream_update", + async move { + Pin::new(&store_owned) + .update(digest_clone, rx, UploadSizeInfo::ExactSize(6)) + .await + }, + ); tx.send("foo".into()).await?; tx.send("bar".into()).await?; tx.send_eof()?; diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index d7a3b92dc..d6c01d866 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -25,6 +25,7 @@ rust_library( "src/resource_info.rs", "src/retry.rs", "src/store_trait.rs", + "src/task.rs", "src/tls_utils.rs", "src/write_counter.rs", ], diff --git a/nativelink-util/src/common.rs b/nativelink-util/src/common.rs index 7add99b37..f10612988 100644 --- a/nativelink-util/src/common.rs +++ b/nativelink-util/src/common.rs @@ -15,10 +15,7 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::fmt; -use std::future::Future; use std::hash::Hash; -use std::pin::Pin; -use std::task::{Context, Poll}; use bytes::{BufMut, Bytes, BytesMut}; use hex::FromHex; @@ -26,7 +23,6 @@ use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::Digest; use prost::Message; use serde::{Deserialize, Serialize}; -use tokio::task::{JoinError, JoinHandle}; pub use crate::fs; @@ -143,32 +139,6 @@ impl From<&DigestInfo> for Digest { } } -/// Simple wrapper that will abort a future that is running in another spawn in the -/// event that this handle gets dropped. -pub struct JoinHandleDropGuard { - inner: JoinHandle, -} - -impl JoinHandleDropGuard { - pub fn new(inner: JoinHandle) -> Self { - Self { inner } - } -} - -impl Future for JoinHandleDropGuard { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - Pin::new(&mut self.inner).poll(cx) - } -} - -impl Drop for JoinHandleDropGuard { - fn drop(&mut self) { - self.inner.abort(); - } -} - // Simple utility trait that makes it easier to apply `.try_map` to Vec. // This will convert one vector into another vector with a different type. pub trait VecExt { diff --git a/nativelink-util/src/connection_manager.rs b/nativelink-util/src/connection_manager.rs index 7002e83dd..1d80f63c3 100644 --- a/nativelink-util/src/connection_manager.rs +++ b/nativelink-util/src/connection_manager.rs @@ -26,6 +26,7 @@ use tokio::sync::{mpsc, oneshot}; use tonic::transport::{channel, Channel, Endpoint}; use tracing::{event, Level}; +use crate::background_spawn; use crate::retry::{self, Retrier, RetryResult}; /// A helper utility that enables management of a suite of connections to an @@ -148,7 +149,7 @@ impl ConnectionManager { retry, ), }; - tokio::spawn(async move { + background_spawn!("connection_manager_worker_spawn", async move { worker .service_requests(connections_per_endpoint, worker_rx, connection_rx) .await; diff --git a/nativelink-util/src/digest_hasher.rs b/nativelink-util/src/digest_hasher.rs index b0ddb5f35..87b5852e6 100644 --- a/nativelink-util/src/digest_hasher.rs +++ b/nativelink-util/src/digest_hasher.rs @@ -23,8 +23,8 @@ use nativelink_proto::build::bazel::remote::execution::v2::digest_function::Valu use sha2::{Digest, Sha256}; use tokio::io::{AsyncRead, AsyncReadExt}; -use crate::common::{DigestInfo, JoinHandleDropGuard}; -use crate::fs; +use crate::common::DigestInfo; +use crate::{fs, spawn_blocking}; static DEFAULT_DIGEST_HASHER_FUNC: OnceLock = OnceLock::new(); @@ -224,7 +224,7 @@ impl DigestHasher for DigestHasherImpl { match self.hash_func_impl { DigestHasherFuncImpl::Sha256(_) => self.hash_file(file).await, DigestHasherFuncImpl::Blake3(mut hasher) => { - JoinHandleDropGuard::new(tokio::task::spawn_blocking(move || { + spawn_blocking!("digest_for_file", move || { hasher.update_mmap(file.get_path()).map_err(|e| { make_err!(Code::Internal, "Error in blake3's update_mmap: {e:?}") })?; @@ -232,7 +232,7 @@ impl DigestHasher for DigestHasherImpl { DigestInfo::new(hasher.finalize().into(), hasher.count() as i64), file, )) - })) + }) .await .err_tip(|| "Could not spawn blocking task in digest_for_file")? } diff --git a/nativelink-util/src/fs.rs b/nativelink-util/src/fs.rs index ab36bf488..3b878400d 100644 --- a/nativelink-util/src/fs.rs +++ b/nativelink-util/src/fs.rs @@ -35,6 +35,8 @@ use tokio::sync::{Semaphore, SemaphorePermit}; use tokio::time::timeout; use tracing::{event, Level}; +use crate::spawn_blocking; + /// Default read buffer size when reading to/from disk. pub const DEFAULT_READ_BUFF_SIZE: usize = 16384; @@ -302,7 +304,7 @@ where T: Send + 'static, { let permit = get_permit().await?; - tokio::task::spawn_blocking(move || f(permit)) + spawn_blocking!("fs_call_with_permit", move || f(permit)) .await .unwrap_or_else(|e| Err(make_err!(Code::Internal, "background task failed: {e:?}"))) } diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 3ac0caa2b..2d88aac06 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -27,5 +27,6 @@ pub mod proto_stream_utils; pub mod resource_info; pub mod retry; pub mod store_trait; +pub mod task; pub mod tls_utils; pub mod write_counter; diff --git a/nativelink-util/src/task.rs b/nativelink-util/src/task.rs new file mode 100644 index 000000000..91deb5f03 --- /dev/null +++ b/nativelink-util/src/task.rs @@ -0,0 +1,119 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::Future; +use tokio::task::{spawn_blocking, JoinError, JoinHandle}; +pub use tracing::error_span as __error_span; +use tracing::{Instrument, Span}; + +#[inline(always)] +pub fn instrument_future(f: F, span: Span) -> impl Future +where + F: Future + Send + 'static, +{ + f.instrument(span) +} + +#[inline(always)] +pub fn __spawn_with_span(f: F, span: Span) -> JoinHandle +where + T: Send + 'static, + F: Future + Send + 'static, +{ + #[allow(clippy::disallowed_methods)] + tokio::spawn(instrument_future(f, span)) +} + +#[inline(always)] +pub fn __spawn_blocking(f: F, span: Span) -> JoinHandle +where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, +{ + #[allow(clippy::disallowed_methods)] + spawn_blocking(move || span.in_scope(f)) +} + +#[macro_export] +macro_rules! background_spawn { + ($name:expr, $fut:expr) => {{ + $crate::task::__spawn_with_span($fut, $crate::task::__error_span!($name)) + }}; + ($name:expr, $fut:expr, $($fields:tt)*) => {{ + $crate::task::__spawn_with_span($fut, $crate::task::__error_span!($name, $($fields)*)) + }}; + (name: $name:expr, fut: $fut:expr, target: $target:expr, $($fields:tt)*) => {{ + $crate::task::__spawn_with_span($fut, $crate::task::__error_span!(target: $target, $name, $($fields)*)) + }}; +} + +#[macro_export] +macro_rules! spawn { + ($name:expr, $fut:expr) => {{ + $crate::task::JoinHandleDropGuard::new($crate::background_spawn!($name, $fut)) + }}; + ($name:expr, $fut:expr, $($fields:tt)*) => {{ + $crate::task::JoinHandleDropGuard::new($crate::background_spawn!($name, $fut, $($fields)*)) + }}; + (name: $name:expr, fut: $fut:expr, target: $target:expr, $($fields:tt)*) => {{ + $crate::task::JoinHandleDropGuard::new($crate::background_spawn!($name, $fut, target: $target, $($fields)*)) + }}; +} + +#[macro_export] +macro_rules! spawn_blocking { + ($name:expr, $fut:expr) => {{ + $crate::task::JoinHandleDropGuard::new($crate::task::__spawn_blocking($fut, $crate::task::__error_span!($name))) + }}; + ($name:expr, $fut:expr, $($fields:tt)*) => {{ + $crate::task::JoinHandleDropGuard::new($crate::task::__spawn_blocking($fut, $crate::task::__error_span!($name, $($fields)*))) + }}; + ($name:expr, $fut:expr, target: $target:expr) => {{ + $crate::task::JoinHandleDropGuard::new($crate::task::__spawn_blocking($fut, $crate::task::__error_span!(target: $target, $name))) + }}; + ($name:expr, $fut:expr, target: $target:expr, $($fields:tt)*) => {{ + $crate::task::JoinHandleDropGuard::new($crate::task::__spawn_blocking($fut, $crate::task::__error_span!(target: $target, $name, $($fields)*))) + }}; +} + +/// Simple wrapper that will abort a future that is running in another spawn in the +/// event that this handle gets dropped. +#[derive(Debug)] +#[must_use] +pub struct JoinHandleDropGuard { + inner: JoinHandle, +} + +impl JoinHandleDropGuard { + pub fn new(inner: JoinHandle) -> Self { + Self { inner } + } +} + +impl Future for JoinHandleDropGuard { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx) + } +} + +impl Drop for JoinHandleDropGuard { + fn drop(&mut self) { + self.inner.abort(); + } +} diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs index bf54bc40a..939222a43 100644 --- a/nativelink-worker/src/local_worker.rs +++ b/nativelink-worker/src/local_worker.rs @@ -37,13 +37,13 @@ use nativelink_util::metrics_utils::{ AsyncCounterWrapper, Collector, CollectorState, CounterWithTime, MetricsComponent, Registry, }; use nativelink_util::store_trait::Store; -use nativelink_util::tls_utils; +use nativelink_util::{spawn, tls_utils}; use tokio::process; use tokio::sync::mpsc; use tokio::time::sleep; use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::Streaming; -use tracing::{error_span, event, instrument, Instrument, Level}; +use tracing::{event, instrument, Level}; use crate::running_actions_manager::{ ExecutionConfiguration, Metrics as RunningActionManagerMetrics, RunningAction, @@ -319,7 +319,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, self.actions_in_transit.fetch_add(1, Ordering::Release); futures.push( - tokio::spawn(start_action_fut.instrument(error_span!("worker_start_action"))).map(move |res| { + spawn!("worker_start_action", start_action_fut).map(move |res| { let res = res.err_tip(|| "Failed to launch spawn")?; if let Err(err) = &res { event!( diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs index 324851333..860ab614a 100644 --- a/nativelink-worker/src/running_actions_manager.rs +++ b/nativelink-worker/src/running_actions_manager.rs @@ -59,12 +59,13 @@ use nativelink_util::action_messages::{ to_execute_response, ActionInfo, ActionResult, DirectoryInfo, ExecutionMetadata, FileInfo, NameOrPath, SymlinkInfo, }; -use nativelink_util::common::{fs, DigestInfo, JoinHandleDropGuard}; +use nativelink_util::common::{fs, DigestInfo}; use nativelink_util::digest_hasher::{DigestHasher, DigestHasherFunc}; use nativelink_util::metrics_utils::{ AsyncCounterWrapper, CollectorState, CounterWithTime, MetricsComponent, }; use nativelink_util::store_trait::{Store, UploadSizeInfo}; +use nativelink_util::{background_spawn, spawn, spawn_blocking}; use parking_lot::Mutex; use prost::Message; use relative_path::RelativePath; @@ -73,10 +74,9 @@ use serde::Deserialize; use tokio::io::{AsyncReadExt, AsyncSeekExt}; use tokio::process; use tokio::sync::{oneshot, watch}; -use tokio::task::spawn_blocking; use tokio_stream::wrappers::ReadDirStream; use tonic::Request; -use tracing::{enabled, error_span, event, Instrument, Level}; +use tracing::{enabled, event, Level}; use uuid::Uuid; pub type ActionId = [u8; 32]; @@ -172,7 +172,7 @@ pub fn download_to_directory<'a>( })?; } if let Some(mtime) = mtime { - spawn_blocking(move || { + spawn_blocking!("download_to_directory_set_mtime", move || { set_file_mtime( &dest, FileTime::from_unix_time(mtime.seconds, mtime.nanos as u32), @@ -870,10 +870,12 @@ impl RunningActionImpl { Level::ERROR, "Child process was not cleaned up before dropping the call to execute(), killing in background spawn." ); - tokio::spawn(async move { child_process.kill().await }); + background_spawn!("running_actions_manager_kill_child_process", async move { + child_process.kill().await + }); }); - let all_stdout_fut = JoinHandleDropGuard::new(tokio::spawn(async move { + let all_stdout_fut = spawn!("stdout_reader", async move { let mut all_stdout = BytesMut::new(); loop { let sz = stdout_reader @@ -885,8 +887,8 @@ impl RunningActionImpl { } } Result::::Ok(all_stdout.freeze()) - })); - let all_stderr_fut = JoinHandleDropGuard::new(tokio::spawn(async move { + }); + let all_stderr_fut = spawn!("stderr_reader", async move { let mut all_stderr = BytesMut::new(); loop { let sz = stderr_reader @@ -898,7 +900,7 @@ impl RunningActionImpl { } } Result::::Ok(all_stderr.freeze()) - })); + }); let mut killed_action = false; let timer = self.metrics().child_process.begin_timer(); @@ -1252,23 +1254,20 @@ impl Drop for RunningActionImpl { let running_actions_manager = self.running_actions_manager.clone(); let action_id = self.action_id; let action_directory = self.action_directory.clone(); - tokio::spawn( - async move { - let Err(err) = - do_cleanup(&running_actions_manager, &action_id, &action_directory).await - else { - return; - }; - event!( - Level::ERROR, - action_id = hex::encode(action_id), - ?action_directory, - ?err, - "Error cleaning up action" - ); - } - .instrument(error_span!("RunningActionImpl::drop")), - ); + background_spawn!("running_action_impl_drop", async move { + let Err(err) = + do_cleanup(&running_actions_manager, &action_id, &action_directory).await + else { + return; + }; + event!( + Level::ERROR, + action_id = hex::encode(action_id), + ?action_directory, + ?err, + "Error cleaning up action" + ); + }); } } diff --git a/nativelink-worker/tests/utils/local_worker_test_utils.rs b/nativelink-worker/tests/utils/local_worker_test_utils.rs index 07d1fda29..1c157bbe2 100644 --- a/nativelink-worker/tests/utils/local_worker_test_utils.rs +++ b/nativelink-worker/tests/utils/local_worker_test_utils.rs @@ -23,7 +23,8 @@ use nativelink_error::Error; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ ExecuteResult, GoingAwayRequest, KeepAliveRequest, SupportedProperties, UpdateForWorker, }; -use nativelink_util::common::JoinHandleDropGuard; +use nativelink_util::spawn; +use nativelink_util::task::JoinHandleDropGuard; use nativelink_worker::local_worker::LocalWorker; use nativelink_worker::worker_api_client_wrapper::WorkerApiClientTrait; use tokio::sync::mpsc; @@ -192,7 +193,7 @@ pub async fn setup_local_worker_with_config(local_worker_config: LocalWorkerConf }), Box::new(move |_| Box::pin(async move { /* No sleep */ })), ); - let drop_guard = JoinHandleDropGuard::new(tokio::spawn(async move { worker.run().await })); + let drop_guard = spawn!("local_worker_spawn", async move { worker.run().await }); let (tx_stream, streaming_response) = setup_grpc_stream(); TestContext { diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs index 47f5a8d09..a3f745bcd 100644 --- a/src/bin/nativelink.rs +++ b/src/bin/nativelink.rs @@ -51,6 +51,7 @@ use nativelink_util::metrics_utils::{ use nativelink_util::store_trait::{ set_default_digest_size_health_check, DEFAULT_DIGEST_SIZE_HEALTH_CHECK_CFG, }; +use nativelink_util::{background_spawn, spawn, spawn_blocking}; use nativelink_worker::local_worker::new_local_worker; use parking_lot::Mutex; use rustls_pemfile::{certs as extract_certs, crls as extract_crls}; @@ -58,7 +59,6 @@ use scopeguard::guard; use tokio::net::TcpListener; #[cfg(target_family = "unix")] use tokio::signal::unix::{signal, SignalKind}; -use tokio::task::spawn_blocking; use tokio_rustls::rustls::pki_types::{CertificateDer, CertificateRevocationListDer}; use tokio_rustls::rustls::server::WebPkiClientVerifier; use tokio_rustls::rustls::{RootCertStore, ServerConfig as TlsServerConfig}; @@ -66,7 +66,7 @@ use tokio_rustls::TlsAcceptor; use tonic::codec::CompressionEncoding; use tonic::transport::Server as TonicServer; use tower::util::ServiceExt; -use tracing::{error_span, event, Instrument, Level}; +use tracing::{event, Level}; use tracing_subscriber::filter::{EnvFilter, LevelFilter}; #[global_allocator] @@ -428,7 +428,7 @@ async fn inner_main( // We spawn on a thread that can block to give more freedom to our metrics // collection. This allows it to call functions like `tokio::block_in_place` // if it needs to wait on a future. - spawn_blocking(move || { + spawn_blocking!("prometheus_metrics", move || { let mut buf = String::new(); let root_metrics_registry_guard = futures::executor::block_on(root_metrics_registry.lock()); @@ -705,8 +705,9 @@ async fn inner_main( }, Ok, ); - tokio::spawn( - async move { + background_spawn!( + name: "http_connection", + fut: async move { // Move it into our spawn, so if our spawn dies the cleanup happens. let _guard = scope_guard; if let Err(err) = fut.await { @@ -717,13 +718,10 @@ async fn inner_main( "Failed running service" ); } - } - .instrument(error_span!( - target: "nativelink::services", - "http_connection", - ?remote_addr, - ?socket_addr - )), + }, + target: "nativelink::services", + ?remote_addr, + ?socket_addr, ); } })); @@ -794,7 +792,7 @@ async fn inner_main( let worker_metrics = root_worker_metrics.sub_registry_with_prefix(&name); local_worker.register_metrics(worker_metrics); worker_names.insert(name.clone()); - tokio::spawn(local_worker.run().instrument(error_span!("worker", ?name))) + spawn!("worker", local_worker.run(), ?name) } }; root_futures.push(Box::pin(spawn_fut.map_ok_or_else(|e| Err(e.into()), |v| v))); @@ -900,29 +898,31 @@ fn main() -> Result<(), Box> { .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); - let runtime = tokio::runtime::Builder::new_multi_thread() - .max_blocking_threads(max_blocking_threads) - .enable_all() - .on_thread_start(move || set_metrics_enabled_for_this_thread(metrics_enabled)) - .build()?; - - runtime.spawn(async move { - tokio::signal::ctrl_c() - .await - .expect("Failed to listen to SIGINT"); - eprintln!("User terminated process via SIGINT"); - std::process::exit(130); - }); - - #[cfg(target_family = "unix")] - runtime.spawn(async move { - signal(SignalKind::terminate()) - .expect("Failed to listen to SIGTERM") - .recv() - .await; - eprintln!("Process terminated via SIGTERM"); - std::process::exit(143); - }); - - runtime.block_on(inner_main(cfg, server_start_time)) + #[allow(clippy::disallowed_methods)] + { + let runtime = tokio::runtime::Builder::new_multi_thread() + .max_blocking_threads(max_blocking_threads) + .enable_all() + .on_thread_start(move || set_metrics_enabled_for_this_thread(metrics_enabled)) + .build()?; + runtime.spawn(async move { + tokio::signal::ctrl_c() + .await + .expect("Failed to listen to SIGINT"); + eprintln!("User terminated process via SIGINT"); + std::process::exit(130); + }); + + #[cfg(target_family = "unix")] + runtime.spawn(async move { + signal(SignalKind::terminate()) + .expect("Failed to listen to SIGTERM") + .recv() + .await; + eprintln!("Process terminated via SIGTERM"); + std::process::exit(143); + }); + + runtime.block_on(inner_main(cfg, server_start_time)) + } }