diff --git a/src/server/common.cc b/src/server/common.cc
index 62046a5c4c6b..7b7d97eb25cd 100644
--- a/src/server/common.cc
+++ b/src/server/common.cc
@@ -9,6 +9,8 @@
 #include
 #include
+#include
+
 extern "C" {
 #include "redis/object.h"
 #include "redis/rdb.h"
@@ -250,4 +252,8 @@ std::string GenericError::Format() const {
   return absl::StrCat(ec_.message(), ":", details_);
 }
 
+void Context::Cancel() {
+  Error(std::make_error_code(errc::operation_canceled), "Context cancelled");
+}
+
 }  // namespace dfly
diff --git a/src/server/common.h b/src/server/common.h
index e0e74bdfbfb5..e5bc83fa3857 100644
--- a/src/server/common.h
+++ b/src/server/common.h
@@ -246,7 +246,7 @@ using AggregateGenericError = AggregateValue<GenericError>;
 
 // Contest combines Cancellation and AggregateGenericError in one class.
 // Allows setting an error_handler to run on errors.
-class Context : public Cancellation {
+class Context : private Cancellation {
  public:
   // The error handler should return false if this error is ignored.
   using ErrHandler = std::function<bool(const GenericError&)>;
@@ -255,6 +255,24 @@ class Context : public Cancellation {
   Context(ErrHandler err_handler) : Cancellation{}, err_handler_{std::move(err_handler)} {
   }
 
+  operator GenericError() {
+    return err_;
+  }
+
+  operator std::error_code() {
+    return err_.GetError();
+  }
+
+  // Cancelling the internal Cancellation is only possible through the context directly,
+  // because it needs to emit a cancellation error.
+  operator const Cancellation*() {
+    return this;
+  }
+
+  using Cancellation::IsCancelled;
+
+  void Cancel();
+
   template <typename... T> void Error(T... ts) {
     std::lock_guard lk{mu_};
     if (err_)
@@ -263,7 +281,7 @@ class Context : public Cancellation {
     GenericError new_err{std::forward<T>(ts)...};
     if (!err_handler_ || err_handler_(new_err)) {
       err_ = std::move(new_err);
-      Cancel();
+      Cancellation::Cancel();
     }
   }
diff --git a/src/server/dflycmd.cc b/src/server/dflycmd.cc
index 544119203fb6..078ac2c0b9d4 100644
--- a/src/server/dflycmd.cc
+++ b/src/server/dflycmd.cc
@@ -250,7 +250,7 @@ void DflyCmd::Sync(CmdArgList args, ConnectionContext* cntx) {
   TransactionGuard tg{cntx->transaction};
   AggregateStatus status;
 
-  auto cb = [this, &status, replica_ptr](unsigned index, auto*) {
+  auto cb = [this, &status, replica_ptr = replica_ptr](unsigned index, auto*) {
     status = StartFullSyncInThread(&replica_ptr->flows[index], &replica_ptr->cntx,
                                    EngineShard::tlocal());
   };
 
@@ -283,7 +283,7 @@ void DflyCmd::StartStable(CmdArgList args, ConnectionContext* cntx) {
   TransactionGuard tg{cntx->transaction};
   AggregateStatus status;
 
-  auto cb = [this, &status, replica_ptr](unsigned index, auto*) {
+  auto cb = [this, &status, replica_ptr = replica_ptr](unsigned index, auto*) {
     EngineShard* shard = EngineShard::tlocal();
     FlowInfo* flow = &replica_ptr->flows[index];
 
@@ -325,7 +325,7 @@ OpStatus DflyCmd::StartFullSyncInThread(FlowInfo* flow, Context* cntx, EngineShard* shard) {
   // Shard can be null for io thread.
   if (shard != nullptr) {
     CHECK(!sf_->journal()->OpenInThread(false, ""sv));  // can only happen in persistent mode.
-    flow->saver->StartSnapshotInShard(true, cntx, shard);
+    flow->saver->StartSnapshotInShard(true, *cntx, shard);
   }
 
   flow->full_sync_fb = ::boost::fibers::fiber(&DflyCmd::FullSyncFb, this, flow, cntx);
@@ -383,7 +383,7 @@ void DflyCmd::FullSyncFb(FlowInfo* flow, Context* cntx) {
     return cntx->Error(ec);
   }
 
-  if (ec = saver->SaveBody(cntx, nullptr); ec) {
+  if (ec = saver->SaveBody(*cntx, nullptr); ec) {
     return cntx->Error(ec);
   }
 
diff --git a/src/server/replica.cc b/src/server/replica.cc
index c5562ea7c325..1eae2b0671a3 100644
--- a/src/server/replica.cc
+++ b/src/server/replica.cc
@@ -137,29 +137,39 @@ bool Replica::Start(ConnectionContext* cntx) {
   // 3. Spawn main coordination fiber.
   sync_fb_ = ::boost::fibers::fiber(&Replica::MainReplicationFb, this);
 
+  // 4. Init basic context, which is used only for cancellation.
+  // The full and stable sync contexts will be used for error propagation as well.
+  cntx_.emplace([this](const auto& ge) {
+    CloseAllSockets();
+    return true;
+  });
+
+  // Lock the running latch.
+  running_latch_mu_.lock();
+
   (*cntx)->SendOk();
   return true;
 }
 
 void Replica::Stop() {
+  // Mark disabled to prevent retries.
   if (sock_) {
     sock_->proactor()->Await([this] {
       state_mask_ = 0;  // Specifically ~R_ENABLED.
-      auto ec = sock_->Shutdown(SHUT_RDWR);
-      LOG_IF(ERROR, ec) << "Could not shutdown socket " << ec;
     });
   }
 
-  // Close sub flows.
-  auto partition = Partition(num_df_flows_);
-  shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
-    for (auto id : partition[index]) {
-      shard_flows_[id]->Stop();
-    }
-  });
+  // The context is fully responsible for cleaning up on cancellation.
+  {
+    lock_guard lk(cntx_mu_);
+    CHECK(cntx_);
+    cntx_->Cancel();
+  }
 
-  if (sync_fb_.joinable())
-    sync_fb_.join();
+  // Make sure the replica has fully stopped and done all cleanups,
+  // so we can freely release resources (connections).
+  running_latch_mu_.lock();
+  running_latch_mu_.unlock();
 }
 
 void Replica::Pause(bool pause) {
@@ -236,6 +246,7 @@ void Replica::MainReplicationFb() {
   }
 
   VLOG(1) << "Main replication fiber finished";
+  running_latch_mu_.unlock();
 }
 
 error_code Replica::ConnectSocket() {
@@ -431,43 +442,62 @@ error_code Replica::InitiateDflySync() {
     shard_flows_[i].reset(new Replica(master_context_, i, &service_));
   }
 
-  SyncBlock sb{num_df_flows_};
+  // Allocate shared, because the error handler might outlive the whole task.
+  auto sync_block = std::make_shared<SyncBlock>(num_df_flows_);
+
+  // Init context.
+  // From full sync onwards, cleanup is managed by the currently running task.
+  // Context responsibility is always transferred atomically.
+  {
+    lock_guard lk(cntx_mu_);
+    cntx_.emplace([this, sync_block](const auto& ge) {
+      // Prevent blocking the fiber that issued the error by starting another fiber.
+      ::boost::fibers::fiber{[this, sync_block]() {
+        // Unblock the sync block.
+        {
+          lock_guard lk(sync_block->mu_);
+          sync_block->flows_left = 0;
+        }
+        sync_block->cv_.notify_all();
+      }}.detach();
+
+      // Close sockets to unblock flows in case of cancellation.
+      CloseAllSockets();
+      return true;
+    });
+  }
 
-  AggregateError ec;
+  // Start full sync flows.
   auto partition = Partition(num_df_flows_);
   shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
     for (auto id : partition[index]) {
-      if ((ec = shard_flows_[id]->StartFullSyncFlow(&sb)))
-        break;
+      auto ec = shard_flows_[id]->StartFullSyncFlow(sync_block.get(), &*cntx_);
+      if (ec)
+        cntx_->Error(ec);
     }
   });
 
-  RETURN_ON_ERR(*ec);
-
-  ReqSerializer serializer{sock_.get()};
-
-  // Master waits for this command in order to start sending replication stream.
-  RETURN_ON_ERR(SendCommand(StrCat("DFLY SYNC ", master_context_.dfly_session_id), &serializer));
-
-  base::IoBuf io_buf{128};
-  unsigned consumed = 0;
-  RETURN_ON_ERR(ReadRespReply(&io_buf, &consumed));
-  if (!CheckRespIsSimpleReply("OK")) {
-    LOG(ERROR) << "Sync failed " << ToSV(io_buf.InputBuffer());
-    return make_error_code(errc::bad_message);
+  // Send DFLY SYNC.
+  if (auto ec = SendNextPhaseRequest(); ec) {
+    cntx_->Error(ec);
   }
 
   // Wait for all flows to receive full sync cut.
+  // In case of an error, this is unblocked by the error handler.
   {
     VLOG(1) << "Blocking before full sync cut";
-    std::unique_lock lk(sb.mu_);
-    sb.cv_.wait(lk, [&]() { return sb.flows_left == 0; });
+    std::unique_lock lk(sync_block->mu_);
+    sync_block->cv_.wait(lk, [&]() { return sync_block->flows_left == 0; });
   }
 
-  LOG(INFO) << "Full sync finished";
-  state_mask_ |= R_SYNC_OK;
+  if (!cntx_->IsCancelled()) {
+    LOG(INFO) << "Full sync finished";
+    state_mask_ |= R_SYNC_OK;
+  } else {
+    JoinAllFlows();
+  }
 
-  return error_code{};
+  return *cntx_;
 }
 
 error_code Replica::ConsumeRedisStream() {
@@ -517,43 +547,88 @@
 }
 
 error_code Replica::ConsumeDflyStream() {
-  // Request master to transition to stable sync.
-  {
-    ReqSerializer serializer{sock_.get()};
-    serializer.SendCommand(StrCat("DFLY STARTSTABLE ", master_context_.dfly_session_id));
-    RETURN_ON_ERR(serializer.ec());
+  DCHECK(cntx_);  // Context from full sync.
+
+  // Send DFLY STARTSTABLE.
+  if (auto ec = SendNextPhaseRequest(); ec) {
+    cntx_->Error(ec);
   }
 
   // Wait for all flows to finish full sync.
-  for (auto& sub_repl : shard_flows_)
-    sub_repl->sync_fb_.join();
+  JoinAllFlows();
+
+  // Transfer cleanup responsibility to the stable state task.
+  {
+    lock_guard lk(cntx_mu_);
+
+    // We were cancelled during transition, fail fast.
+    if (cntx_->IsCancelled()) {
+      return *cntx_;
+    }
+
+    cntx_.emplace([this](const auto& ge) {
+      CloseAllSockets();
+      return true;
+    });
+  }
 
-  AggregateError all_ec;
   vector<vector<unsigned>> partition = Partition(num_df_flows_);
   shard_set->pool()->AwaitFiberOnAll([&](unsigned index, auto*) {
     const auto& local_ids = partition[index];
     for (unsigned id : local_ids) {
-      all_ec = shard_flows_[id]->StartStableSyncFlow();
-      if (all_ec)
-        break;
+      auto ec = shard_flows_[id]->StartStableSyncFlow(&*cntx_);
+      if (ec)
+        cntx_->Error(ec);
     }
   });
 
-  RETURN_ON_ERR(*all_ec);
+  // Wait for all shard flows to join.
+  JoinAllFlows();
 
-  base::IoBuf io_buf(16_KB);
-  std::error_code ec;
-  while (!ec) {
-    io::MutableBytes buf = io_buf.AppendBuffer();
-    io::Result<size_t> size_res = sock_->Recv(buf);
-    if (!size_res)
-      return size_res.error();
+  return *cntx_;  // TODO replace with generic error
+}
+
+void Replica::CloseAllSockets() {
+  if (sock_) {
+    sock_->proactor()->Await([this] {
+      auto ec = sock_->Shutdown(SHUT_RDWR);
+      LOG_IF(ERROR, ec) << "Could not shutdown socket " << ec;
+    });
   }
 
-  return error_code{};
+  for (auto& flow : shard_flows_) {
+    flow->CloseAllSockets();
+  }
 }
 
-error_code Replica::StartFullSyncFlow(SyncBlock* sb) {
+void Replica::JoinAllFlows() {
+  for (auto& flow : shard_flows_) {
+    if (flow->sync_fb_.joinable()) {
+      flow->sync_fb_.join();
+    }
+  }
+}
+
+error_code Replica::SendNextPhaseRequest() {
+  ReqSerializer serializer{sock_.get()};
+
+  // Ask master to start the next replication phase.
+  string request = (state_mask_ & R_SYNC_OK) ? "STARTSTABLE" : "SYNC";
+  RETURN_ON_ERR(
+      SendCommand(StrCat("DFLY ", request, " ", master_context_.dfly_session_id), &serializer));
+
+  base::IoBuf io_buf{128};
+  unsigned consumed = 0;
+  RETURN_ON_ERR(ReadRespReply(&io_buf, &consumed));
+  if (!CheckRespIsSimpleReply("OK")) {
+    LOG(ERROR) << "Phase transition failed " << ToSV(io_buf.InputBuffer());
+    return make_error_code(errc::bad_message);
+  }
+
+  return std::error_code{};
+}
+
+error_code Replica::StartFullSyncFlow(SyncBlock* sb, Context* cntx) {
   CHECK(!sock_);
   DCHECK(!master_context_.master_repl_id.empty() && !master_context_.dfly_session_id.empty());
 
@@ -595,12 +670,12 @@
 
   // We can not discard io_buf because it may contain data
   // besides the response we parsed. Therefore we pass it further to ReplicateDFFb.
-  sync_fb_ = ::boost::fibers::fiber(&Replica::FullSyncDflyFb, this, sb, move(eof_token));
+  sync_fb_ = ::boost::fibers::fiber(&Replica::FullSyncDflyFb, this, move(eof_token), sb, cntx);
 
   return error_code{};
 }
 
-error_code Replica::StartStableSyncFlow() {
+error_code Replica::StartStableSyncFlow(Context* cntx) {
   DCHECK(!master_context_.master_repl_id.empty() && !master_context_.dfly_session_id.empty());
   ProactorBase* mythread = ProactorBase::me();
   CHECK(mythread);
@@ -608,12 +683,12 @@
   CHECK(sock_->IsOpen());
   // sock_.reset(mythread->CreateSocket());
   // RETURN_ON_ERR(sock_->Connect(master_context_.master_ep));
-  sync_fb_ = ::boost::fibers::fiber(&Replica::StableSyncDflyFb, this);
+  sync_fb_ = ::boost::fibers::fiber(&Replica::StableSyncDflyFb, this, cntx);
 
   return std::error_code{};
 }
 
-void Replica::FullSyncDflyFb(SyncBlock* sb, string eof_token) {
+void Replica::FullSyncDflyFb(string eof_token, SyncBlock* sb, Context* cntx) {
   DCHECK(leftover_buf_);
   SocketSource ss{sock_.get()};
   io::PrefixSource ps{leftover_buf_->InputBuffer(), &ss};
@@ -627,7 +702,10 @@
       sb->cv_.notify_all();
     }
   });
-  loader.Load(&ps);
+
+  // Load the incoming rdb stream.
+  if (std::error_code ec = loader.Load(&ps); ec)
+    return cntx->Error(ec, "Error loading rdb format");
 
   // Try finding eof token.
   io::PrefixSource chained_tail{loader.Leftover(), &ps};
@@ -638,7 +716,8 @@
         chained_tail.ReadAtLeast(io::MutableBytes{buf.get(), eof_token.size()}, eof_token.size());
 
     if (!res || *res != eof_token.size()) {
-      LOG(ERROR) << "Error finding eof token in the stream";
+      return cntx->Error(std::make_error_code(errc::protocol_error),
+                         "Error finding eof token in stream");
     }
   }
 
@@ -656,7 +735,7 @@
   VLOG(1) << "FullSyncDflyFb finished after reading " << loader.bytes_read() << " bytes";
 }
 
-void Replica::StableSyncDflyFb() {
+void Replica::StableSyncDflyFb(Context* cntx) {
   base::IoBuf io_buf(16_KB);
   parser_.reset(new RedisParser);
 
@@ -668,24 +747,22 @@
     leftover_buf_.reset();
   }
 
-  error_code ec;
   string ack_cmd;
 
-  while (!ec) {
+  while (!cntx->IsCancelled()) {
    io::MutableBytes buf = io_buf.AppendBuffer();
    io::Result<size_t> size_res = sock_->Recv(buf);
    if (!size_res)
-      return;
+      return cntx->Error(size_res.error());
 
    last_io_time_ = sock_->proactor()->GetMonotonicTimeNs();
 
    io_buf.CommitWrite(*size_res);
    repl_offs_ += *size_res;
 
-    ec = ParseAndExecute(&io_buf);
+    if (auto ec = ParseAndExecute(&io_buf); ec)
+      return cntx->Error(ec);
   }
-
-  return;
 }
 
 error_code Replica::ReadRespReply(base::IoBuf* io_buf, uint32_t* consumed) {
diff --git a/src/server/replica.h b/src/server/replica.h
index 6c2500306e45..914f0fe4e212 100644
--- a/src/server/replica.h
+++ b/src/server/replica.h
@@ -11,6 +11,7 @@
 #include "base/io_buf.h"
 #include "facade/facade_types.h"
 #include "facade/redis_parser.h"
+#include "server/common.h"
 #include "util/fiber_socket_base.h"
 
 namespace facade {
@@ -82,21 +83,26 @@ class Replica {
   std::error_code ConsumeRedisStream();  // Redis stable state.
   std::error_code ConsumeDflyStream();   // Dragonfly stable state.
 
+  void CloseAllSockets();  // Close main socket and all flow sockets.
+  void JoinAllFlows();     // Join all flows if possible.
+
+  std::error_code SendNextPhaseRequest();  // Send DFLY SYNC or DFLY STARTSTABLE.
+
  private: /* Main dlfly flow mode functions */
   // Initialize as single dfly flow.
   Replica(const MasterContext& context, uint32_t dfly_flow_id, Service* service);
 
   // Start replica initialized as dfly flow.
-  std::error_code StartFullSyncFlow(SyncBlock* block);
+  std::error_code StartFullSyncFlow(SyncBlock* block, Context* cntx);
 
   // Transition into stable state mode as dfly flow.
-  std::error_code StartStableSyncFlow();
+  std::error_code StartStableSyncFlow(Context* cntx);
 
   // Single flow full sync fiber spawned by StartFullSyncFlow.
-  void FullSyncDflyFb(SyncBlock* block, std::string eof_token);
+  void FullSyncDflyFb(std::string eof_token, SyncBlock* block, Context* cntx);
 
   // Single flow stable state sync fiber spawned by StartStableSyncFlow.
-  void StableSyncDflyFb();
+  void StableSyncDflyFb(Context* cntx);
 
  private: /* Utility */
   struct PSyncResponse {
@@ -165,6 +171,10 @@ class Replica {
   facade::RespVec resp_args_;
   facade::CmdArgVec cmd_str_args_;
 
+  std::optional<Context> cntx_;  // context for tasks in replica.
+  ::boost::fibers::mutex cntx_mu_;           // guards updating cntx_ with new instances.
+  ::boost::fibers::mutex running_latch_mu_;  // locked when running.
+
   // repl_offs - till what offset we've already read from the master.
   // ack_offs_ last acknowledged offset.
   size_t repl_offs_ = 0, ack_offs_ = 0;
diff --git a/tests/dragonfly/replication_test.py b/tests/dragonfly/replication_test.py
index b0eab05276d8..520d4a830455 100644
--- a/tests/dragonfly/replication_test.py
+++ b/tests/dragonfly/replication_test.py
@@ -226,3 +226,57 @@ async def disconnect(replica, c_replica, crash_type):
 
     # Check master survived all disconnects
     assert await c_master.ping()
+
+
+"""
+Test crashing the master and letting the replicas re-connect to it.
+"""
+
+master_crash_cases = [
+    (8, [8], 100),
+    (6, [6, 6, 6], 500),
+    (4, [2] * 8, 500),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("t_master, t_replicas, n_keys", master_crash_cases)
+async def test_master_crash(df_local_factory, t_master, t_replicas, n_keys):
+    master = df_local_factory.create(port=1111, proactor_threads=t_master)
+    replicas = [
+        df_local_factory.create(port=BASE_PORT+i+1, proactor_threads=t)
+        for i, t in enumerate(t_replicas)
+    ]
+
+    master.start()
+    for replica in replicas:
+        replica.start()
+
+    c_master = aioredis.Redis(port=master.port)
+    c_replicas = [aioredis.Redis(port=replica.port) for replica in replicas]
+
+    # Fill master with test data
+    await batch_fill_data_async(c_master, gen_test_data(n_keys, seed=0))
+
+    # Do full sync
+    async def full_sync(c_replica):
+        await c_replica.execute_command("REPLICAOF localhost " + str(master.port))
+        await wait_available_async(c_replica)
+
+    await asyncio.gather(*(full_sync(c) for c in c_replicas))
+
+    # Crash master
+    master.stop(kill=True)
+
+    # Start master
+    master.start()
+
+    # Fill data with new seed
+    c_master = aioredis.Redis(port=master.port)
+    await batch_fill_data_async(c_master, gen_test_data(n_keys, seed=1))
+
+    # Check replicas received it
+    await asyncio.sleep(0.5 * len(replicas))  # this takes time, really
+    for c_replica in c_replicas:
+        await wait_available_async(c_replica)
+        await batch_check_data_async(c_replica, gen_test_data(n_keys, seed=1))
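
Reviewer note: below is a minimal, self-contained sketch of the error-handling pattern this PR builds on, assuming only the standard library. MiniContext is a hypothetical stand-in for dfly::Context (the real class additionally carries GenericError details and integrates with fibers): the first reported error runs a cleanup handler (e.g. CloseAllSockets) and flips a cancellation flag that worker loops such as StableSyncDflyFb poll, while Cancel() is just a special error routed through the same path.

// Standalone sketch of the error-handling/cancellation pattern used by this PR.
// MiniContext is a hypothetical, simplified stand-in for dfly::Context.
#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <system_error>
#include <thread>

class MiniContext {
 public:
  // Handler returns false if the error should be ignored.
  using ErrHandler = std::function<bool(const std::error_code&)>;

  explicit MiniContext(ErrHandler handler) : handler_(std::move(handler)) {}

  bool IsCancelled() const { return cancelled_.load(std::memory_order_relaxed); }

  // First error wins; the handler performs cleanup (e.g. closing sockets),
  // then the cancellation flag unblocks all polling loops.
  void Error(std::error_code ec) {
    std::lock_guard<std::mutex> lk(mu_);
    if (err_) return;  // keep only the first error
    if (!handler_ || handler_(ec)) {
      err_ = ec;
      cancelled_.store(true, std::memory_order_relaxed);
    }
  }

  // Cancellation is just a special error, so it runs the same cleanup path.
  void Cancel() { Error(std::make_error_code(std::errc::operation_canceled)); }

  std::error_code GetError() const {
    std::lock_guard<std::mutex> lk(mu_);
    return err_;
  }

 private:
  ErrHandler handler_;
  mutable std::mutex mu_;
  std::error_code err_;
  std::atomic<bool> cancelled_{false};
};

int main() {
  MiniContext cntx([](const std::error_code& ec) {
    std::cout << "cleanup after: " << ec.message() << "\n";  // e.g. CloseAllSockets()
    return true;                                             // do not ignore the error
  });

  // A worker loop that polls the cancellation flag, like StableSyncDflyFb.
  std::thread worker([&] {
    while (!cntx.IsCancelled()) std::this_thread::yield();
  });

  cntx.Cancel();  // Stop(): cancel -> handler runs -> worker unblocks
  worker.join();
  std::cout << "final error: " << cntx.GetError().message() << "\n";
}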