From 148195e3b05d297ce17a883351ea81d175b451f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Thu, 21 Nov 2024 15:26:30 +0100 Subject: [PATCH 1/5] cargo: add thiserror dependency Once we bump to 0.15, the error can occur during the conversion to QueryRowsResult. The returned error is not a QueryError, but a standalone type. This is why, we need to extend our `CassErrorResult` so it's an enum over possible error types returned during query execution and conversions. thiserror crate will come in handy, when defining such error type. It allows to quickly generate `From` implementations for corresponding variants and most importantly, it helps generating `std::fmt::Display` implementation. --- scylla-rust-wrapper/Cargo.lock | 1 + scylla-rust-wrapper/Cargo.toml | 1 + 2 files changed, 2 insertions(+) diff --git a/scylla-rust-wrapper/Cargo.lock b/scylla-rust-wrapper/Cargo.lock index 970e99b2..06de9c51 100644 --- a/scylla-rust-wrapper/Cargo.lock +++ b/scylla-rust-wrapper/Cargo.lock @@ -1077,6 +1077,7 @@ dependencies = [ "rusty-fork", "scylla", "scylla-proxy", + "thiserror", "tokio", "tracing", "tracing-subscriber", diff --git a/scylla-rust-wrapper/Cargo.toml b/scylla-rust-wrapper/Cargo.toml index 249fbac4..c60da4bf 100644 --- a/scylla-rust-wrapper/Cargo.toml +++ b/scylla-rust-wrapper/Cargo.toml @@ -26,6 +26,7 @@ openssl = "0.10.32" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } tracing = "0.1.37" futures = "0.3" +thiserror = "1.0" [build-dependencies] bindgen = "0.65" From 34715e5b6523e6ce982d8b22f1375c3d36a5497e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Thu, 21 Nov 2024 15:41:54 +0100 Subject: [PATCH 2/5] errors: change CassResultError to an enum Changing it to enum right now, to reduce the noise during the commit that actually bumps the version. --- scylla-rust-wrapper/src/cass_error.rs | 16 +++ scylla-rust-wrapper/src/query_error.rs | 142 +++++++++++++++++-------- scylla-rust-wrapper/src/session.rs | 6 +- 3 files changed, 117 insertions(+), 47 deletions(-) diff --git a/scylla-rust-wrapper/src/cass_error.rs b/scylla-rust-wrapper/src/cass_error.rs index de5a0e3c..3b4eb803 100644 --- a/scylla-rust-wrapper/src/cass_error.rs +++ b/scylla-rust-wrapper/src/cass_error.rs @@ -2,6 +2,16 @@ use scylla::transport::errors::*; // Re-export error types. pub(crate) use crate::cass_error_types::{CassError, CassErrorSource}; +use crate::query_error::CassErrorResult; + +// TODO: From is bad practice. Will be replaced in the next commit. +impl From<&CassErrorResult> for CassError { + fn from(error: &CassErrorResult) -> Self { + match error { + CassErrorResult::Query(query_error) => CassError::from(query_error), + } + } +} impl From<&QueryError> for CassError { fn from(error: &QueryError) -> Self { @@ -133,6 +143,12 @@ pub trait CassErrorMessage { fn msg(&self) -> String; } +impl CassErrorMessage for CassErrorResult { + fn msg(&self) -> String { + self.to_string() + } +} + impl CassErrorMessage for QueryError { fn msg(&self) -> String { self.to_string() diff --git a/scylla-rust-wrapper/src/query_error.rs b/scylla-rust-wrapper/src/query_error.rs index ef5bfe56..5a4473a3 100644 --- a/scylla-rust-wrapper/src/query_error.rs +++ b/scylla-rust-wrapper/src/query_error.rs @@ -5,8 +5,13 @@ use crate::cass_types::CassConsistency; use crate::types::*; use scylla::statement::Consistency; use scylla::transport::errors::*; +use thiserror::Error; -pub type CassErrorResult = QueryError; +#[derive(Error, Debug)] +pub enum CassErrorResult { + #[error(transparent)] + Query(#[from] QueryError), +} impl From for CassConsistency { fn from(c: Consistency) -> CassConsistency { @@ -59,21 +64,26 @@ pub unsafe extern "C" fn cass_error_result_consistency( ) -> CassConsistency { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::Unavailable { consistency, .. }, _) => { - CassConsistency::from(*consistency) - } - QueryError::DbError(DbError::ReadTimeout { consistency, .. }, _) => { - CassConsistency::from(*consistency) - } - QueryError::DbError(DbError::WriteTimeout { consistency, .. }, _) => { - CassConsistency::from(*consistency) - } - QueryError::DbError(DbError::ReadFailure { consistency, .. }, _) => { - CassConsistency::from(*consistency) - } - QueryError::DbError(DbError::WriteFailure { consistency, .. }, _) => { - CassConsistency::from(*consistency) - } + CassErrorResult::Query(QueryError::DbError( + DbError::Unavailable { consistency, .. }, + _, + )) => CassConsistency::from(*consistency), + CassErrorResult::Query(QueryError::DbError( + DbError::ReadTimeout { consistency, .. }, + _, + )) => CassConsistency::from(*consistency), + CassErrorResult::Query(QueryError::DbError( + DbError::WriteTimeout { consistency, .. }, + _, + )) => CassConsistency::from(*consistency), + CassErrorResult::Query(QueryError::DbError( + DbError::ReadFailure { consistency, .. }, + _, + )) => CassConsistency::from(*consistency), + CassErrorResult::Query(QueryError::DbError( + DbError::WriteFailure { consistency, .. }, + _, + )) => CassConsistency::from(*consistency), _ => CassConsistency::CASS_CONSISTENCY_UNKNOWN, } } @@ -84,11 +94,21 @@ pub unsafe extern "C" fn cass_error_result_responses_received( ) -> cass_int32_t { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::Unavailable { alive, .. }, _) => *alive, - QueryError::DbError(DbError::ReadTimeout { received, .. }, _) => *received, - QueryError::DbError(DbError::WriteTimeout { received, .. }, _) => *received, - QueryError::DbError(DbError::ReadFailure { received, .. }, _) => *received, - QueryError::DbError(DbError::WriteFailure { received, .. }, _) => *received, + CassErrorResult::Query(QueryError::DbError(DbError::Unavailable { alive, .. }, _)) => { + *alive + } + CassErrorResult::Query(QueryError::DbError(DbError::ReadTimeout { received, .. }, _)) => { + *received + } + CassErrorResult::Query(QueryError::DbError(DbError::WriteTimeout { received, .. }, _)) => { + *received + } + CassErrorResult::Query(QueryError::DbError(DbError::ReadFailure { received, .. }, _)) => { + *received + } + CassErrorResult::Query(QueryError::DbError(DbError::WriteFailure { received, .. }, _)) => { + *received + } _ => -1, } } @@ -99,11 +119,21 @@ pub unsafe extern "C" fn cass_error_result_responses_required( ) -> cass_int32_t { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::Unavailable { required, .. }, _) => *required, - QueryError::DbError(DbError::ReadTimeout { required, .. }, _) => *required, - QueryError::DbError(DbError::WriteTimeout { required, .. }, _) => *required, - QueryError::DbError(DbError::ReadFailure { required, .. }, _) => *required, - QueryError::DbError(DbError::WriteFailure { required, .. }, _) => *required, + CassErrorResult::Query(QueryError::DbError(DbError::Unavailable { required, .. }, _)) => { + *required + } + CassErrorResult::Query(QueryError::DbError(DbError::ReadTimeout { required, .. }, _)) => { + *required + } + CassErrorResult::Query(QueryError::DbError(DbError::WriteTimeout { required, .. }, _)) => { + *required + } + CassErrorResult::Query(QueryError::DbError(DbError::ReadFailure { required, .. }, _)) => { + *required + } + CassErrorResult::Query(QueryError::DbError(DbError::WriteFailure { required, .. }, _)) => { + *required + } _ => -1, } } @@ -114,8 +144,14 @@ pub unsafe extern "C" fn cass_error_result_num_failures( ) -> cass_int32_t { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::ReadFailure { numfailures, .. }, _) => *numfailures, - QueryError::DbError(DbError::WriteFailure { numfailures, .. }, _) => *numfailures, + CassErrorResult::Query(QueryError::DbError( + DbError::ReadFailure { numfailures, .. }, + _, + )) => *numfailures, + CassErrorResult::Query(QueryError::DbError( + DbError::WriteFailure { numfailures, .. }, + _, + )) => *numfailures, _ => -1, } } @@ -126,14 +162,20 @@ pub unsafe extern "C" fn cass_error_result_data_present( ) -> cass_bool_t { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::ReadTimeout { data_present, .. }, _) => { + CassErrorResult::Query(QueryError::DbError( + DbError::ReadTimeout { data_present, .. }, + _, + )) => { if *data_present { cass_true } else { cass_false } } - QueryError::DbError(DbError::ReadFailure { data_present, .. }, _) => { + CassErrorResult::Query(QueryError::DbError( + DbError::ReadFailure { data_present, .. }, + _, + )) => { if *data_present { cass_true } else { @@ -150,12 +192,14 @@ pub unsafe extern "C" fn cass_error_result_write_type( ) -> CassWriteType { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::WriteTimeout { write_type, .. }, _) => { - CassWriteType::from(write_type) - } - QueryError::DbError(DbError::WriteFailure { write_type, .. }, _) => { - CassWriteType::from(write_type) - } + CassErrorResult::Query(QueryError::DbError( + DbError::WriteTimeout { write_type, .. }, + _, + )) => CassWriteType::from(write_type), + CassErrorResult::Query(QueryError::DbError( + DbError::WriteFailure { write_type, .. }, + _, + )) => CassWriteType::from(write_type), _ => CassWriteType::CASS_WRITE_TYPE_UNKNOWN, } } @@ -168,11 +212,14 @@ pub unsafe extern "C" fn cass_error_result_keyspace( ) -> CassError { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::AlreadyExists { keyspace, .. }, _) => { + CassErrorResult::Query(QueryError::DbError(DbError::AlreadyExists { keyspace, .. }, _)) => { write_str_to_c(keyspace.as_str(), c_keyspace, c_keyspace_len); CassError::CASS_OK } - QueryError::DbError(DbError::FunctionFailure { keyspace, .. }, _) => { + CassErrorResult::Query(QueryError::DbError( + DbError::FunctionFailure { keyspace, .. }, + _, + )) => { write_str_to_c(keyspace.as_str(), c_keyspace, c_keyspace_len); CassError::CASS_OK } @@ -188,7 +235,7 @@ pub unsafe extern "C" fn cass_error_result_table( ) -> CassError { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::AlreadyExists { table, .. }, _) => { + CassErrorResult::Query(QueryError::DbError(DbError::AlreadyExists { table, .. }, _)) => { write_str_to_c(table.as_str(), c_table, c_table_len); CassError::CASS_OK } @@ -204,7 +251,10 @@ pub unsafe extern "C" fn cass_error_result_function( ) -> CassError { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::FunctionFailure { function, .. }, _) => { + CassErrorResult::Query(QueryError::DbError( + DbError::FunctionFailure { function, .. }, + _, + )) => { write_str_to_c(function.as_str(), c_function, c_function_len); CassError::CASS_OK } @@ -216,9 +266,10 @@ pub unsafe extern "C" fn cass_error_result_function( pub unsafe extern "C" fn cass_error_num_arg_types(error_result: *const CassErrorResult) -> size_t { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::FunctionFailure { arg_types, .. }, _) => { - arg_types.len() as size_t - } + CassErrorResult::Query(QueryError::DbError( + DbError::FunctionFailure { arg_types, .. }, + _, + )) => arg_types.len() as size_t, _ => 0, } } @@ -232,7 +283,10 @@ pub unsafe extern "C" fn cass_error_result_arg_type( ) -> CassError { let error_result: &CassErrorResult = ptr_to_ref(error_result); match error_result { - QueryError::DbError(DbError::FunctionFailure { arg_types, .. }, _) => { + CassErrorResult::Query(QueryError::DbError( + DbError::FunctionFailure { arg_types, .. }, + _, + )) => { if index >= arg_types.len() as size_t { return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; } diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 6de63835..77eef0be 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -224,7 +224,7 @@ pub unsafe extern "C" fn cass_session_execute_batch( tracing_id: None, paging_state_response: PagingStateResponse::NoMorePages, }))), - Err(err) => Ok(CassResultValue::QueryError(Arc::new(err))), + Err(err) => Ok(CassResultValue::QueryError(Arc::new(err.into()))), } }; @@ -243,7 +243,7 @@ async fn request_with_timeout( match tokio::time::timeout(Duration::from_millis(request_timeout_ms), future).await { Ok(result) => result, Err(_timeout_err) => Ok(CassResultValue::QueryError(Arc::new( - QueryError::TimeoutError, + QueryError::TimeoutError.into(), ))), } } @@ -372,7 +372,7 @@ pub unsafe extern "C" fn cass_session_execute( Ok(CassResultValue::QueryResult(cass_result)) } - Err(err) => Ok(CassResultValue::QueryError(Arc::new(err))), + Err(err) => Ok(CassResultValue::QueryError(Arc::new(err.into()))), } }; From 97adf6648127fb400e9c86a4bfcebd34fc0381f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Thu, 21 Nov 2024 15:48:54 +0100 Subject: [PATCH 3/5] errors: replace From<&> conversions with a trait Partially fixes https://github.com/scylladb/cpp-rust-driver/issues/177. --- scylla-rust-wrapper/src/cass_error.rs | 47 ++++++++++++++------------ scylla-rust-wrapper/src/future.rs | 3 +- scylla-rust-wrapper/src/query_error.rs | 2 +- scylla-rust-wrapper/src/session.rs | 6 ++-- 4 files changed, 31 insertions(+), 27 deletions(-) diff --git a/scylla-rust-wrapper/src/cass_error.rs b/scylla-rust-wrapper/src/cass_error.rs index 3b4eb803..55fc4fdd 100644 --- a/scylla-rust-wrapper/src/cass_error.rs +++ b/scylla-rust-wrapper/src/cass_error.rs @@ -4,20 +4,23 @@ use scylla::transport::errors::*; pub(crate) use crate::cass_error_types::{CassError, CassErrorSource}; use crate::query_error::CassErrorResult; -// TODO: From is bad practice. Will be replaced in the next commit. -impl From<&CassErrorResult> for CassError { - fn from(error: &CassErrorResult) -> Self { - match error { - CassErrorResult::Query(query_error) => CassError::from(query_error), +pub trait ToCassError { + fn to_cass_error(&self) -> CassError; +} + +impl ToCassError for CassErrorResult { + fn to_cass_error(&self) -> CassError { + match self { + CassErrorResult::Query(query_error) => query_error.to_cass_error(), } } } -impl From<&QueryError> for CassError { - fn from(error: &QueryError) -> Self { - match error { - QueryError::DbError(db_error, _string) => CassError::from(db_error), - QueryError::BadQuery(bad_query) => CassError::from(bad_query), +impl ToCassError for QueryError { + fn to_cass_error(&self) -> CassError { + match self { + QueryError::DbError(db_error, _string) => db_error.to_cass_error(), + QueryError::BadQuery(bad_query) => bad_query.to_cass_error(), QueryError::ProtocolError(_str) => CassError::CASS_ERROR_SERVER_PROTOCOL_ERROR, QueryError::TimeoutError => CassError::CASS_ERROR_LIB_REQUEST_TIMED_OUT, // This may be either read or write timeout error QueryError::UnableToAllocStreamId => CassError::CASS_ERROR_LIB_NO_STREAMS, @@ -42,9 +45,9 @@ impl From<&QueryError> for CassError { } } -impl From<&DbError> for CassError { - fn from(error: &DbError) -> Self { - match error { +impl ToCassError for DbError { + fn to_cass_error(&self) -> CassError { + match self { DbError::ServerError => CassError::CASS_ERROR_SERVER_SERVER_ERROR, DbError::ProtocolError => CassError::CASS_ERROR_SERVER_PROTOCOL_ERROR, DbError::AuthenticationError => CassError::CASS_ERROR_SERVER_BAD_CREDENTIALS, @@ -72,9 +75,9 @@ impl From<&DbError> for CassError { } } -impl From<&BadQuery> for CassError { - fn from(error: &BadQuery) -> Self { - match error { +impl ToCassError for BadQuery { + fn to_cass_error(&self) -> CassError { + match self { BadQuery::SerializeValuesError(_serialize_values_error) => { CassError::CASS_ERROR_LAST_ENTRY } @@ -91,9 +94,9 @@ impl From<&BadQuery> for CassError { } } -impl From<&NewSessionError> for CassError { - fn from(error: &NewSessionError) -> Self { - match error { +impl ToCassError for NewSessionError { + fn to_cass_error(&self) -> CassError { + match self { NewSessionError::FailedToResolveAnyHostname(_hostnames) => { CassError::CASS_ERROR_LIB_NO_HOSTS_AVAILABLE } @@ -127,9 +130,9 @@ impl From<&NewSessionError> for CassError { } } -impl From<&BadKeyspaceName> for CassError { - fn from(error: &BadKeyspaceName) -> Self { - match error { +impl ToCassError for BadKeyspaceName { + fn to_cass_error(&self) -> CassError { + match self { BadKeyspaceName::Empty => CassError::CASS_ERROR_LAST_ENTRY, BadKeyspaceName::TooLong(_string, _usize) => CassError::CASS_ERROR_LAST_ENTRY, BadKeyspaceName::IllegalCharacter(_string, _char) => CassError::CASS_ERROR_LAST_ENTRY, diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 34df2b37..874aacdd 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -1,6 +1,7 @@ use crate::argconv::*; use crate::cass_error::CassError; use crate::cass_error::CassErrorMessage; +use crate::cass_error::ToCassError; use crate::prepared::CassPrepared; use crate::query_error::CassErrorResult; use crate::query_result::CassResult; @@ -320,7 +321,7 @@ pub unsafe extern "C" fn cass_future_ready(future_raw: *const CassFuture) -> cas #[no_mangle] pub unsafe extern "C" fn cass_future_error_code(future_raw: *const CassFuture) -> CassError { ptr_to_ref(future_raw).with_waited_result(|r: &mut CassFutureResult| match r { - Ok(CassResultValue::QueryError(err)) => CassError::from(err.as_ref()), + Ok(CassResultValue::QueryError(err)) => err.to_cass_error(), Err((err, _)) => *err, _ => CassError::CASS_OK, }) diff --git a/scylla-rust-wrapper/src/query_error.rs b/scylla-rust-wrapper/src/query_error.rs index 5a4473a3..b47262fa 100644 --- a/scylla-rust-wrapper/src/query_error.rs +++ b/scylla-rust-wrapper/src/query_error.rs @@ -55,7 +55,7 @@ pub unsafe extern "C" fn cass_error_result_free(error_result: *const CassErrorRe #[no_mangle] pub unsafe extern "C" fn cass_error_result_code(error_result: *const CassErrorResult) -> CassError { let error_result: &CassErrorResult = ptr_to_ref(error_result); - CassError::from(error_result) + error_result.to_cass_error() } #[no_mangle] diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 77eef0be..4f92438d 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -128,7 +128,7 @@ impl CassSessionInner { let session = session_builder .build() .await - .map_err(|err| (CassError::from(&err), err.msg()))?; + .map_err(|err| (err.to_cass_error(), err.msg()))?; *session_guard = Some(CassSessionInner { session, @@ -538,7 +538,7 @@ pub unsafe extern "C" fn cass_session_prepare_from_existing( let prepared = session .prepare(query.query.clone()) .await - .map_err(|err| (CassError::from(&err), err.msg()))?; + .map_err(|err| (err.to_cass_error(), err.msg()))?; Ok(CassResultValue::Prepared(Arc::new( CassPrepared::new_from_prepared_statement(prepared), @@ -582,7 +582,7 @@ pub unsafe extern "C" fn cass_session_prepare_n( let mut prepared = session .prepare(query) .await - .map_err(|err| (CassError::from(&err), err.msg()))?; + .map_err(|err| (err.to_cass_error(), err.msg()))?; // Set Cpp Driver default configuration for queries: prepared.set_consistency(Consistency::One); From 49e5673450c7aa27f0e295c019bb54a31d6e4f7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Thu, 21 Nov 2024 16:00:06 +0100 Subject: [PATCH 4/5] result: extract constructor logic to query_result.rs The logic will become much more complex once we bump to 0.15. It introduces additional nesting, thus I prefer to extract it to separate functions (since cass_session_execute is already complex enough). I also moved the necessary utility functions to query_result module. --- scylla-rust-wrapper/src/query_result.rs | 162 +++++++++++++++++++++++- scylla-rust-wrapper/src/session.rs | 150 +--------------------- 2 files changed, 165 insertions(+), 147 deletions(-) diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index cdf3bd93..e8d52461 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -7,10 +7,12 @@ use crate::inet::CassInet; use crate::metadata::{ CassColumnMeta, CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta, CassTableMeta, }; +use crate::query_result::Value::{CollectionValue, RegularValue}; use crate::types::*; use crate::uuid::CassUuid; -use scylla::frame::response::result::{ColumnSpec, CqlValue}; +use scylla::frame::response::result::{ColumnSpec, CqlValue, Row}; use scylla::transport::PagingStateResponse; +use scylla::QueryResult; use std::convert::TryInto; use std::os::raw::c_char; use std::sync::Arc; @@ -23,6 +25,34 @@ pub struct CassResult { pub paging_state_response: PagingStateResponse, } +impl CassResult { + /// It creates CassResult object based on the: + /// - query result + /// - paging state response + /// - optional cached result metadata - it's provided for prepared statements + pub fn from_result_payload( + result: QueryResult, + paging_state_response: PagingStateResponse, + maybe_result_metadata: Option>, + ) -> Self { + // maybe_result_metadata is: + // - Some(_) for prepared statements + // - None for unprepared statements + let metadata = maybe_result_metadata + .unwrap_or_else(|| Arc::new(CassResultMetadata::from_column_specs(result.col_specs()))); + let cass_rows = result + .rows + .map(|rows| create_cass_rows_from_rows(rows, &metadata)); + + CassResult { + rows: cass_rows, + metadata, + tracing_id: result.tracing_id, + paging_state_response, + } + } +} + #[derive(Debug)] pub struct CassResultMetadata { pub col_specs: Vec, @@ -51,6 +81,18 @@ pub struct CassRow { pub result_metadata: Arc, } +pub fn create_cass_rows_from_rows( + rows: Vec, + metadata: &Arc, +) -> Vec { + rows.into_iter() + .map(|r| CassRow { + columns: create_cass_row_columns(r, metadata), + result_metadata: metadata.clone(), + }) + .collect() +} + pub enum Value { RegularValue(CqlValue), CollectionValue(Collection), @@ -73,6 +115,120 @@ pub struct CassValue { pub value_type: Arc, } +fn create_cass_row_columns(row: Row, metadata: &Arc) -> Vec { + row.columns + .into_iter() + .zip(metadata.col_specs.iter()) + .map(|(val, col_spec)| { + let column_type = Arc::clone(&col_spec.data_type); + CassValue { + value: val.map(|col_val| get_column_value(col_val, &column_type)), + value_type: column_type, + } + }) + .collect() +} + +fn get_column_value(column: CqlValue, column_type: &Arc) -> Value { + match (column, column_type.as_ref()) { + ( + CqlValue::List(list), + CassDataType::List { + typ: Some(list_type), + .. + }, + ) => CollectionValue(Collection::List( + list.into_iter() + .map(|val| CassValue { + value_type: list_type.clone(), + value: Some(get_column_value(val, list_type)), + }) + .collect(), + )), + ( + CqlValue::Map(map), + CassDataType::Map { + typ: MapDataType::KeyAndValue(key_type, value_type), + .. + }, + ) => CollectionValue(Collection::Map( + map.into_iter() + .map(|(key, val)| { + ( + CassValue { + value_type: key_type.clone(), + value: Some(get_column_value(key, key_type)), + }, + CassValue { + value_type: value_type.clone(), + value: Some(get_column_value(val, value_type)), + }, + ) + }) + .collect(), + )), + ( + CqlValue::Set(set), + CassDataType::Set { + typ: Some(set_type), + .. + }, + ) => CollectionValue(Collection::Set( + set.into_iter() + .map(|val| CassValue { + value_type: set_type.clone(), + value: Some(get_column_value(val, set_type)), + }) + .collect(), + )), + ( + CqlValue::UserDefinedType { + keyspace, + type_name, + fields, + }, + CassDataType::UDT(udt_type), + ) => CollectionValue(Collection::UserDefinedType { + keyspace, + type_name, + fields: fields + .into_iter() + .enumerate() + .map(|(index, (name, val_opt))| { + let udt_field_type_opt = udt_type.get_field_by_index(index); + if let (Some(val), Some(udt_field_type)) = (val_opt, udt_field_type_opt) { + return ( + name, + Some(CassValue { + value_type: udt_field_type.clone(), + value: Some(get_column_value(val, udt_field_type)), + }), + ); + } + (name, None) + }) + .collect(), + }), + (CqlValue::Tuple(tuple), CassDataType::Tuple(tuple_types)) => { + CollectionValue(Collection::Tuple( + tuple + .into_iter() + .enumerate() + .map(|(index, val_opt)| { + val_opt + .zip(tuple_types.get(index)) + .map(|(val, tuple_field_type)| CassValue { + value_type: tuple_field_type.clone(), + value: Some(get_column_value(val, tuple_field_type)), + }) + }) + .collect(), + )) + } + (regular_value, _) => RegularValue(regular_value), + } +} + pub struct CassResultIterator { result: Arc, position: Option, @@ -1392,11 +1548,11 @@ mod tests { cass_result_column_data_type, cass_result_column_name, cass_result_first_row, ptr_to_cstr_n, ptr_to_ref, size_t, }, - session::create_cass_rows_from_rows, }; use super::{ - cass_result_column_count, cass_result_column_type, CassResult, CassResultMetadata, + cass_result_column_count, cass_result_column_type, create_cass_rows_from_rows, CassResult, + CassResultMetadata, }; fn col_spec(name: &'static str, typ: ColumnType<'static>) -> ColumnSpec<'static> { diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 4f92438d..189f5e9b 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -1,7 +1,7 @@ use crate::argconv::*; use crate::batch::CassBatch; use crate::cass_error::*; -use crate::cass_types::{CassDataType, MapDataType, UDTDataType}; +use crate::cass_types::{CassDataType, UDTDataType}; use crate::cluster::build_session_builder; use crate::cluster::CassCluster; use crate::exec_profile::{CassExecProfile, ExecProfileName, PerStatementExecProfile}; @@ -9,13 +9,11 @@ use crate::future::{CassFuture, CassFutureResult, CassResultValue}; use crate::metadata::create_table_metadata; use crate::metadata::{CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta}; use crate::prepared::CassPrepared; -use crate::query_result::Value::{CollectionValue, RegularValue}; -use crate::query_result::{CassResult, CassResultMetadata, CassRow, CassValue, Collection, Value}; +use crate::query_result::{CassResult, CassResultMetadata}; use crate::statement::CassStatement; use crate::statement::Statement; use crate::types::{cass_uint64_t, size_t}; use crate::uuid::CassUuid; -use scylla::frame::response::result::{CqlValue, Row}; use scylla::frame::types::Consistency; use scylla::query::Query; use scylla::transport::errors::QueryError; @@ -354,21 +352,11 @@ pub unsafe extern "C" fn cass_session_execute( match query_res { Ok((result, paging_state_response, maybe_result_metadata)) => { - // maybe_result_metadata is: - // - Some(_) for prepared statements - // - None for unprepared statements - let metadata = maybe_result_metadata.unwrap_or_else(|| { - Arc::new(CassResultMetadata::from_column_specs(result.col_specs())) - }); - let cass_rows = result - .rows - .map(|rows| create_cass_rows_from_rows(rows, &metadata)); - let cass_result = Arc::new(CassResult { - rows: cass_rows, - metadata, - tracing_id: result.tracing_id, + let cass_result = Arc::new(CassResult::from_result_payload( + result, paging_state_response, - }); + maybe_result_metadata, + )); Ok(CassResultValue::QueryResult(cass_result)) } @@ -384,132 +372,6 @@ pub unsafe extern "C" fn cass_session_execute( } } -pub(crate) fn create_cass_rows_from_rows( - rows: Vec, - metadata: &Arc, -) -> Vec { - rows.into_iter() - .map(|r| CassRow { - columns: create_cass_row_columns(r, metadata), - result_metadata: metadata.clone(), - }) - .collect() -} - -fn create_cass_row_columns(row: Row, metadata: &Arc) -> Vec { - row.columns - .into_iter() - .zip(metadata.col_specs.iter()) - .map(|(val, col_spec)| { - let column_type = Arc::clone(&col_spec.data_type); - CassValue { - value: val.map(|col_val| get_column_value(col_val, &column_type)), - value_type: column_type, - } - }) - .collect() -} - -fn get_column_value(column: CqlValue, column_type: &Arc) -> Value { - match (column, column_type.as_ref()) { - ( - CqlValue::List(list), - CassDataType::List { - typ: Some(list_type), - .. - }, - ) => CollectionValue(Collection::List( - list.into_iter() - .map(|val| CassValue { - value_type: list_type.clone(), - value: Some(get_column_value(val, list_type)), - }) - .collect(), - )), - ( - CqlValue::Map(map), - CassDataType::Map { - typ: MapDataType::KeyAndValue(key_type, value_type), - .. - }, - ) => CollectionValue(Collection::Map( - map.into_iter() - .map(|(key, val)| { - ( - CassValue { - value_type: key_type.clone(), - value: Some(get_column_value(key, key_type)), - }, - CassValue { - value_type: value_type.clone(), - value: Some(get_column_value(val, value_type)), - }, - ) - }) - .collect(), - )), - ( - CqlValue::Set(set), - CassDataType::Set { - typ: Some(set_type), - .. - }, - ) => CollectionValue(Collection::Set( - set.into_iter() - .map(|val| CassValue { - value_type: set_type.clone(), - value: Some(get_column_value(val, set_type)), - }) - .collect(), - )), - ( - CqlValue::UserDefinedType { - keyspace, - type_name, - fields, - }, - CassDataType::UDT(udt_type), - ) => CollectionValue(Collection::UserDefinedType { - keyspace, - type_name, - fields: fields - .into_iter() - .enumerate() - .map(|(index, (name, val_opt))| { - let udt_field_type_opt = udt_type.get_field_by_index(index); - if let (Some(val), Some(udt_field_type)) = (val_opt, udt_field_type_opt) { - return ( - name, - Some(CassValue { - value_type: udt_field_type.clone(), - value: Some(get_column_value(val, udt_field_type)), - }), - ); - } - (name, None) - }) - .collect(), - }), - (CqlValue::Tuple(tuple), CassDataType::Tuple(tuple_types)) => { - CollectionValue(Collection::Tuple( - tuple - .into_iter() - .enumerate() - .map(|(index, val_opt)| { - val_opt - .zip(tuple_types.get(index)) - .map(|(val, tuple_field_type)| CassValue { - value_type: tuple_field_type.clone(), - value: Some(get_column_value(val, tuple_field_type)), - }) - }) - .collect(), - )) - } - (regular_value, _) => RegularValue(regular_value), - } -} - #[no_mangle] pub unsafe extern "C" fn cass_session_prepare_from_existing( cass_session: *mut CassSession, From a8b0637a4d57bc35eb589ae6f59199dbf2c01180 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Uzarski?= Date: Thu, 21 Nov 2024 15:06:00 +0100 Subject: [PATCH 5/5] cargo: bump rust-driver to 0.15 1. CassResult was refactored yet again. Now, we clearly distinguish betwen Rows and non-Rows result. `CassResultKind` and `CassRowsResult` are introduced. `CassRowsResult` holds information about the (for now, eagerly deserialized) rows and metadata. The usages of `CassResult` are adjusted throughout the codebase. 2. CassErrorResult was extended by two variants: - metadata deserialization error - happens during conversion from QueryResult to QueryRowsResult - deserialization error - happens during eager rows deserialization. 3. CassResult construction logic was adjusted to new QueryResult. As mentioned above, some errors can occur during conversion, thus `CassResult::from_result_payload` returns a StdResult now. For now, we eagerly deserialize the rows. This is equivalent to the version before this PR. --- scylla-rust-wrapper/Cargo.lock | 78 ++++++++++- scylla-rust-wrapper/Cargo.toml | 4 +- scylla-rust-wrapper/src/cass_error.rs | 10 ++ scylla-rust-wrapper/src/query_error.rs | 6 + scylla-rust-wrapper/src/query_result.rs | 168 +++++++++++++++++------- scylla-rust-wrapper/src/session.rs | 14 +- 6 files changed, 215 insertions(+), 65 deletions(-) diff --git a/scylla-rust-wrapper/Cargo.lock b/scylla-rust-wrapper/Cargo.lock index 06de9c51..9b75378e 100644 --- a/scylla-rust-wrapper/Cargo.lock +++ b/scylla-rust-wrapper/Cargo.lock @@ -1027,8 +1027,8 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "scylla" -version = "0.14.0" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=64b4afcd#64b4afcdb4286b21f6cc1acb55266d6607f250e0" +version = "0.15.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.15.0#f59908c54e6b6407112311a3745e32d4bd218d0c" dependencies = [ "arc-swap", "async-trait", @@ -1086,8 +1086,8 @@ dependencies = [ [[package]] name = "scylla-cql" -version = "0.3.0" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=64b4afcd#64b4afcdb4286b21f6cc1acb55266d6607f250e0" +version = "0.4.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.15.0#f59908c54e6b6407112311a3745e32d4bd218d0c" dependencies = [ "async-trait", "byteorder", @@ -1095,15 +1095,17 @@ dependencies = [ "lz4_flex", "scylla-macros", "snap", + "stable_deref_trait", "thiserror", "tokio", "uuid", + "yoke", ] [[package]] name = "scylla-macros" -version = "0.6.0" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=64b4afcd#64b4afcdb4286b21f6cc1acb55266d6607f250e0" +version = "0.7.0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.15.0#f59908c54e6b6407112311a3745e32d4bd218d0c" dependencies = [ "darling", "proc-macro2", @@ -1114,7 +1116,7 @@ dependencies = [ [[package]] name = "scylla-proxy" version = "0.0.3" -source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=64b4afcd#64b4afcdb4286b21f6cc1acb55266d6607f250e0" +source = "git+https://github.com/scylladb/scylla-rust-driver.git?rev=v0.15.0#f59908c54e6b6407112311a3745e32d4bd218d0c" dependencies = [ "bigdecimal", "byteorder", @@ -1191,6 +1193,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "static_assertions" version = "1.1.0" @@ -1225,6 +1233,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.61", +] + [[package]] name = "tempfile" version = "3.5.0" @@ -1782,6 +1801,30 @@ dependencies = [ "winapi", ] +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.61", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.7.34" @@ -1801,3 +1844,24 @@ dependencies = [ "quote", "syn 2.0.61", ] + +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.61", + "synstructure", +] diff --git a/scylla-rust-wrapper/Cargo.toml b/scylla-rust-wrapper/Cargo.toml index c60da4bf..69a74236 100644 --- a/scylla-rust-wrapper/Cargo.toml +++ b/scylla-rust-wrapper/Cargo.toml @@ -10,7 +10,7 @@ categories = ["database"] license = "MIT OR Apache-2.0" [dependencies] -scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "64b4afcd", features = [ +scylla = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.15.0", features = [ "ssl", ] } tokio = { version = "1.27.0", features = ["full"] } @@ -33,7 +33,7 @@ bindgen = "0.65" chrono = "0.4.20" [dev-dependencies] -scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "64b4afcd" } +scylla-proxy = { git = "https://github.com/scylladb/scylla-rust-driver.git", rev = "v0.15.0" } assert_matches = "1.5.0" ntest = "0.9.3" diff --git a/scylla-rust-wrapper/src/cass_error.rs b/scylla-rust-wrapper/src/cass_error.rs index 55fc4fdd..06ab321c 100644 --- a/scylla-rust-wrapper/src/cass_error.rs +++ b/scylla-rust-wrapper/src/cass_error.rs @@ -12,6 +12,16 @@ impl ToCassError for CassErrorResult { fn to_cass_error(&self) -> CassError { match self { CassErrorResult::Query(query_error) => query_error.to_cass_error(), + + // TODO: + // For now let's leave these as LIB_INVALID_DATA. + // I don't see any variants that would make more sense. + // TBH, I'm almost sure that we should introduce additional enum variants + // of CassError in the future ~ muzarski. + CassErrorResult::ResultMetadataLazyDeserialization(_) => { + CassError::CASS_ERROR_LIB_INVALID_DATA + } + CassErrorResult::Deserialization(_) => CassError::CASS_ERROR_LIB_INVALID_DATA, } } } diff --git a/scylla-rust-wrapper/src/query_error.rs b/scylla-rust-wrapper/src/query_error.rs index b47262fa..f6a91e1c 100644 --- a/scylla-rust-wrapper/src/query_error.rs +++ b/scylla-rust-wrapper/src/query_error.rs @@ -3,6 +3,8 @@ use crate::cass_error::*; use crate::cass_error_types::CassWriteType; use crate::cass_types::CassConsistency; use crate::types::*; +use scylla::deserialize::DeserializationError; +use scylla::frame::frame_errors::ResultMetadataAndRowsCountParseError; use scylla::statement::Consistency; use scylla::transport::errors::*; use thiserror::Error; @@ -11,6 +13,10 @@ use thiserror::Error; pub enum CassErrorResult { #[error(transparent)] Query(#[from] QueryError), + #[error(transparent)] + ResultMetadataLazyDeserialization(#[from] ResultMetadataAndRowsCountParseError), + #[error("Failed to deserialize rows: {0}")] + Deserialization(#[from] DeserializationError), } impl From for CassConsistency { diff --git a/scylla-rust-wrapper/src/query_result.rs b/scylla-rust-wrapper/src/query_result.rs index e8d52461..84761eb1 100644 --- a/scylla-rust-wrapper/src/query_result.rs +++ b/scylla-rust-wrapper/src/query_result.rs @@ -7,10 +7,12 @@ use crate::inet::CassInet; use crate::metadata::{ CassColumnMeta, CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta, CassTableMeta, }; +use crate::query_error::CassErrorResult; use crate::query_result::Value::{CollectionValue, RegularValue}; use crate::types::*; use crate::uuid::CassUuid; use scylla::frame::response::result::{ColumnSpec, CqlValue, Row}; +use scylla::transport::query_result::{ColumnSpecs, IntoRowsResultError}; use scylla::transport::PagingStateResponse; use scylla::QueryResult; use std::convert::TryInto; @@ -18,11 +20,20 @@ use std::os::raw::c_char; use std::sync::Arc; use uuid::Uuid; -pub struct CassResult { - pub rows: Option>, +pub enum CassResultKind { + NonRows, + Rows(CassRowsResult), +} + +pub struct CassRowsResult { + pub rows: Vec, pub metadata: Arc, +} + +pub struct CassResult { pub tracing_id: Option, pub paging_state_response: PagingStateResponse, + pub kind: CassResultKind, } impl CassResult { @@ -34,21 +45,51 @@ impl CassResult { result: QueryResult, paging_state_response: PagingStateResponse, maybe_result_metadata: Option>, - ) -> Self { - // maybe_result_metadata is: - // - Some(_) for prepared statements - // - None for unprepared statements - let metadata = maybe_result_metadata - .unwrap_or_else(|| Arc::new(CassResultMetadata::from_column_specs(result.col_specs()))); - let cass_rows = result - .rows - .map(|rows| create_cass_rows_from_rows(rows, &metadata)); - - CassResult { - rows: cass_rows, - metadata, - tracing_id: result.tracing_id, - paging_state_response, + ) -> Result { + match result.into_rows_result() { + Ok(rows_result) => { + // maybe_result_metadata is: + // - Some(_) for prepared statements + // - None for unprepared statements + let metadata = maybe_result_metadata.unwrap_or_else(|| { + Arc::new(CassResultMetadata::from_column_spec_views( + rows_result.column_specs(), + )) + }); + + // For now, let's eagerly deserialize rows into type-erased CqlValues. + // Lazy deserialization requires a non-trivial refactor that needs to be discussed. + let rows: Vec = rows_result + .rows::() + // SAFETY: this unwrap is safe, because `Row` always + // passes the typecheck, no matter the type of the columns. + .unwrap() + .collect::>()?; + let cass_rows = create_cass_rows_from_rows(rows, &metadata); + + let cass_result = CassResult { + tracing_id: rows_result.tracing_id(), + paging_state_response, + kind: CassResultKind::Rows(CassRowsResult { + rows: cass_rows, + metadata, + }), + }; + + Ok(cass_result) + } + Err(IntoRowsResultError::ResultNotRows(result)) => { + let cass_result = CassResult { + tracing_id: result.tracing_id(), + paging_state_response, + kind: CassResultKind::NonRows, + }; + + Ok(cass_result) + } + Err(IntoRowsResultError::ResultMetadataLazyDeserializationError(err)) => { + Err(err.into()) + } } } } @@ -72,6 +113,30 @@ impl CassResultMetadata { CassResultMetadata { col_specs } } + + // I don't like introducing this method, but there is a discrepancy + // between the types representing column specs returned from + // `QueryRowsResult::column_specs()` (returns ColumnSpecs<'_>) and + // `PreparedStatement::get_result_set_col_specs()` (returns &[ColumnSpec<'_>). + // + // I tried to workaround it with accepting a generic type, such as iterator, + // but then again, types of items we are iterating over differ as well - + // ColumnSpecView<'_> vs ColumnSpec<'_>. + // + // This should probably be adjusted on rust-driver side. + pub fn from_column_spec_views(col_specs: ColumnSpecs<'_>) -> CassResultMetadata { + let col_specs = col_specs + .iter() + .map(|col_spec| { + let name = col_spec.name().to_owned(); + let data_type = Arc::new(get_column_type(col_spec.typ())); + + CassColumnSpec { name, data_type } + }) + .collect(); + + CassResultMetadata { col_specs } + } } /// The lifetime of CassRow is bound to CassResult. @@ -311,9 +376,11 @@ pub unsafe extern "C" fn cass_iterator_next(iterator: *mut CassIterator) -> cass result_iterator.position = Some(new_pos); - match &result_iterator.result.rows { - Some(rs) => (new_pos < rs.len()) as cass_bool_t, - None => false as cass_bool_t, + match &result_iterator.result.kind { + CassResultKind::Rows(rows_result) => { + (new_pos < rows_result.rows.len()) as cass_bool_t + } + CassResultKind::NonRows => false as cass_bool_t, } } CassIterator::CassRowIterator(row_iterator) => { @@ -410,12 +477,11 @@ pub unsafe extern "C" fn cass_iterator_get_row(iterator: *const CassIterator) -> None => return std::ptr::null(), }; - let row: &CassRow = match result_iterator - .result - .rows - .as_ref() - .and_then(|rs| rs.get(iter_position)) - { + let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result_iterator.result.kind else { + return std::ptr::null(); + }; + + let row: &CassRow = match rows.get(iter_position) { Some(row) => row, None => return std::ptr::null(), }; @@ -1068,16 +1134,15 @@ pub unsafe extern "C" fn cass_result_column_name( let result_from_raw = ptr_to_ref(result); let index_usize: usize = index.try_into().unwrap(); - if index_usize >= result_from_raw.metadata.col_specs.len() { + let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result_from_raw.kind else { + return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; + }; + + if index_usize >= metadata.col_specs.len() { return CassError::CASS_ERROR_LIB_INDEX_OUT_OF_BOUNDS; } - let column_name = &result_from_raw - .metadata - .col_specs - .get(index_usize) - .unwrap() - .name; + let column_name = &metadata.col_specs.get(index_usize).unwrap().name; write_str_to_c(column_name, name, name_length); @@ -1106,8 +1171,11 @@ pub unsafe extern "C" fn cass_result_column_data_type( .try_into() .expect("Provided index is out of bounds. Max possible value is usize::MAX"); - result_from_raw - .metadata + let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result_from_raw.kind else { + return std::ptr::null(); + }; + + metadata .col_specs .get(index_usize) .map(|col_spec| Arc::as_ptr(&col_spec.data_type)) @@ -1474,28 +1542,33 @@ pub unsafe extern "C" fn cass_value_secondary_sub_type( pub unsafe extern "C" fn cass_result_row_count(result_raw: *const CassResult) -> size_t { let result = ptr_to_ref(result_raw); - if result.rows.as_ref().is_none() { + let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result.kind else { return 0; - } + }; - result.rows.as_ref().unwrap().len() as size_t + rows.len() as size_t } #[no_mangle] pub unsafe extern "C" fn cass_result_column_count(result_raw: *const CassResult) -> size_t { let result = ptr_to_ref(result_raw); - result.metadata.col_specs.len() as size_t + let CassResultKind::Rows(CassRowsResult { metadata, .. }) = &result.kind else { + return 0; + }; + + metadata.col_specs.len() as size_t } #[no_mangle] pub unsafe extern "C" fn cass_result_first_row(result_raw: *const CassResult) -> *const CassRow { let result = ptr_to_ref(result_raw); - result - .rows - .as_ref() - .and_then(|rows| rows.first()) + let CassResultKind::Rows(CassRowsResult { rows, .. }) = &result.kind else { + return std::ptr::null(); + }; + + rows.first() .map(|row| row as *const CassRow) .unwrap_or(std::ptr::null()) } @@ -1552,7 +1625,7 @@ mod tests { use super::{ cass_result_column_count, cass_result_column_type, create_cass_rows_from_rows, CassResult, - CassResultMetadata, + CassResultKind, CassResultMetadata, CassRowsResult, }; fn col_spec(name: &'static str, typ: ColumnType<'static>) -> ColumnSpec<'static> { @@ -1588,10 +1661,9 @@ mod tests { ); CassResult { - rows: Some(rows), - metadata, tracing_id: None, paging_state_response: PagingStateResponse::NoMorePages, + kind: CassResultKind::Rows(CassRowsResult { rows, metadata }), } } @@ -1678,12 +1750,10 @@ mod tests { } fn create_non_rows_cass_result() -> CassResult { - let metadata = Arc::new(CassResultMetadata::from_column_specs(&[])); CassResult { - rows: None, - metadata, tracing_id: None, paging_state_response: PagingStateResponse::NoMorePages, + kind: CassResultKind::NonRows, } } diff --git a/scylla-rust-wrapper/src/session.rs b/scylla-rust-wrapper/src/session.rs index 189f5e9b..5ec3cee4 100644 --- a/scylla-rust-wrapper/src/session.rs +++ b/scylla-rust-wrapper/src/session.rs @@ -9,7 +9,7 @@ use crate::future::{CassFuture, CassFutureResult, CassResultValue}; use crate::metadata::create_table_metadata; use crate::metadata::{CassKeyspaceMeta, CassMaterializedViewMeta, CassSchemaMeta}; use crate::prepared::CassPrepared; -use crate::query_result::{CassResult, CassResultMetadata}; +use crate::query_result::{CassResult, CassResultKind, CassResultMetadata}; use crate::statement::CassStatement; use crate::statement::Statement; use crate::types::{cass_uint64_t, size_t}; @@ -217,10 +217,9 @@ pub unsafe extern "C" fn cass_session_execute_batch( let query_res = session.batch(&state.batch, &state.bound_values).await; match query_res { Ok(_result) => Ok(CassResultValue::QueryResult(Arc::new(CassResult { - rows: None, - metadata: Arc::new(CassResultMetadata::from_column_specs(&[])), tracing_id: None, paging_state_response: PagingStateResponse::NoMorePages, + kind: CassResultKind::NonRows, }))), Err(err) => Ok(CassResultValue::QueryError(Arc::new(err.into()))), } @@ -352,13 +351,14 @@ pub unsafe extern "C" fn cass_session_execute( match query_res { Ok((result, paging_state_response, maybe_result_metadata)) => { - let cass_result = Arc::new(CassResult::from_result_payload( + match CassResult::from_result_payload( result, paging_state_response, maybe_result_metadata, - )); - - Ok(CassResultValue::QueryResult(cass_result)) + ) { + Ok(result) => Ok(CassResultValue::QueryResult(Arc::new(result))), + Err(e) => Ok(CassResultValue::QueryError(Arc::new(e))), + } } Err(err) => Ok(CassResultValue::QueryError(Arc::new(err.into()))), }