From b783552501857d231552723afd9363a3dcc7f3f1 Mon Sep 17 00:00:00 2001 From: Mindaugas Vinkelis Date: Thu, 8 Aug 2024 13:05:05 +0300 Subject: [PATCH 1/2] selecting caching strategy --- CONTRIBUTING.md | 2 +- diesel/src/connection/mod.rs | 13 + .../mod.rs} | 121 +++--- .../connection/statement_cache/strategy.rs | 392 ++++++++++++++++++ diesel/src/connection/transaction_manager.rs | 4 + diesel/src/mysql/connection/mod.rs | 4 + diesel/src/pg/connection/mod.rs | 180 +------- diesel/src/r2d2.rs | 4 + diesel/src/sqlite/connection/mod.rs | 96 ++--- diesel/src/sqlite/connection/stmt.rs | 4 +- diesel_derives/src/multiconnection.rs | 15 + 11 files changed, 544 insertions(+), 291 deletions(-) rename diesel/src/connection/{statement_cache.rs => statement_cache/mod.rs} (82%) create mode 100644 diesel/src/connection/statement_cache/strategy.rs diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2b24204d1370..8923128bc707 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -133,7 +133,7 @@ To run rustfmt tests locally: rustup component add clippy ``` -3. Install [typos](https://github.com/crate-ci/typos) via `cargo install typos` +3. Install [typos](https://github.com/crate-ci/typos) via `cargo install typos-cli` 4. Use `cargo xtask tidy` to check if your changes follow the expected code style. This will run `cargo fmt --check`, `typos` and `cargo clippy` internally. See `cargo xtask tidy --help` diff --git a/diesel/src/connection/mod.rs b/diesel/src/connection/mod.rs index 278b8b3ec0f8..5740ca0f4c73 100644 --- a/diesel/src/connection/mod.rs +++ b/diesel/src/connection/mod.rs @@ -45,6 +45,16 @@ pub use self::instrumentation::{DynInstrumentation, StrQueryHelper}; ))] pub(crate) use self::private::MultiConnectionHelper; +/// Set cache size for a connection +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum CacheSize { + /// Caches all queries if possible + Unbounded, + /// Disable statement cache + Disabled, +} + /// Perform simple operations on a backend. /// /// You should likely use [`Connection`] instead. @@ -401,6 +411,9 @@ where /// Set a specific [`Instrumentation`] implementation for this connection fn set_instrumentation(&mut self, instrumentation: impl Instrumentation); + + /// Set the prepared statement cache size to [`CacheSize`] for this connection + fn set_prepared_statement_cache_size(&mut self, size: CacheSize); } /// The specific part of a [`Connection`] which actually loads data from the database diff --git a/diesel/src/connection/statement_cache.rs b/diesel/src/connection/statement_cache/mod.rs similarity index 82% rename from diesel/src/connection/statement_cache.rs rename to diesel/src/connection/statement_cache/mod.rs index eb44461a4d15..9dd42e0d74d4 100644 --- a/diesel/src/connection/statement_cache.rs +++ b/diesel/src/connection/statement_cache/mod.rs @@ -10,8 +10,9 @@ //! statements is [`SimpleConnection::batch_execute`](super::SimpleConnection::batch_execute). //! //! In order to avoid the cost of re-parsing and planning subsequent queries, -//! Diesel caches the prepared statement whenever possible. Queries will fall -//! into one of three buckets: +//! by default Diesel caches the prepared statement whenever possible, but +//! this an be customized by calling [`Connection::set_cache_size`](super::Connection::set_cache_size). +//! Queries will fall into one of three buckets: //! //! - Unsafe to cache //! - Cached by SQL @@ -94,16 +95,21 @@ use std::any::TypeId; use std::borrow::Cow; -use std::collections::HashMap; use std::hash::Hash; use std::ops::{Deref, DerefMut}; +use strategy::{StatementCacheStrategy, WithCacheStrategy, WithoutCacheStrategy}; + use crate::backend::Backend; use crate::connection::InstrumentationEvent; use crate::query_builder::*; use crate::result::QueryResult; -use super::Instrumentation; +use super::{CacheSize, Instrumentation}; + +/// Various interfaces and implementations to control connection statement caching. +#[allow(unreachable_pub)] +pub mod strategy; /// A prepared statement cache #[allow(missing_debug_implementations, unreachable_pub)] @@ -112,7 +118,10 @@ use super::Instrumentation; doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")) )] pub struct StatementCache { - pub(crate) cache: HashMap, Statement>, + cache: Box>, + // increment every time a query is cached + // some backends might use it to create unique prepared statement names + cache_counter: u64, } /// A helper type that indicates if a certain query @@ -128,45 +137,51 @@ pub struct StatementCache { )] #[allow(unreachable_pub)] pub enum PrepareForCache { - /// The statement will be cached - Yes, + /// The statement will be cached + Yes { + /// Counter might be used as unique identifier for prepared statement. + #[allow(dead_code)] + counter: u64, + }, /// The statement won't be cached No, } -#[allow( - clippy::len_without_is_empty, - clippy::new_without_default, - unreachable_pub -)] +#[allow(clippy::new_without_default, unreachable_pub)] impl StatementCache where - DB: Backend, + DB: Backend + 'static, + Statement: 'static, DB::TypeMetadata: Clone, DB::QueryBuilder: Default, StatementCacheKey: Hash + Eq, { - /// Create a new prepared statement cache + /// Create a new prepared statement cache using [`CacheSize::Unbounded`] as caching strategy. #[allow(unreachable_pub)] pub fn new() -> Self { StatementCache { - cache: HashMap::new(), + cache: Box::new(WithCacheStrategy::default()), + cache_counter: 0, } } - /// Get the current length of the statement cache - #[allow(unreachable_pub)] - #[cfg(any( - feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes", - feature = "postgres", - all(feature = "sqlite", test) - ))] - #[cfg_attr( - docsrs, - doc(cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")) - )] - pub fn len(&self) -> usize { - self.cache.len() + /// Set caching strategy from predefined implementations + pub fn set_cache_size(&mut self, size: CacheSize) { + if self.cache.cache_size() != size { + self.cache = match size { + CacheSize::Unbounded => Box::new(WithCacheStrategy::default()), + CacheSize::Disabled => Box::new(WithoutCacheStrategy::default()), + } + } + } + + /// Setting custom caching strategy. It is used in tests, to verify caching logic + #[allow(dead_code)] + pub(crate) fn set_strategy(&mut self, s: Strategy) + where + Strategy: StatementCacheStrategy + 'static, + { + self.cache = Box::new(s); } /// Prepare a query as prepared statement @@ -193,50 +208,44 @@ where T: QueryFragment + QueryId, F: FnMut(&str, PrepareForCache) -> QueryResult, { - self.cached_statement_non_generic( + Self::cached_statement_non_generic( + self.cache.as_mut(), T::query_id(), source, backend, bind_types, - &mut prepare_fn, - instrumentation, + &mut |sql, is_cached| { + if is_cached { + instrumentation.on_connection_event(InstrumentationEvent::CacheQuery { sql }); + self.cache_counter += 1; + prepare_fn( + sql, + PrepareForCache::Yes { + counter: self.cache_counter, + }, + ) + } else { + prepare_fn(sql, PrepareForCache::No) + } + }, ) } /// Reduce the amount of monomorphized code by factoring this via dynamic dispatch - fn cached_statement_non_generic( - &mut self, + fn cached_statement_non_generic<'a>( + cache: &'a mut dyn StatementCacheStrategy, maybe_type_id: Option, source: &dyn QueryFragmentForCachedStatement, backend: &DB, bind_types: &[DB::TypeMetadata], - prepare_fn: &mut dyn FnMut(&str, PrepareForCache) -> QueryResult, - instrumentation: &mut dyn Instrumentation, - ) -> QueryResult> { - use std::collections::hash_map::Entry::{Occupied, Vacant}; - + prepare_fn: &mut dyn FnMut(&str, bool) -> QueryResult, + ) -> QueryResult> { let cache_key = StatementCacheKey::for_source(maybe_type_id, source, bind_types, backend)?; - if !source.is_safe_to_cache_prepared(backend)? { let sql = cache_key.sql(source, backend)?; - return prepare_fn(&sql, PrepareForCache::No).map(MaybeCached::CannotCache); + return prepare_fn(&sql, false).map(MaybeCached::CannotCache); } - - let cached_result = match self.cache.entry(cache_key) { - Occupied(entry) => entry.into_mut(), - Vacant(entry) => { - let statement = { - let sql = entry.key().sql(source, backend)?; - instrumentation - .on_connection_event(InstrumentationEvent::CacheQuery { sql: &sql }); - prepare_fn(&sql, PrepareForCache::Yes) - }; - - entry.insert(statement?) - } - }; - - Ok(MaybeCached::Cached(cached_result)) + cache.get(cache_key, backend, source, prepare_fn) } } diff --git a/diesel/src/connection/statement_cache/strategy.rs b/diesel/src/connection/statement_cache/strategy.rs new file mode 100644 index 000000000000..7e3740526368 --- /dev/null +++ b/diesel/src/connection/statement_cache/strategy.rs @@ -0,0 +1,392 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::hash::Hash; + +use crate::{backend::Backend, result::Error}; + +use super::{CacheSize, MaybeCached, QueryFragmentForCachedStatement, StatementCacheKey}; + +/// Implement this trait, in order to control statement caching. +#[allow(unreachable_pub)] +pub trait StatementCacheStrategy +where + DB: Backend, + StatementCacheKey: Hash + Eq, +{ + /// Returns which prepared statement cache size is implemented by this trait + fn cache_size(&self) -> CacheSize; + + /// Every query (which is safe to cache) will go through this function + /// The implementation will decide whether to cache statement or not + /// * `prepare_fn` - will be invoked if prepared statement wasn't cached already + /// * first argument is sql query string + /// * second argument specify whether statement will be cached (true) or not (false). + fn get( + &mut self, + key: StatementCacheKey, + backend: &DB, + source: &dyn QueryFragmentForCachedStatement, + prepare_fn: &mut dyn FnMut(&str, bool) -> Result, + ) -> Result, Error>; +} + +/// Cache all (safe) statements for as long as connection is alive. +#[allow(missing_debug_implementations, unreachable_pub)] +pub struct WithCacheStrategy +where + DB: Backend, +{ + cache: HashMap, Statement>, +} + +impl Default for WithCacheStrategy +where + DB: Backend, +{ + fn default() -> Self { + Self { + cache: Default::default(), + } + } +} + +impl StatementCacheStrategy for WithCacheStrategy +where + DB: Backend, + StatementCacheKey: Hash + Eq, + DB::TypeMetadata: Clone, + DB::QueryBuilder: Default, +{ + fn get( + &mut self, + key: StatementCacheKey, + backend: &DB, + source: &dyn QueryFragmentForCachedStatement, + prepare_fn: &mut dyn FnMut(&str, bool) -> Result, + ) -> Result, Error> { + let entry = self.cache.entry(key); + match entry { + Entry::Occupied(e) => Ok(MaybeCached::Cached(e.into_mut())), + Entry::Vacant(e) => { + let sql = e.key().sql(source, backend)?; + let st = prepare_fn(&sql, true)?; + Ok(MaybeCached::Cached(e.insert(st))) + } + } + } + + fn cache_size(&self) -> CacheSize { + CacheSize::Unbounded + } +} + +/// No statements will be cached, +#[allow(missing_debug_implementations, unreachable_pub)] +#[derive(Clone, Copy, Default)] +pub struct WithoutCacheStrategy {} + +impl StatementCacheStrategy for WithoutCacheStrategy +where + DB: Backend, + StatementCacheKey: Hash + Eq, + DB::TypeMetadata: Clone, + DB::QueryBuilder: Default, +{ + fn get( + &mut self, + key: StatementCacheKey, + backend: &DB, + source: &dyn QueryFragmentForCachedStatement, + prepare_fn: &mut dyn FnMut(&str, bool) -> Result, + ) -> Result, Error> { + let sql = key.sql(source, backend)?; + Ok(MaybeCached::CannotCache(prepare_fn(&sql, false)?)) + } + + fn cache_size(&self) -> CacheSize { + CacheSize::Disabled + } +} + +#[allow(dead_code)] +#[cfg(test)] +mod testing_utils { + + use crate::{ + connection::{Instrumentation, InstrumentationEvent}, + Connection, + }; + + #[derive(Default)] + pub struct RecordCacheEvents { + pub list: Vec, + } + + impl Instrumentation for RecordCacheEvents { + fn on_connection_event(&mut self, event: InstrumentationEvent<'_>) { + if let InstrumentationEvent::CacheQuery { sql } = event { + self.list.push(sql.to_owned()); + } + } + } + + pub fn count_cache_calls(conn: &mut impl Connection) -> usize { + conn.instrumentation() + .as_any() + .downcast_ref::() + .unwrap() + .list + .len() + } +} + +#[cfg(test)] +#[cfg(feature = "postgres")] +mod tests_pg { + use crate::{ + dsl::sql, + expression::functions::define_sql_function, + insertable::Insertable, + macros::table, + pg::Pg, + sql_types::{Integer, VarChar}, + test_helpers::pg_database_url, + Connection, ExpressionMethods, IntoSql, PgConnection, QueryDsl, RunQueryDsl, + }; + + use super::testing_utils::{count_cache_calls, RecordCacheEvents}; + + table! { + users { + id -> Integer, + name -> Text, + } + } + + pub fn connection() -> PgConnection { + let mut conn = PgConnection::establish(&pg_database_url()).unwrap(); + conn.set_instrumentation(RecordCacheEvents::default()); + conn + } + + #[test] + fn prepared_statements_are_cached() { + let connection = &mut connection(); + + let query = crate::select(1.into_sql::()); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + } + + #[test] + fn queries_with_identical_sql_but_different_types_are_cached_separately() { + let connection = &mut connection(); + + let query = crate::select(1.into_sql::()); + let query2 = crate::select("hi".into_sql::()); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); + assert_eq!(2, count_cache_calls(connection)); + } + + #[test] + fn queries_with_identical_types_and_sql_but_different_bind_types_are_cached_separately() { + let connection = &mut connection(); + + let query = crate::select(1.into_sql::()).into_boxed::(); + let query2 = crate::select("hi".into_sql::()).into_boxed::(); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); + assert_eq!(2, count_cache_calls(connection)); + } + + define_sql_function!(fn lower(x: VarChar) -> VarChar); + + #[test] + fn queries_with_identical_types_and_binds_but_different_sql_are_cached_separately() { + let connection = &mut connection(); + + let hi = "HI".into_sql::(); + let query = crate::select(hi).into_boxed::(); + let query2 = crate::select(lower(hi)).into_boxed::(); + + assert_eq!(Ok("HI".to_string()), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); + assert_eq!(2, count_cache_calls(connection)); + } + + #[test] + fn queries_with_sql_literal_nodes_are_not_cached() { + let connection = &mut connection(); + let query = crate::select(sql::("1")); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + } + + #[test] + fn inserts_from_select_are_cached() { + let connection = &mut connection(); + connection.begin_test_transaction().unwrap(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(connection) + .unwrap(); + + let query = users::table.filter(users::id.eq(42)); + let insert = query + .insert_into(users::table) + .into_columns((users::id, users::name)); + assert!(insert.execute(connection).is_ok()); + assert_eq!(1, count_cache_calls(connection)); + + let query = users::table.filter(users::id.eq(42)).into_boxed(); + let insert = query + .insert_into(users::table) + .into_columns((users::id, users::name)); + assert!(insert.execute(connection).is_ok()); + assert_eq!(2, count_cache_calls(connection)); + } + + #[test] + fn single_inserts_are_cached() { + let connection = &mut connection(); + connection.begin_test_transaction().unwrap(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(connection) + .unwrap(); + + let insert = + crate::insert_into(users::table).values((users::id.eq(42), users::name.eq("Foo"))); + + assert!(insert.execute(connection).is_ok()); + assert_eq!(1, count_cache_calls(connection)); + } + + #[test] + fn dynamic_batch_inserts_are_not_cached() { + let connection = &mut connection(); + connection.begin_test_transaction().unwrap(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(connection) + .unwrap(); + + let insert = crate::insert_into(users::table) + .values(vec![(users::id.eq(42), users::name.eq("Foo"))]); + + assert!(insert.execute(connection).is_ok()); + assert_eq!(0, count_cache_calls(connection)); + } + + #[test] + fn static_batch_inserts_are_cached() { + let connection = &mut connection(); + connection.begin_test_transaction().unwrap(); + + crate::sql_query( + "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + ) + .execute(connection) + .unwrap(); + + let insert = + crate::insert_into(users::table).values([(users::id.eq(42), users::name.eq("Foo"))]); + + assert!(insert.execute(connection).is_ok()); + assert_eq!(1, count_cache_calls(connection)); + } + + #[test] + fn queries_containing_in_with_vec_are_cached() { + let connection = &mut connection(); + let one_as_expr = 1.into_sql::(); + let query = crate::select(one_as_expr.eq_any(vec![1, 2, 3])); + + assert_eq!(Ok(true), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + } +} + +#[cfg(test)] +#[cfg(feature = "sqlite")] +mod tests_sqlite { + + use crate::{ + dsl::sql, query_dsl::RunQueryDsl, sql_types::Integer, Connection, ExpressionMethods, + IntoSql, SqliteConnection, + }; + + use super::testing_utils::{count_cache_calls, RecordCacheEvents}; + + pub fn connection() -> SqliteConnection { + let mut conn = SqliteConnection::establish(":memory:").unwrap(); + conn.set_instrumentation(RecordCacheEvents::default()); + conn + } + + #[test] + fn prepared_statements_are_cached_when_run() { + let connection = &mut connection(); + let query = crate::select(1.into_sql::()); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + } + + #[test] + fn sql_literal_nodes_are_not_cached() { + let connection = &mut connection(); + let query = crate::select(sql::("1")); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + } + + #[test] + fn queries_containing_sql_literal_nodes_are_not_cached() { + let connection = &mut connection(); + let one_as_expr = 1.into_sql::(); + let query = crate::select(one_as_expr.eq(sql::("1"))); + + assert_eq!(Ok(true), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + } + + #[test] + fn queries_containing_in_with_vec_are_not_cached() { + let connection = &mut connection(); + let one_as_expr = 1.into_sql::(); + let query = crate::select(one_as_expr.eq_any(vec![1, 2, 3])); + + assert_eq!(Ok(true), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + } + + #[test] + fn queries_containing_in_with_subselect_are_cached() { + let connection = &mut connection(); + let one_as_expr = 1.into_sql::(); + let query = crate::select(one_as_expr.eq_any(crate::select(one_as_expr))); + + assert_eq!(Ok(true), query.get_result(connection)); + assert_eq!(1, count_cache_calls(connection)); + } +} diff --git a/diesel/src/connection/transaction_manager.rs b/diesel/src/connection/transaction_manager.rs index bad7846e83a8..e361346d8612 100644 --- a/diesel/src/connection/transaction_manager.rs +++ b/diesel/src/connection/transaction_manager.rs @@ -594,6 +594,10 @@ mod test { ) { self.instrumentation = Some(Box::new(instrumentation)); } + + fn set_prepared_statement_cache_size(&mut self, _size: crate::connection::CacheSize) { + panic!("implement, if you want to use it") + } } } diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index fc355ddbadb2..d43c21ad2642 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -205,6 +205,10 @@ impl Connection for MysqlConnection { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.instrumentation = instrumentation.into(); } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.statement_cache.set_cache_size(size); + } } #[inline(always)] diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 8b0e1925c071..6b8d32b41be4 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -5,6 +5,8 @@ mod result; mod row; mod stmt; +use statement_cache::PrepareForCache; + use self::copy::{CopyFromSink, CopyToBuffer}; use self::cursor::*; use self::private::ConnectionAndTransactionManager; @@ -124,7 +126,8 @@ pub(super) use self::result::PgResult; #[allow(missing_debug_implementations)] #[cfg(feature = "postgres")] pub struct PgConnection { - statement_cache: StatementCache, + /// pub(crate) for tests + pub(crate) statement_cache: StatementCache, metadata_cache: PgMetadataCache, connection_and_transaction_manager: ConnectionAndTransactionManager, } @@ -236,6 +239,10 @@ impl Connection for PgConnection { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.connection_and_transaction_manager.instrumentation = instrumentation.into(); } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.statement_cache.set_cache_size(size); + } } impl LoadConnection for PgConnection @@ -499,18 +506,16 @@ impl PgConnection { let binds = bind_collector.binds; let metadata = bind_collector.metadata; - let cache_len = self.statement_cache.len(); let cache = &mut self.statement_cache; let conn = &mut self.connection_and_transaction_manager.raw_connection; let query = cache.cached_statement( &source, &Pg, &metadata, - |sql, _| { - let query_name = if source.is_safe_to_cache_prepared(&Pg)? { - Some(format!("__diesel_stmt_{cache_len}")) - } else { - None + |sql, is_cached| { + let query_name = match is_cached { + PrepareForCache::Yes { counter } => Some(format!("__diesel_stmt_{counter}")), + PrepareForCache::No => None, }; Statement::prepare(conn, sql, query_name.as_deref(), &metadata) }, @@ -614,12 +619,14 @@ mod tests { extern crate dotenvy; use super::*; - use crate::dsl::sql; use crate::prelude::*; use crate::result::Error::DatabaseError; - use crate::sql_types::{Integer, VarChar}; use std::num::NonZeroU32; + fn connection() -> PgConnection { + crate::test_helpers::pg_connection_no_transaction() + } + #[test] fn malformed_sql_query() { let connection = &mut connection(); @@ -633,67 +640,6 @@ mod tests { } } - #[test] - fn prepared_statements_are_cached() { - let connection = &mut connection(); - - let query = crate::select(1.into_sql::()); - - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); - } - - #[test] - fn queries_with_identical_sql_but_different_types_are_cached_separately() { - let connection = &mut connection(); - - let query = crate::select(1.into_sql::()); - let query2 = crate::select("hi".into_sql::()); - - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); - assert_eq!(2, connection.statement_cache.len()); - } - - #[test] - fn queries_with_identical_types_and_sql_but_different_bind_types_are_cached_separately() { - let connection = &mut connection(); - - let query = crate::select(1.into_sql::()).into_boxed::(); - let query2 = crate::select("hi".into_sql::()).into_boxed::(); - - assert_eq!(0, connection.statement_cache.len()); - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); - assert_eq!(2, connection.statement_cache.len()); - } - - define_sql_function!(fn lower(x: VarChar) -> VarChar); - - #[test] - fn queries_with_identical_types_and_binds_but_different_sql_are_cached_separately() { - let connection = &mut connection(); - - let hi = "HI".into_sql::(); - let query = crate::select(hi).into_boxed::(); - let query2 = crate::select(lower(hi)).into_boxed::(); - - assert_eq!(0, connection.statement_cache.len()); - assert_eq!(Ok("HI".to_string()), query.get_result(connection)); - assert_eq!(Ok("hi".to_string()), query2.get_result(connection)); - assert_eq!(2, connection.statement_cache.len()); - } - - #[test] - fn queries_with_sql_literal_nodes_are_not_cached() { - let connection = &mut connection(); - let query = crate::select(sql::("1")); - - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); - } - table! { users { id -> Integer, @@ -701,100 +647,6 @@ mod tests { } } - #[test] - fn inserts_from_select_are_cached() { - let connection = &mut connection(); - connection.begin_test_transaction().unwrap(); - - crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", - ) - .execute(connection) - .unwrap(); - - let query = users::table.filter(users::id.eq(42)); - let insert = query - .insert_into(users::table) - .into_columns((users::id, users::name)); - assert!(insert.execute(connection).is_ok()); - assert_eq!(1, connection.statement_cache.len()); - - let query = users::table.filter(users::id.eq(42)).into_boxed(); - let insert = query - .insert_into(users::table) - .into_columns((users::id, users::name)); - assert!(insert.execute(connection).is_ok()); - assert_eq!(2, connection.statement_cache.len()); - } - - #[test] - fn single_inserts_are_cached() { - let connection = &mut connection(); - connection.begin_test_transaction().unwrap(); - - crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", - ) - .execute(connection) - .unwrap(); - - let insert = - crate::insert_into(users::table).values((users::id.eq(42), users::name.eq("Foo"))); - - assert!(insert.execute(connection).is_ok()); - assert_eq!(1, connection.statement_cache.len()); - } - - #[test] - fn dynamic_batch_inserts_are_not_cached() { - let connection = &mut connection(); - connection.begin_test_transaction().unwrap(); - - crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", - ) - .execute(connection) - .unwrap(); - - let insert = crate::insert_into(users::table) - .values(vec![(users::id.eq(42), users::name.eq("Foo"))]); - - assert!(insert.execute(connection).is_ok()); - assert_eq!(0, connection.statement_cache.len()); - } - - #[test] - fn static_batch_inserts_are_cached() { - let connection = &mut connection(); - connection.begin_test_transaction().unwrap(); - - crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", - ) - .execute(connection) - .unwrap(); - - let insert = - crate::insert_into(users::table).values([(users::id.eq(42), users::name.eq("Foo"))]); - - assert!(insert.execute(connection).is_ok()); - assert_eq!(1, connection.statement_cache.len()); - } - - #[test] - fn queries_containing_in_with_vec_are_cached() { - let connection = &mut connection(); - let one_as_expr = 1.into_sql::(); - let query = crate::select(one_as_expr.eq_any(vec![1, 2, 3])); - - assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); - } - - fn connection() -> PgConnection { - crate::test_helpers::pg_connection_no_transaction() - } - #[test] fn transaction_manager_returns_an_error_when_attempting_to_commit_outside_of_a_transaction() { use crate::connection::{AnsiTransactionManager, TransactionManager}; diff --git a/diesel/src/r2d2.rs b/diesel/src/r2d2.rs index 62df9ad7cbca..6d361b405f77 100644 --- a/diesel/src/r2d2.rs +++ b/diesel/src/r2d2.rs @@ -341,6 +341,10 @@ where fn set_instrumentation(&mut self, instrumentation: impl crate::connection::Instrumentation) { (**self).set_instrumentation(instrumentation) } + + fn set_prepared_statement_cache_size(&mut self, size: crate::connection::CacheSize) { + (**self).set_prepared_statement_cache_size(size) + } } impl LoadConnection for PooledConnection diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 86dcd00be063..40021f0079e6 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -121,7 +121,8 @@ pub struct SqliteConnection { // statement_cache needs to be before raw_connection // otherwise we will get errors about open statements before closing the // connection itself - statement_cache: StatementCache, + // pub(crate) for tests + pub(crate) statement_cache: StatementCache, raw_connection: RawConnection, transaction_state: AnsiTransactionManager, // this exists for the sole purpose of implementing `WithMetadataLookup` trait @@ -206,6 +207,10 @@ impl Connection for SqliteConnection { fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) { self.instrumentation = instrumentation.into(); } + + fn set_prepared_statement_cache_size(&mut self, size: CacheSize) { + self.statement_cache.set_cache_size(size); + } } impl LoadConnection for SqliteConnection { @@ -561,6 +566,10 @@ mod tests { use crate::prelude::*; use crate::sql_types::Integer; + fn connection() -> SqliteConnection { + SqliteConnection::establish(":memory:").unwrap() + } + #[test] fn database_serializes_and_deserializes_successfully() { let expected_users = vec![ @@ -576,81 +585,32 @@ mod tests { ), ]; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let conn1 = &mut connection(); let _ = crate::sql_query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)") - .execute(connection); + .execute(conn1); let _ = crate::sql_query("INSERT INTO users (name, email) VALUES ('John Doe', 'john.doe@example.com'), ('Jane Doe', 'jane.doe@example.com')") - .execute(connection); + .execute(conn1); - let serialized_database = connection.serialize_database_to_buffer(); + let serialized_database = conn1.serialize_database_to_buffer(); - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); - connection + let conn2 = &mut connection(); + conn2 .deserialize_readonly_database_from_buffer(serialized_database.as_slice()) .unwrap(); let query = sql::<(Integer, Text, Text)>("SELECT id, name, email FROM users ORDER BY id"); - let actual_users = query.load::<(i32, String, String)>(connection).unwrap(); + let actual_users = query.load::<(i32, String, String)>(conn2).unwrap(); assert_eq!(expected_users, actual_users); } - #[test] - fn prepared_statements_are_cached_when_run() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); - let query = crate::select(1.into_sql::()); - - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); - } - - #[test] - fn sql_literal_nodes_are_not_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); - let query = crate::select(sql::("1")); - - assert_eq!(Ok(1), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); - } - - #[test] - fn queries_containing_sql_literal_nodes_are_not_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); - let one_as_expr = 1.into_sql::(); - let query = crate::select(one_as_expr.eq(sql::("1"))); - - assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); - } - - #[test] - fn queries_containing_in_with_vec_are_not_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); - let one_as_expr = 1.into_sql::(); - let query = crate::select(one_as_expr.eq_any(vec![1, 2, 3])); - - assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(0, connection.statement_cache.len()); - } - - #[test] - fn queries_containing_in_with_subselect_are_cached() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); - let one_as_expr = 1.into_sql::(); - let query = crate::select(one_as_expr.eq_any(crate::select(one_as_expr))); - - assert_eq!(Ok(true), query.get_result(connection)); - assert_eq!(1, connection.statement_cache.len()); - } - use crate::sql_types::Text; define_sql_function!(fn fun_case(x: Text) -> Text); #[test] fn register_custom_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); fun_case_utils::register_impl(connection, |x: String| { x.chars() .enumerate() @@ -675,7 +635,7 @@ mod tests { #[test] fn register_multiarg_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); my_add_utils::register_impl(connection, |x: i32, y: i32| x + y).unwrap(); let added = crate::select(my_add(1, 2)).get_result::(connection); @@ -686,7 +646,7 @@ mod tests { #[test] fn register_noarg_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); answer_utils::register_impl(connection, || 42).unwrap(); let answer = crate::select(answer()).get_result::(connection); @@ -695,7 +655,7 @@ mod tests { #[test] fn register_nondeterministic_noarg_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); answer_utils::register_nondeterministic_impl(connection, || 42).unwrap(); let answer = crate::select(answer()).get_result::(connection); @@ -706,7 +666,7 @@ mod tests { #[test] fn register_nondeterministic_function() { - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let mut y = 0; add_counter_utils::register_nondeterministic_impl(connection, move |x: i32| { y += 1; @@ -752,7 +712,7 @@ mod tests { fn register_aggregate_function() { use self::my_sum_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); crate::sql_query( "CREATE TABLE my_sum_example (id integer primary key autoincrement, value integer)", ) @@ -774,7 +734,7 @@ mod tests { fn register_aggregate_function_returns_finalize_default_on_empty_set() { use self::my_sum_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); crate::sql_query( "CREATE TABLE my_sum_example (id integer primary key autoincrement, value integer)", ) @@ -836,7 +796,7 @@ mod tests { fn register_aggregate_multiarg_function() { use self::range_max_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); crate::sql_query( r#"CREATE TABLE range_max_example ( id integer primary key autoincrement, @@ -872,7 +832,7 @@ mod tests { fn register_collation_function() { use self::my_collation_example::dsl::*; - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); connection .register_collation("RUSTNOCASE", |rhs, lhs| { @@ -952,7 +912,7 @@ mod tests { } } - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let res = crate::select( CustomWrapper("".into()) @@ -982,7 +942,7 @@ mod tests { } } - let connection = &mut SqliteConnection::establish(":memory:").unwrap(); + let connection = &mut connection(); let res = crate::select( CustomWrapper(Vec::new()) diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 99bb271918a3..0a3c253c091b 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -15,7 +15,7 @@ use std::io::{stderr, Write}; use std::os::raw as libc; use std::ptr::{self, NonNull}; -pub(super) struct Statement { +pub(crate) struct Statement { inner_statement: NonNull, } @@ -38,7 +38,7 @@ impl Statement { raw_connection.internal_connection.as_ptr(), CString::new(sql)?.as_ptr(), n_byte, - if matches!(is_cached, PrepareForCache::Yes) { + if matches!(is_cached, PrepareForCache::Yes { counter: _ }) { ffi::SQLITE_PREPARE_PERSISTENT as u32 } else { 0 diff --git a/diesel_derives/src/multiconnection.rs b/diesel_derives/src/multiconnection.rs index fde7838aad8b..23e430c2a74c 100644 --- a/diesel_derives/src/multiconnection.rs +++ b/diesel_derives/src/multiconnection.rs @@ -118,6 +118,15 @@ fn generate_connection_impl( } }); + let set_cache_impl = connection_types.iter().map(|c| { + let variant_ident = c.name; + quote::quote! { + #ident::#variant_ident(conn) => { + diesel::connection::Connection::set_prepared_statement_cache_size(conn, size); + } + } + }); + let get_instrumentation_impl = connection_types.iter().map(|c| { let variant_ident = c.name; quote::quote! { @@ -367,6 +376,12 @@ fn generate_connection_impl( } } + fn set_prepared_statement_cache_size(&mut self, size: diesel::connection::CacheSize) { + match self { + #(#set_cache_impl,)* + } + } + fn begin_test_transaction(&mut self) -> diesel::QueryResult<()> { match self { #(#impl_begin_test_transaction,)* From bda514b4f88110a64d31f76bada2e81c599477a9 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Fri, 30 Aug 2024 09:47:00 +0200 Subject: [PATCH 2/2] Address review comments --- .../connection/statement_cache/strategy.rs | 62 +++++++++++++------ diesel/src/pg/connection/mod.rs | 3 +- diesel/src/sqlite/connection/mod.rs | 3 +- diesel/src/sqlite/connection/stmt.rs | 2 +- 4 files changed, 47 insertions(+), 23 deletions(-) diff --git a/diesel/src/connection/statement_cache/strategy.rs b/diesel/src/connection/statement_cache/strategy.rs index 7e3740526368..2636070a46cc 100644 --- a/diesel/src/connection/statement_cache/strategy.rs +++ b/diesel/src/connection/statement_cache/strategy.rs @@ -143,16 +143,15 @@ mod testing_utils { #[cfg(test)] #[cfg(feature = "postgres")] mod tests_pg { - use crate::{ - dsl::sql, - expression::functions::define_sql_function, - insertable::Insertable, - macros::table, - pg::Pg, - sql_types::{Integer, VarChar}, - test_helpers::pg_database_url, - Connection, ExpressionMethods, IntoSql, PgConnection, QueryDsl, RunQueryDsl, - }; + use crate::connection::CacheSize; + use crate::dsl::sql; + use crate::expression::functions::define_sql_function; + use crate::insertable::Insertable; + use crate::pg::Pg; + use crate::sql_types::{Integer, VarChar}; + use crate::table; + use crate::test_helpers::pg_database_url; + use crate::{Connection, ExpressionMethods, IntoSql, PgConnection, QueryDsl, RunQueryDsl}; use super::testing_utils::{count_cache_calls, RecordCacheEvents}; @@ -238,7 +237,7 @@ mod tests_pg { connection.begin_test_transaction().unwrap(); crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + "CREATE TEMPORARY TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", ) .execute(connection) .unwrap(); @@ -264,7 +263,7 @@ mod tests_pg { connection.begin_test_transaction().unwrap(); crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + "CREATE TEMPORARY TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", ) .execute(connection) .unwrap(); @@ -282,7 +281,7 @@ mod tests_pg { connection.begin_test_transaction().unwrap(); crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + "CREATE TEMPORARY TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", ) .execute(connection) .unwrap(); @@ -300,7 +299,7 @@ mod tests_pg { connection.begin_test_transaction().unwrap(); crate::sql_query( - "CREATE TABLE IF NOT EXISTS users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", + "CREATE TEMPORARY TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);", ) .execute(connection) .unwrap(); @@ -321,16 +320,30 @@ mod tests_pg { assert_eq!(Ok(true), query.get_result(connection)); assert_eq!(1, count_cache_calls(connection)); } + + #[test] + fn disabling_the_cache_works() { + let connection = &mut connection(); + connection.set_prepared_statement_cache_size(CacheSize::Disabled); + + let query = crate::select(1.into_sql::()); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + } } #[cfg(test)] #[cfg(feature = "sqlite")] mod tests_sqlite { - use crate::{ - dsl::sql, query_dsl::RunQueryDsl, sql_types::Integer, Connection, ExpressionMethods, - IntoSql, SqliteConnection, - }; + use crate::connection::CacheSize; + use crate::dsl::sql; + use crate::query_dsl::RunQueryDsl; + use crate::sql_types::Integer; + use crate::{Connection, ExpressionMethods, IntoSql, SqliteConnection}; use super::testing_utils::{count_cache_calls, RecordCacheEvents}; @@ -389,4 +402,17 @@ mod tests_sqlite { assert_eq!(Ok(true), query.get_result(connection)); assert_eq!(1, count_cache_calls(connection)); } + + #[test] + fn disabling_the_cache_works() { + let connection = &mut connection(); + connection.set_prepared_statement_cache_size(CacheSize::Disabled); + + let query = crate::select(1.into_sql::()); + + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + assert_eq!(Ok(1), query.get_result(connection)); + assert_eq!(0, count_cache_calls(connection)); + } } diff --git a/diesel/src/pg/connection/mod.rs b/diesel/src/pg/connection/mod.rs index 6b8d32b41be4..90ad676090b4 100644 --- a/diesel/src/pg/connection/mod.rs +++ b/diesel/src/pg/connection/mod.rs @@ -126,8 +126,7 @@ pub(super) use self::result::PgResult; #[allow(missing_debug_implementations)] #[cfg(feature = "postgres")] pub struct PgConnection { - /// pub(crate) for tests - pub(crate) statement_cache: StatementCache, + statement_cache: StatementCache, metadata_cache: PgMetadataCache, connection_and_transaction_manager: ConnectionAndTransactionManager, } diff --git a/diesel/src/sqlite/connection/mod.rs b/diesel/src/sqlite/connection/mod.rs index 40021f0079e6..c6a99d61ec73 100644 --- a/diesel/src/sqlite/connection/mod.rs +++ b/diesel/src/sqlite/connection/mod.rs @@ -121,8 +121,7 @@ pub struct SqliteConnection { // statement_cache needs to be before raw_connection // otherwise we will get errors about open statements before closing the // connection itself - // pub(crate) for tests - pub(crate) statement_cache: StatementCache, + statement_cache: StatementCache, raw_connection: RawConnection, transaction_state: AnsiTransactionManager, // this exists for the sole purpose of implementing `WithMetadataLookup` trait diff --git a/diesel/src/sqlite/connection/stmt.rs b/diesel/src/sqlite/connection/stmt.rs index 0a3c253c091b..f3ea3befaa7d 100644 --- a/diesel/src/sqlite/connection/stmt.rs +++ b/diesel/src/sqlite/connection/stmt.rs @@ -15,7 +15,7 @@ use std::io::{stderr, Write}; use std::os::raw as libc; use std::ptr::{self, NonNull}; -pub(crate) struct Statement { +pub(super) struct Statement { inner_statement: NonNull, }