editoast: refactor main.rs db_pool #5695

Merged: 1 commit, Nov 16, 2023
126 changes: 58 additions & 68 deletions editoast/src/main.rs
@@ -32,8 +32,8 @@ use clap::Parser;
use client::{
ClearArgs, Client, Color, Commands, DeleteProfileSetArgs, ElectricalProfilesCommands,
GenerateArgs, ImportProfileSetArgs, ImportRailjsonArgs, ImportRollingStockArgs, InfraCloneArgs,
InfraCommands, ListProfileSetArgs, MakeMigrationArgs, PostgresConfig, RedisConfig, RefreshArgs,
RunserverArgs, SearchCommands,
InfraCommands, ListProfileSetArgs, MakeMigrationArgs, RedisConfig, RefreshArgs, RunserverArgs,
SearchCommands,
};
use colored::*;
use diesel::{sql_query, ConnectionError, ConnectionResult};
@@ -82,6 +82,13 @@ async fn main() {
async fn run() -> Result<(), Box<dyn Error + Send + Sync>> {
let client = Client::parse();
let pg_config = client.postgres_config;
let create_db_pool = || {
Ok::<_, Box<dyn Error + Send + Sync>>(Data::new(get_pool(
pg_config.url()?,
pg_config.pool_size,
)))
};

let redis_config = client.redis_config;

match client.color {
@@ -91,10 +98,10 @@ async fn run() -> Result<(), Box<dyn Error + Send + Sync>> {
}

match client.command {
Commands::Runserver(args) => runserver(args, pg_config, redis_config).await,
Commands::Generate(args) => generate(args, pg_config, redis_config).await,
Commands::ImportRailjson(args) => import_railjson(args, pg_config).await,
Commands::ImportRollingStock(args) => import_rolling_stock(args, pg_config).await,
Commands::Runserver(args) => runserver(args, create_db_pool()?, redis_config).await,
Commands::Generate(args) => generate(args, create_db_pool()?, redis_config).await,
Commands::ImportRailjson(args) => import_railjson(args, create_db_pool()?).await,
Commands::ImportRollingStock(args) => import_rolling_stock(args, create_db_pool()?).await,
Commands::OsmToRailjson(args) => {
converters::osm_to_railjson(args.osm_pbf_in, args.railjson_out)
}
@@ -104,13 +111,13 @@
}
Commands::ElectricalProfiles(subcommand) => match subcommand {
ElectricalProfilesCommands::Import(args) => {
electrical_profile_set_import(args, pg_config).await
electrical_profile_set_import(args, create_db_pool()?).await
}
ElectricalProfilesCommands::List(args) => {
electrical_profile_set_list(args, pg_config).await
electrical_profile_set_list(args, create_db_pool()?).await
}
ElectricalProfilesCommands::Delete(args) => {
electrical_profile_set_delete(args, pg_config).await
electrical_profile_set_delete(args, create_db_pool()?).await
}
},
Commands::Search(SearchCommands::List) => {
@@ -122,11 +129,11 @@ async fn run() -> Result<(), Box<dyn Error + Send + Sync>> {
Ok(())
}
Commands::Search(SearchCommands::Refresh(args)) => {
refresh_search_tables(args, pg_config).await
refresh_search_tables(args, create_db_pool()?).await
}
Commands::Infra(subcommand) => match subcommand {
InfraCommands::Clone(args) => clone_infra(args, pg_config).await,
InfraCommands::Clear(args) => clear_infra(args, pg_config, redis_config).await,
InfraCommands::Clone(args) => clone_infra(args, create_db_pool()?).await,
InfraCommands::Clear(args) => clear_infra(args, create_db_pool()?, redis_config).await,
},
}
}
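
The dispatch above replaces the per-command `PostgresConfig` plumbing with a single `create_db_pool` closure that each database-backed subcommand calls on demand. Below is a minimal, self-contained sketch of that pattern; `PgConfig`, `DbPool`, `get_pool`, and `dispatch` are simplified stand-ins for illustration, not editoast's actual types or API:

```rust
use std::error::Error;

// Simplified stand-ins for editoast's PostgresConfig and DbPool; every name
// in this sketch is illustrative only.
struct PgConfig {
    url: String,
    pool_size: usize,
}

struct DbPool {
    _url: String,
    _max_size: usize,
}

fn get_pool(url: String, max_size: usize) -> DbPool {
    DbPool { _url: url, _max_size: max_size }
}

fn dispatch(pg_config: PgConfig, command_needs_db: bool) -> Result<(), Box<dyn Error + Send + Sync>> {
    // Build the pool lazily: only subcommands that need a database call the
    // closure, and configuration errors surface at that point.
    let create_db_pool = || {
        if pg_config.url.is_empty() {
            return Err::<DbPool, Box<dyn Error + Send + Sync>>("empty database URL".into());
        }
        Ok(get_pool(pg_config.url.clone(), pg_config.pool_size))
    };

    if command_needs_db {
        let _pool = create_db_pool()?;
    }
    Ok(())
}

fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
    // An `osm-to-railjson`-style command would pass `false` and never build a pool.
    dispatch(
        PgConfig {
            url: "postgres://localhost/osrd".into(),
            pool_size: 4,
        },
        true,
    )
}
```

Because the closure is evaluated lazily, commands that never touch Postgres (for example `osm-to-railjson`) neither open a pool nor hit a database-URL error.
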
@@ -204,13 +211,11 @@ fn get_pool(url: Url, max_size: usize) -> DbPool {
/// Create and run the server
async fn runserver(
args: RunserverArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
redis_config: RedisConfig,
) -> Result<(), Box<dyn Error + Send + Sync>> {
info!("Building server...");
// Config databases
let pool = get_pool(pg_config.url()?, pg_config.pool_size);

// Config database
let redis = RedisClient::new(redis_config)?;

// Custom Json extractor configuration
@@ -266,7 +271,7 @@ async fn runserver(
.wrap(Logger::new(actix_logger_format).log_target("actix_logger"))
.app_data(json_cfg.clone())
.app_data(payload_config.clone())
.app_data(Data::new(pool.clone()))
.app_data(db_pool.clone())
.app_data(Data::new(redis.clone()))
.app_data(infra_caches.clone())
.app_data(Data::new(MapLayers::parse()))
@@ -299,11 +304,10 @@ async fn build_redis_pool_and_invalidate_all_cache(redis_config: RedisConfig, in
/// This command refresh all infra given as input (if no infra given then refresh all of them)
async fn generate(
args: GenerateArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
redis_config: RedisConfig,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size));
let mut conn = pool.get().await?;
let mut conn = db_pool.get().await?;
let mut infras = vec![];
if args.infra_ids.is_empty() {
// Retrieve all available infra
@@ -313,7 +317,7 @@ async fn generate(
} else {
// Retrieve given infras
for id in args.infra_ids {
let infra = match Infra::retrieve(pool.clone(), id as i64).await? {
let infra = match Infra::retrieve(db_pool.clone(), id as i64).await? {
Some(infra) => infra,
None => {
return Err(InfraApiError::NotFound {
@@ -335,7 +339,7 @@ async fn generate(
);
let infra_cache = InfraCache::load(&mut conn, &infra).await?;
if infra
.refresh(pool.clone(), args.force, &infra_cache)
.refresh(db_pool.clone(), args.force, &infra_cache)
.await?
{
build_redis_pool_and_invalidate_all_cache(redis_config.clone(), infra.id.unwrap())
@@ -362,9 +366,8 @@

async fn import_rolling_stock(
args: ImportRollingStockArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size));
for rolling_stock_path in args.rolling_stock_path {
let rolling_stock_file = File::open(rolling_stock_path)?;
let mut rolling_stock: RollingStockModel =
@@ -375,7 +378,7 @@ async fn import_rolling_stock(
);
rolling_stock.locked = Some(false);
rolling_stock.version = Some(0);
let rolling_stock = rolling_stock.create(pool.clone()).await?;
let rolling_stock = rolling_stock.create(db_pool.clone()).await?;
info!(
"✅ Rolling stock {}[{}] saved!",
rolling_stock.name.clone().unwrap().bold(),
@@ -387,9 +390,8 @@

async fn clone_infra(
infra_args: InfraCloneArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let db_pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size));
match Infra::clone(infra_args.id, db_pool, infra_args.new_name).await {
Ok(cloned_infra) => println!(
"✅ Infra {} (ID: {}) was successfully cloned",
@@ -409,23 +411,21 @@

async fn import_railjson(
args: ImportRailjsonArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let railjson_file = File::open(args.railjson_path)?;

let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size));

let infra: Infra = InfraForm {
name: args.infra_name,
}
.into();
let railjson: RailJson = serde_json::from_reader(BufReader::new(railjson_file))?;

info!("🍞 Importing infra {}", infra.name.clone().unwrap().bold());
let infra = infra.persist(railjson, pool.clone()).await?;
let infra = infra.persist(railjson, db_pool.clone()).await?;
let infra_id = infra.id.unwrap();

let mut conn = pool.get().await?;
let mut conn = db_pool.get().await?;
let infra = infra
.bump_version(&mut conn)
.await
@@ -439,7 +439,7 @@ async fn import_railjson(
// Generate only if the was set
if args.generate {
let infra_cache = InfraCache::load(&mut conn, &infra).await?;
infra.refresh(pool, true, &infra_cache).await?;
infra.refresh(db_pool, true, &infra_cache).await?;
info!(
"✅ Infra {}[{}] generated data refreshed!",
infra.name.unwrap().bold(),
@@ -451,7 +451,7 @@

async fn electrical_profile_set_import(
args: ImportProfileSetArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let electrical_profile_set_file = File::open(args.electrical_profile_set_path)?;

@@ -463,7 +463,7 @@ async fn electrical_profile_set_import(
data: Some(DieselJson(electrical_profile_set_data)),
};

let mut conn = establish_connection(pg_config.url()?.as_str()).await?;
let mut conn = db_pool.get().await?;
let created_ep_set = ep_set.create_conn(&mut conn).await.unwrap();
let ep_set_id = created_ep_set.id.unwrap();
info!("✅ Electrical profile set {ep_set_id} created");
@@ -472,9 +472,9 @@

async fn electrical_profile_set_list(
args: ListProfileSetArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let mut conn = establish_connection(pg_config.url()?.as_str()).await?;
let mut conn = db_pool.get().await?;
let electrical_profile_sets = ElectricalProfileSet::list_light(&mut conn).await.unwrap();
if !args.quiet {
println!("Electrical profile sets:\nID - Name");
@@ -490,11 +490,10 @@

async fn electrical_profile_set_delete(
args: DeleteProfileSetArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size));
for profile_set_id in args.profile_set_ids {
let deleted = ElectricalProfileSet::delete(pool.clone(), profile_set_id)
let deleted = ElectricalProfileSet::delete(db_pool.clone(), profile_set_id)
.await
.unwrap();
if !deleted {
@@ -510,11 +509,10 @@ async fn electrical_profile_set_delete(
/// This command clear all generated data for the given infra
async fn clear_infra(
args: ClearArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
redis_config: RedisConfig,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let pool = Data::new(get_pool(pg_config.url()?, pg_config.pool_size));
let mut conn = pool.get().await?;
let mut conn = db_pool.get().await?;
let mut infras = vec![];
if args.infra_ids.is_empty() {
// Retrieve all available infra
@@ -524,7 +522,7 @@ async fn clear_infra(
} else {
// Retrieve given infras
for id in args.infra_ids {
match Infra::retrieve(pool.clone(), id as i64).await? {
match Infra::retrieve(db_pool.clone(), id as i64).await? {
Some(infra) => infras.push(infra),
None => {
eprintln!("❌ Infrastructure not found, ID: {}", id);
@@ -629,7 +627,7 @@ fn make_search_migration(args: MakeMigrationArgs) {

async fn refresh_search_tables(
args: RefreshArgs,
pg_config: PostgresConfig,
db_pool: Data<DbPool>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let objects = if args.objects.is_empty() {
SearchConfigFinder::all()
@@ -641,7 +639,7 @@ async fn refresh_search_tables(
args.objects
};

let mut conn = establish_connection(pg_config.url()?.as_str()).await?;
let mut conn = db_pool.get().await?;
for object in objects {
let Some(search_config) = SearchConfigFinder::find(&object) else {
eprintln!("❌ No search object found for {object}");
@@ -669,40 +667,37 @@
mod tests {
use super::*;

use crate::fixtures::tests::{electrical_profile_set, TestFixture};
use actix_web::test as actix_test;
use crate::fixtures::tests::{db_pool, electrical_profile_set, TestFixture};
use diesel::sql_query;
use diesel::sql_types::Text;
use diesel_async::{AsyncConnection, RunQueryDsl};
use diesel_async::RunQueryDsl;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use rstest::rstest;
use std::io::Write;
use tempfile::NamedTempFile;

#[actix_test]
async fn import_railjson_ko_file_not_found() {
#[rstest]
async fn import_railjson_ko_file_not_found(db_pool: Data<DbPool>) {
// GIVEN
let pg_config = Default::default();
let args: ImportRailjsonArgs = ImportRailjsonArgs {
infra_name: "test".into(),
railjson_path: "non/existing/railjson/file/location".into(),
generate: false,
};

// WHEN
let result = import_railjson(args, pg_config).await;
let result = import_railjson(args, db_pool).await;

// THEN
assert!(result.is_err())
}

#[actix_test]
async fn import_railjson_ok() {
#[rstest]
async fn import_railjson_ok(db_pool: Data<DbPool>) {
// GIVEN
let railjson = Default::default();
let file = generate_railjson_temp_file(&railjson);
let pg_config = PostgresConfig::default();
let infra_name = format!(
"{}_{}",
"infra",
@@ -717,16 +712,13 @@ mod tests {
};

// WHEN
let result = import_railjson(args, pg_config.clone()).await;
let result = import_railjson(args, db_pool.clone()).await;

// THEN
assert!(result.is_ok());

// CLEANUP
let pg_config_url = pg_config.url().expect("cannot get postgres config url");
let mut conn = PgConnection::establish(pg_config_url.as_str())
.await
.expect("Error while connecting DB");
let mut conn = db_pool.get().await.unwrap();
sql_query("DELETE FROM infra WHERE name = $1")
.bind::<Text, _>(infra_name)
.execute(&mut conn)
@@ -743,25 +735,22 @@
#[rstest]
async fn test_electrical_profile_set_delete(
#[future] electrical_profile_set: TestFixture<ElectricalProfileSet>,
db_pool: Data<DbPool>,
) {
// GIVEN
let pg_config = PostgresConfig::default();
let electrical_profile_set = electrical_profile_set.await;

let args = DeleteProfileSetArgs {
profile_set_ids: vec![electrical_profile_set.id()],
};

// WHEN
electrical_profile_set_delete(args, pg_config.clone())
electrical_profile_set_delete(args, db_pool.clone())
.await
.unwrap();

// THEN
let pg_config_url = pg_config.url().expect("cannot get postgres config url");
let mut conn = PgConnection::establish(pg_config_url.as_str())
.await
.unwrap();
let mut conn = db_pool.get().await.unwrap();
let empty = !ElectricalProfileSet::list_light(&mut conn)
.await
.unwrap()
@@ -773,11 +762,12 @@
#[rstest]
async fn test_electrical_profile_set_list_doesnt_fail(
#[future] electrical_profile_set: TestFixture<ElectricalProfileSet>,
db_pool: Data<DbPool>,
) {
let _electrical_profile_set = electrical_profile_set.await;
for quiet in [true, false] {
let args = ListProfileSetArgs { quiet };
electrical_profile_set_list(args, PostgresConfig::default())
electrical_profile_set_list(args, db_pool.clone())
.await
.unwrap();
}
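
The test-module hunks swap `#[actix_test]` for `#[rstest]` and take the shared `db_pool` fixture from `crate::fixtures::tests` as a test argument instead of building a `PostgresConfig` by hand. The sketch below shows the injection mechanism in isolation; the `Pool` type and the local `db_pool` fixture are illustrative stand-ins, not the crate's real fixture:

```rust
use rstest::{fixture, rstest};
use std::sync::Arc;

// Dummy pool type so the example stands alone; editoast's fixture hands out
// a shared actix_web::web::Data<DbPool> instead.
struct Pool {
    url: String,
}

#[fixture]
fn db_pool() -> Arc<Pool> {
    // A real fixture would build (or reuse) an actual connection pool here.
    Arc::new(Pool {
        url: "postgres://localhost/osrd_test".into(),
    })
}

#[rstest]
fn uses_injected_pool(db_pool: Arc<Pool>) {
    // rstest resolves the `db_pool` argument by calling the fixture above,
    // so the test never constructs its own database configuration.
    assert!(db_pool.url.starts_with("postgres://"));
}
```

Centralising pool construction in a fixture keeps connection settings in one place and lets tests share a pool rather than opening ad-hoc connections with `establish_connection`.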