From 016ffe8d288eba6589a602afd4b5cdad8b1fc602 Mon Sep 17 00:00:00 2001 From: Valentin Chanas Date: Wed, 15 Nov 2023 15:07:22 +0100 Subject: [PATCH] editoast: refactor main.rs db_pool --- editoast/src/main.rs | 126 ++++++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 68 deletions(-) diff --git a/editoast/src/main.rs b/editoast/src/main.rs index 1823fbfb9aa..a5f0726b36c 100644 --- a/editoast/src/main.rs +++ b/editoast/src/main.rs @@ -32,8 +32,8 @@ use clap::Parser; use client::{ ClearArgs, Client, Color, Commands, DeleteProfileSetArgs, ElectricalProfilesCommands, GenerateArgs, ImportProfileSetArgs, ImportRailjsonArgs, ImportRollingStockArgs, InfraCloneArgs, - InfraCommands, ListProfileSetArgs, MakeMigrationArgs, PostgresConfig, RedisConfig, RefreshArgs, - RunserverArgs, SearchCommands, + InfraCommands, ListProfileSetArgs, MakeMigrationArgs, RedisConfig, RefreshArgs, RunserverArgs, + SearchCommands, }; use colored::*; use diesel::{sql_query, ConnectionError, ConnectionResult}; @@ -82,6 +82,13 @@ async fn main() { async fn run() -> Result<(), Box> { let client = Client::parse(); let pg_config = client.postgres_config; + let create_db_pool = || { + Ok::<_, Box>(Data::new(get_pool( + pg_config.url()?, + pg_config.pool_size, + ))) + }; + let redis_config = client.redis_config; match client.color { @@ -91,10 +98,10 @@ async fn run() -> Result<(), Box> { } match client.command { - Commands::Runserver(args) => runserver(args, pg_config, redis_config).await, - Commands::Generate(args) => generate(args, pg_config, redis_config).await, - Commands::ImportRailjson(args) => import_railjson(args, pg_config).await, - Commands::ImportRollingStock(args) => import_rolling_stock(args, pg_config).await, + Commands::Runserver(args) => runserver(args, create_db_pool()?, redis_config).await, + Commands::Generate(args) => generate(args, create_db_pool()?, redis_config).await, + Commands::ImportRailjson(args) => import_railjson(args, create_db_pool()?).await, + Commands::ImportRollingStock(args) => import_rolling_stock(args, create_db_pool()?).await, Commands::OsmToRailjson(args) => { converters::osm_to_railjson(args.osm_pbf_in, args.railjson_out) } @@ -104,13 +111,13 @@ async fn run() -> Result<(), Box> { } Commands::ElectricalProfiles(subcommand) => match subcommand { ElectricalProfilesCommands::Import(args) => { - electrical_profile_set_import(args, pg_config).await + electrical_profile_set_import(args, create_db_pool()?).await } ElectricalProfilesCommands::List(args) => { - electrical_profile_set_list(args, pg_config).await + electrical_profile_set_list(args, create_db_pool()?).await } ElectricalProfilesCommands::Delete(args) => { - electrical_profile_set_delete(args, pg_config).await + electrical_profile_set_delete(args, create_db_pool()?).await } }, Commands::Search(SearchCommands::List) => { @@ -122,11 +129,11 @@ async fn run() -> Result<(), Box> { Ok(()) } Commands::Search(SearchCommands::Refresh(args)) => { - refresh_search_tables(args, pg_config).await + refresh_search_tables(args, create_db_pool()?).await } Commands::Infra(subcommand) => match subcommand { - InfraCommands::Clone(args) => clone_infra(args, pg_config).await, - InfraCommands::Clear(args) => clear_infra(args, pg_config, redis_config).await, + InfraCommands::Clone(args) => clone_infra(args, create_db_pool()?).await, + InfraCommands::Clear(args) => clear_infra(args, create_db_pool()?, redis_config).await, }, } } @@ -204,13 +211,11 @@ fn get_pool(url: Url, max_size: usize) -> DbPool { /// Create and run the server async fn runserver( args: RunserverArgs, - pg_config: PostgresConfig, + db_pool: Data, redis_config: RedisConfig, ) -> Result<(), Box> { info!("Building server..."); - // Config databases - let pool = get_pool(pg_config.url()?, pg_config.pool_size); - + // Config database let redis = RedisClient::new(redis_config)?; // Custom Json extractor configuration @@ -266,7 +271,7 @@ async fn runserver( .wrap(Logger::new(actix_logger_format).log_target("actix_logger")) .app_data(json_cfg.clone()) .app_data(payload_config.clone()) - .app_data(Data::new(pool.clone())) + .app_data(db_pool.clone()) .app_data(Data::new(redis.clone())) .app_data(infra_caches.clone()) .app_data(Data::new(MapLayers::parse())) @@ -299,11 +304,10 @@ async fn build_redis_pool_and_invalidate_all_cache(redis_config: RedisConfig, in /// This command refresh all infra given as input (if no infra given then refresh all of them) async fn generate( args: GenerateArgs, - pg_config: PostgresConfig, + db_pool: Data, redis_config: RedisConfig, ) -> Result<(), Box> { - let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size)); - let mut conn = pool.get().await?; + let mut conn = db_pool.get().await?; let mut infras = vec![]; if args.infra_ids.is_empty() { // Retrieve all available infra @@ -313,7 +317,7 @@ async fn generate( } else { // Retrieve given infras for id in args.infra_ids { - let infra = match Infra::retrieve(pool.clone(), id as i64).await? { + let infra = match Infra::retrieve(db_pool.clone(), id as i64).await? { Some(infra) => infra, None => { return Err(InfraApiError::NotFound { @@ -335,7 +339,7 @@ async fn generate( ); let infra_cache = InfraCache::load(&mut conn, &infra).await?; if infra - .refresh(pool.clone(), args.force, &infra_cache) + .refresh(db_pool.clone(), args.force, &infra_cache) .await? { build_redis_pool_and_invalidate_all_cache(redis_config.clone(), infra.id.unwrap()) @@ -362,9 +366,8 @@ async fn generate( async fn import_rolling_stock( args: ImportRollingStockArgs, - pg_config: PostgresConfig, + db_pool: Data, ) -> Result<(), Box> { - let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size)); for rolling_stock_path in args.rolling_stock_path { let rolling_stock_file = File::open(rolling_stock_path)?; let mut rolling_stock: RollingStockModel = @@ -375,7 +378,7 @@ async fn import_rolling_stock( ); rolling_stock.locked = Some(false); rolling_stock.version = Some(0); - let rolling_stock = rolling_stock.create(pool.clone()).await?; + let rolling_stock = rolling_stock.create(db_pool.clone()).await?; info!( "✅ Rolling stock {}[{}] saved!", rolling_stock.name.clone().unwrap().bold(), @@ -387,9 +390,8 @@ async fn import_rolling_stock( async fn clone_infra( infra_args: InfraCloneArgs, - pg_config: PostgresConfig, + db_pool: Data, ) -> Result<(), Box> { - let db_pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size)); match Infra::clone(infra_args.id, db_pool, infra_args.new_name).await { Ok(cloned_infra) => println!( "✅ Infra {} (ID: {}) was successfully cloned", @@ -409,12 +411,10 @@ async fn clone_infra( async fn import_railjson( args: ImportRailjsonArgs, - pg_config: PostgresConfig, + db_pool: Data, ) -> Result<(), Box> { let railjson_file = File::open(args.railjson_path)?; - let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size)); - let infra: Infra = InfraForm { name: args.infra_name, } @@ -422,10 +422,10 @@ async fn import_railjson( let railjson: RailJson = serde_json::from_reader(BufReader::new(railjson_file))?; info!("🍞 Importing infra {}", infra.name.clone().unwrap().bold()); - let infra = infra.persist(railjson, pool.clone()).await?; + let infra = infra.persist(railjson, db_pool.clone()).await?; let infra_id = infra.id.unwrap(); - let mut conn = pool.get().await?; + let mut conn = db_pool.get().await?; let infra = infra .bump_version(&mut conn) .await @@ -439,7 +439,7 @@ async fn import_railjson( // Generate only if the was set if args.generate { let infra_cache = InfraCache::load(&mut conn, &infra).await?; - infra.refresh(pool, true, &infra_cache).await?; + infra.refresh(db_pool, true, &infra_cache).await?; info!( "✅ Infra {}[{}] generated data refreshed!", infra.name.unwrap().bold(), @@ -451,7 +451,7 @@ async fn import_railjson( async fn electrical_profile_set_import( args: ImportProfileSetArgs, - pg_config: PostgresConfig, + db_pool: Data, ) -> Result<(), Box> { let electrical_profile_set_file = File::open(args.electrical_profile_set_path)?; @@ -463,7 +463,7 @@ async fn electrical_profile_set_import( data: Some(DieselJson(electrical_profile_set_data)), }; - let mut conn = establish_connection(pg_config.url()?.as_str()).await?; + let mut conn = db_pool.get().await?; let created_ep_set = ep_set.create_conn(&mut conn).await.unwrap(); let ep_set_id = created_ep_set.id.unwrap(); info!("✅ Electrical profile set {ep_set_id} created"); @@ -472,9 +472,9 @@ async fn electrical_profile_set_import( async fn electrical_profile_set_list( args: ListProfileSetArgs, - pg_config: PostgresConfig, + db_pool: Data, ) -> Result<(), Box> { - let mut conn = establish_connection(pg_config.url()?.as_str()).await?; + let mut conn = db_pool.get().await?; let electrical_profile_sets = ElectricalProfileSet::list_light(&mut conn).await.unwrap(); if !args.quiet { println!("Electrical profile sets:\nID - Name"); @@ -490,11 +490,10 @@ async fn electrical_profile_set_list( async fn electrical_profile_set_delete( args: DeleteProfileSetArgs, - pg_config: PostgresConfig, + db_pool: Data, ) -> Result<(), Box> { - let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size)); for profile_set_id in args.profile_set_ids { - let deleted = ElectricalProfileSet::delete(pool.clone(), profile_set_id) + let deleted = ElectricalProfileSet::delete(db_pool.clone(), profile_set_id) .await .unwrap(); if !deleted { @@ -510,11 +509,10 @@ async fn electrical_profile_set_delete( /// This command clear all generated data for the given infra async fn clear_infra( args: ClearArgs, - pg_config: PostgresConfig, + db_pool: Data, redis_config: RedisConfig, ) -> Result<(), Box> { - let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size)); - let mut conn = pool.get().await?; + let mut conn = db_pool.get().await?; let mut infras = vec![]; if args.infra_ids.is_empty() { // Retrieve all available infra @@ -524,7 +522,7 @@ async fn clear_infra( } else { // Retrieve given infras for id in args.infra_ids { - match Infra::retrieve(pool.clone(), id as i64).await? { + match Infra::retrieve(db_pool.clone(), id as i64).await? { Some(infra) => infras.push(infra), None => { eprintln!("❌ Infrastructure not found, ID: {}", id); @@ -629,7 +627,7 @@ fn make_search_migration(args: MakeMigrationArgs) { async fn refresh_search_tables( args: RefreshArgs, - pg_config: PostgresConfig, + db_pool: Data, ) -> Result<(), Box> { let objects = if args.objects.is_empty() { SearchConfigFinder::all() @@ -641,7 +639,7 @@ async fn refresh_search_tables( args.objects }; - let mut conn = establish_connection(pg_config.url()?.as_str()).await?; + let mut conn = db_pool.get().await?; for object in objects { let Some(search_config) = SearchConfigFinder::find(&object) else { eprintln!("❌ No search object found for {object}"); @@ -669,21 +667,19 @@ async fn refresh_search_tables( mod tests { use super::*; - use crate::fixtures::tests::{electrical_profile_set, TestFixture}; - use actix_web::test as actix_test; + use crate::fixtures::tests::{db_pool, electrical_profile_set, TestFixture}; use diesel::sql_query; use diesel::sql_types::Text; - use diesel_async::{AsyncConnection, RunQueryDsl}; + use diesel_async::RunQueryDsl; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use rstest::rstest; use std::io::Write; use tempfile::NamedTempFile; - #[actix_test] - async fn import_railjson_ko_file_not_found() { + #[rstest] + async fn import_railjson_ko_file_not_found(db_pool: Data) { // GIVEN - let pg_config = Default::default(); let args: ImportRailjsonArgs = ImportRailjsonArgs { infra_name: "test".into(), railjson_path: "non/existing/railjson/file/location".into(), @@ -691,18 +687,17 @@ mod tests { }; // WHEN - let result = import_railjson(args, pg_config).await; + let result = import_railjson(args, db_pool).await; // THEN assert!(result.is_err()) } - #[actix_test] - async fn import_railjson_ok() { + #[rstest] + async fn import_railjson_ok(db_pool: Data) { // GIVEN let railjson = Default::default(); let file = generate_railjson_temp_file(&railjson); - let pg_config = PostgresConfig::default(); let infra_name = format!( "{}_{}", "infra", @@ -717,16 +712,13 @@ mod tests { }; // WHEN - let result = import_railjson(args, pg_config.clone()).await; + let result = import_railjson(args, db_pool.clone()).await; // THEN assert!(result.is_ok()); // CLEANUP - let pg_config_url = pg_config.url().expect("cannot get postgres config url"); - let mut conn = PgConnection::establish(pg_config_url.as_str()) - .await - .expect("Error while connecting DB"); + let mut conn = db_pool.get().await.unwrap(); sql_query("DELETE FROM infra WHERE name = $1") .bind::(infra_name) .execute(&mut conn) @@ -743,9 +735,9 @@ mod tests { #[rstest] async fn test_electrical_profile_set_delete( #[future] electrical_profile_set: TestFixture, + db_pool: Data, ) { // GIVEN - let pg_config = PostgresConfig::default(); let electrical_profile_set = electrical_profile_set.await; let args = DeleteProfileSetArgs { @@ -753,15 +745,12 @@ mod tests { }; // WHEN - electrical_profile_set_delete(args, pg_config.clone()) + electrical_profile_set_delete(args, db_pool.clone()) .await .unwrap(); // THEN - let pg_config_url = pg_config.url().expect("cannot get postgres config url"); - let mut conn = PgConnection::establish(pg_config_url.as_str()) - .await - .unwrap(); + let mut conn = db_pool.get().await.unwrap(); let empty = !ElectricalProfileSet::list_light(&mut conn) .await .unwrap() @@ -773,11 +762,12 @@ mod tests { #[rstest] async fn test_electrical_profile_set_list_doesnt_fail( #[future] electrical_profile_set: TestFixture, + db_pool: Data, ) { let _electrical_profile_set = electrical_profile_set.await; for quiet in [true, false] { let args = ListProfileSetArgs { quiet }; - electrical_profile_set_list(args, PostgresConfig::default()) + electrical_profile_set_list(args, db_pool.clone()) .await .unwrap(); }