Skip to content

Commit

Permalink
Accept either bytes or strings for queries
Browse files Browse the repository at this point in the history
As discussed in #194, MySQL actually accepts arbitrary sequences of
bytes, not just utf-8 strings, as queries, but this crate is limited to
only working with types that impl AsRef<str>. To allow sending arbitrary
byte slices as queries this commit, which goes along with and requires
blackbeam/rust_mysql_common#64, introduces a new
`AsQuery` trait, which is impl'd for all of the standard library types
that either impl AsRef<str> or AsRef<[u8]>, and uses that trait in place
of `AsRef<str>` for all query methods, going on down the chain to error
types, internal cache structures, etc. as well.

Fixes: #194
  • Loading branch information
glittershark committed May 25, 2022
1 parent efe51e0 commit f4f2ab6
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 38 deletions.
4 changes: 2 additions & 2 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1382,8 +1382,8 @@ mod test {
.stmt_cache_ref()
.iter()
.map(|item| item.1.query.0.as_ref())
.collect::<Vec<&str>>();
assert_eq!(order, &["DO 6", "DO 5", "DO 3"]);
.collect::<Vec<&[u8]>>();
assert_eq!(order, &[b"DO 6", b"DO 5", b"DO 3"]);
conn.disconnect().await?;
Ok(())
}
Expand Down
8 changes: 4 additions & 4 deletions src/conn/routines/prepare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ use super::Routine;
/// A routine that performs `COM_STMT_PREPARE`.
#[derive(Debug, Clone)]
pub struct PrepareRoutine {
query: Arc<str>,
query: Arc<[u8]>,
}

