diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 2f83bdc..87f7a1d 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -238,3 +238,28 @@ jobs: # Update Cargo.lock to minimal version dependencies. cargo update -Z minimal-versions cargo hack check --all-features --ignore-private + + build-feature-power-set: + if: github.event_name == 'push' || github.event_name == 'schedule' || + github.event.pull_request.head.repo.full_name != github.repository + + name: Build with each feature combination + runs-on: ubuntu-latest + needs: ["build"] + steps: + - name: Checkout source + uses: actions/checkout@v4 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly + + - name: Install cargo-hack + uses: taiki-e/install-action@cargo-hack + + - name: Cache Cargo registry + uses: Swatinem/rust-cache@v2 + + - name: Run cargo check with every combination of features + run: cargo hack check --feature-powerset --exclude-features db --no-dev-deps diff --git a/Cargo.lock b/Cargo.lock index 88bee32..3897e80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -717,8 +717,7 @@ checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" [[package]] name = "dummy" version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3ee4e39146145f7dd28e6c85ffdce489d93c0d9c88121063b8aacabbd9858d2" +source = "git+https://github.com/m4tx/fake-rs.git#7414661e15da393b0c4a45dcdd81f8e57c70b459" dependencies = [ "darling", "proc-macro2", @@ -829,8 +828,7 @@ dependencies = [ [[package]] name = "fake" version = "3.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "661cb0601b5f4050d1e65452c5b0ea555c0b3e88fb5ed7855906adc6c42523ef" +source = "git+https://github.com/m4tx/fake-rs.git#7414661e15da393b0c4a45dcdd81f8e57c70b459" dependencies = [ "chrono", "deunicode", diff --git a/Cargo.toml b/Cargo.toml index 459af7c..dd54b5a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,8 @@ darling = "0.20" derive_builder = "0.20" derive_more = "1" env_logger = "0.11" -fake = "3" +# TODO: replace with upstream when https://github.com/cksac/fake-rs/pull/204 is merged and released +fake = { git = "https://github.com/m4tx/fake-rs.git" } flareon = { path = "flareon" } flareon_codegen = { path = "flareon-codegen" } flareon_macros = { path = "flareon-macros" } diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 0000000..91c8bb3 --- /dev/null +++ b/clippy.toml @@ -0,0 +1 @@ +doc-valid-idents = ["PostgreSQL", "MySQL", "SQLite"] diff --git a/compose.yml b/compose.yml index f03a41a..5fc5ccc 100644 --- a/compose.yml +++ b/compose.yml @@ -1,8 +1,18 @@ services: + mariadb: + image: mariadb:11 + container_name: flareon-mariadb + environment: + MARIADB_DATABASE: mysql + MARIADB_USER: flareon + MARIADB_PASSWORD: flareon + MARIADB_ALLOW_EMPTY_ROOT_PASSWORD: 1 + ports: + - "3306:3306" + postgres: - image: postgres:16-alpine + image: postgres:17-alpine container_name: flareon-postgres - restart: always environment: POSTGRES_USER: flareon POSTGRES_PASSWORD: flareon diff --git a/flareon-macros/src/dbtest.rs b/flareon-macros/src/dbtest.rs index b450452..d5452fa 100644 --- a/flareon-macros/src/dbtest.rs +++ b/flareon-macros/src/dbtest.rs @@ -6,6 +6,7 @@ pub(super) fn fn_to_dbtest(test_function_decl: ItemFn) -> syn::Result syn::Result flareon::Result { } async fn authenticate(request: &mut Request, login_form: LoginForm) -> flareon::Result { + #[cfg(feature = "db")] let user = request - .authenticate(&DatabaseUserCredentials::new( + .authenticate(&crate::auth::db::DatabaseUserCredentials::new( login_form.username, // TODO unify auth::Password and forms::fields::Password flareon::auth::Password::new(login_form.password.into_string()), )) .await?; + #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))] + let mut user: Option> = None; + if let Some(user) = user { request.login(user).await?; Ok(true) diff --git a/flareon/src/auth.rs b/flareon/src/auth.rs index 8593b48..e31b92a 100644 --- a/flareon/src/auth.rs +++ b/flareon/src/auth.rs @@ -6,6 +6,7 @@ //! //! For the default way to store users in the database, see the [`db`] module. +#[cfg(feature = "db")] pub mod db; use std::any::Any; @@ -14,8 +15,6 @@ use std::sync::Arc; use async_trait::async_trait; use chrono::{DateTime, FixedOffset}; -use flareon::config::SecretKey; -use flareon::db::impl_postgres::PostgresValueRef; #[cfg(test)] use mockall::automock; use password_auth::VerifyError; @@ -23,7 +22,8 @@ use serde::{Deserialize, Serialize}; use subtle::ConstantTimeEq; use thiserror::Error; -use crate::db::impl_sqlite::SqliteValueRef; +use crate::config::SecretKey; +#[cfg(feature = "db")] use crate::db::{ColumnType, DatabaseField, FromDbValue, SqlxValueRef, ToDbValue}; use crate::request::{Request, RequestExt}; @@ -403,20 +403,35 @@ impl Debug for PasswordHash { const MAX_PASSWORD_HASH_LENGTH: u32 = 128; +#[cfg(feature = "db")] impl DatabaseField for PasswordHash { const TYPE: ColumnType = ColumnType::String(MAX_PASSWORD_HASH_LENGTH); } +#[cfg(feature = "db")] impl FromDbValue for PasswordHash { - fn from_sqlite(value: SqliteValueRef) -> flareon::db::Result { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: crate::db::impl_sqlite::SqliteValueRef) -> flareon::db::Result { PasswordHash::new(value.get::()?).map_err(flareon::db::DatabaseError::value_decode) } - fn from_postgres(value: PostgresValueRef) -> flareon::db::Result { + #[cfg(feature = "postgres")] + fn from_postgres( + value: crate::db::impl_postgres::PostgresValueRef, + ) -> flareon::db::Result { + PasswordHash::new(value.get::()?).map_err(flareon::db::DatabaseError::value_decode) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: crate::db::impl_mysql::MySqlValueRef) -> crate::db::Result + where + Self: Sized, + { PasswordHash::new(value.get::()?).map_err(flareon::db::DatabaseError::value_decode) } } +#[cfg(feature = "db")] impl ToDbValue for PasswordHash { fn to_sea_query_value(&self) -> sea_query::Value { self.0.clone().into() @@ -710,6 +725,28 @@ pub trait AuthBackend: Send + Sync { ) -> Result>>; } +#[derive(Debug, Copy, Clone)] +pub struct NoAuthBackend; + +#[async_trait] +impl AuthBackend for NoAuthBackend { + async fn authenticate( + &self, + _request: &Request, + _credentials: &(dyn Any + Send + Sync), + ) -> Result>> { + Ok(None) + } + + async fn get_by_id( + &self, + _request: &Request, + _id: UserId, + ) -> Result>> { + Ok(None) + } +} + #[cfg(test)] mod tests { use std::sync::Mutex; @@ -720,27 +757,6 @@ mod tests { use crate::config::ProjectConfig; use crate::test::TestRequestBuilder; - struct NoUserAuthBackend; - - #[async_trait] - impl AuthBackend for NoUserAuthBackend { - async fn authenticate( - &self, - _request: &Request, - _credentials: &(dyn Any + Send + Sync), - ) -> Result>> { - Ok(None) - } - - async fn get_by_id( - &self, - _request: &Request, - _id: UserId, - ) -> Result>> { - Ok(None) - } - } - struct MockAuthBackend { return_user: F, } @@ -894,7 +910,7 @@ mod tests { #[tokio::test] async fn user_anonymous() { - let mut request = test_request_with_auth_backend(NoUserAuthBackend {}); + let mut request = test_request_with_auth_backend(NoAuthBackend {}); let user = request.user().await.unwrap(); assert!(!user.is_authenticated()); @@ -955,7 +971,7 @@ mod tests { /// session (can happen if the user is deleted from the database) #[tokio::test] async fn logout_on_invalid_user_id_in_session() { - let mut request = test_request_with_auth_backend(NoUserAuthBackend {}); + let mut request = test_request_with_auth_backend(NoAuthBackend {}); request .session_mut() diff --git a/flareon/src/config.rs b/flareon/src/config.rs index 92396c1..2f4e872 100644 --- a/flareon/src/config.rs +++ b/flareon/src/config.rs @@ -16,6 +16,7 @@ use derive_builder::Builder; use derive_more::Debug; use subtle::ConstantTimeEq; +#[cfg(feature = "db")] use crate::auth::db::DatabaseUserBackend; use crate::auth::AuthBackend; @@ -63,6 +64,7 @@ pub struct ProjectConfig { #[debug("..")] #[builder(setter(custom))] auth_backend: Arc, + #[cfg(feature = "db")] database_config: DatabaseConfig, } @@ -81,17 +83,20 @@ impl ProjectConfigBuilder { .auth_backend .clone() .unwrap_or_else(default_auth_backend), + #[cfg(feature = "db")] database_config: self.database_config.clone().unwrap_or_default(), } } } +#[cfg(feature = "db")] #[derive(Debug, Clone, Builder)] pub struct DatabaseConfig { #[builder(setter(into))] url: String, } +#[cfg(feature = "db")] impl DatabaseConfig { #[must_use] pub fn builder() -> DatabaseConfigBuilder { @@ -104,6 +109,7 @@ impl DatabaseConfig { } } +#[cfg(feature = "db")] impl Default for DatabaseConfig { fn default() -> Self { Self { @@ -119,7 +125,15 @@ impl Default for ProjectConfig { } fn default_auth_backend() -> Arc { - Arc::new(DatabaseUserBackend::new()) + #[cfg(feature = "db")] + { + Arc::new(DatabaseUserBackend::new()) + } + + #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))] + { + Arc::new(flareon::auth::NoAuthBackend) + } } impl ProjectConfig { @@ -144,6 +158,7 @@ impl ProjectConfig { } #[must_use] + #[cfg(feature = "db")] pub fn database_config(&self) -> &DatabaseConfig { &self.database_config } diff --git a/flareon/src/db.rs b/flareon/src/db.rs index 4aa18a3..738c275 100644 --- a/flareon/src/db.rs +++ b/flareon/src/db.rs @@ -4,10 +4,15 @@ //! the error types that can occur when interacting with the database. mod fields; +#[cfg(feature = "mysql")] +pub mod impl_mysql; +#[cfg(feature = "postgres")] pub mod impl_postgres; +#[cfg(feature = "sqlite")] pub mod impl_sqlite; pub mod migrations; pub mod query; +mod sea_query_db; use std::fmt::Write; use std::hash::Hash; @@ -24,8 +29,13 @@ use sea_query_binder::{SqlxBinder, SqlxValues}; use sqlx::{Type, TypeInfo}; use thiserror::Error; +#[cfg(feature = "mysql")] +use crate::db::impl_mysql::{DatabaseMySql, MySqlRow, MySqlValueRef}; +#[cfg(feature = "postgres")] use crate::db::impl_postgres::{DatabasePostgres, PostgresRow, PostgresValueRef}; +#[cfg(feature = "sqlite")] use crate::db::impl_sqlite::{DatabaseSqlite, SqliteRow, SqliteValueRef}; +use crate::db::migrations::ColumnTypeMapper; /// An error that can occur when interacting with the database. #[derive(Debug, Error)] @@ -206,8 +216,12 @@ impl Column { #[non_exhaustive] #[derive(Debug)] pub enum Row { + #[cfg(feature = "sqlite")] Sqlite(SqliteRow), + #[cfg(feature = "postgres")] Postgres(PostgresRow), + #[cfg(feature = "mysql")] + MySql(MySqlRow), } impl Row { @@ -223,12 +237,18 @@ impl Row { /// returned by the database. pub fn get(&self, index: usize) -> Result { let result = match self { + #[cfg(feature = "sqlite")] Row::Sqlite(sqlite_row) => sqlite_row .get_raw(index) .and_then(|value| T::from_sqlite(value))?, - Row::Postgres(postgres) => postgres + #[cfg(feature = "postgres")] + Row::Postgres(postgres_row) => postgres_row .get_raw(index) .and_then(|value| T::from_postgres(value))?, + #[cfg(feature = "mysql")] + Row::MySql(mysql_row) => mysql_row + .get_raw(index) + .and_then(|value| T::from_mysql(value))?, }; Ok(result) @@ -261,25 +281,38 @@ pub trait DatabaseField: FromDbValue + ToDbValue { /// A trait for converting a database value to a Rust value. pub trait FromDbValue { - /// Converts the given `SQLite` database value to a Rust value. + /// Converts the given SQLite database value to a Rust value. /// /// # Errors /// /// This method can return an error if the value is not compatible with the /// Rust type. + #[cfg(feature = "sqlite")] fn from_sqlite(value: SqliteValueRef) -> Result where Self: Sized; - /// Converts the given `Postgresql` database value to a Rust value. + /// Converts the given PostgreSQL database value to a Rust value. /// /// # Errors /// /// This method can return an error if the value is not compatible with the /// Rust type. + #[cfg(feature = "postgres")] fn from_postgres(value: PostgresValueRef) -> Result where Self: Sized; + + /// Converts the given MySQL database value to a Rust value. + /// + /// # Errors + /// + /// This method can return an error if the value is not compatible with the + /// Rust type. + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result + where + Self: Sized; } /// A trait for converting a Rust value to a database value. @@ -342,8 +375,12 @@ pub struct Database { #[derive(Debug)] enum DatabaseImpl { + #[cfg(feature = "sqlite")] Sqlite(DatabaseSqlite), + #[cfg(feature = "postgres")] Postgres(DatabasePostgres), + #[cfg(feature = "mysql")] + MySql(DatabaseMySql), } impl Database { @@ -371,23 +408,35 @@ impl Database { /// ``` pub async fn new>(url: T) -> Result { let url = url.into(); - let db = if url.starts_with("sqlite:") { + + #[cfg(feature = "sqlite")] + if url.starts_with("sqlite:") { let inner = DatabaseSqlite::new(&url).await?; - Self { + return Ok(Self { _url: url, inner: DatabaseImpl::Sqlite(inner), - } - } else if url.starts_with("postgresql:") { + }); + } + + #[cfg(feature = "postgres")] + if url.starts_with("postgresql:") { let inner = DatabasePostgres::new(&url).await?; - Self { + return Ok(Self { _url: url, inner: DatabaseImpl::Postgres(inner), - } - } else { - todo!("Other databases are not supported yet"); - }; + }); + } + + #[cfg(feature = "mysql")] + if url.starts_with("mysql:") { + let inner = DatabaseMySql::new(&url).await?; + return Ok(Self { + _url: url, + inner: DatabaseImpl::MySql(inner), + }); + } - Ok(db) + panic!("Unsupported database URL: {url}"); } /// Closes the database connection. @@ -414,8 +463,12 @@ impl Database { /// ``` pub async fn close(&self) -> Result<()> { match &self.inner { + #[cfg(feature = "sqlite")] DatabaseImpl::Sqlite(inner) => inner.close().await, + #[cfg(feature = "postgres")] DatabaseImpl::Postgres(inner) => inner.close().await, + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.close().await, } } @@ -574,8 +627,12 @@ impl Database { let values = SqlxValues(sea_query::Values(values)); let result = match &self.inner { + #[cfg(feature = "sqlite")] DatabaseImpl::Sqlite(inner) => inner.raw_with(query, values).await?, + #[cfg(feature = "postgres")] DatabaseImpl::Postgres(inner) => inner.raw_with(query, values).await?, + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.raw_with(query, values).await?, }; Ok(result) @@ -586,10 +643,14 @@ impl Database { T: SqlxBinder, { let result = match &self.inner { + #[cfg(feature = "sqlite")] DatabaseImpl::Sqlite(inner) => inner.fetch_option(statement).await?.map(Row::Sqlite), + #[cfg(feature = "postgres")] DatabaseImpl::Postgres(inner) => { inner.fetch_option(statement).await?.map(Row::Postgres) } + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.fetch_option(statement).await?.map(Row::MySql), }; Ok(result) @@ -600,18 +661,27 @@ impl Database { T: SqlxBinder, { let result = match &self.inner { + #[cfg(feature = "sqlite")] DatabaseImpl::Sqlite(inner) => inner .fetch_all(statement) .await? .into_iter() .map(Row::Sqlite) .collect(), + #[cfg(feature = "postgres")] DatabaseImpl::Postgres(inner) => inner .fetch_all(statement) .await? .into_iter() .map(Row::Postgres) .collect(), + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner + .fetch_all(statement) + .await? + .into_iter() + .map(Row::MySql) + .collect(), }; Ok(result) @@ -619,11 +689,15 @@ impl Database { async fn execute_statement(&self, statement: &T) -> Result where - T: SqlxBinder, + T: SqlxBinder + Sync, { let result = match &self.inner { + #[cfg(feature = "sqlite")] DatabaseImpl::Sqlite(inner) => inner.execute_statement(statement).await?, + #[cfg(feature = "postgres")] DatabaseImpl::Postgres(inner) => inner.execute_statement(statement).await?, + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.execute_statement(statement).await?, }; Ok(result) @@ -634,14 +708,31 @@ impl Database { statement: T, ) -> Result { let result = match &self.inner { + #[cfg(feature = "sqlite")] DatabaseImpl::Sqlite(inner) => inner.execute_schema(statement).await?, + #[cfg(feature = "postgres")] DatabaseImpl::Postgres(inner) => inner.execute_schema(statement).await?, + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.execute_schema(statement).await?, }; Ok(result) } } +impl ColumnTypeMapper for Database { + fn sea_query_column_type_for(&self, column_type: ColumnType) -> sea_query::ColumnType { + match &self.inner { + #[cfg(feature = "sqlite")] + DatabaseImpl::Sqlite(inner) => inner.sea_query_column_type_for(column_type), + #[cfg(feature = "postgres")] + DatabaseImpl::Postgres(inner) => inner.sea_query_column_type_for(column_type), + #[cfg(feature = "mysql")] + DatabaseImpl::MySql(inner) => inner.sea_query_column_type_for(column_type), + } + } +} + #[cfg_attr(test, automock)] #[async_trait] pub trait DatabaseBackend: Send + Sync { @@ -821,8 +912,7 @@ pub enum ColumnType { Time, Date, DateTime, - Timestamp, - TimestampWithTimeZone, + DateTimeWithTimeZone, Text, Blob, String(u32), diff --git a/flareon/src/db/fields.rs b/flareon/src/db/fields.rs index 1fa1b8e..47abc69 100644 --- a/flareon/src/db/fields.rs +++ b/flareon/src/db/fields.rs @@ -1,37 +1,45 @@ use flareon::db::DatabaseField; use sea_query::Value; +#[cfg(feature = "mysql")] +use crate::db::impl_mysql::MySqlValueRef; +#[cfg(feature = "postgres")] +use crate::db::impl_postgres::PostgresValueRef; +#[cfg(feature = "sqlite")] +use crate::db::impl_sqlite::SqliteValueRef; use crate::db::{ - ColumnType, DatabaseError, FromDbValue, LimitedString, PostgresValueRef, Result, - SqliteValueRef, SqlxValueRef, ToDbValue, + ColumnType, DatabaseError, FromDbValue, LimitedString, Result, SqlxValueRef, ToDbValue, }; -macro_rules! impl_db_field { - ($ty:ty, $column_type:ident) => { - impl DatabaseField for $ty { - const TYPE: ColumnType = ColumnType::$column_type; +macro_rules! impl_from_sqlite_default { + () => { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef) -> Result { + value.get::() } + }; +} - impl FromDbValue for $ty { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::<$ty>() - } - - fn from_postgres(value: PostgresValueRef) -> Result { - value.get::<$ty>() - } +macro_rules! impl_from_postgres_default { + () => { + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef) -> Result { + value.get::() } + }; +} - impl FromDbValue for Option<$ty> { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::>() - } - - fn from_postgres(value: PostgresValueRef) -> Result { - value.get::>() - } +macro_rules! impl_from_mysql_default { + () => { + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + value.get::() } + }; +} +macro_rules! impl_to_db_value_default { + ($ty:ty) => { impl ToDbValue for $ty { fn to_sea_query_value(&self) -> Value { self.clone().into() @@ -46,133 +54,84 @@ macro_rules! impl_db_field { }; } -macro_rules! impl_db_field_unsigned { - ($ty:ty, $signed_ty:ty, $column_type:ident) => { +macro_rules! impl_db_field { + ($ty:ty, $column_type:ident) => { impl DatabaseField for $ty { const TYPE: ColumnType = ColumnType::$column_type; } impl FromDbValue for $ty { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::<$ty>() - } + impl_from_sqlite_default!(); - fn from_postgres(value: PostgresValueRef) -> Result { - value.get::<$signed_ty>().map(|v| v as $ty) - } + impl_from_postgres_default!(); + + impl_from_mysql_default!(); } impl FromDbValue for Option<$ty> { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::>() - } + impl_from_sqlite_default!(); - fn from_postgres(value: PostgresValueRef) -> Result { - value - .get::>() - .map(|v| v.map(|v| v as $ty)) - } - } + impl_from_postgres_default!(); - impl ToDbValue for $ty { - fn to_sea_query_value(&self) -> Value { - self.clone().into() - } + impl_from_mysql_default!(); } - impl ToDbValue for Option<$ty> { - fn to_sea_query_value(&self) -> Value { - self.clone().into() - } - } + impl_to_db_value_default!($ty); }; } -impl DatabaseField for i8 { - const TYPE: ColumnType = ColumnType::TinyInteger; -} - -impl FromDbValue for i8 { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::() - } - - fn from_postgres(value: PostgresValueRef) -> Result { - value.get::().map(|v| v as i8) - } -} - -impl FromDbValue for Option { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::>() - } - - fn from_postgres(value: PostgresValueRef) -> Result { - value.get::>().map(|v| v.map(|v| v as i8)) - } -} - -impl ToDbValue for i8 { - fn to_sea_query_value(&self) -> Value { - (*self).into() - } -} - -impl ToDbValue for Option { - fn to_sea_query_value(&self) -> Value { - (*self).into() - } -} +macro_rules! impl_db_field_with_postgres_int_cast { + ($dest_ty:ty, $src_ty:ty, $column_type:ident) => { + impl DatabaseField for $dest_ty { + const TYPE: ColumnType = ColumnType::$column_type; + } -impl DatabaseField for u8 { - const TYPE: ColumnType = ColumnType::TinyUnsignedInteger; -} + impl FromDbValue for $dest_ty { + impl_from_sqlite_default!(); -impl FromDbValue for u8 { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::() - } + impl_from_mysql_default!(); - fn from_postgres(value: PostgresValueRef) -> Result { - value.get::().map(|v| v as u8) - } -} + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef) -> Result { + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + value.get::<$src_ty>().map(|v| v as $dest_ty) + } + } -impl FromDbValue for Option { - fn from_sqlite(value: SqliteValueRef) -> Result { - value.get::>() - } + impl FromDbValue for Option<$dest_ty> { + impl_from_sqlite_default!(); - fn from_postgres(value: PostgresValueRef) -> Result { - value.get::>().map(|v| v.map(|v| v as u8)) - } -} + impl_from_mysql_default!(); -impl ToDbValue for u8 { - fn to_sea_query_value(&self) -> Value { - (*self).into() - } -} + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef) -> Result { + #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_sign_loss)] + value + .get::>() + .map(|v| v.map(|v| v as $dest_ty)) + } + } -impl ToDbValue for Option { - fn to_sea_query_value(&self) -> Value { - (*self).into() - } + impl_to_db_value_default!($dest_ty); + }; } impl_db_field!(bool, Boolean); impl_db_field!(i16, SmallInteger); impl_db_field!(i32, Integer); impl_db_field!(i64, BigInteger); -impl_db_field_unsigned!(u16, i16, SmallUnsignedInteger); -impl_db_field_unsigned!(u32, i32, UnsignedInteger); -impl_db_field_unsigned!(u64, i64, BigUnsignedInteger); +impl_db_field_with_postgres_int_cast!(i8, i16, TinyInteger); +impl_db_field_with_postgres_int_cast!(u8, i16, TinyUnsignedInteger); +impl_db_field_with_postgres_int_cast!(u16, i16, SmallUnsignedInteger); +impl_db_field_with_postgres_int_cast!(u32, i32, UnsignedInteger); +impl_db_field_with_postgres_int_cast!(u64, i64, BigUnsignedInteger); impl_db_field!(f32, Float); impl_db_field!(f64, Double); impl_db_field!(chrono::NaiveDate, Date); impl_db_field!(chrono::NaiveTime, Time); impl_db_field!(chrono::NaiveDateTime, DateTime); -impl_db_field!(chrono::DateTime, TimestampWithTimeZone); impl_db_field!(String, Text); impl_db_field!(Vec, Blob); @@ -182,6 +141,35 @@ impl ToDbValue for &str { } } +impl DatabaseField for chrono::DateTime { + const TYPE: ColumnType = ColumnType::DateTimeWithTimeZone; +} + +impl FromDbValue for chrono::DateTime { + impl_from_sqlite_default!(); + + impl_from_postgres_default!(); + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + Ok(value.get::>()?.fixed_offset()) + } +} +impl FromDbValue for Option> { + impl_from_sqlite_default!(); + + impl_from_postgres_default!(); + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + Ok(value + .get::>>()? + .map(|dt| dt.fixed_offset())) + } +} + +impl_to_db_value_default!(chrono::DateTime); + impl ToDbValue for Option<&str> { fn to_sea_query_value(&self) -> Value { self.map(ToString::to_string).into() @@ -201,15 +189,23 @@ impl DatabaseField for LimitedString { } impl FromDbValue for LimitedString { + #[cfg(feature = "sqlite")] fn from_sqlite(value: SqliteValueRef) -> Result { let str = value.get::()?; Self::new(str).map_err(DatabaseError::value_decode) } + #[cfg(feature = "postgres")] fn from_postgres(value: PostgresValueRef) -> Result { let str = value.get::()?; Self::new(str).map_err(DatabaseError::value_decode) } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + let str = value.get::()?; + Self::new(str).map_err(DatabaseError::value_decode) + } } impl ToDbValue for LimitedString { @@ -217,3 +213,9 @@ impl ToDbValue for LimitedString { self.0.clone().into() } } + +impl ToDbValue for Option> { + fn to_sea_query_value(&self) -> Value { + self.clone().map(|s| s.0).into() + } +} diff --git a/flareon/src/db/impl_mysql.rs b/flareon/src/db/impl_mysql.rs new file mode 100644 index 0000000..2314104 --- /dev/null +++ b/flareon/src/db/impl_mysql.rs @@ -0,0 +1,24 @@ +use crate::db::sea_query_db::impl_sea_query_db_backend; +use crate::db::ColumnType; + +impl_sea_query_db_backend!(DatabaseMySql: sqlx::mysql::MySql, sqlx::mysql::MySqlPool, MySqlRow, MySqlValueRef, sea_query::MysqlQueryBuilder); + +impl DatabaseMySql { + fn prepare_values(_values: &mut sea_query_binder::SqlxValues) { + // No changes are needed for MySQL + } + + pub(super) fn sea_query_column_type_for( + &self, + column_type: ColumnType, + ) -> sea_query::ColumnType { + match column_type { + ColumnType::DateTime | ColumnType::DateTimeWithTimeZone => { + return sea_query::ColumnType::custom("DATETIME(6)"); + } + _ => {} + } + + sea_query::ColumnType::from(column_type) + } +} diff --git a/flareon/src/db/impl_postgres.rs b/flareon/src/db/impl_postgres.rs index 6f16172..5ade464 100644 --- a/flareon/src/db/impl_postgres.rs +++ b/flareon/src/db/impl_postgres.rs @@ -1,114 +1,9 @@ -use derive_more::Debug; -use flareon::db::{SqlxRowRef, SqlxValueRef}; -use log::debug; -use sea_query::{PostgresQueryBuilder, SchemaStatementBuilder}; -use sea_query_binder::{SqlxBinder, SqlxValues}; -use sqlx::{Database, PgPool, Row}; +use crate::db::sea_query_db::impl_sea_query_db_backend; -use super::{Result, RowsNum, StatementResult}; - -#[derive(Debug)] -pub(super) struct DatabasePostgres { - db_connection: PgPool, -} +impl_sea_query_db_backend!(DatabasePostgres: sqlx::postgres::Postgres, sqlx::postgres::PgPool, PostgresRow, PostgresValueRef, sea_query::PostgresQueryBuilder); impl DatabasePostgres { - pub(super) async fn new(url: &str) -> Result { - let db_connection = PgPool::connect(url).await?; - - Ok(Self { db_connection }) - } - - pub(super) async fn close(&self) -> Result<()> { - self.db_connection.close().await; - Ok(()) - } - - pub(super) async fn fetch_option( - &self, - statement: &T, - ) -> Result> { - let (sql, values) = Self::build_sql(statement); - - let row = Self::sqlx_query_with(&sql, values) - .fetch_optional(&self.db_connection) - .await?; - Ok(row.map(PostgresRow::new)) - } - - pub(super) async fn fetch_all(&self, statement: &T) -> Result> { - let (sql, values) = Self::build_sql(statement); - - let result = Self::sqlx_query_with(&sql, values) - .fetch_all(&self.db_connection) - .await? - .into_iter() - .map(PostgresRow::new) - .collect(); - Ok(result) - } - - pub(super) async fn execute_statement( - &self, - statement: &T, - ) -> Result { - let (sql, mut values) = Self::build_sql(statement); - Self::prepare_values(&mut values); - - debug!("Postgres Query: `{}` (values: {:?})", sql, values); - - self.execute_sqlx(Self::sqlx_query_with(&sql, values)).await - } - - pub(super) async fn execute_schema( - &self, - statement: T, - ) -> Result { - let sql = statement.build(PostgresQueryBuilder); - debug!("Schema modification: {}", sql); - - self.execute_sqlx(sqlx::query(&sql)).await - } - - pub(super) async fn raw_with(&self, sql: &str, values: SqlxValues) -> Result { - self.execute_sqlx(Self::sqlx_query_with(sql, values)).await - } - - async fn execute_sqlx<'a, A>( - &self, - sqlx_statement: sqlx::query::Query<'a, sqlx::postgres::Postgres, A>, - ) -> Result - where - A: 'a + sqlx::IntoArguments<'a, sqlx::postgres::Postgres>, - { - let result = sqlx_statement.execute(&self.db_connection).await?; - let result = StatementResult { - rows_affected: RowsNum(result.rows_affected()), - }; - - debug!("Rows affected: {}", result.rows_affected.0); - Ok(result) - } - - fn build_sql(statement: &T) -> (String, SqlxValues) - where - T: SqlxBinder, - { - let (sql, values) = statement.build_sqlx(PostgresQueryBuilder); - debug!("Postgres Query: `{}` (values: {:?})", sql, values); - - (sql, values) - } - - fn sqlx_query_with( - sql: &str, - mut values: SqlxValues, - ) -> sqlx::query::Query<'_, sqlx::postgres::Postgres, SqlxValues> { - Self::prepare_values(&mut values); - sqlx::query_with(sql, values) - } - - fn prepare_values(values: &mut SqlxValues) { + fn prepare_values(values: &mut sea_query_binder::SqlxValues) { for value in &mut values.0 .0 { Self::tinyint_to_smallint(value); Self::unsigned_to_signed(value); @@ -120,15 +15,16 @@ impl DatabasePostgres { /// and we'll get an error. fn tinyint_to_smallint(value: &mut sea_query::Value) { if let sea_query::Value::TinyInt(num) = value { - *value = sea_query::Value::SmallInt(num.map(|v| v as i16)); + *value = sea_query::Value::SmallInt(num.map(i16::from)); } else if let sea_query::Value::TinyUnsigned(num) = value { - *value = sea_query::Value::SmallInt(num.map(|v| v as i16)); + *value = sea_query::Value::SmallInt(num.map(i16::from)); } } - /// PostgreSQL doesn't support unsigned integers, so we need to convert them - /// to signed integers. + /// PostgreSQL doesn't support unsigned integers, so we need to convert + /// them to signed integers. fn unsigned_to_signed(value: &mut sea_query::Value) { + #[allow(clippy::cast_possible_wrap)] if let sea_query::Value::SmallUnsigned(num) = value { *value = sea_query::Value::SmallInt(num.map(|v| v as i16)); } else if let sea_query::Value::Unsigned(num) = value { @@ -137,46 +33,11 @@ impl DatabasePostgres { *value = sea_query::Value::BigInt(num.map(|v| v as i64)); } } -} - -#[derive(Debug)] -pub struct PostgresRow { - #[debug("...")] - inner: sqlx::postgres::PgRow, -} - -impl PostgresRow { - #[must_use] - fn new(inner: sqlx::postgres::PgRow) -> Self { - Self { inner } - } -} -impl SqlxRowRef for PostgresRow { - type ValueRef<'r> = PostgresValueRef<'r>; - - fn get_raw(&self, index: usize) -> Result> { - Ok(PostgresValueRef::new(self.inner.try_get_raw(index)?)) - } -} - -#[derive(Debug)] -pub struct PostgresValueRef<'r> { - #[debug("...")] - inner: sqlx::postgres::PgValueRef<'r>, -} - -impl<'r> PostgresValueRef<'r> { - #[must_use] - fn new(inner: sqlx::postgres::PgValueRef<'r>) -> Self { - Self { inner } - } -} - -impl<'r> SqlxValueRef<'r> for PostgresValueRef<'r> { - type DB = sqlx::Postgres; - - fn get_raw(self) -> ::ValueRef<'r> { - self.inner + pub(super) fn sea_query_column_type_for( + &self, + column_type: crate::db::ColumnType, + ) -> sea_query::ColumnType { + sea_query::ColumnType::from(column_type) } } diff --git a/flareon/src/db/impl_sqlite.rs b/flareon/src/db/impl_sqlite.rs index fcd3d7d..5f228b1 100644 --- a/flareon/src/db/impl_sqlite.rs +++ b/flareon/src/db/impl_sqlite.rs @@ -1,141 +1,16 @@ -use derive_more::Debug; -use flareon::db::{SqlxRowRef, SqlxValueRef}; -use log::debug; -use sea_query::{SchemaStatementBuilder, SqliteQueryBuilder}; -use sea_query_binder::{SqlxBinder, SqlxValues}; -use sqlx::{Database, Row, SqlitePool}; +use crate::db::sea_query_db::impl_sea_query_db_backend; -use super::{Result, RowsNum, StatementResult}; - -#[derive(Debug)] -pub(super) struct DatabaseSqlite { - db_connection: SqlitePool, -} +impl_sea_query_db_backend!(DatabaseSqlite: sqlx::sqlite::Sqlite, sqlx::sqlite::SqlitePool, SqliteRow, SqliteValueRef, sea_query::SqliteQueryBuilder); impl DatabaseSqlite { - pub(super) async fn new(url: &str) -> Result { - let db_connection = SqlitePool::connect(url).await?; - - Ok(Self { db_connection }) - } - - pub(super) async fn close(&self) -> Result<()> { - self.db_connection.close().await; - Ok(()) - } - - pub(super) async fn fetch_option( - &self, - statement: &T, - ) -> Result> { - let (sql, values) = Self::build_sql(statement); - - let row = sqlx::query_with(&sql, values) - .fetch_optional(&self.db_connection) - .await?; - Ok(row.map(SqliteRow::new)) - } - - pub(super) async fn fetch_all(&self, statement: &T) -> Result> { - let (sql, values) = Self::build_sql(statement); - - let result = sqlx::query_with(&sql, values) - .fetch_all(&self.db_connection) - .await? - .into_iter() - .map(SqliteRow::new) - .collect(); - Ok(result) - } - - pub(super) async fn execute_statement( - &self, - statement: &T, - ) -> Result { - let (sql, values) = Self::build_sql(statement); - - self.execute_sqlx(sqlx::query_with(&sql, values)).await - } - - pub(super) async fn execute_schema( - &self, - statement: T, - ) -> Result { - let sql = statement.build(SqliteQueryBuilder); - debug!("Schema modification: {}", sql); - - self.execute_sqlx(sqlx::query(&sql)).await - } - - pub(super) async fn raw_with(&self, sql: &str, values: SqlxValues) -> Result { - self.execute_sqlx(sqlx::query_with(sql, values)).await + fn prepare_values(_values: &mut sea_query_binder::SqlxValues) { + // No changes are needed for SQLite } - async fn execute_sqlx<'a, A>( + pub(super) fn sea_query_column_type_for( &self, - sqlx_statement: sqlx::query::Query<'a, sqlx::sqlite::Sqlite, A>, - ) -> Result - where - A: 'a + sqlx::IntoArguments<'a, sqlx::sqlite::Sqlite>, - { - let result = sqlx_statement.execute(&self.db_connection).await?; - let result = StatementResult { - rows_affected: RowsNum(result.rows_affected()), - }; - - debug!("Rows affected: {}", result.rows_affected.0); - Ok(result) - } - - fn build_sql(statement: &T) -> (String, SqlxValues) - where - T: SqlxBinder, - { - let (sql, values) = statement.build_sqlx(SqliteQueryBuilder); - debug!("SQLite Query: `{}` (values: {:?})", sql, values); - - (sql, values) - } -} - -#[derive(Debug)] -pub struct SqliteRow { - #[debug("...")] - inner: sqlx::sqlite::SqliteRow, -} - -impl SqliteRow { - #[must_use] - fn new(inner: sqlx::sqlite::SqliteRow) -> Self { - Self { inner } - } -} - -impl SqlxRowRef for SqliteRow { - type ValueRef<'r> = SqliteValueRef<'r>; - - fn get_raw(&self, index: usize) -> Result> { - Ok(SqliteValueRef::new(self.inner.try_get_raw(index)?)) - } -} - -#[derive(Debug)] -pub struct SqliteValueRef<'r> { - #[debug("...")] - inner: sqlx::sqlite::SqliteValueRef<'r>, -} - -impl<'r> SqliteValueRef<'r> { - #[must_use] - fn new(inner: sqlx::sqlite::SqliteValueRef<'r>) -> Self { - Self { inner } - } -} - -impl<'r> SqlxValueRef<'r> for SqliteValueRef<'r> { - type DB = sqlx::Sqlite; - - fn get_raw(self) -> ::ValueRef<'r> { - self.inner + column_type: crate::db::ColumnType, + ) -> sea_query::ColumnType { + sea_query::ColumnType::from(column_type) } } diff --git a/flareon/src/db/migrations.rs b/flareon/src/db/migrations.rs index d484bcd..9cf36bb 100644 --- a/flareon/src/db/migrations.rs +++ b/flareon/src/db/migrations.rs @@ -226,7 +226,7 @@ impl Operation { } => { let mut query = sea_query::Table::create().table(*table_name).to_owned(); for field in *fields { - query.col(ColumnDef::from(field)); + query.col(field.as_column_def(database)); } if *if_not_exists { query.if_not_exists(); @@ -236,7 +236,7 @@ impl Operation { OperationInner::AddField { table_name, field } => { let query = sea_query::Table::alter() .table(*table_name) - .add_column(ColumnDef::from(field)) + .add_column(field.as_column_def(database)) .to_owned(); database.execute_schema(query).await?; } @@ -372,27 +372,33 @@ impl Field { self.unique = true; self } -} -impl From<&Field> for ColumnDef { - fn from(column: &Field) -> Self { - let mut def = ColumnDef::new_with_type(column.name, column.ty.into()); - if column.primary_key { + fn as_column_def(&self, mapper: &T) -> ColumnDef { + let mut def = + ColumnDef::new_with_type(self.name, mapper.sea_query_column_type_for(self.ty)); + if self.primary_key { def.primary_key(); } - if column.auto_value { + if self.auto_value { def.auto_increment(); } - if column.null { + if self.null { def.null(); + } else { + def.not_null(); } - if column.unique { + if self.unique { def.unique_key(); } def } } +#[cfg_attr(test, mockall::automock)] +pub(super) trait ColumnTypeMapper { + fn sea_query_column_type_for(&self, column_type: ColumnType) -> sea_query::ColumnType; +} + macro_rules! unwrap_builder_option { ($self:ident, $field:ident) => { match $self.$field { @@ -599,8 +605,7 @@ impl From for sea_query::ColumnType { ColumnType::Time => Self::Time, ColumnType::Date => Self::Date, ColumnType::DateTime => Self::DateTime, - ColumnType::Timestamp => Self::Timestamp, - ColumnType::TimestampWithTimeZone => Self::TimestampWithTimeZone, + ColumnType::DateTimeWithTimeZone => Self::TimestampWithTimeZone, ColumnType::Text => Self::Text, ColumnType::Blob => Self::Blob, ColumnType::String(len) => Self::String(StringLen::N(len)), @@ -752,7 +757,11 @@ mod tests { .null() .unique(); - let column_def = ColumnDef::from(&field); + let mut mapper = MockColumnTypeMapper::new(); + mapper + .expect_sea_query_column_type_for() + .return_const(sea_query::ColumnType::Integer); + let column_def = field.as_column_def(&mapper); assert_eq!(column_def.get_column_name(), "id"); assert_eq!( @@ -769,7 +778,11 @@ mod tests { fn test_field_to_column_def_without_options() { let field = Field::new(Identifier::new("name"), ColumnType::Text); - let column_def = ColumnDef::from(&field); + let mut mapper = MockColumnTypeMapper::new(); + mapper + .expect_sea_query_column_type_for() + .return_const(sea_query::ColumnType::Text); + let column_def = field.as_column_def(&mapper); assert_eq!(column_def.get_column_name(), "name"); assert_eq!( diff --git a/flareon/src/db/sea_query_db.rs b/flareon/src/db/sea_query_db.rs new file mode 100644 index 0000000..fc08016 --- /dev/null +++ b/flareon/src/db/sea_query_db.rs @@ -0,0 +1,162 @@ +/// Implements the database backend for a specific engine using `SeaQuery`. +/// +/// Note that this macro doesn't implement certain engine-specific methods, and +/// they need to be implemented in a separate `impl` block. These methods are: +/// * `prepare_values` +/// * `sea_query_column_type_for` +macro_rules! impl_sea_query_db_backend { + ($db_name:ident : $sqlx_db_ty:ty, $pool_ty:ty, $row_name:ident, $value_ref_name:ident, $query_builder:expr) => { + #[derive(Debug)] + pub(super) struct $db_name { + db_connection: $pool_ty, + } + + impl $db_name { + pub(super) async fn new(url: &str) -> crate::db::Result { + let db_connection = <$pool_ty>::connect(url).await?; + + Ok(Self { db_connection }) + } + + pub(super) async fn close(&self) -> crate::db::Result<()> { + self.db_connection.close().await; + Ok(()) + } + + pub(super) async fn fetch_option( + &self, + statement: &T, + ) -> crate::db::Result> { + let (sql, values) = Self::build_sql(statement); + + let row = Self::sqlx_query_with(&sql, values) + .fetch_optional(&self.db_connection) + .await?; + Ok(row.map($row_name::new)) + } + + pub(super) async fn fetch_all( + &self, + statement: &T, + ) -> crate::db::Result> { + let (sql, values) = Self::build_sql(statement); + + let result = Self::sqlx_query_with(&sql, values) + .fetch_all(&self.db_connection) + .await? + .into_iter() + .map($row_name::new) + .collect(); + Ok(result) + } + + pub(super) async fn execute_statement( + &self, + statement: &T, + ) -> crate::db::Result { + let (sql, mut values) = Self::build_sql(statement); + Self::prepare_values(&mut values); + + self.execute_sqlx(Self::sqlx_query_with(&sql, values)).await + } + + pub(super) async fn execute_schema( + &self, + statement: T, + ) -> crate::db::Result { + let sql = statement.build($query_builder); + log::debug!("Schema modification: {}", sql); + + self.execute_sqlx(sqlx::query(&sql)).await + } + + pub(super) async fn raw_with( + &self, + sql: &str, + values: sea_query_binder::SqlxValues, + ) -> crate::db::Result { + self.execute_sqlx(Self::sqlx_query_with(sql, values)).await + } + + async fn execute_sqlx<'a, A>( + &self, + sqlx_statement: sqlx::query::Query<'a, $sqlx_db_ty, A>, + ) -> crate::db::Result + where + A: 'a + sqlx::IntoArguments<'a, $sqlx_db_ty>, + { + let result = sqlx_statement.execute(&self.db_connection).await?; + let result = crate::db::StatementResult { + rows_affected: crate::db::RowsNum(result.rows_affected()), + }; + + log::debug!("Rows affected: {}", result.rows_affected.0); + Ok(result) + } + + fn build_sql(statement: &T) -> (String, sea_query_binder::SqlxValues) + where + T: sea_query_binder::SqlxBinder, + { + let (sql, values) = statement.build_sqlx($query_builder); + + (sql, values) + } + + fn sqlx_query_with( + sql: &str, + mut values: sea_query_binder::SqlxValues, + ) -> sqlx::query::Query<'_, $sqlx_db_ty, sea_query_binder::SqlxValues> { + Self::prepare_values(&mut values); + log::debug!("Query: `{}` (values: {:?})", sql, values); + + sqlx::query_with(sql, values) + } + } + + #[derive(derive_more::Debug)] + pub struct $row_name { + #[debug("...")] + inner: <$sqlx_db_ty as sqlx::Database>::Row, + } + + impl $row_name { + #[must_use] + fn new(inner: <$sqlx_db_ty as sqlx::Database>::Row) -> Self { + Self { inner } + } + } + + impl crate::db::SqlxRowRef for $row_name { + type ValueRef<'r> = $value_ref_name<'r>; + + fn get_raw(&self, index: usize) -> crate::db::Result> { + use sqlx::Row; + Ok($value_ref_name::new(self.inner.try_get_raw(index)?)) + } + } + + #[derive(derive_more::Debug)] + pub struct $value_ref_name<'r> { + #[debug("...")] + inner: <$sqlx_db_ty as sqlx::Database>::ValueRef<'r>, + } + + impl<'r> $value_ref_name<'r> { + #[must_use] + fn new(inner: <$sqlx_db_ty as sqlx::Database>::ValueRef<'r>) -> Self { + Self { inner } + } + } + + impl<'r> crate::db::SqlxValueRef<'r> for $value_ref_name<'r> { + type DB = $sqlx_db_ty; + + fn get_raw(self) -> ::ValueRef<'r> { + self.inner + } + } + }; +} + +pub(super) use impl_sea_query_db_backend; diff --git a/flareon/src/error.rs b/flareon/src/error.rs index c356fca..0cd7a12 100644 --- a/flareon/src/error.rs +++ b/flareon/src/error.rs @@ -68,6 +68,7 @@ impl From for askama::Error { impl_error_from_repr!(askama::Error); impl_error_from_repr!(crate::router::path::ReverseError); +#[cfg(feature = "db")] impl_error_from_repr!(crate::db::DatabaseError); impl_error_from_repr!(crate::forms::FormError); impl_error_from_repr!(crate::auth::AuthError); @@ -105,6 +106,7 @@ pub(crate) enum ErrorRepr { TemplateRender(#[from] askama::Error), /// An error occurred while communicating with the database. #[error("Database error: {0}")] + #[cfg(feature = "db")] DatabaseError(#[from] crate::db::DatabaseError), /// An error occurred while parsing a form. #[error("Failed to process a form: {0}")] diff --git a/flareon/src/lib.rs b/flareon/src/lib.rs index 4270a8a..8813223 100644 --- a/flareon/src/lib.rs +++ b/flareon/src/lib.rs @@ -45,6 +45,7 @@ extern crate self as flareon; +#[cfg(feature = "db")] pub mod db; mod error; pub mod forms; @@ -76,8 +77,6 @@ use axum::handler::HandlerWithoutStateExt; use bytes::Bytes; use derive_more::{Debug, Deref, Display, From}; pub use error::Error; -use flareon::config::DatabaseConfig; -use flareon::router::RouterService; pub use flareon_macros::main; use futures_core::Stream; use futures_util::FutureExt; @@ -91,12 +90,17 @@ use tower::util::BoxCloneService; use tower::Service; use crate::admin::AdminModelManager; +#[cfg(feature = "db")] +use crate::config::DatabaseConfig; use crate::config::ProjectConfig; +#[cfg(feature = "db")] use crate::db::migrations::{DynMigration, MigrationEngine}; +#[cfg(feature = "db")] use crate::db::Database; use crate::error::ErrorRepr; use crate::error_page::{ErrorPageTrigger, FlareonDiagnostics}; use crate::response::Response; +use crate::router::RouterService; /// A type alias for a result that can return a `flareon::Error`. pub type Result = std::result::Result; @@ -156,6 +160,7 @@ pub trait FlareonApp: Send + Sync { Router::empty() } + #[cfg(feature = "db")] fn migrations(&self) -> Vec> { vec![] } @@ -352,6 +357,7 @@ pub struct AppContext { #[debug("...")] apps: Vec>, router: Arc, + #[cfg(feature = "db")] database: Option>, } @@ -361,12 +367,13 @@ impl AppContext { config: Arc, apps: Vec>, router: Arc, - database: Option>, + #[cfg(feature = "db")] database: Option>, ) -> Self { Self { config, apps, router, + #[cfg(feature = "db")] database, } } @@ -387,11 +394,13 @@ impl AppContext { } #[must_use] + #[cfg(feature = "db")] pub fn try_database(&self) -> Option<&Arc> { self.database.as_ref() } #[must_use] + #[cfg(feature = "db")] pub fn database(&self) -> &Database { self.try_database().expect( "Database missing. Did you forget to add the database when configuring FlareonProject?", @@ -419,6 +428,7 @@ impl FlareonProjectBuilder { config: Arc::new(ProjectConfig::default()), apps: vec![], router: Arc::new(Router::default()), + #[cfg(feature = "db")] database: None, }, urls: Vec::new(), @@ -518,8 +528,11 @@ where /// Builds the Flareon project instance. pub async fn build(mut self) -> Result { - let database = Self::init_database(self.context.config.database_config()).await?; - self.context.database = Some(database); + #[cfg(feature = "db")] + { + let database = Self::init_database(self.context.config.database_config()).await?; + self.context.database = Some(database); + } Ok(FlareonProject { context: self.context, @@ -527,6 +540,7 @@ where }) } + #[cfg(feature = "db")] async fn init_database(config: &DatabaseConfig) -> Result> { let database = Database::new(config.url()).await?; Ok(Arc::new(database)) @@ -583,6 +597,7 @@ pub async fn run(project: FlareonProject, address_str: &str) -> Result<()> { pub async fn run_at(project: FlareonProject, listener: tokio::net::TcpListener) -> Result<()> { let (mut context, mut project_handler) = project.into_context(); + #[cfg(feature = "db")] if let Some(database) = &context.database { let mut migrations: Vec> = Vec::new(); for app in &context.apps { @@ -601,6 +616,7 @@ pub async fn run_at(project: FlareonProject, listener: tokio::net::TcpListener) context.apps = apps; let context = Arc::new(context); + #[cfg(feature = "db")] let context_cleanup = context.clone(); let handler = |axum_request: axum::extract::Request| async move { @@ -660,6 +676,7 @@ pub async fn run_at(project: FlareonProject, listener: tokio::net::TcpListener) if config::REGISTER_PANIC_HOOK { let _ = std::panic::take_hook(); } + #[cfg(feature = "db")] if let Some(database) = &context_cleanup.database { database.close().await?; } diff --git a/flareon/src/request.rs b/flareon/src/request.rs index 3708cc5..6126087 100644 --- a/flareon/src/request.rs +++ b/flareon/src/request.rs @@ -20,6 +20,7 @@ use bytes::Bytes; use indexmap::IndexMap; use tower_sessions::Session; +#[cfg(feature = "db")] use crate::db::Database; use crate::error::ErrorRepr; use crate::headers::FORM_CONTENT_TYPE; @@ -57,6 +58,7 @@ pub trait RequestExt: private::Sealed { #[must_use] fn path_params_mut(&mut self) -> &mut PathParams; + #[cfg(feature = "db")] #[must_use] fn db(&self) -> &Database; @@ -116,6 +118,7 @@ impl RequestExt for Request { self.extensions_mut().get_or_insert_default::() } + #[cfg(feature = "db")] fn db(&self) -> &Database { self.context().database() } diff --git a/flareon/src/test.rs b/flareon/src/test.rs index 8639dfc..ee086d9 100644 --- a/flareon/src/test.rs +++ b/flareon/src/test.rs @@ -1,23 +1,23 @@ //! Test utilities for Flareon projects. use std::future::poll_fn; -use std::mem; -use std::ops::Deref; use std::sync::Arc; use derive_more::Debug; -use flareon::{prepare_request, FlareonProject}; use tower::Service; use tower_sessions::{MemoryStore, Session}; +#[cfg(feature = "db")] use crate::auth::db::DatabaseUserBackend; use crate::config::ProjectConfig; +#[cfg(feature = "db")] use crate::db::migrations::{DynMigration, MigrationEngine, MigrationWrapper}; +#[cfg(feature = "db")] use crate::db::Database; use crate::request::{Request, RequestExt}; use crate::response::Response; use crate::router::Router; -use crate::{AppContext, Body, BoxedHandler, Result}; +use crate::{prepare_request, AppContext, Body, BoxedHandler, FlareonProject, Result}; /// A test client for making requests to a Flareon project. /// @@ -61,6 +61,7 @@ pub struct TestRequestBuilder { url: String, session: Option, config: Option>, + #[cfg(feature = "db")] database: Option>, form_data: Option>, } @@ -110,11 +111,13 @@ impl TestRequestBuilder { self } + #[cfg(feature = "db")] pub fn database(&mut self, database: Arc) -> &mut Self { self.database = Some(database); self } + #[cfg(feature = "db")] pub fn with_db_auth(&mut self, db: Arc) -> &mut Self { let auth_backend = DatabaseUserBackend; let config = ProjectConfig::builder().auth_backend(auth_backend).build(); @@ -147,6 +150,7 @@ impl TestRequestBuilder { self.config.clone().unwrap_or_default(), Vec::new(), Arc::new(Router::empty()), + #[cfg(feature = "db")] self.database.clone(), ); prepare_request(&mut request, Arc::new(app_context)); @@ -176,6 +180,7 @@ impl TestRequestBuilder { } } +#[cfg(feature = "db")] #[derive(Debug)] pub struct TestDatabase { database: Arc, @@ -183,6 +188,7 @@ pub struct TestDatabase { migrations: Vec, } +#[cfg(feature = "db")] impl TestDatabase { fn new(database: Database, kind: TestDatabaseKind) -> TestDatabase { Self { @@ -198,12 +204,13 @@ impl TestDatabase { Ok(Self::new(database, TestDatabaseKind::Sqlite)) } - /// Create a new Postgres database for testing and connects to it. + /// Create a new PostgreSQL database for testing and connects to it. /// /// The database URL is read from the `POSTGRES_URL` environment variable. /// Note that it shouldn't include the database name — the function will /// create a new database for the test by connecting to the `postgres` - /// database. + /// database. If no URL is provided, it defaults to + /// `postgresql://flareon:flareon@localhost`. /// /// The database is created with the name `test_flareon__{test_name}`. /// Make sure that `test_name` is unique for each test so that the databases @@ -213,15 +220,15 @@ impl TestDatabase { /// means that the database will not be dropped if the test panics. pub async fn new_postgres(test_name: &str) -> Result { let db_url = std::env::var("POSTGRES_URL") - .unwrap_or_else(|_| "postgresql://flareon:flareon@localhost:5432".to_string()); + .unwrap_or_else(|_| "postgresql://flareon:flareon@localhost".to_string()); let database = Database::new(format!("{db_url}/postgres")).await?; - let test_database_name = format!("test_flareon__{}", test_name); + let test_database_name = format!("test_flareon__{test_name}"); database - .raw(&format!("DROP DATABASE IF EXISTS {}", test_database_name)) + .raw(&format!("DROP DATABASE IF EXISTS {test_database_name}")) .await?; database - .raw(&format!("CREATE DATABASE {}", test_database_name)) + .raw(&format!("CREATE DATABASE {test_database_name}")) .await?; database.close().await?; @@ -236,6 +243,45 @@ impl TestDatabase { )) } + /// Create a new MySQL database for testing and connects to it. + /// + /// The database URL is read from the `MYSQL_URL` environment variable. + /// Note that it shouldn't include the database name — the function will + /// create a new database for the test by connecting to the `mysql` + /// database. If no URL is provided, it defaults to + /// `mysql://root:@localhost`. + /// + /// The database is created with the name `test_flareon__{test_name}`. + /// Make sure that `test_name` is unique for each test so that the databases + /// don't conflict with each other. + /// + /// The database is dropped when `self.cleanup()` is called. Note that this + /// means that the database will not be dropped if the test panics. + pub async fn new_mysql(test_name: &str) -> Result { + let db_url = + std::env::var("MYSQL_URL").unwrap_or_else(|_| "mysql://root:@localhost".to_string()); + let database = Database::new(format!("{db_url}/mysql")).await?; + + let test_database_name = format!("test_flareon__{test_name}"); + database + .raw(&format!("DROP DATABASE IF EXISTS {test_database_name}")) + .await?; + database + .raw(&format!("CREATE DATABASE {test_database_name}")) + .await?; + database.close().await?; + + let database = Database::new(format!("{db_url}/{test_database_name}")).await?; + + Ok(Self::new( + database, + TestDatabaseKind::MySql { + db_url, + db_name: test_database_name, + }, + )) + } + pub fn add_migrations>( &mut self, migrations: V, @@ -245,6 +291,7 @@ impl TestDatabase { self } + #[cfg(feature = "db")] pub fn with_auth(&mut self) -> &mut Self { self.add_migrations(flareon::auth::db::migrations::MIGRATIONS.to_vec()); self @@ -252,7 +299,7 @@ impl TestDatabase { pub async fn run_migrations(&mut self) -> &mut Self { if !self.migrations.is_empty() { - let engine = MigrationEngine::new(mem::take(&mut self.migrations)); + let engine = MigrationEngine::new(std::mem::take(&mut self.migrations)); engine.run(&self.database()).await.unwrap(); } self @@ -270,7 +317,13 @@ impl TestDatabase { TestDatabaseKind::Postgres { db_url, db_name } => { let database = Database::new(format!("{db_url}/postgres")).await?; - database.raw(&format!("DROP DATABASE {}", db_name)).await?; + database.raw(&format!("DROP DATABASE {db_name}")).await?; + database.close().await?; + } + TestDatabaseKind::MySql { db_url, db_name } => { + let database = Database::new(format!("{db_url}/mysql")).await?; + + database.raw(&format!("DROP DATABASE {db_name}")).await?; database.close().await?; } } @@ -279,7 +332,8 @@ impl TestDatabase { } } -impl Deref for TestDatabase { +#[cfg(feature = "db")] +impl std::ops::Deref for TestDatabase { type Target = Database; fn deref(&self) -> &Self::Target { @@ -287,8 +341,10 @@ impl Deref for TestDatabase { } } +#[cfg(feature = "db")] #[derive(Debug, Clone)] enum TestDatabaseKind { Sqlite, Postgres { db_url: String, db_name: String }, + MySql { db_url: String, db_name: String }, } diff --git a/flareon/tests/db.rs b/flareon/tests/db.rs index 68b1c69..466e9fa 100644 --- a/flareon/tests/db.rs +++ b/flareon/tests/db.rs @@ -116,6 +116,7 @@ struct AllFieldsModel { field_f64: f64, field_date: chrono::NaiveDate, field_time: chrono::NaiveTime, + #[dummy(faker = "fake::chrono::Precision::<6>")] field_datetime: chrono::NaiveDateTime, #[dummy(faker = "fake::chrono::Precision::<6>")] field_datetime_timezone: chrono::DateTime, @@ -177,7 +178,14 @@ async fn all_fields_model(db: &mut TestDatabase) { normalize_datetimes(&mut models_from_db); assert_eq!(models.len(), models_from_db.len()); - assert!(models.iter().all(|model| models_from_db.contains(model))); + for model in &models { + assert!( + models_from_db.contains(model), + "Could not find model {:?} in models_from_db: {:?}", + model, + models_from_db + ); + } } /// Normalize the datetimes to UTC.