From 1cb4c455d778dce89bed441fb3b01639ee870d79 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Wed, 30 Oct 2024 03:19:06 -0700 Subject: [PATCH] postgres: implement query/query_raw using rust-postgres query_typed --- quaint/src/connector/postgres/native/mod.rs | 145 ++++++++++++-------- 1 file changed, 84 insertions(+), 61 deletions(-) diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index c294733fc51..93f81e503e3 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -27,7 +27,7 @@ use futures::{future::FutureExt, lock::Mutex}; use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; use postgres_native_tls::MakeTlsConnector; -use postgres_types::{Kind as PostgresKind, Type as PostgresType}; +use postgres_types::{Kind as PostgresKind, Type as PostgresType, ToSql}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ fmt::{Debug, Display}, @@ -540,29 +540,37 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params + .iter() + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); + // Execute the query using `query_typed` let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .perform_io(self.client.0.query_typed(sql, params_with_types.as_slice())) .await?; - let col_types = stmt - .columns() - .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); + // Extract column information from the first row, if available + let (col_types, column_names) = if let Some(row) = rows.first() { + let columns = row.columns(); + let col_types = columns + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let column_names = columns.iter().map(|c| c.name().to_string()).collect(); + + (col_types, column_names) + } else { + (Vec::new(), Vec::new()) + }; + let mut result = ResultSet::new(column_names, col_types, Vec::new()); + + // Process each row in the result set for row in rows { result.rows.push(row.get_result_row()?); } @@ -582,28 +590,35 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let col_types = stmt - .columns() + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); + + // Execute the query using `query_typed` let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .perform_io(self.client.0.query_typed(sql, params_with_types.as_slice())) .await?; - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); + // Extract column information from the first row, if available + let (col_types, column_names) = if let Some(row) = rows.first() { + let columns = row.columns(); + let col_types = columns + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let column_names = columns.iter().map(|c| c.name().to_string()).collect(); + + (col_types, column_names) + } else { + (Vec::new(), Vec::new()) + }; + + let mut result = ResultSet::new(column_names, col_types, Vec::new()); for row in rows { result.rows.push(row.get_result_row()?); @@ -705,20 +720,24 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params + .iter() + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; + .perform_io(self.client.0.query_typed_raw::<&(dyn ToSql + Sync), _>( + sql, + params_with_types.as_slice().iter() + .map(|(v, t)| (*v, t.clone())) + .collect::>() + )) + .await? + .rows_affected() + .unwrap_or(0); Ok(changes) }, @@ -735,20 +754,24 @@ impl Queryable for PostgreSql { sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + let converted_params = conversion::conv_params(params); + let param_types = conversion::params_to_types(params); + let params_with_types: Vec<(&(dyn ToSql + Sync), PostgresType)> = converted_params + .iter() + .zip(param_types) + .map(|(value, ty)| (*value as &(dyn ToSql + Sync), ty)) + .collect(); let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; + .perform_io(self.client.0.query_typed_raw::<&(dyn ToSql + Sync), _>( + sql, + params_with_types.as_slice().iter() + .map(|(v, t)| (*v, t.clone())) + .collect::>() + )) + .await? + .rows_affected() + .unwrap_or(0); Ok(changes) },