Skip to content

Commit

Permalink
Fix rebase failures
Browse files Browse the repository at this point in the history
Now that we're wrapping the connection in our own smart pointer that
doesn't implement `Connection` we need some explicit derefs (adding the
manual `Connection` impl isn't worth avoiding these `*`s). We need to be
able to clone the connection pool (only in tests, but this also only
requires modifying the test variant), so we need the `Arc`.

Similarly, since in tests the connection is a re-entrant mutex, we can't
grab the connection before spawning the worker thread. The lock isn't
`Send` that's for a very good reason. So we instead need to clone a
handle to the pool and grab the connection on the thread we intend to
use it.
  • Loading branch information
sgrif committed Feb 12, 2019
1 parent 65539fd commit 766b29d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 11 deletions.
18 changes: 10 additions & 8 deletions src/background/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ impl<Env: RefUnwindSafe + Send + Sync + 'static> Runner<Env> {
where
F: FnOnce(storage::BackgroundJob) -> CargoResult<()> + Send + UnwindSafe + 'static,
{
let conn = self.connection().expect("Could not acquire connection");
// The connection may not be `Send` so we need to clone the pool instead
let pool = self.connection_pool.clone();
self.thread_pool.execute(move || {
let conn = pool.get().expect("Could not acquire connection");
conn.transaction::<_, Box<dyn CargoError>, _>(|| {
let job = storage::find_next_unlocked_job(&conn).optional()?;
let job = match job {
Expand Down Expand Up @@ -192,7 +194,7 @@ mod tests {

let remaining_jobs = background_jobs
.count()
.get_result(&runner.connection().unwrap());
.get_result(&*runner.connection().unwrap());
assert_eq!(Ok(0), remaining_jobs);
}

Expand Down Expand Up @@ -223,15 +225,15 @@ mod tests {
.select(id)
.filter(retries.eq(0))
.for_update()
.load::<i64>(&conn)
.load::<i64>(&*conn)
.unwrap();
assert_eq!(0, available_jobs.len());

// Sanity check to make sure the job actually is there
let total_jobs_including_failed = background_jobs
.select(id)
.for_update()
.load::<i64>(&conn)
.load::<i64>(&*conn)
.unwrap();
assert_eq!(1, total_jobs_including_failed.len());

Expand All @@ -251,7 +253,7 @@ mod tests {
.find(job_id)
.select(retries)
.for_update()
.first::<i32>(&runner.connection().unwrap())
.first::<i32>(&*runner.connection().unwrap())
.unwrap();
assert_eq!(1, tries);
}
Expand All @@ -277,7 +279,7 @@ mod tests {
impl<'a> Drop for TestGuard<'a> {
fn drop(&mut self) {
::diesel::sql_query("TRUNCATE TABLE background_jobs")
.execute(&runner().connection().unwrap())
.execute(&*runner().connection().unwrap())
.unwrap();
}
}
Expand All @@ -290,14 +292,14 @@ mod tests {
let manager = r2d2::ConnectionManager::new(database_url);
let pool = r2d2::Pool::builder().max_size(2).build(manager).unwrap();

Runner::builder(pool, ()).thread_count(2).build()
Runner::builder(DieselPool::Pool(pool), ()).thread_count(2).build()
}

fn create_dummy_job(runner: &Runner<()>) -> storage::BackgroundJob {
::diesel::insert_into(background_jobs)
.values((job_type.eq("Foo"), data.eq(json!(null))))
.returning((id, job_type, data))
.get_result(&runner.connection().unwrap())
.get_result(&*runner.connection().unwrap())
.unwrap()
}
}
2 changes: 1 addition & 1 deletion src/bin/background-worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn main() {

// We're only using 1 thread, so we only need 1 connection
let db_config = r2d2::Pool::builder().max_size(1);
let db_pool = db::diesel_pool(&config.db_url, db_config);
let db_pool = db::diesel_pool(&config.db_url, config.env, db_config);

let builder = background::Runner::builder(db_pool, environment).thread_count(1);
let runner = job_runner(builder);
Expand Down
8 changes: 6 additions & 2 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use conduit::Request;
use diesel::prelude::*;
use diesel::r2d2::{self, ConnectionManager, CustomizeConnection};
use parking_lot::{ReentrantMutex, ReentrantMutexGuard};
use std::sync::Arc;
use std::ops::Deref;
use url::Url;

Expand All @@ -12,9 +13,10 @@ use crate::util::CargoResult;
use crate::Env;

#[allow(missing_debug_implementations)]
#[derive(Clone)]
pub enum DieselPool {
Pool(r2d2::Pool<ConnectionManager<PgConnection>>),
Test(ReentrantMutex<PgConnection>),
Test(Arc<ReentrantMutex<PgConnection>>),
}

impl DieselPool {
Expand All @@ -33,7 +35,7 @@ impl DieselPool {
}

fn test_conn(conn: PgConnection) -> Self {
DieselPool::Test(ReentrantMutex::new(conn))
DieselPool::Test(Arc::new(ReentrantMutex::new(conn)))
}
}

Expand All @@ -43,6 +45,8 @@ pub enum DieselPooledConn<'a> {
Test(ReentrantMutexGuard<'a, PgConnection>),
}

unsafe impl<'a> Send for DieselPooledConn<'a> {}

impl Deref for DieselPooledConn<'_> {
type Target = PgConnection;

Expand Down

0 comments on commit 766b29d

Please sign in to comment.