impl PrepareRoutine {
pub fn new(raw_query: Cow<'_, str>) -> Self {
pub fn new(raw_query: Cow<'_, [u8]>) -> Self {
Self {
query: raw_query.into_owned().into_boxed_str().into(),
query: raw_query.into_owned().into_boxed_slice().into(),
}
}
}

impl Routine<Arc<StmtInner>> for PrepareRoutine {
fn call<'a>(&'a mut self, conn: &'a mut Conn) -> BoxFuture<'a, crate::Result<Arc<StmtInner>>> {
async move {
conn.write_command_data(Command::COM_STMT_PREPARE, self.query.as_bytes())
conn.write_command_data(Command::COM_STMT_PREPARE, &self.query)
.await?;

let packet = conn.read_packet().await?;
Expand Down
16 changes: 8 additions & 8 deletions src/conn/stmt_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ use std::{
use crate::queryable::stmt::StmtInner;

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct QueryString(pub Arc<str>);
pub struct QueryString(pub Arc<[u8]>);

impl Borrow<str> for QueryString {
fn borrow(&self) -> &str {
impl Borrow<[u8]> for QueryString {
fn borrow(&self) -> &[u8] {
&*self.0.as_ref()
}
}

impl PartialEq<str> for QueryString {
fn eq(&self, other: &str) -> bool {
impl PartialEq<[u8]> for QueryString {
fn eq(&self, other: &[u8]) -> bool {
&*self.0.as_ref() == other
}
}
Expand Down Expand Up @@ -68,7 +68,7 @@ impl StmtCache {
}
}

pub fn put(&mut self, query: Arc<str>, stmt: Arc<StmtInner>) -> Option<Arc<StmtInner>> {
pub fn put(&mut self, query: Arc<[u8]>, stmt: Arc<StmtInner>) -> Option<Arc<StmtInner>> {
if self.cap == 0 {
return None;
}
Expand All @@ -95,7 +95,7 @@ impl StmtCache {

pub fn remove(&mut self, id: u32) {
if let Some(entry) = self.cache.pop(&id) {
self.query_map.remove::<str>(entry.query.borrow());
self.query_map.remove::<[u8]>(entry.query.borrow());
}
}

Expand Down Expand Up @@ -135,7 +135,7 @@ impl super::Conn {
/// Returns statement, if cached.
///
/// `raw_query` is the query with `?` placeholders (not with `:<name>` placeholders).
pub(crate) fn get_cached_stmt(&mut self, raw_query: &str) -> Option<Arc<StmtInner>> {
pub(crate) fn get_cached_stmt(&mut self, raw_query: &[u8]) -> Option<Arc<StmtInner>> {
self.stmt_cache_mut()
.by_query(raw_query)
.map(|entry| entry.stmt.clone())
Expand Down
4 changes: 2 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ pub enum DriverError {
#[error("Error converting from mysql row.")]
FromRow { row: Row },

#[error("Missing named parameter `{}'.", name)]
MissingNamedParam { name: String },
#[error("Missing named parameter `{}'.", String::from_utf8_lossy(&name))]
MissingNamedParam { name: Vec<u8> },

#[error("Named and positional parameters mixed in one statement.")]
MixedParams,
Expand Down
50 changes: 49 additions & 1 deletion src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,54 @@ use crate::{
BinaryProtocol, BoxFuture, Params, QueryResult, ResultSetStream, TextProtocol,
};

/// Types that can be treated as a MySQL query.
///
/// This trait is implemented by all "string-ish" standard library types, like `String`, `&str`,
/// `Cow<str>`, but also all types that can be treated as a slice of bytes (such as `Vec<u8>` and
/// `&[u8]`), since MySQL does not require queries to be valid UTF-8.
pub trait AsQuery: Send + Sync {
fn as_query(&self) -> &[u8];
}

impl AsQuery for &'_ [u8] {
fn as_query(&self) -> &[u8] {
self
}
}

macro_rules! impl_as_query_as_ref {
($type: ty) => {
impl AsQuery for $type {
fn as_query(&self) -> &[u8] {
self.as_ref()
}
}
};
}

impl_as_query_as_ref!(Vec<u8>);
impl_as_query_as_ref!(&Vec<u8>);
impl_as_query_as_ref!(Box<[u8]>);
impl_as_query_as_ref!(std::borrow::Cow<'_, [u8]>);
impl_as_query_as_ref!(std::sync::Arc<[u8]>);

macro_rules! impl_as_query_as_bytes {
($type: ty) => {
impl AsQuery for $type {
fn as_query(&self) -> &[u8] {
self.as_bytes()
}
}
};
}

impl_as_query_as_bytes!(String);
impl_as_query_as_bytes!(&String);
impl_as_query_as_bytes!(&str);
impl_as_query_as_bytes!(Box<str>);
impl_as_query_as_bytes!(std::borrow::Cow<'_, str>);
impl_as_query_as_bytes!(std::sync::Arc<str>);

/// MySql text query.
///
/// This trait covers the set of `query*` methods on the `Queryable` trait.
Expand Down Expand Up @@ -157,7 +205,7 @@ pub trait Query: Send + Sized {
}
}

impl<Q: AsRef<str> + Send + Sync> Query for Q {
impl<Q: AsQuery> Query for Q {
type Protocol = TextProtocol;

fn run<'a, 't: 'a, C>(self, conn: C) -> BoxFuture<'a, QueryResult<'a, 't, TextProtocol>>
Expand Down
27 changes: 13 additions & 14 deletions src/queryable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::{
consts::CapabilityFlags,
error::*,
prelude::{FromRow, StatementLike},
query::AsQuery,
queryable::query_result::ResultSetMeta,
BoxFuture, Column, Conn, Params, ResultSetStream, Row,
};
Expand Down Expand Up @@ -102,10 +103,9 @@ impl Conn {
/// Low level function that performs a text query.
pub(crate) async fn raw_query<'a, Q>(&'a mut self, query: Q) -> Result<()>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
{
self.routine(QueryRoutine::new(query.as_ref().as_bytes()))
.await
self.routine(QueryRoutine::new(query.as_query())).await
}
}

Expand All @@ -122,7 +122,7 @@ pub trait Queryable: Send {
query: Q,
) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>>
where
Q: AsRef<str> + Send + Sync + 'a;
Q: AsQuery + 'a;

/// Prepares the given statement.
///
Expand Down Expand Up @@ -165,7 +165,7 @@ pub trait Queryable: Send {
/// to make this conversion infallible.
fn query<'a, T, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Vec<T>>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
T: FromRow + Send + 'static,
{
async move { self.query_iter(query).await?.collect_and_drop::<T>().await }.boxed()
Expand All @@ -180,7 +180,7 @@ pub trait Queryable: Send {
/// to make this conversion infallible.
fn query_first<'a, T, Q>(&'a mut self, query: Q) -> BoxFuture<'a, Option<T>>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
T: FromRow + Send + 'static,
{
async move {
Expand All @@ -205,7 +205,7 @@ pub trait Queryable: Send {
/// to make this conversion infallible.
fn query_map<'a, T, F, Q, U>(&'a mut self, query: Q, mut f: F) -> BoxFuture<'a, Vec<U>>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
T: FromRow + Send + 'static,
F: FnMut(T) -> U + Send + 'a,
U: Send,
Expand All @@ -229,7 +229,7 @@ pub trait Queryable: Send {
/// to make this conversion infallible.
fn query_fold<'a, T, F, Q, U>(&'a mut self, query: Q, init: U, mut f: F) -> BoxFuture<'a, U>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
T: FromRow + Send + 'static,
F: FnMut(U, T) -> U + Send + 'a,
U: Send + 'a,
Expand All @@ -246,7 +246,7 @@ pub trait Queryable: Send {
/// Performs the given query and drops the query result.
fn query_drop<'a, Q>(&'a mut self, query: Q) -> BoxFuture<'a, ()>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
{
async move { self.query_iter(query).await?.drop_result().await }.boxed()
}
Expand Down Expand Up @@ -397,7 +397,7 @@ pub trait Queryable: Send {
) -> BoxFuture<'a, ResultSetStream<'a, 'a, 'static, T, TextProtocol>>
where
T: Unpin + FromRow + Send + 'static,
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
{
async move {
self.query_iter(query)
Expand Down Expand Up @@ -451,11 +451,10 @@ impl Queryable for Conn {
query: Q,
) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
{
async move {
self.routine(QueryRoutine::new(query.as_ref().as_bytes()))
.await?;
self.routine(QueryRoutine::new(query.as_query())).await?;
Ok(QueryResult::new(self))
}
.boxed()
Expand Down Expand Up @@ -525,7 +524,7 @@ impl Queryable for Transaction<'_> {
query: Q,
) -> BoxFuture<'a, QueryResult<'a, 'static, TextProtocol>>
where
Q: AsRef<str> + Send + Sync + 'a,
Q: AsQuery + 'a,
{
self.0.query_iter(query)
}
Expand Down
16 changes: 9 additions & 7 deletions src/queryable/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ use crate::{
Column, Params,
};

use super::AsQuery;

/// Result of a `StatementLike::to_statement` call.
pub enum ToStatementResult<'a> {
/// Statement is immediately available.
Expand All @@ -37,12 +39,12 @@ pub trait StatementLike: Send + Sync {
Self: 'a;
}

fn to_statement_move<'a, T: AsRef<str> + Send + Sync + 'a>(
fn to_statement_move<'a, T: AsQuery + 'a>(
stmt: T,
conn: &'a mut crate::Conn,
) -> ToStatementResult<'a> {
let fut = async move {
let (named_params, raw_query) = parse_named_params(stmt.as_ref())?;
let (named_params, raw_query) = parse_named_params(stmt.as_query())?;
let inner_stmt = match conn.get_cached_stmt(&*raw_query) {
Some(inner_stmt) => inner_stmt,
None => conn.prepare_statement(raw_query).await?,
Expand Down Expand Up @@ -119,7 +121,7 @@ impl<T: StatementLike + Clone> StatementLike for &'_ T {
/// Statement data.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StmtInner {
pub(crate) raw_query: Arc<str>,
pub(crate) raw_query: Arc<[u8]>,
columns: Option<Box<[Column]>>,
params: Option<Box<[Column]>>,
stmt_packet: StmtPacket,
Expand All @@ -130,7 +132,7 @@ impl StmtInner {
pub(crate) fn from_payload(
pld: &[u8],
connection_id: u32,
raw_query: Arc<str>,
raw_query: Arc<[u8]>,
) -> std::io::Result<Self> {
let stmt_packet = ParseBuf(pld).parse(())?;

Expand Down Expand Up @@ -192,11 +194,11 @@ impl StmtInner {
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Statement {
pub(crate) inner: Arc<StmtInner>,
pub(crate) named_params: Option<Vec<String>>,
pub(crate) named_params: Option<Vec<Vec<u8>>>,
}

impl Statement {
pub(crate) fn new(inner: Arc<StmtInner>, named_params: Option<Vec<String>>) -> Self {
pub(crate) fn new(inner: Arc<StmtInner>, named_params: Option<Vec<Vec<u8>>>) -> Self {
Self {
inner,
named_params,
Expand Down Expand Up @@ -275,7 +277,7 @@ impl crate::Conn {
/// Low-level helper, that prepares the given statement.
///
/// `raw_query` is a query with `?` placeholders (if any).
async fn prepare_statement(&mut self, raw_query: Cow<'_, str>) -> Result<Arc<StmtInner>> {
async fn prepare_statement(&mut self, raw_query: Cow<'_, [u8]>) -> Result<Arc<StmtInner>> {
let inner_stmt = self.routine(PrepareRoutine::new(raw_query)).await?;

if let Some(old_stmt) = self.cache_stmt(&inner_stmt) {
Expand Down

0 comments on commit f4f2ab6

Please sign in to comment.