From abdbb1ad5e0431f085d11c31df1029a4a3d3fae3 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 14:15:23 +0300 Subject: [PATCH 1/9] Move client code to separate class --- examples/kv.cpp | 3 +- examples/kv.h | 3 +- examples/sql.cpp | 7 ++--- src/raft.cpp | 74 ++++++++++++++++++++++++++++++++++++++++++------ src/raft.h | 44 +++++++++++++++++++++++----- src/server.cpp | 5 +--- 6 files changed, 108 insertions(+), 28 deletions(-) diff --git a/examples/kv.cpp b/examples/kv.cpp index d5935aa..9dfed7a 100644 --- a/examples/kv.cpp +++ b/examples/kv.cpp @@ -66,11 +66,10 @@ TMessageHolder TKv::Write(TMessageHolder message, uint64_t return {}; } -TMessageHolder TKv::Prepare(TMessageHolder command, uint64_t term) { +TMessageHolder TKv::Prepare(TMessageHolder command) { auto dataSize = command->Len - sizeof(TCommandRequest); auto entry = NewHoldedMessage(sizeof(TLogEntry)+dataSize); memcpy(entry->Data, command->Data, dataSize); - entry->Term = term; return entry; } diff --git a/examples/kv.h b/examples/kv.h index 5df1bb7..5021ce3 100644 --- a/examples/kv.h +++ b/examples/kv.h @@ -9,9 +9,8 @@ class TKv: public IRsm { public: TMessageHolder Read(TMessageHolder message, uint64_t index) override; TMessageHolder Write(TMessageHolder message, uint64_t index) override; - TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; + TMessageHolder Prepare(TMessageHolder message) override; private: - uint64_t LastAppliedIndex = 0; std::unordered_map H; }; diff --git a/examples/sql.cpp b/examples/sql.cpp index b447834..4dca764 100644 --- a/examples/sql.cpp +++ b/examples/sql.cpp @@ -50,7 +50,7 @@ class TSql: public IRsm { // insert, update, create TMessageHolder Write(TMessageHolder message, uint64_t index) override; // convert request to log message - TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; + TMessageHolder Prepare(TMessageHolder message) override; private: bool Execute(const std::string& q); @@ -59,7 +59,6 @@ class TSql: public IRsm { TResult Result; std::string LastError; - uint64_t LastAppliedIndex = 0; sqlite3* Db = nullptr; }; @@ -185,12 +184,10 @@ TMessageHolder TSql::Reply(const std::string& ans, uint64_t index) return res; } -TMessageHolder TSql::Prepare(TMessageHolder command, uint64_t term) { +TMessageHolder TSql::Prepare(TMessageHolder command) { auto dataSize = command->Len - sizeof(TCommandRequest); - std::cerr << "Prepare entry of size: " << dataSize << ", in term: " << term << std::endl; auto entry = NewHoldedMessage(sizeof(TLogEntry)+dataSize); memcpy(entry->Data, command->Data, dataSize); - entry->Term = term; return entry; } diff --git a/src/raft.cpp b/src/raft.cpp index 7289bac..6941c2c 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -49,12 +49,11 @@ TMessageHolder TDummyRsm::Write(TMessageHolder message, uin return {}; } -TMessageHolder TDummyRsm::Prepare(TMessageHolder command, uint64_t term) +TMessageHolder TDummyRsm::Prepare(TMessageHolder command) { auto dataSize = command->Len - sizeof(TCommandRequest); auto entry = NewHoldedMessage(sizeof(TLogEntry)+dataSize); memcpy(entry->Data, command->Data, dataSize); - entry->Term = term; return entry; } @@ -280,8 +279,7 @@ void TRaft::OnCommandRequest(TMessageHolder command, const std: // TODO: move this logic to separate class if (StateName == EState::LEADER) { if (command->Flags & TCommandRequest::EWrite) { - auto entry = Rsm->Prepare(command, State->CurrentTerm); - State->Append(std::move(entry)); + Append(std::move(Rsm->Prepare(command))); } auto index = State->LastLogIndex; if (replyTo) { @@ -437,7 +435,7 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder message, con void TRaft::ProcessCommitted() { auto commitIndex = VolatileState->CommitIndex; - for (auto i = VolatileState->LastApplied+1; i <= commitIndex; i++) { + for (auto i = Rsm->LastAppliedIndex+1; i <= commitIndex; i++) { auto entry = State->Get(i-1); if (entry->Flags == TLogEntry::EStub) { continue; @@ -448,11 +446,10 @@ void TRaft::ProcessCommitted() { .Reply = reply ? reply : NewHoldedMessage(TCommandResponse {.Index = i}) }); } - VolatileState->LastApplied = commitIndex; } void TRaft::ProcessWaiting() { - auto lastApplied = VolatileState->LastApplied; + auto lastApplied = Rsm->LastAppliedIndex; while (!Waiting.empty() && Waiting.top().Index <= lastApplied) { auto w = Waiting.top(); Waiting.pop(); TMessageHolder reply; @@ -540,7 +537,6 @@ void TRaft::ProcessTimeout(ITimeSource::Time now) { auto nextVolatileState = std::make_unique(TVolatileState { .CommitIndex = VolatileState->CommitIndex, - .LastApplied = VolatileState->LastApplied, .NextIndex = nextIndex, .RpcDue = rpcDue, .ElectionDue = ITimeSource::Max, @@ -576,3 +572,65 @@ ITimeSource::Time TRaft::MakeElection(ITimeSource::Time now) { uint64_t delta = (uint64_t)((1.0 + (double)rand_(&Seed) / (double)UINT_MAX) * TTimeout::Election.count()); return now + std::chrono::milliseconds(delta); } + +void TRaft::Append(TMessageHolder entry) { + entry->Term = State->CurrentTerm; + State->Append(std::move(entry)); +} + +void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { + auto stateName = Raft->CurrentStateName(); + auto index = Raft->GetState()->LastLogIndex; + auto leaderId = Raft->GetVolatileState()->LeaderId; + if (stateName == EState::LEADER) { + if (command->Flags & TCommandRequest::EWrite) { + Raft->Append(std::move(Rsm->Prepare(command))); + } + if (replyTo) { + Waiting.emplace(TWaiting{index, std::move(command), replyTo}); + } + } else if (stateName == EState::FOLLOWER && replyTo) { + if (command->Flags & TCommandRequest::EWrite) { + if (command->Cookie) { + // already forwarded + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; + } + + if (leaderId == 0) { + // TODO: wait for state change + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; + } + + assert(leaderId != Id); + // Forward + command->Cookie = std::max(1, ForwardCookie); + Nodes[leaderId]->Send(std::move(command)); + Forwarded[ForwardCookie] = replyTo; + ForwardCookie++; + } else { + Waiting.emplace(TWaiting{index, std::move(command), replyTo}); + } + } else if (stateName == EState::CANDIDATE && replyTo) { + // TODO: wait for state change + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + } +} + +void TRequestProcessor::ProcessCommitted() { + auto commitIndex = Raft->GetVolatileState()->CommitIndex; + auto& state = Raft->GetState(); + for (auto i = Rsm->LastAppliedIndex+1; i <= commitIndex; i++) { + auto entry = state->Get(i-1); + if (entry->Flags == TLogEntry::EStub) { + continue; + } + auto reply = Rsm->Write(entry, i); + WriteAnswers.emplace(TAnswer { + .Index = i, + .Reply = reply ? reply : NewHoldedMessage(TCommandResponse {.Index = i}) + }); + } +} + diff --git a/src/raft.h b/src/raft.h index 65b6c39..5fb2cad 100644 --- a/src/raft.h +++ b/src/raft.h @@ -23,16 +23,17 @@ struct IRsm { virtual ~IRsm() = default; virtual TMessageHolder Read(TMessageHolder message, uint64_t index) = 0; virtual TMessageHolder Write(TMessageHolder message, uint64_t index) = 0; - virtual TMessageHolder Prepare(TMessageHolder message, uint64_t term) = 0; + virtual TMessageHolder Prepare(TMessageHolder message) = 0; + + uint64_t LastAppliedIndex = 0; }; struct TDummyRsm: public IRsm { TMessageHolder Read(TMessageHolder message, uint64_t index) override; TMessageHolder Write(TMessageHolder message, uint64_t index) override; - TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; + TMessageHolder Prepare(TMessageHolder message) override; private: - uint64_t LastAppliedIndex; std::vector> Log; }; @@ -40,7 +41,6 @@ using TNodeDict = std::unordered_map>; struct TVolatileState { uint64_t CommitIndex = 0; - uint64_t LastApplied = 0; uint32_t LeaderId = 0; std::unordered_map NextIndex; std::unordered_map MatchIndex; @@ -54,7 +54,6 @@ struct TVolatileState { std::vector Indices; TVolatileState& Vote(uint32_t id); - TVolatileState& SetLastApplied(int index); TVolatileState& CommitAdvance(int nservers, const IState& state); TVolatileState& SetCommitIndex(int index); TVolatileState& SetElectionDue(ITimeSource::Time); @@ -68,7 +67,6 @@ struct TVolatileState { bool operator==(const TVolatileState& other) const { return CommitIndex == other.CommitIndex && - LastApplied == other.LastApplied && NextIndex == other.NextIndex && MatchIndex == other.MatchIndex; } @@ -88,8 +86,10 @@ class TRaft { void Process(ITimeSource::Time now, TMessageHolder message, const std::shared_ptr& replyTo = {}); void ProcessTimeout(ITimeSource::Time now); + void Append(TMessageHolder entry); + // ut - const auto GetState() const { + const auto& GetState() const { return State; } @@ -173,3 +173,33 @@ class TRaft { uint32_t Seed = 31337; }; +class TRequestProcessor { +public: + void OnCommandRequest(TMessageHolder message, const std::shared_ptr& replyTo); + void OnCommandResponse(TMessageHolder message); + void ProcessCommitted(); + +private: + std::shared_ptr Raft; + std::shared_ptr Rsm; + TNodeDict Nodes; + + struct TWaiting { + uint64_t Index; + TMessageHolder Command; + std::shared_ptr ReplyTo; + bool operator< (const TWaiting& other) const { + return Index > other.Index; + } + }; + std::priority_queue Waiting; + + struct TAnswer { + uint64_t Index; + TMessageHolder Reply; + }; + std::queue WriteAnswers; + uint32_t ForwardCookie = 1; + std::unordered_map> Forwarded; +}; + diff --git a/src/server.cpp b/src/server.cpp index 0f56900..9a5dafe 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -198,8 +198,7 @@ void TRaftServer::DebugPrint() { std::cout << "Leader, " << "Term: " << state.CurrentTerm << ", " << "Index: " << state.LastLogIndex << ", " - << "CommitIndex: " << volatileState.CommitIndex << ", " - << "LastApplied: " << volatileState.LastApplied << ", "; + << "CommitIndex: " << volatileState.CommitIndex << ", "; std::cout << "Delay: "; for (auto [id, index] : volatileState.MatchIndex) { std::cout << id << ":" << (state.LastLogIndex - index) << " "; @@ -218,14 +217,12 @@ void TRaftServer::DebugPrint() { << "Term: " << state.CurrentTerm << ", " << "Index: " << state.LastLogIndex << ", " << "CommitIndex: " << volatileState.CommitIndex << ", " - << "LastApplied: " << volatileState.LastApplied << ", " << "\n"; } else if (Raft->CurrentStateName() == EState::FOLLOWER) { std::cout << "Follower, " << "Term: " << state.CurrentTerm << ", " << "Index: " << state.LastLogIndex << ", " << "CommitIndex: " << volatileState.CommitIndex << ", " - << "LastApplied: " << volatileState.LastApplied << ", " << "\n"; } PersistentFields = state; From b102fea5dcd85e0f945939c3da2750e2df819dc7 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 19:25:10 +0300 Subject: [PATCH 2/9] Client code moved out of concencus module --- src/raft.cpp | 139 +++++++++++++------------------------------------ src/raft.h | 36 +++++-------- src/server.cpp | 18 +++++-- src/server.h | 2 + 4 files changed, 65 insertions(+), 130 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index 6941c2c..38dcc12 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -135,7 +135,7 @@ TVolatileState& TVolatileState::SetCommitIndex(int index) } TRaft::TRaft(std::shared_ptr rsm, std::shared_ptr state, int node, const TNodeDict& nodes) - : Rsm(rsm) + : Rsm_(rsm) , Id(node) , Nodes(nodes) , MinVotes((nodes.size()+2+nodes.size()%2)/2) @@ -275,55 +275,6 @@ void TRaft::OnAppendEntries(TMessageHolder message) { } } -void TRaft::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { - // TODO: move this logic to separate class - if (StateName == EState::LEADER) { - if (command->Flags & TCommandRequest::EWrite) { - Append(std::move(Rsm->Prepare(command))); - } - auto index = State->LastLogIndex; - if (replyTo) { - Waiting.emplace(TWaiting{index, std::move(command), replyTo}); - } - } else if (StateName == EState::FOLLOWER && replyTo) { - if (command->Flags & TCommandRequest::EWrite) { - if (command->Cookie) { - // already forwarded - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); - return; - } - - if (VolatileState->LeaderId == 0) { - // TODO: wait for state change - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); - return; - } - - assert(VolatileState->LeaderId != Id); - // Forward - command->Cookie = std::max(1, ForwardCookie); - Nodes[VolatileState->LeaderId]->Send(std::move(command)); - Forwarded[ForwardCookie] = replyTo; - ForwardCookie++; - } else { - Waiting.emplace(TWaiting{State->LastLogIndex, std::move(command), replyTo}); - } - } else if (StateName == EState::CANDIDATE && replyTo) { - // TODO: wait for state change - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); - } -} - -void TRaft::OnCommandResponse(TMessageHolder command) { - // forwarded - auto it = Forwarded.find(command->Cookie); - if (it == Forwarded.end()) { - return; - } - it->second->Send(std::move(command)); - Forwarded.erase(it); -} - TMessageHolder TRaft::CreateVote(uint32_t nodeId) { auto mes = NewHoldedMessage( TMessageEx {.Src = Id, .Dst = nodeId, .Term = State->CurrentTerm}, @@ -398,13 +349,6 @@ void TRaft::Become(EState newStateName) { } void TRaft::Process(ITimeSource::Time now, TMessageHolder message, const std::shared_ptr& replyTo) { - // client request - if (auto maybeCommandRequest = message.Maybe()) { - return OnCommandRequest(std::move(maybeCommandRequest.Cast()), replyTo); - } else if (auto maybeCommandResponse = message.Maybe()) { - return OnCommandResponse(std::move(maybeCommandResponse.Cast())); - } - if (message.IsEx()) { auto messageEx = message.Cast(); if (messageEx->Term > State->CurrentTerm) { @@ -433,49 +377,10 @@ void TRaft::Process(ITimeSource::Time now, TMessageHolder message, con } } -void TRaft::ProcessCommitted() { - auto commitIndex = VolatileState->CommitIndex; - for (auto i = Rsm->LastAppliedIndex+1; i <= commitIndex; i++) { - auto entry = State->Get(i-1); - if (entry->Flags == TLogEntry::EStub) { - continue; - } - auto reply = Rsm->Write(entry, i); - WriteAnswers.emplace(TAnswer { - .Index = i, - .Reply = reply ? reply : NewHoldedMessage(TCommandResponse {.Index = i}) - }); - } -} - -void TRaft::ProcessWaiting() { - auto lastApplied = Rsm->LastAppliedIndex; - while (!Waiting.empty() && Waiting.top().Index <= lastApplied) { - auto w = Waiting.top(); Waiting.pop(); - TMessageHolder reply; - if (w.Command->Flags & TCommandRequest::EWrite) { - while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) { - WriteAnswers.pop(); - } - assert(!WriteAnswers.empty()); - auto answer = std::move(WriteAnswers.front()); WriteAnswers.pop(); - assert(answer.Index == w.Index); - reply = std::move(answer.Reply.Cast()); - } else { - reply = Rsm->Read(std::move(w.Command), w.Index).Cast(); - } - reply->Cookie = w.Command->Cookie; - w.ReplyTo->Send(std::move(reply)); - } -} - void TRaft::FollowerTimeout(ITimeSource::Time now) { if (VolatileState->ElectionDue <= now) { Become(EState::CANDIDATE); } - - ProcessCommitted(); - ProcessWaiting(); // For forwarded requests } void TRaft::CandidateTimeout(ITimeSource::Time now) { @@ -502,9 +407,6 @@ void TRaft::LeaderTimeout(ITimeSource::Time now) { if (Nservers == 1) { VolatileState->CommitAdvance(Nservers, *State); } - - ProcessCommitted(); - ProcessWaiting(); } void TRaft::ProcessTimeout(ITimeSource::Time now) { @@ -580,14 +482,13 @@ void TRaft::Append(TMessageHolder entry) { void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { auto stateName = Raft->CurrentStateName(); - auto index = Raft->GetState()->LastLogIndex; auto leaderId = Raft->GetVolatileState()->LeaderId; if (stateName == EState::LEADER) { if (command->Flags & TCommandRequest::EWrite) { Raft->Append(std::move(Rsm->Prepare(command))); } if (replyTo) { - Waiting.emplace(TWaiting{index, std::move(command), replyTo}); + Waiting.emplace(TWaiting{Raft->GetState()->LastLogIndex, std::move(command), replyTo}); } } else if (stateName == EState::FOLLOWER && replyTo) { if (command->Flags & TCommandRequest::EWrite) { @@ -603,14 +504,14 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command return; } - assert(leaderId != Id); + assert(leaderId != Raft->GetId()); // Forward command->Cookie = std::max(1, ForwardCookie); Nodes[leaderId]->Send(std::move(command)); Forwarded[ForwardCookie] = replyTo; ForwardCookie++; } else { - Waiting.emplace(TWaiting{index, std::move(command), replyTo}); + Waiting.emplace(TWaiting{Raft->GetState()->LastLogIndex, std::move(command), replyTo}); } } else if (stateName == EState::CANDIDATE && replyTo) { // TODO: wait for state change @@ -618,6 +519,16 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command } } +void TRequestProcessor::OnCommandResponse(TMessageHolder command) { + // forwarded + auto it = Forwarded.find(command->Cookie); + if (it == Forwarded.end()) { + return; + } + it->second->Send(std::move(command)); + Forwarded.erase(it); +} + void TRequestProcessor::ProcessCommitted() { auto commitIndex = Raft->GetVolatileState()->CommitIndex; auto& state = Raft->GetState(); @@ -632,5 +543,27 @@ void TRequestProcessor::ProcessCommitted() { .Reply = reply ? reply : NewHoldedMessage(TCommandResponse {.Index = i}) }); } + Rsm->LastAppliedIndex = commitIndex; +} + +void TRequestProcessor::ProcessWaiting() { + auto lastApplied = Rsm->LastAppliedIndex; + while (!Waiting.empty() && Waiting.top().Index <= lastApplied) { + auto w = Waiting.top(); Waiting.pop(); + TMessageHolder reply; + if (w.Command->Flags & TCommandRequest::EWrite) { + while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) { + WriteAnswers.pop(); + } + assert(!WriteAnswers.empty()); + auto answer = std::move(WriteAnswers.front()); WriteAnswers.pop(); + assert(answer.Index == w.Index); + reply = std::move(answer.Reply.Cast()); + } else { + reply = Rsm->Read(std::move(w.Command), w.Index).Cast(); + } + reply->Cookie = w.Command->Cookie; + w.ReplyTo->Send(std::move(reply)); + } } diff --git a/src/raft.h b/src/raft.h index 5fb2cad..65ec909 100644 --- a/src/raft.h +++ b/src/raft.h @@ -119,6 +119,10 @@ class TRaft { return Nservers; } + auto& GetRsm() { + return Rsm_; + } + private: void Candidate(ITimeSource::Time now, TMessageHolder message); void Follower(ITimeSource::Time now, TMessageHolder message); @@ -129,20 +133,15 @@ class TRaft { void OnAppendEntries(ITimeSource::Time now, TMessageHolder message); void OnAppendEntries(TMessageHolder message); - void OnCommandRequest(TMessageHolder message, const std::shared_ptr& replyTo); - void OnCommandResponse(TMessageHolder message); - void LeaderTimeout(ITimeSource::Time now); void CandidateTimeout(ITimeSource::Time now); void FollowerTimeout(ITimeSource::Time now); TMessageHolder CreateVote(uint32_t nodeId); TMessageHolder CreateAppendEntries(uint32_t nodeId); - void ProcessCommitted(); - void ProcessWaiting(); ITimeSource::Time MakeElection(ITimeSource::Time now); - std::shared_ptr Rsm; + std::shared_ptr Rsm_; uint32_t Id; TNodeDict Nodes; int MinVotes; @@ -151,33 +150,22 @@ class TRaft { std::shared_ptr State; std::unique_ptr VolatileState; - struct TWaiting { - uint64_t Index; - TMessageHolder Command; - std::shared_ptr ReplyTo; - bool operator< (const TWaiting& other) const { - return Index > other.Index; - } - }; - std::priority_queue Waiting; - - struct TAnswer { - uint64_t Index; - TMessageHolder Reply; - }; - std::queue WriteAnswers; - uint32_t ForwardCookie = 1; - std::unordered_map> Forwarded; - EState StateName; uint32_t Seed = 31337; }; class TRequestProcessor { public: + TRequestProcessor(std::shared_ptr raft, std::shared_ptr rsm, const TNodeDict& nodes) + : Raft(raft) + , Rsm(rsm) + , Nodes(nodes) + { } + void OnCommandRequest(TMessageHolder message, const std::shared_ptr& replyTo); void OnCommandResponse(TMessageHolder message); void ProcessCommitted(); + void ProcessWaiting(); private: std::shared_ptr Raft; diff --git a/src/server.cpp b/src/server.cpp index 9a5dafe..6e901bf 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -122,8 +122,17 @@ NNet::TVoidTask TRaftServer::InboundConnection(TSocket socket) { Nodes.insert(client); while (true) { auto mes = co_await TMessageReader(client->Sock()).Read(); - Raft->Process(TimeSource->Now(), std::move(mes), client); + // client request + if (auto maybeCommandRequest = mes.template Maybe()) { + RequestProcessor->OnCommandRequest(std::move(maybeCommandRequest.Cast()), client); + } else if (auto maybeCommandResponse = mes.template Maybe()) { + RequestProcessor->OnCommandResponse(std::move(maybeCommandResponse.Cast())); + } else { + Raft->Process(TimeSource->Now(), std::move(mes), client); + } Raft->ProcessTimeout(TimeSource->Now()); + RequestProcessor->ProcessCommitted(); + RequestProcessor->ProcessWaiting(); DrainNodes(); } } catch (const std::exception & ex) { @@ -163,8 +172,11 @@ NNet::TVoidTask TRaftServer::OutboundServe(std::shared_ptrSock()).Read(); // TODO: check message type // TODO: should be only TCommandResponse - Raft->Process(TimeSource->Now(), std::move(mes), nullptr); - DrainNodes(); + if (auto maybeCommandResponse = mes.template Maybe()) { + RequestProcessor->OnCommandResponse(std::move(maybeCommandResponse.Cast())); + RequestProcessor->ProcessWaiting(); + DrainNodes(); + } } catch (const std::exception& ex) { // wait for reconnection std::cerr << "Exception: " << ex.what() << "\n"; diff --git a/src/server.h b/src/server.h index 570f950..5430e4f 100644 --- a/src/server.h +++ b/src/server.h @@ -123,6 +123,7 @@ class TRaftServer { : Poller(poller) , Socket(std::move(socket)) , Raft(raft) + , RequestProcessor(std::make_shared(raft, raft->GetRsm(), nodes)) , TimeSource(ts) { for (const auto& [_, node] : nodes) { @@ -143,6 +144,7 @@ class TRaftServer { typename TSocket::TPoller& Poller; TSocket Socket; std::shared_ptr Raft; + std::shared_ptr RequestProcessor; std::unordered_set> Nodes; std::shared_ptr TimeSource; From c2670b0e4c97c649312637a38fc350b200252039 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 20:50:24 +0300 Subject: [PATCH 3/9] Move rsm for request processor --- examples/kv.cpp | 4 ++-- examples/sql.cpp | 4 ++-- server/server.cpp | 6 +++--- src/raft.cpp | 5 ++--- src/raft.h | 7 +------ src/server.h | 3 ++- test/test_raft.cpp | 3 +-- 7 files changed, 13 insertions(+), 19 deletions(-) diff --git a/examples/kv.cpp b/examples/kv.cpp index 9dfed7a..581ff7f 100644 --- a/examples/kv.cpp +++ b/examples/kv.cpp @@ -201,11 +201,11 @@ int main(int argc, char** argv) { if (persist) { state = std::make_shared("state", myHost.Id); } - auto raft = std::make_shared(rsm, state, myHost.Id, nodes); + auto raft = std::make_shared(state, myHost.Id, nodes); TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller()); socket.Bind(); socket.Listen(); - TRaftServer server(loop.Poller(), std::move(socket), raft, nodes, timeSource); + TRaftServer server(loop.Poller(), std::move(socket), raft, rsm, nodes, timeSource); server.Serve(); loop.Loop(); } else { diff --git a/examples/sql.cpp b/examples/sql.cpp index 4dca764..c12752d 100644 --- a/examples/sql.cpp +++ b/examples/sql.cpp @@ -299,11 +299,11 @@ int main(int argc, char** argv) std::shared_ptr rsm = std::make_shared("sql_file.db", myHost.Id); auto state = std::make_shared("sql_state", myHost.Id); - auto raft = std::make_shared(rsm, state, myHost.Id, nodes); + auto raft = std::make_shared(state, myHost.Id, nodes); TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller()); socket.Bind(); socket.Listen(); - TRaftServer server(loop.Poller(), std::move(socket), raft, nodes, timeSource); + TRaftServer server(loop.Poller(), std::move(socket), raft, rsm, nodes, timeSource); server.Serve(); loop.Loop(); } else { diff --git a/server/server.cpp b/server/server.cpp index a1f9342..255e763 100644 --- a/server/server.cpp +++ b/server/server.cpp @@ -72,17 +72,17 @@ int main(int argc, char** argv) { } std::shared_ptr rsm = std::make_shared(); - auto raft = std::make_shared(rsm, std::make_shared(), myHost.Id, nodes); + auto raft = std::make_shared(std::make_shared(), myHost.Id, nodes); TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller()); socket.Bind(); socket.Listen(); if (ssl) { auto sslSocket = NNet::TSslSocket(std::move(socket), *serverContext.get()); - TRaftServer server(loop.Poller(), std::move(sslSocket), raft, nodes, timeSource); + TRaftServer server(loop.Poller(), std::move(sslSocket), raft, rsm, nodes, timeSource); server.Serve(); loop.Loop(); } else { - TRaftServer server(loop.Poller(), std::move(socket), raft, nodes, timeSource); + TRaftServer server(loop.Poller(), std::move(socket), raft, rsm, nodes, timeSource); server.Serve(); loop.Loop(); } diff --git a/src/raft.cpp b/src/raft.cpp index 38dcc12..f63e539 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -134,9 +134,8 @@ TVolatileState& TVolatileState::SetCommitIndex(int index) return *this; } -TRaft::TRaft(std::shared_ptr rsm, std::shared_ptr state, int node, const TNodeDict& nodes) - : Rsm_(rsm) - , Id(node) +TRaft::TRaft(std::shared_ptr state, int node, const TNodeDict& nodes) + : Id(node) , Nodes(nodes) , MinVotes((nodes.size()+2+nodes.size()%2)/2) , Npeers(nodes.size()) diff --git a/src/raft.h b/src/raft.h index 65ec909..72a6c46 100644 --- a/src/raft.h +++ b/src/raft.h @@ -81,7 +81,7 @@ enum class EState: int { class TRaft { public: - TRaft(std::shared_ptr rsm, std::shared_ptr state, int node, const TNodeDict& nodes); + TRaft(std::shared_ptr state, int node, const TNodeDict& nodes); void Process(ITimeSource::Time now, TMessageHolder message, const std::shared_ptr& replyTo = {}); void ProcessTimeout(ITimeSource::Time now); @@ -119,10 +119,6 @@ class TRaft { return Nservers; } - auto& GetRsm() { - return Rsm_; - } - private: void Candidate(ITimeSource::Time now, TMessageHolder message); void Follower(ITimeSource::Time now, TMessageHolder message); @@ -141,7 +137,6 @@ class TRaft { TMessageHolder CreateAppendEntries(uint32_t nodeId); ITimeSource::Time MakeElection(ITimeSource::Time now); - std::shared_ptr Rsm_; uint32_t Id; TNodeDict Nodes; int MinVotes; diff --git a/src/server.h b/src/server.h index 5430e4f..7ed2703 100644 --- a/src/server.h +++ b/src/server.h @@ -118,12 +118,13 @@ class TRaftServer { typename TSocket::TPoller& poller, TSocket socket, const std::shared_ptr& raft, + const std::shared_ptr& rsm, const TNodeDict& nodes, const std::shared_ptr& ts) : Poller(poller) , Socket(std::move(socket)) , Raft(raft) - , RequestProcessor(std::make_shared(raft, raft->GetRsm(), nodes)) + , RequestProcessor(std::make_shared(raft, rsm, nodes)) , TimeSource(ts) { for (const auto& [_, node] : nodes) { diff --git a/test/test_raft.cpp b/test/test_raft.cpp index 7d3454e..e39ce82 100644 --- a/test/test_raft.cpp +++ b/test/test_raft.cpp @@ -66,13 +66,12 @@ std::shared_ptr MakeRaft( int count = 3, TState st = {}) { - std::shared_ptr rsm = std::make_shared(); TNodeDict nodes; for (int i = 2; i <= count; i++) { nodes[i] = std::make_shared(sendFunc); } std::shared_ptr state = std::make_shared(st); - return std::make_shared(std::move(rsm), std::move(state), 1, nodes); + return std::make_shared(std::move(state), 1, nodes); } template From f9bb963473363b2cd55638b2dca19561ab8136c5 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 21:16:14 +0300 Subject: [PATCH 4/9] Simplify --- src/raft.cpp | 50 +++++++++++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index f63e539..83fe04f 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -482,36 +482,40 @@ void TRaft::Append(TMessageHolder entry) { void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { auto stateName = Raft->CurrentStateName(); auto leaderId = Raft->GetVolatileState()->LeaderId; - if (stateName == EState::LEADER) { - if (command->Flags & TCommandRequest::EWrite) { - Raft->Append(std::move(Rsm->Prepare(command))); + + // read request + if (! (command->Flags & TCommandRequest::EWrite)) { + if (replyTo) { + Waiting.emplace(TWaiting{Raft->GetState()->LastLogIndex, std::move(command), replyTo}); } + return; + } + + // write request + if (stateName == EState::LEADER) { + Raft->Append(std::move(Rsm->Prepare(command))); if (replyTo) { Waiting.emplace(TWaiting{Raft->GetState()->LastLogIndex, std::move(command), replyTo}); } } else if (stateName == EState::FOLLOWER && replyTo) { - if (command->Flags & TCommandRequest::EWrite) { - if (command->Cookie) { - // already forwarded - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); - return; - } + if (command->Cookie) { + // already forwarded + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; + } - if (leaderId == 0) { - // TODO: wait for state change - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); - return; - } - - assert(leaderId != Raft->GetId()); - // Forward - command->Cookie = std::max(1, ForwardCookie); - Nodes[leaderId]->Send(std::move(command)); - Forwarded[ForwardCookie] = replyTo; - ForwardCookie++; - } else { - Waiting.emplace(TWaiting{Raft->GetState()->LastLogIndex, std::move(command), replyTo}); + if (leaderId == 0) { + // TODO: wait for state change + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; } + + assert(leaderId != Raft->GetId()); + // Forward + command->Cookie = std::max(1, ForwardCookie); + Nodes[leaderId]->Send(std::move(command)); + Forwarded[ForwardCookie] = replyTo; + ForwardCookie++; } else if (stateName == EState::CANDIDATE && replyTo) { // TODO: wait for state change replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); From 80ccb70f45c8ed104592f892a19ad054341167c3 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 22:06:27 +0300 Subject: [PATCH 5/9] Simplify --- src/raft.cpp | 13 +++++++++---- src/raft.h | 3 ++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index 83fe04f..c510c71 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -474,14 +474,19 @@ ITimeSource::Time TRaft::MakeElection(ITimeSource::Time now) { return now + std::chrono::milliseconds(delta); } -void TRaft::Append(TMessageHolder entry) { +uint64_t TRaft::Append(TMessageHolder entry) { entry->Term = State->CurrentTerm; State->Append(std::move(entry)); + return State->LastLogIndex; +} + +uint32_t TRaft::GetLeaderId() const { + return VolatileState->LeaderId; } void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { auto stateName = Raft->CurrentStateName(); - auto leaderId = Raft->GetVolatileState()->LeaderId; + auto leaderId = Raft->GetLeaderId(); // read request if (! (command->Flags & TCommandRequest::EWrite)) { @@ -493,9 +498,9 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command // write request if (stateName == EState::LEADER) { - Raft->Append(std::move(Rsm->Prepare(command))); + auto index = Raft->Append(std::move(Rsm->Prepare(command))); if (replyTo) { - Waiting.emplace(TWaiting{Raft->GetState()->LastLogIndex, std::move(command), replyTo}); + Waiting.emplace(TWaiting{index, std::move(command), replyTo}); } } else if (stateName == EState::FOLLOWER && replyTo) { if (command->Cookie) { diff --git a/src/raft.h b/src/raft.h index 72a6c46..818d26f 100644 --- a/src/raft.h +++ b/src/raft.h @@ -86,7 +86,8 @@ class TRaft { void Process(ITimeSource::Time now, TMessageHolder message, const std::shared_ptr& replyTo = {}); void ProcessTimeout(ITimeSource::Time now); - void Append(TMessageHolder entry); + uint64_t Append(TMessageHolder entry); + uint32_t GetLeaderId() const; // ut const auto& GetState() const { From c5f8a9729d065185cc5f04d65a307aede01fd33a Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 22:12:24 +0300 Subject: [PATCH 6/9] Simplify --- src/raft.cpp | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index c510c71..ad79272 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -502,29 +502,38 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command if (replyTo) { Waiting.emplace(TWaiting{index, std::move(command), replyTo}); } - } else if (stateName == EState::FOLLOWER && replyTo) { - if (command->Cookie) { - // already forwarded - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); - return; - } + return; + } + + // forwarding write request + if (!replyTo) { + // nothing to forward + return; + } + + if (command->Cookie) { + // already forwarded + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; + } - if (leaderId == 0) { - // TODO: wait for state change - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); - return; - } + if (stateName == EState::CANDIDATE || leaderId == 0) { + // TODO: wait for state change + replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; + } + if (stateName == EState::FOLLOWER) { assert(leaderId != Raft->GetId()); // Forward command->Cookie = std::max(1, ForwardCookie); Nodes[leaderId]->Send(std::move(command)); Forwarded[ForwardCookie] = replyTo; ForwardCookie++; - } else if (stateName == EState::CANDIDATE && replyTo) { - // TODO: wait for state change - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + return; } + + assert(false && "Wrong state"); } void TRequestProcessor::OnCommandResponse(TMessageHolder command) { From a29ae86b22dcbc9edd79bbec472e21a4e09fcd1b Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 22:37:00 +0300 Subject: [PATCH 7/9] Simplify --- src/raft.cpp | 12 +++++++++--- src/raft.h | 6 ++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index ad79272..c6db9fb 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -484,6 +484,10 @@ uint32_t TRaft::GetLeaderId() const { return VolatileState->LeaderId; } +uint64_t TRaft::GetLastIndex() const { + return State->LastLogIndex; +} + void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { auto stateName = Raft->CurrentStateName(); auto leaderId = Raft->GetLeaderId(); @@ -491,7 +495,8 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command // read request if (! (command->Flags & TCommandRequest::EWrite)) { if (replyTo) { - Waiting.emplace(TWaiting{Raft->GetState()->LastLogIndex, std::move(command), replyTo}); + assert(Waiting.empty() || Waiting.back().Index <= Raft->GetLastIndex()); + Waiting.emplace(TWaiting{Raft->GetLastIndex(), std::move(command), replyTo}); } return; } @@ -500,6 +505,7 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command if (stateName == EState::LEADER) { auto index = Raft->Append(std::move(Rsm->Prepare(command))); if (replyTo) { + assert(Waiting.empty() || Waiting.back().Index <= index); Waiting.emplace(TWaiting{index, std::move(command), replyTo}); } return; @@ -565,8 +571,8 @@ void TRequestProcessor::ProcessCommitted() { void TRequestProcessor::ProcessWaiting() { auto lastApplied = Rsm->LastAppliedIndex; - while (!Waiting.empty() && Waiting.top().Index <= lastApplied) { - auto w = Waiting.top(); Waiting.pop(); + while (!Waiting.empty() && Waiting.back().Index <= lastApplied) { + auto w = Waiting.back(); Waiting.pop(); TMessageHolder reply; if (w.Command->Flags & TCommandRequest::EWrite) { while (!WriteAnswers.empty() && WriteAnswers.front().Index < w.Index) { diff --git a/src/raft.h b/src/raft.h index 818d26f..eba0f62 100644 --- a/src/raft.h +++ b/src/raft.h @@ -88,6 +88,7 @@ class TRaft { uint64_t Append(TMessageHolder entry); uint32_t GetLeaderId() const; + uint64_t GetLastIndex() const; // ut const auto& GetState() const { @@ -172,11 +173,8 @@ class TRequestProcessor { uint64_t Index; TMessageHolder Command; std::shared_ptr ReplyTo; - bool operator< (const TWaiting& other) const { - return Index > other.Index; - } }; - std::priority_queue Waiting; + std::queue Waiting; struct TAnswer { uint64_t Index; From 6976ca84f4049309daa00c92d873284fb0c13a67 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 23:01:42 +0300 Subject: [PATCH 8/9] Simplify --- src/raft.cpp | 23 +++++++++++++++++++++-- src/raft.h | 2 ++ src/server.cpp | 9 +++++---- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index c6db9fb..c213f83 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -488,6 +488,26 @@ uint64_t TRaft::GetLastIndex() const { return State->LastLogIndex; } +void TRequestProcessor::CheckStateChange() { + if (WaitingStateChange.empty()) { + return; + } + + auto stateName = Raft->CurrentStateName(); + auto leaderId = Raft->GetLeaderId(); + + if (stateName == EState::CANDIDATE || leaderId == 0) { + return; + } + + std::queue apply; + WaitingStateChange.swap(apply); + while (!apply.empty()) { + auto w = std::move(apply.front()); apply.pop(); + OnCommandRequest(std::move(w.Command), w.ReplyTo); + } +} + void TRequestProcessor::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { auto stateName = Raft->CurrentStateName(); auto leaderId = Raft->GetLeaderId(); @@ -524,8 +544,7 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command } if (stateName == EState::CANDIDATE || leaderId == 0) { - // TODO: wait for state change - replyTo->Send(NewHoldedMessage(TCommandResponse{.Cookie = command->Cookie, .ErrorCode = 1})); + WaitingStateChange.emplace(TWaiting{0, std::move(command), replyTo}); return; } diff --git a/src/raft.h b/src/raft.h index eba0f62..8fb0dab 100644 --- a/src/raft.h +++ b/src/raft.h @@ -159,6 +159,7 @@ class TRequestProcessor { , Nodes(nodes) { } + void CheckStateChange(); void OnCommandRequest(TMessageHolder message, const std::shared_ptr& replyTo); void OnCommandResponse(TMessageHolder message); void ProcessCommitted(); @@ -175,6 +176,7 @@ class TRequestProcessor { std::shared_ptr ReplyTo; }; std::queue Waiting; + std::queue WaitingStateChange; struct TAnswer { uint64_t Index; diff --git a/src/server.cpp b/src/server.cpp index 6e901bf..ce490c5 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -122,7 +122,7 @@ NNet::TVoidTask TRaftServer::InboundConnection(TSocket socket) { Nodes.insert(client); while (true) { auto mes = co_await TMessageReader(client->Sock()).Read(); - // client request + // client request if (auto maybeCommandRequest = mes.template Maybe()) { RequestProcessor->OnCommandRequest(std::move(maybeCommandRequest.Cast()), client); } else if (auto maybeCommandResponse = mes.template Maybe()) { @@ -131,6 +131,7 @@ NNet::TVoidTask TRaftServer::InboundConnection(TSocket socket) { Raft->Process(TimeSource->Now(), std::move(mes), client); } Raft->ProcessTimeout(TimeSource->Now()); + RequestProcessor->CheckStateChange(); RequestProcessor->ProcessCommitted(); RequestProcessor->ProcessWaiting(); DrainNodes(); @@ -170,16 +171,16 @@ NNet::TVoidTask TRaftServer::OutboundServe(std::shared_ptrSock()).Read(); - // TODO: check message type - // TODO: should be only TCommandResponse if (auto maybeCommandResponse = mes.template Maybe()) { RequestProcessor->OnCommandResponse(std::move(maybeCommandResponse.Cast())); RequestProcessor->ProcessWaiting(); DrainNodes(); + } else { + std::cerr << "Wrong message type: " << mes->Type << std::endl; } } catch (const std::exception& ex) { // wait for reconnection - std::cerr << "Exception: " << ex.what() << "\n"; + std::cerr << "Exception: " << ex.what() << std::endl; error = true; } if (error) { From 7009a8168e065a337d226714bada24a327d0a5ba Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Thu, 26 Dec 2024 23:14:51 +0300 Subject: [PATCH 9/9] Cleanup clients --- src/raft.cpp | 25 +++++++++++++++++++++---- src/raft.h | 4 +++- src/server.cpp | 2 +- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index c213f83..69b0581 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -553,7 +553,8 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command // Forward command->Cookie = std::max(1, ForwardCookie); Nodes[leaderId]->Send(std::move(command)); - Forwarded[ForwardCookie] = replyTo; + Cookie2Client[ForwardCookie] = replyTo; + Client2Cookie[replyTo].emplace(ForwardCookie); ForwardCookie++; return; } @@ -563,12 +564,28 @@ void TRequestProcessor::OnCommandRequest(TMessageHolder command void TRequestProcessor::OnCommandResponse(TMessageHolder command) { // forwarded - auto it = Forwarded.find(command->Cookie); - if (it == Forwarded.end()) { + auto it = Cookie2Client.find(command->Cookie); + if (it == Cookie2Client.end()) { return; } it->second->Send(std::move(command)); - Forwarded.erase(it); + auto jt = Client2Cookie.find(it->second); + jt->second.erase(command->Cookie); + if (jt->second.empty()) { + Client2Cookie.erase(jt); + } + Cookie2Client.erase(it); +} + +void TRequestProcessor::CleanUp(const std::shared_ptr& replyTo) { + auto jt = Client2Cookie.find(replyTo); + if (jt == Client2Cookie.end()) { + return; + } + for (auto cookie : jt->second) { + Cookie2Client.erase(cookie); + } + Client2Cookie.erase(jt); } void TRequestProcessor::ProcessCommitted() { diff --git a/src/raft.h b/src/raft.h index 8fb0dab..17e6ae9 100644 --- a/src/raft.h +++ b/src/raft.h @@ -164,6 +164,7 @@ class TRequestProcessor { void OnCommandResponse(TMessageHolder message); void ProcessCommitted(); void ProcessWaiting(); + void CleanUp(const std::shared_ptr& replyTo); private: std::shared_ptr Raft; @@ -184,6 +185,7 @@ class TRequestProcessor { }; std::queue WriteAnswers; uint32_t ForwardCookie = 1; - std::unordered_map> Forwarded; + std::unordered_map> Cookie2Client; + std::unordered_map, std::unordered_set> Client2Cookie; }; diff --git a/src/server.cpp b/src/server.cpp index ce490c5..c79307b 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -139,8 +139,8 @@ NNet::TVoidTask TRaftServer::InboundConnection(TSocket socket) { } catch (const std::exception & ex) { std::cerr << "Exception: " << ex.what() << "\n"; } - // TODO: erase also from Forwarded Nodes.erase(client); + RequestProcessor->CleanUp(client); co_return; }