diff --git a/src/coll/mod.rs b/src/coll/mod.rs
index ed9bfee0f..e20350aab 100644
--- a/src/coll/mod.rs
+++ b/src/coll/mod.rs
@@ -1,6 +1,12 @@
 pub mod options;
 
-use std::{borrow::Borrow, collections::HashSet, fmt, fmt::Debug, sync::Arc};
+use std::{
+    borrow::Borrow,
+    collections::{HashMap, HashSet},
+    fmt,
+    fmt::Debug,
+    sync::Arc,
+};
 
 use futures_util::{
     future,
@@ -1196,7 +1202,7 @@ where
 
         let mut cumulative_failure: Option<BulkWriteFailure> = None;
         let mut error_labels: HashSet<String> = Default::default();
-        let mut cumulative_result: Option<InsertManyResult> = None;
+        let mut cumulative_inserted_ids = HashMap::new();
 
         let mut n_attempted = 0;
 
@@ -1211,13 +1217,8 @@
             {
                 Ok(result) => {
                     let current_batch_size = result.inserted_ids.len();
-
-                    let cumulative_result =
-                        cumulative_result.get_or_insert_with(InsertManyResult::new);
                     for (index, id) in result.inserted_ids {
-                        cumulative_result
-                            .inserted_ids
-                            .insert(index + n_attempted, id);
+                        cumulative_inserted_ids.insert(index + n_attempted, id);
                     }
 
                     n_attempted += current_batch_size;
@@ -1235,6 +1236,10 @@
                         let failure_ref =
                             cumulative_failure.get_or_insert_with(BulkWriteFailure::new);
 
+                        for (index, id) in bw.inserted_ids {
+                            cumulative_inserted_ids.insert(index + n_attempted, id);
+                        }
+
                         if let Some(write_errors) = bw.write_errors {
                             for err in write_errors {
                                 let index = n_attempted + err.index;
@@ -1255,7 +1260,8 @@
                         if ordered {
                             // this will always be true since we invoked get_or_insert_with
                             // above.
-                            if let Some(failure) = cumulative_failure {
+                            if let Some(mut failure) = cumulative_failure {
+                                failure.inserted_ids = cumulative_inserted_ids;
                                 return Err(Error::new(
                                     ErrorKind::BulkWrite(failure),
                                     Some(error_labels),
@@ -1271,11 +1277,14 @@
         }
 
         match cumulative_failure {
-            Some(failure) => Err(Error::new(
-                ErrorKind::BulkWrite(failure),
-                Some(error_labels),
-            )),
-            None => Ok(cumulative_result.unwrap_or_else(InsertManyResult::new)),
+            Some(mut failure) => {
+                failure.inserted_ids = cumulative_inserted_ids;
+                Err(Error::new(
+                    ErrorKind::BulkWrite(failure),
+                    Some(error_labels),
+                ))
+            }
+            None => Ok(InsertManyResult::new(cumulative_inserted_ids)),
         }
     }
 
diff --git a/src/results.rs b/src/results.rs
index 5245fbca5..3367ae213 100644
--- a/src/results.rs
+++ b/src/results.rs
@@ -41,10 +41,8 @@ pub struct InsertManyResult {
 }
 
 impl InsertManyResult {
-    pub(crate) fn new() -> Self {
-        InsertManyResult {
-            inserted_ids: HashMap::new(),
-        }
+    pub(crate) fn new(inserted_ids: HashMap<usize, Bson>) -> Self {
+        InsertManyResult { inserted_ids }
     }
 }
 
diff --git a/src/test/coll.rs b/src/test/coll.rs
index 49daa77bc..c4b46bece 100644
--- a/src/test/coll.rs
+++ b/src/test/coll.rs
@@ -27,7 +27,7 @@ use crate::{
         UpdateOptions,
         WriteConcern,
     },
-    results::DeleteResult,
+    results::{DeleteResult, InsertManyResult},
     runtime,
     test::{
         log_uncaptured,
@@ -1201,3 +1201,213 @@ fn assert_duplicate_key_error_with_utf8_replacement(error: &ErrorKind) {
         ),
     }
 }
+
+async fn run_inserted_ids_test(
+    client: &TestClient,
+    docs: &Vec<Document>,
+    ordered: bool,
+) -> Result<InsertManyResult> {
+    let coll = client.init_db_and_coll("bulk_write_test", "test").await;
+    coll.insert_one(doc! { "_id": 1}, None).await.unwrap();
+
+    let insert_opts = InsertManyOptions::builder().ordered(ordered).build();
+    coll.insert_many(docs, insert_opts).await
+}
+
+/// Verify that when an insert_many fails, the returned BulkWriteFailure has
+/// its inserted_ids correctly populated.
+#[cfg_attr(feature = "tokio-runtime", tokio::test)]
+#[cfg_attr(feature = "async-std-runtime", async_std::test)]
+async fn bulk_write_failure_has_inserted_ids() {
+    let _guard: RwLockReadGuard<()> = LOCK.run_concurrently().await;
+
+    let client = TestClient::new().await;
+
+    // an ordered, single-batch bulk write where the last doc generates a write error:
+    // everything before the last doc should be inserted
+    let docs = vec![doc! { "_id": 2}, doc! { "_id": 1}];
+    let res = run_inserted_ids_test(&client, &docs, true);
+    let err = res.await.expect_err("insert_many should fail");
+    match *err.kind {
+        ErrorKind::BulkWrite(failure) => {
+            assert_eq!(
+                failure.inserted_ids.len(),
+                1,
+                "one document should have been inserted"
+            );
+            assert!(
+                failure.inserted_ids.contains_key(&0),
+                "document at index 0 should have been inserted"
+            );
+            assert_eq!(
+                failure.inserted_ids.get(&0).unwrap(),
+                &Bson::Int32(2),
+                "inserted document should have _id 2"
+            );
+        }
+        _ => panic!("Expected BulkWrite error, but got: {:?}", err),
+    }
+
+    // an ordered, single-batch bulk write where the first doc generates a write error:
+    // nothing should be inserted
+    let docs = vec![doc! { "_id": 1}, doc! { "_id": 2}];
+    let res = run_inserted_ids_test(&client, &docs, true);
+    let err = res.await.expect_err("insert_many should fail");
+    match *err.kind {
+        ErrorKind::BulkWrite(failure) => {
+            assert_eq!(
+                failure.inserted_ids.len(),
+                0,
+                "inserted_ids should be empty"
+            );
+        }
+        _ => panic!("Expected BulkWrite error, but got: {:?}", err),
+    }
+
+    // an unordered, single-batch bulk write where the first doc generates a write error:
+    // everything after the first doc should be inserted
+    let res = run_inserted_ids_test(&client, &docs, false);
+    let err = res.await.expect_err("insert_many should fail");
+    match *err.kind {
+        ErrorKind::BulkWrite(failure) => {
+            assert_eq!(
+                failure.inserted_ids.len(),
+                1,
+                "one document should have been inserted"
+            );
+            assert!(
+                failure.inserted_ids.contains_key(&1),
+                "document at index 1 should have been inserted"
+            );
+            assert_eq!(
+                failure.inserted_ids.get(&1).unwrap(),
+                &Bson::Int32(2),
+                "inserted document should have _id 2"
+            );
+        }
+        _ => panic!("Expected BulkWrite error, but got: {:?}", err),
+    }
+
+    // an ordered, 2-batch bulk write where a document in the first batch generates a write
+    // error: nothing should be inserted
+    // note: these numbers were chosen because maxWriteBatchSize is 100,000
+    let mut docs = Vec::with_capacity(100001);
+    for i in 1..100002 {
+        docs.push(doc! { "_id": Bson::Int32(i) });
+    }
+
+    let res = run_inserted_ids_test(&client, &docs, true);
+    let err = res.await.expect_err("insert_many should fail");
+    match *err.kind {
+        ErrorKind::BulkWrite(failure) => {
+            assert_eq!(
+                failure.inserted_ids.len(),
+                0,
+                "0 documents should have been inserted"
+            );
+        }
+        _ => panic!("Expected BulkWrite error, but got: {:?}", err),
+    }
+
+    // an unordered, 2-batch bulk write where a document in the first batch generates a write
+    // error: everything besides that document should be inserted
+    let res = run_inserted_ids_test(&client, &docs, false);
+    let err = res.await.expect_err("insert_many should fail");
+    match *err.kind {
+        ErrorKind::BulkWrite(failure) => {
+            assert_eq!(
+                failure.inserted_ids.len(),
+                100000,
+                "100,000 documents should have been inserted"
+            );
+            // docs at index 1 up to and including 100,000 should have been inserted
+            for (i, doc) in docs.iter().enumerate().take(100001).skip(1) {
+                match failure.inserted_ids.get(&i) {
+                    Some(doc_id) => {
+                        let expected_id = doc.get("_id").unwrap();
+                        assert_eq!(
+                            doc_id, expected_id,
+                            "Doc at index {} did not have expected _id",
+                            i
+                        );
+                    }
+                    None => panic!("Document at index {} should have been inserted", i),
+                }
+            }
+        }
+        _ => panic!("Expected BulkWrite error, but got: {:?}", err),
+    }
+
+    // an ordered, 2-batch bulk write where the second-to-last document in the second batch
+    // generates a write error: everything before that document should be inserted
+    let mut docs = Vec::with_capacity(100003);
+    for i in 2..100003 {
+        docs.push(doc! { "_id": Bson::Int32(i) });
+    }
+    docs.push(doc! { "_id": 1 });
+    docs.push(doc! { "_id": 100003 });
+
+    let res = run_inserted_ids_test(&client, &docs, true);
+    let err = res.await.expect_err("insert_many should fail");
+    match *err.kind {
+        ErrorKind::BulkWrite(failure) => {
+            assert_eq!(
+                failure.inserted_ids.len(),
+                100001,
+                "100,001 documents should have been inserted"
+            );
+            // docs at index 0 up to and including 100,000 should have been inserted;
+            // doc at index 100,001 generates a duplicate key error
+            for (i, doc) in docs.iter().enumerate().take(100001) {
+                match failure.inserted_ids.get(&i) {
+                    Some(doc_id) => {
+                        let expected_id = doc.get("_id").unwrap();
+                        assert_eq!(
+                            doc_id, expected_id,
+                            "Doc at index {} did not have expected _id",
+                            i
+                        );
+                    }
+                    None => panic!("Document at index {} should have been inserted", i),
+                }
+            }
+        }
+        _ => panic!("Expected BulkWrite error, but got: {:?}", err),
+    }
+
+    // an unordered, 2-batch bulk write where the second-to-last document in the second batch
+    // generates a write error: everything besides that document should be inserted
+    let res = run_inserted_ids_test(&client, &docs, false);
+    let err = res.await.expect_err("insert_many should fail");
+    match *err.kind {
+        ErrorKind::BulkWrite(failure) => {
+            assert_eq!(
+                failure.inserted_ids.len(),
+                100002,
+                "100,002 documents should have been inserted"
+            );
+            // docs at index 0 up to and including 100,000 should have been inserted
+            for (i, doc) in docs.iter().enumerate().take(100001) {
+                match failure.inserted_ids.get(&i) {
+                    Some(doc_id) => {
+                        let expected_id = doc.get("_id").unwrap();
+                        assert_eq!(
+                            doc_id, expected_id,
+                            "Doc at index {} did not have expected _id",
+                            i
+                        );
+                    }
+                    None => panic!("Document at index {} should have been inserted", i),
+                }
+            }
+            // doc at index 100,001 generates a duplicate key error; doc at index
+            // 100,002 should be inserted
+            assert_eq!(
+                failure.inserted_ids.get(&100002).unwrap(),
+                docs[100002].get("_id").unwrap(),
+                "inserted_id for index 100002 should be 100003",
+            );
+        }
+        _ => panic!("Expected BulkWrite error, but got: {:?}", err),
+    }
+}
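
For context, a minimal caller-side sketch of the behavior this patch enables: after a partially failed `insert_many`, the `BulkWriteFailure` now carries the `_id`s of the documents that were inserted, keyed by their index, alongside the usual `write_errors`. The URI, database, and collection names below are placeholders, and the sketch assumes the `tokio` runtime and the driver's 2.x `ErrorKind::BulkWrite` variant:

```rust
use mongodb::{
    bson::{doc, Document},
    error::ErrorKind,
    options::InsertManyOptions,
    Client,
};

#[tokio::main]
async fn main() -> mongodb::error::Result<()> {
    // Placeholder URI and namespace, for illustration only.
    let client = Client::with_uri_str("mongodb://localhost:27017").await?;
    let coll = client
        .database("example_db")
        .collection::<Document>("example_coll");

    // The document at index 1 duplicates index 0's _id; with ordered = false
    // the driver still attempts the remaining documents.
    let docs = vec![doc! { "_id": 1 }, doc! { "_id": 1 }, doc! { "_id": 2 }];
    let opts = InsertManyOptions::builder().ordered(false).build();

    match coll.insert_many(docs, opts).await {
        Ok(result) => println!("all inserted: {:?}", result.inserted_ids),
        Err(err) => match *err.kind {
            ErrorKind::BulkWrite(failure) => {
                // With this change, inserted_ids maps the index of each
                // document that *was* inserted (0 and 2 here) to its _id,
                // rather than always being empty.
                println!("inserted despite failure: {:?}", failure.inserted_ids);
                println!("write errors: {:?}", failure.write_errors);
            }
            other => panic!("unexpected error: {:?}", other),
        },
    }

    Ok(())
}
```

With `ordered(true)` the same run would stop at the first duplicate, so `inserted_ids` would contain only index 0, matching the ordered cases exercised by the test above.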