diff --git a/src/binding.cc b/src/binding.cc index b42e570..26c8fed 100644 --- a/src/binding.cc +++ b/src/binding.cc @@ -1677,7 +1677,8 @@ struct Batch { Batch (Database* database) : database_(database), batch_(new leveldb::WriteBatch()), - hasData_(false) {} + hasData_(false), + isShared_(false) {} ~Batch () { delete batch_; @@ -1704,9 +1705,22 @@ struct Batch { return database_->WriteBatch(options, batch_); } + bool IsShared () { + return isShared_; + } + + void Share () { + isShared_ = true; + } + + void Unshare () { + isShared_ = false; + } + Database* database_; leveldb::WriteBatch* batch_; bool hasData_; + bool isShared_; }; /** @@ -1741,11 +1755,15 @@ NAPI_METHOD(batch_put) { NAPI_ARGV(3); NAPI_BATCH_CONTEXT(); - leveldb::Slice key = ToSlice(env, argv[1]); - leveldb::Slice value = ToSlice(env, argv[2]); - batch->Put(key, value); - DisposeSliceBuffer(key); - DisposeSliceBuffer(value); + if (!batch->IsShared()) { + leveldb::Slice key = ToSlice(env, argv[1]); + leveldb::Slice value = ToSlice(env, argv[2]); + batch->Put(key, value); + DisposeSliceBuffer(key); + DisposeSliceBuffer(value); + } else { + napi_throw_error(env, 0, "Unsafe batch put."); + } NAPI_RETURN_UNDEFINED(); } @@ -1757,9 +1775,13 @@ NAPI_METHOD(batch_del) { NAPI_ARGV(2); NAPI_BATCH_CONTEXT(); - leveldb::Slice key = ToSlice(env, argv[1]); - batch->Del(key); - DisposeSliceBuffer(key); + if (!batch->IsShared()) { + leveldb::Slice key = ToSlice(env, argv[1]); + batch->Del(key); + DisposeSliceBuffer(key); + } else { + napi_throw_error(env, 0, "Unsafe batch del."); + } NAPI_RETURN_UNDEFINED(); } @@ -1771,7 +1793,11 @@ NAPI_METHOD(batch_clear) { NAPI_ARGV(1); NAPI_BATCH_CONTEXT(); - batch->Clear(); + if (!batch->IsShared()) { + batch->Clear(); + } else { + napi_throw_error(env, 0, "Unsafe batch clear."); + } NAPI_RETURN_UNDEFINED(); } @@ -1788,6 +1814,9 @@ struct BatchWriteWorker final : public PriorityWorker { : PriorityWorker(env, batch->database_, callback, "leveldown.batch.write"), batch_(batch), sync_(sync) { + // For thread saftey, consider BatchWrite as shared. + batch->Share(); + // Prevent GC of batch object before we execute NAPI_STATUS_THROWS_VOID(napi_create_reference(env_, context, 1, &contextRef_)); } @@ -1802,6 +1831,11 @@ struct BatchWriteWorker final : public PriorityWorker { } } + void DoFinally () override { + database_->DecrementPriorityWork(); + batch_->Unshare(); + } + Batch* batch_; bool sync_; @@ -1816,12 +1850,18 @@ NAPI_METHOD(batch_write) { NAPI_ARGV(3); NAPI_BATCH_CONTEXT(); - napi_value options = argv[1]; - bool sync = BooleanProperty(env, options, "sync", false); napi_value callback = argv[2]; - BatchWriteWorker* worker = new BatchWriteWorker(env, argv[0], batch, callback, sync); - worker->Queue(); + if (!batch->IsShared()) { + napi_value options = argv[1]; + bool sync = BooleanProperty(env, options, "sync", false); + + BatchWriteWorker* worker = new BatchWriteWorker(env, argv[0], batch, callback, sync); + worker->Queue(); + } else { + napi_value argv = CreateError(env, "Unsafe batch write."); + CallFunction(env, callback, 1, &argv); + } NAPI_RETURN_UNDEFINED(); } diff --git a/test/bdb-test.js b/test/bdb-test.js index 1639de5..f86f849 100644 --- a/test/bdb-test.js +++ b/test/bdb-test.js @@ -161,4 +161,56 @@ describe('BDB', function() { } }); }); + + describe('thread safety', function() { + async function checkError(method, message) { + const batch = db.batch(); + const hash = Buffer.alloc(20, 0x11); + + const value = Buffer.alloc(1024 * 1024); + const key = tkey.encode(hash, 12); + + batch.put(key, value); + batch.write(); + + let err = null; + + try { + switch (method) { + case 'clear': + batch.clear(); + break; + case 'put': + batch.put(key, value); + break; + case 'del': + batch.del(key); + break; + case 'write': + await batch.write(); + break; + } + + } catch (e) { + err = e; + } + + assert(err); + assert.equal(err.message, message); + await new Promise((r) => setTimeout(r, 200)); + } + + const methods = { + 'clear': 'Unsafe batch clear.', + 'put': 'Unsafe batch put.', + 'del': 'Unsafe batch del.', + 'write': 'Unsafe batch write.', + }; + + for (const [method, message] of Object.entries(methods)) { + it(`will check safety of ${method}`, async () => { + await checkError(method, message); + }); + } + }); });