diff --git a/.github/import_generation.txt b/.github/import_generation.txt index 8f92bfdd497..7facc89938b 100644 --- a/.github/import_generation.txt +++ b/.github/import_generation.txt @@ -1 +1 @@ -35 +36 diff --git a/.github/last_commit.txt b/.github/last_commit.txt index e06b8762c3b..1604fb18241 100644 --- a/.github/last_commit.txt +++ b/.github/last_commit.txt @@ -1 +1 @@ -7f34dd9500921f4c8501e75e62488659cb276fb4 +3b235ed1f2fc3977cfc6f99a74123c0097ef9795 diff --git a/.github/scripts/copy_sources.sh b/.github/scripts/copy_sources.sh index b5076f0d34d..3d648a7776c 100755 --- a/.github/scripts/copy_sources.sh +++ b/.github/scripts/copy_sources.sh @@ -7,7 +7,6 @@ echo "Copying sources..." cp -r $1/ydb/public/sdk/cpp/* $tmp_dir echo "tmp_dir: $tmp_dir" -rm -r $tmp_dir/client rm -r $tmp_dir/src/client/arrow rm -r $tmp_dir/src/client/cms rm -r $tmp_dir/src/client/config @@ -23,8 +22,10 @@ rm -r $tmp_dir/include/ydb-cpp-sdk/client/draft rm -r $tmp_dir/tests/unit/client/draft mkdir -p $tmp_dir/src/api/client/yc_private +mkdir -p $tmp_dir/src/api/client/yc_private/accessservice mkdir -p $tmp_dir/src/api/client/yc_public +cp -r $1/ydb/public/api/client/yc_private/accessservice/sensitive.proto $tmp_dir/src/api/client/yc_private/accessservice/sensitive.proto cp -r $1/ydb/public/api/client/yc_private/iam $tmp_dir/src/api/client/yc_private cp -r $1/ydb/public/api/client/yc_private/operation $tmp_dir/src/api/client/yc_private cp -r $1/ydb/public/api/client/yc_public/common $tmp_dir/src/api/client/yc_public @@ -33,7 +34,7 @@ cp -r $1/ydb/public/api/grpc $tmp_dir/src/api cp -r $1/ydb/public/api/protos $tmp_dir/src/api rm -r $tmp_dir/src/api/protos/out -rm $tmp_dir/include/ydb-cpp-sdk/type_switcher.h $tmp_dir/include/ydb-cpp-sdk/client/proto/private.h $tmp_dir/src/version.h +rm $tmp_dir/include/ydb-cpp-sdk/type_switcher.h $tmp_dir/src/version.h cp -r $2/util $tmp_dir cp -r $2/library $tmp_dir @@ -57,7 +58,6 @@ cp $2/tests/slo_workloads/.dockerignore $tmp_dir/tests/slo_workloads cp $2/tests/slo_workloads/Dockerfile $tmp_dir/tests/slo_workloads cp $2/include/ydb-cpp-sdk/type_switcher.h $tmp_dir/include/ydb-cpp-sdk/type_switcher.h -cp $2/include/ydb-cpp-sdk/client/proto/private.h $tmp_dir/include/ydb-cpp-sdk/client/proto/private.h cp $2/src/version.h $tmp_dir/src/version.h cd $2 diff --git a/include/ydb-cpp-sdk/client/discovery/discovery.h b/include/ydb-cpp-sdk/client/discovery/discovery.h index 91be5831f25..870d8adc5b2 100644 --- a/include/ydb-cpp-sdk/client/discovery/discovery.h +++ b/include/ydb-cpp-sdk/client/discovery/discovery.h @@ -90,9 +90,21 @@ class TWhoAmIResult : public TStatus { TWhoAmIResult(TStatus&& status, const Ydb::Discovery::WhoAmIResult& proto); const std::string& GetUserName() const; const std::vector& GetGroups() const; + bool IsAdministrationAllowed() const; + bool IsMonitoringAllowed() const; + bool IsViewerAllowed() const; + bool IsDatabaseAllowed() const; + bool IsRegisterNodeAllowed() const; + bool IsBootstrapAllowed() const; private: std::string UserName_; std::vector Groups_; + bool IsAdministrationAllowed_ = false; + bool IsMonitoringAllowed_ = false; + bool IsViewerAllowed_ = false; + bool IsDatabaseAllowed_ = false; + bool IsRegisterNodeAllowed_ = false; + bool IsBootstrapAllowed_ = false; }; using TAsyncWhoAmIResult = NThreading::TFuture; diff --git a/include/ydb-cpp-sdk/client/driver/driver.h b/include/ydb-cpp-sdk/client/driver/driver.h index 72aa008ccca..41efe5c31e1 100644 --- a/include/ydb-cpp-sdk/client/driver/driver.h +++ b/include/ydb-cpp-sdk/client/driver/driver.h @@ -103,6 +103,13 @@ class TDriverConfig { //! default: true, 30, 5, 10 for linux, and true and OS default for others POSIX TDriverConfig& SetTcpKeepAliveSettings(bool enable, size_t idle, size_t count, size_t interval); + //! Set TCP_NODELAY socket option + //! enable - if true TCP_NODELAY is enabled (default, no Nagle algorithm, low latency, packet fragmentation) + //! - if false TCP_NODELAY is disabled (Nagle algorithm enabled, reduced packet fragmentation) + //! NOTE: This affects network performance. Disable only if you want to reduce packet fragmentation. + //! default: true + TDriverConfig& SetTcpNoDelay(bool enable); + //! Enable or disable drain of client logic (e.g. session pool drain) during dtor call TDriverConfig& SetDrainOnDtors(bool allowed); diff --git a/include/ydb-cpp-sdk/client/federated_topic/federated_topic.h b/include/ydb-cpp-sdk/client/federated_topic/federated_topic.h index 52af62a9aad..1e5b8fffe78 100644 --- a/include/ydb-cpp-sdk/client/federated_topic/federated_topic.h +++ b/include/ydb-cpp-sdk/client/federated_topic/federated_topic.h @@ -517,7 +517,7 @@ class TFederatedTopicClient { std::shared_ptr CreateReadSession(const TFederatedReadSessionSettings& settings); //! Create write session. - // std::shared_ptr CreateSimpleBlockingWriteSession(const TFederatedWriteSessionSettings& settings); + std::shared_ptr CreateSimpleBlockingWriteSession(const TFederatedWriteSessionSettings& settings); std::shared_ptr CreateWriteSession(const TFederatedWriteSessionSettings& settings); struct TClusterInfo { diff --git a/include/ydb-cpp-sdk/client/topic/client.h b/include/ydb-cpp-sdk/client/topic/client.h index 83d324f7e48..88d50cca684 100644 --- a/include/ydb-cpp-sdk/client/topic/client.h +++ b/include/ydb-cpp-sdk/client/topic/client.h @@ -48,6 +48,14 @@ class TTopicClient { //! Create write session. std::shared_ptr CreateSimpleBlockingWriteSession(const TWriteSessionSettings& settings); + + //! Create simple blocking keyed write session. Experimental feature. DO NOT USE IN PRODUCTION. + std::shared_ptr CreateSimpleBlockingKeyedWriteSession(const TKeyedWriteSessionSettings& settings); + + //! Create keyed write session. Experimental feature. DO NOT USE IN PRODUCTION. + std::shared_ptr CreateKeyedWriteSession(const TKeyedWriteSessionSettings& settings); + + //! Create write session. std::shared_ptr CreateWriteSession(const TWriteSessionSettings& settings); // Commit offset diff --git a/include/ydb-cpp-sdk/client/topic/read_session.h b/include/ydb-cpp-sdk/client/topic/read_session.h index 1a7d4be36cc..f79ba7e8f9b 100644 --- a/include/ydb-cpp-sdk/client/topic/read_session.h +++ b/include/ydb-cpp-sdk/client/topic/read_session.h @@ -198,6 +198,9 @@ struct TReadSessionSettings: public TRequestSettings { //! Log. FLUENT_SETTING_OPTIONAL(TLog, Log); + + //! InFlightMemoryController. + FLUENT_SETTING_OPTIONAL(std::uint64_t, PartitionMaxInFlightBytes); }; struct TReadSessionGetEventSettings : public TCommonClientSettingsBase { diff --git a/include/ydb-cpp-sdk/client/topic/write_session.h b/include/ydb-cpp-sdk/client/topic/write_session.h index e0c4d4618e2..b3bba9904d8 100644 --- a/include/ydb-cpp-sdk/client/topic/write_session.h +++ b/include/ydb-cpp-sdk/client/topic/write_session.h @@ -144,6 +144,45 @@ struct TWriteSessionSettings : public TRequestSettings { FLUENT_SETTING_DEFAULT(bool, ValidateSeqNo, true); }; +struct TKeyedWriteSessionSettings : public TWriteSessionSettings { + using TSelf = TKeyedWriteSessionSettings; + + enum class EPartitionChooserStrategy { + Bound, + Hash, + }; + + TKeyedWriteSessionSettings() = default; + TKeyedWriteSessionSettings(const TKeyedWriteSessionSettings&) = default; + TKeyedWriteSessionSettings(TKeyedWriteSessionSettings&&) = default; + + TKeyedWriteSessionSettings& operator=(const TKeyedWriteSessionSettings&) = default; + TKeyedWriteSessionSettings& operator=(TKeyedWriteSessionSettings&&) = default; + + //! Session lifetime. + FLUENT_SETTING_DEFAULT(TDuration, SubSessionIdleTimeout, TDuration::Seconds(30)); + + //! Partition chooser strategy. + FLUENT_SETTING_DEFAULT(EPartitionChooserStrategy, PartitionChooserStrategy, EPartitionChooserStrategy::Bound); + + //! Hasher function. + FLUENT_SETTING_DEFAULT(std::function, PartitioningKeyHasher, DefaultPartitioningKeyHasher); + + //! Default partitioning key hasher. + //! Uses MurmurHash. + static std::string DefaultPartitioningKeyHasher(const std::string_view key); + + //! ProducerId prefix to use. + //! ProducerId is generated as ProducerIdPrefix + partition id. + FLUENT_SETTING(std::string, ProducerIdPrefix); + + //! SessionID to use. + FLUENT_SETTING_DEFAULT(std::string, SessionId, ""); + +private: + using TWriteSessionSettings::ProducerId; +}; + //! Contains the message to write and all the options. struct TWriteMessage { using TSelf = TWriteMessage; @@ -276,4 +315,48 @@ class IWriteSession { virtual ~IWriteSession() = default; }; +//! Keyed write session. Experimental SDK. DO NOT USE IN PRODUCTION. +class IKeyedWriteSession { +public: + //! Write single message. + //! continuationToken - a token earlier provided to client with ReadyToAccept event. + virtual void Write(TContinuationToken&& continuationToken, const std::string& key, TWriteMessage&& message, + TTransactionBase* tx = nullptr) = 0; + + //! Future that is set when next event is available. + virtual NThreading::TFuture WaitEvent() = 0; + + //! Wait and return next event. Use WaitEvent() for non-blocking wait. + virtual std::optional GetEvent(bool block = false) = 0; + + //! Get several events in one call. + //! If blocking = false, instantly returns up to maxEventsCount available events. + //! If blocking = true, blocks till maxEventsCount events are available. + //! If maxEventsCount is unset, write session decides the count to return itself. + virtual std::vector GetEvents(bool block = false, std::optional maxEventsCount = std::nullopt) = 0; + + virtual bool Close(TDuration closeTimeout = TDuration::Max()) = 0; + virtual TWriterCounters::TPtr GetCounters() = 0; + virtual ~IKeyedWriteSession() = default; +}; + +//! Simple blocking keyed write session. Experimental SDK. DO NOT USE IN PRODUCTION. +class ISimpleBlockingKeyedWriteSession { +public: + //! Write single message. + //! continuationToken - a token earlier provided to client with ReadyToAccept event. + virtual bool Write(const std::string& key, TWriteMessage&& message, TTransactionBase* tx = nullptr, + TDuration blockTimeout = TDuration::Max()) = 0; + + //! Wait for all writes to complete (no more that closeTimeout()), then close. + //! Return true if all writes were completed and acked, false if timeout was reached and some writes were aborted. + virtual bool Close(TDuration closeTimeout = TDuration::Max()) = 0; + + //! Writer counters with different stats (see TWriterConuters). + virtual TWriterCounters::TPtr GetCounters() = 0; + + //! Close() with timeout = 0 and destroy everything instantly. + virtual ~ISimpleBlockingKeyedWriteSession() = default; +}; + } // namespace NYdb::NTopic diff --git a/library/cpp/threading/future/CMakeLists.txt b/library/cpp/threading/future/CMakeLists.txt index 3ae66ecf4dc..c34885dea26 100644 --- a/library/cpp/threading/future/CMakeLists.txt +++ b/library/cpp/threading/future/CMakeLists.txt @@ -48,3 +48,5 @@ if (YDB_SDK_TESTS) unit ) endif() + +add_subdirectory(subscription) diff --git a/library/cpp/threading/future/subscription/CMakeLists.txt b/library/cpp/threading/future/subscription/CMakeLists.txt new file mode 100644 index 00000000000..972ecd40b81 --- /dev/null +++ b/library/cpp/threading/future/subscription/CMakeLists.txt @@ -0,0 +1,31 @@ +_ydb_sdk_add_library(threading-future-subscription) + +target_link_libraries(threading-future-subscription PUBLIC + yutil + threading-future +) + +target_sources(threading-future-subscription + PRIVATE + subscription.cpp + wait_all.cpp + wait_all_or_exception.cpp + wait_any.cpp +) + +_ydb_sdk_install_targets(TARGETS threading-future-subscription) + +if (YDB_SDK_TESTS) + add_ydb_test(NAME future-subscription-ut + SOURCES + subscription_ut.cpp + wait_all_ut.cpp + wait_all_or_exception_ut.cpp + wait_any_ut.cpp + wait_ut_common.cpp + LINK_LIBRARIES + threading-future-subscription + LABELS + unit + ) +endif() diff --git a/library/cpp/threading/future/subscription/README.md b/library/cpp/threading/future/subscription/README.md new file mode 100644 index 00000000000..6f547926854 --- /dev/null +++ b/library/cpp/threading/future/subscription/README.md @@ -0,0 +1,104 @@ +Subscriptions manager and wait primitives library +================================================= + +Wait primitives +--------------- + +All wait primitives are futures those being signaled when some or all of theirs dependencies are signaled. +Wait privimitives could be constructed either from an initializer_list or from a standard container of futures. + +1. WaitAll is signaled when all its dependencies are signaled: + + ```C++ + #include + + auto w = NWait::WaitAll({ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for all futures + ``` + +2. WaitAny is signaled when any of its dependencies is signaled: + + ```C++ + #include + + auto w = NWait::WaitAny(TVector>{ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for any future + ``` + +3. WaitAllOrException is signaled when all its dependencies are signaled with values or any dependency is signaled with an exception: + + ```C++ + #include + + auto w = NWait::WaitAllOrException(TVector>{ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for all values or for an exception + ``` + +Subscriptions manager +--------------------- + +The subscription manager can manage multiple links beetween futures and callbacks. Multiple managed subscriptions to a single future shares just a single underlying subscription to the future. That allows dynamic creation and deletion of subscriptions and efficient implementation of different wait primitives. +The subscription manager could be used in the following way: + +1. Subscribe to a single future: + + ```C++ + #include + + TFuture LongOperation(); + + ... + auto future = LongRunnigOperation(); + auto m = MakeSubsriptionManager(); + auto id = m->Subscribe(future, [](TFuture const& f) { + try { + auto value = f.GetValue(); + ... + } catch (...) { + ... // handle exception + } + }); + if (id.has_value()) { + ... // Callback will run asynchronously + } else { + ... // Future has been signaled already. The callback has been invoked synchronously + } + ``` + + Note that a callback could be invoked synchronously during a Subscribe call. In this case the returned optional will have no value. + +2. Unsubscribe from a single future: + + ```C++ + // id holds the subscription id from a previous Subscribe call + m->Unsubscribe(id.value()); + ``` + + There is no need to call Unsubscribe if the callback has been called. In this case Unsubscribe will do nothing. And it is safe to call Unsubscribe with the same id multiple times. + +3. Subscribe a single callback to multiple futures: + + ```C++ + auto ids = m->Subscribe({ future1, future2, ..., futureN }, [](auto&& f) { ... }); + ... + ``` + + Futures could be passed to Subscribe method either via an initializer_list or via a standard container like vector or list. Subscribe method accept an optional boolean parameter revertOnSignaled. If the parameter is false (default) then all subscriptions will be performed regardless of the futures states and the returned vector will have a subscription id for each future (even if callback has been executed synchronously for some futures). Otherwise the method will stop on the first signaled future (the callback will be synchronously called for it), no subscriptions will be created and an empty vector will be returned. + +4. Unsubscribe multiple subscriptions: + + ```C++ + // ids is the vector or subscription ids + m->Unsubscribe(ids); + ``` + + The vector of IDs could be a result of a previous Subscribe call or an arbitrary set of IDs of previously created subscriptions. + +5. If you do not want to instantiate a new instance of the subscription manager it is possible to use the default instance: + + ```C++ + auto m = TSubscriptionManager::Default(); + ``` diff --git a/library/cpp/threading/future/subscription/subscription-inl.h b/library/cpp/threading/future/subscription/subscription-inl.h new file mode 100644 index 00000000000..a9d3b3114c7 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription-inl.h @@ -0,0 +1,118 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H) +#error "you should never include subscription-inl.h directly" +#endif + +namespace NThreading { + +namespace NPrivate { + +template +TFutureStateId CheckedStateId(TFuture const& future) { + auto const id = future.StateId(); + if (id.Defined()) { + return *id; + } + ythrow TFutureException() << "Future state should be initialized"; +} + +} + +template +inline TSubscriptionManager::TSubscription::TSubscription(TFuture future, F&& callback, TCallbackExecutor&& executor) + : Callback( + [future = std::move(future), callback = std::forward(callback), executor = std::forward(executor)]() mutable { + executor(std::as_const(future), callback); + }) +{ +} + +template +inline std::optional TSubscriptionManager::Subscribe(TFuture const& future, F&& callback, TCallbackExecutor&& executor) { + auto stateId = NPrivate::CheckedStateId(future); + with_lock(Lock) { + auto const status = TrySubscribe(future, std::forward(callback), stateId, std::forward(executor)); + switch (status) { + case ECallbackStatus::Subscribed: + return TSubscriptionId(stateId, Revision); + case ECallbackStatus::ExecutedSynchronously: + return {}; + default: + Y_ABORT("Unexpected callback status"); + } + } +} + +template +inline TVector TSubscriptionManager::Subscribe(TFutures const& futures, F&& callback, bool revertOnSignaled + , TCallbackExecutor&& executor) +{ + return SubscribeImpl(futures, std::forward(callback), revertOnSignaled, std::forward(executor)); +} + +template +inline TVector TSubscriptionManager::Subscribe(std::initializer_list const> futures, F&& callback + , bool revertOnSignaled, TCallbackExecutor&& executor) +{ + return SubscribeImpl(futures, std::forward(callback), revertOnSignaled, std::forward(executor)); +} + +template +inline TSubscriptionManager::ECallbackStatus TSubscriptionManager::TrySubscribe(TFuture const& future, F&& callback, TFutureStateId stateId + , TCallbackExecutor&& executor) +{ + TSubscription subscription(future, std::forward(callback), std::forward(executor)); + auto const it = Subscriptions.find(stateId); + auto const revision = ++Revision; + if (it == std::end(Subscriptions)) { + auto const success = Subscriptions.emplace(stateId, THashMap{ { revision, std::move(subscription) } }).second; + Y_ABORT_UNLESS(success); + auto self = TSubscriptionManagerPtr(this); + future.Subscribe([self, stateId](TFuture const&) { self->OnCallback(stateId); }); + if (Subscriptions.find(stateId) == std::end(Subscriptions)) { + return ECallbackStatus::ExecutedSynchronously; + } + } else { + Y_ABORT_UNLESS(it->second.emplace(revision, std::move(subscription)).second); + } + return ECallbackStatus::Subscribed; +} + +template +inline TVector TSubscriptionManager::SubscribeImpl(TFutures const& futures, F const& callback, bool revertOnSignaled + , TCallbackExecutor const& executor) +{ + TVector results; + results.reserve(std::size(futures)); + // resolve all state ids to minimize processing under the lock + for (auto const& f : futures) { + results.push_back(TSubscriptionId(NPrivate::CheckedStateId(f), 0)); + } + with_lock(Lock) { + size_t i = 0; + for (auto const& f : futures) { + auto& r = results[i]; + auto const status = TrySubscribe(f, callback, r.StateId(), executor); + switch (status) { + case ECallbackStatus::Subscribed: + break; + case ECallbackStatus::ExecutedSynchronously: + if (revertOnSignaled) { + // revert + results.crop(i); + UnsubscribeImpl(results); + return {}; + } + break; + default: + Y_ABORT("Unexpected callback status"); + } + r.SetSubId(Revision); + ++i; + } + } + return results; +} + +} diff --git a/library/cpp/threading/future/subscription/subscription.cpp b/library/cpp/threading/future/subscription/subscription.cpp new file mode 100644 index 00000000000..e3cb3052c8d --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription.cpp @@ -0,0 +1,65 @@ +#include "subscription.h" + +namespace NThreading { + +bool operator==(TSubscriptionId const& l, TSubscriptionId const& r) noexcept { + return l.StateId() == r.StateId() && l.SubId() == r.SubId(); +} + +bool operator!=(TSubscriptionId const& l, TSubscriptionId const& r) noexcept { + return !(l == r); +} + +void TSubscriptionManager::TSubscription::operator()() { + Callback(); +} + +TSubscriptionManagerPtr TSubscriptionManager::NewInstance() { + return new TSubscriptionManager(); +} + +TSubscriptionManagerPtr TSubscriptionManager::Default() { + static auto instance = NewInstance(); + return instance; +} + +void TSubscriptionManager::Unsubscribe(TSubscriptionId id) { + with_lock(Lock) { + UnsubscribeImpl(id); + } +} + +void TSubscriptionManager::Unsubscribe(TVector const& ids) { + with_lock(Lock) { + UnsubscribeImpl(ids); + } +} + +void TSubscriptionManager::OnCallback(TFutureStateId stateId) noexcept { + THashMap subscriptions; + with_lock(Lock) { + auto const it = Subscriptions.find(stateId); + Y_ABORT_UNLESS(it != Subscriptions.end(), "The callback has been triggered more than once"); + subscriptions.swap(it->second); + Subscriptions.erase(it); + } + for (auto& [_, subscription] : subscriptions) { + subscription(); + } +} + +void TSubscriptionManager::UnsubscribeImpl(TSubscriptionId id) { + auto const it = Subscriptions.find(id.StateId()); + if (it == std::end(Subscriptions)) { + return; + } + it->second.erase(id.SubId()); +} + +void TSubscriptionManager::UnsubscribeImpl(TVector const& ids) { + for (auto const& id : ids) { + UnsubscribeImpl(id); + } +} + +} diff --git a/library/cpp/threading/future/subscription/subscription.h b/library/cpp/threading/future/subscription/subscription.h new file mode 100644 index 00000000000..afe5eda7111 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription.h @@ -0,0 +1,186 @@ +#pragma once + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace NThreading { + +namespace NPrivate { + +struct TNoexceptExecutor { + template + void operator()(TFuture const& future, F&& callee) const noexcept { + return callee(future); + } +}; + +} + +class TSubscriptionManager; + +using TSubscriptionManagerPtr = TIntrusivePtr; + +//! A subscription id +class TSubscriptionId { +private: + TFutureStateId StateId_; + ui64 SubId_; // Secondary id to make the whole subscription id unique + + friend class TSubscriptionManager; + +public: + TFutureStateId StateId() const noexcept { + return StateId_; + } + + ui64 SubId() const noexcept { + return SubId_; + } + +private: + TSubscriptionId(TFutureStateId stateId, ui64 subId) + : StateId_(stateId) + , SubId_(subId) + { + } + + void SetSubId(ui64 subId) noexcept { + SubId_ = subId; + } +}; + +bool operator==(TSubscriptionId const& l, TSubscriptionId const& r) noexcept; +bool operator!=(TSubscriptionId const& l, TSubscriptionId const& r) noexcept; + +//! The subscription manager manages subscriptions to futures +/** It provides an ability to create (and drop) multiple subscriptions to any future + with just a single underlying subscription per future. + + When a future is signaled all its subscriptions are removed. + So, there no need to call Unsubscribe for subscriptions to already signaled futures. + + Warning!!! For correct operation this class imposes the following requirement to futures/promises: + Any used future must be signaled (value or exception set) before the future state destruction. + Otherwise subscriptions and futures may happen. + Current future design does not provide the required guarantee. But that should be fixed soon. +**/ +class TSubscriptionManager final : public TAtomicRefCount { +private: + //! A single subscription + class TSubscription { + private: + std::function Callback; + + public: + template + TSubscription(TFuture future, F&& callback, TCallbackExecutor&& executor); + + void operator()(); + }; + + struct TFutureStateIdHash { + size_t operator()(TFutureStateId const id) const noexcept { + auto const value = id.Value(); + return ::hash()(value); + } + }; + +private: + THashMap, TFutureStateIdHash> Subscriptions; + ui64 Revision = 0; + TMutex Lock; + +public: + //! Creates a new subscription manager instance + static TSubscriptionManagerPtr NewInstance(); + + //! The default subscription manager instance + static TSubscriptionManagerPtr Default(); + + //! Attempts to subscribe the callback to the future + /** Subscription should succeed if the future is not signaled yet. + Otherwise the callback will be called synchronously and nullopt will be returned + + @param future - The future to subscribe to + @param callback - The callback to attach + @return The subscription id on success, nullopt if the future has been signaled already + **/ + template + std::optional Subscribe(TFuture const& future, F&& callback + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Drops the subscription with the given id + /** @param id - The subscription id + **/ + void Unsubscribe(TSubscriptionId id); + + //! Attempts to subscribe the callback to the set of futures + /** @param futures - The futures to subscribe to + @param callback - The callback to attach + @param revertOnSignaled - Shows whether to stop and revert the subscription process if one of the futures is in signaled state + @return The vector of subscription ids if no revert happened or an empty vector otherwise + A subscription id will be valid even if a corresponding future has been signaled + **/ + template + TVector Subscribe(TFutures const& futures, F&& callback, bool revertOnSignaled = false + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Attempts to subscribe the callback to the set of futures + /** @param futures - The futures to subscribe to + @param callback - The callback to attach + @param revertOnSignaled - Shows whether to stop and revert the subscription process if one of the futures is in signaled state + @return The vector of subscription ids if no revert happened or an empty vector otherwise + A subscription id will be valid even if a corresponding future has been signaled + **/ + template + TVector Subscribe(std::initializer_list const> futures, F&& callback, bool revertOnSignaled = false + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Drops the subscriptions with the given ids + /** @param ids - The subscription ids + **/ + void Unsubscribe(TVector const& ids); + +private: + enum class ECallbackStatus { + Subscribed, //! A subscription has been created. The callback will be called asynchronously. + ExecutedSynchronously //! A callback has been called synchronously. No subscription has been created + }; + +private: + //! .ctor + TSubscriptionManager() = default; + //! Processes a callback from a future + void OnCallback(TFutureStateId stateId) noexcept; + //! Attempts to create a subscription + /** This method should be called under the lock + **/ + template + ECallbackStatus TrySubscribe(TFuture const& future, F&& callback, TFutureStateId stateId, TCallbackExecutor&& executor); + //! Batch subscribe implementation + template + TVector SubscribeImpl(TFutures const& futures, F const& callback, bool revertOnSignaled + , TCallbackExecutor const& executor); + //! Unsubscribe implementation + /** This method should be called under the lock + **/ + void UnsubscribeImpl(TSubscriptionId id); + //! Batch unsubscribe implementation + /** This method should be called under the lock + **/ + void UnsubscribeImpl(TVector const& ids); +}; + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H +#include "subscription-inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H diff --git a/library/cpp/threading/future/subscription/subscription_ut.cpp b/library/cpp/threading/future/subscription/subscription_ut.cpp new file mode 100644 index 00000000000..d018ea15cc2 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription_ut.cpp @@ -0,0 +1,432 @@ +#include "subscription.h" + +#include + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TSubscriptionManagerTest) { + + Y_UNIT_TEST(TestSubscribeUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestSubscribeSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto f = MakeFuture(); + + size_t callCount = 0; + auto id = m->Subscribe(f, [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(!id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsignaledAndSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 1); + + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount1, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(id.value()); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 0); + } + + Y_UNIT_TEST(TestSubscribeUnsignaledUnsubscribeSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestUnsubscribeTwice) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 0); + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 0); + } + + Y_UNIT_TEST(TestSubscribeOneUnsignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_UNEQUAL(id2.value(), id3.value()); + UNIT_ASSERT_UNEQUAL(id3.value(), id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeOneSignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto f = MakeFuture(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(f, [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(f, [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(!id1.has_value()); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeOneUnsignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + size_t callCount4 = 0; + auto id4 = m->Subscribe(p.GetFuture(), [&callCount4](auto&&) { ++callCount4; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT(id4.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + + m->Unsubscribe(id3.value()); + m->Unsubscribe(id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 1); + } + + Y_UNIT_TEST(TestSubscribeManyUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p1.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_UNEQUAL(id2.value(), id3.value()); + UNIT_ASSERT_UNEQUAL(id3.value(), id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + + p1.SetValue(33); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p2.SetValue(111); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeManySignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(f1, [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(f2, [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f2, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(!id1.has_value()); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeManyMixed) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(42); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p1.SetValue(45); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p2.SetValue(-7); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeMany) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto p3 = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p3.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + size_t callCount4 = 0; + auto id4 = m->Subscribe(p2.GetFuture(), [&callCount4](auto&&) { ++callCount4; } ); + size_t callCount5 = 0; + auto id5 = m->Subscribe(p1.GetFuture(), [&callCount5](auto&&) { ++callCount5; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT(id4.has_value()); + UNIT_ASSERT(id5.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 0); + + m->Unsubscribe(id1.value()); + p1.SetValue(-1); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + + m->Unsubscribe(id4.value()); + p2.SetValue(23); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + + p3.SetValue(100500); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeManyUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), p2.GetFuture(), p1.GetFuture() }, [&callCount](auto&&) { ++callCount; }); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 0); + + p1.SetValue(33); + UNIT_ASSERT_EQUAL(callCount, 2); + + p2.SetValue(111); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManySignaledNoRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount = 0; + auto ids = m->Subscribe({ f1, f2, f1 }, [&callCount](auto&&) { ++callCount; }); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManySignaledRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount = 0; + auto ids = m->Subscribe({ f1, f2, f1 }, [&callCount](auto&&) { ++callCount; }, true); + + UNIT_ASSERT(ids.empty()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeManyMixedNoRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(42); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), p2.GetFuture(), f }, [&callCount](auto&&) { ++callCount; } ); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 1); + + p1.SetValue(45); + UNIT_ASSERT_EQUAL(callCount, 2); + + p2.SetValue(-7); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManyMixedRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), f, p2.GetFuture() }, [&callCount](auto&&) { ++callCount; }, true); + + UNIT_ASSERT(ids.empty()); + UNIT_ASSERT_EQUAL(callCount, 1); + + p1.SetValue(); + p2.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeUnsubscribeMany) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto p3 = NewPromise(); + + size_t callCount = 0; + auto ids = m->Subscribe( + TVector>{ p1.GetFuture(), p2.GetFuture(), p3.GetFuture(), p2.GetFuture(), p1.GetFuture() } + , [&callCount](auto&&) { ++callCount; } ); + + UNIT_ASSERT_EQUAL(ids.size(), 5); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(TVector{ ids[0], ids[3] }); + UNIT_ASSERT_EQUAL(callCount, 0); + + p1.SetValue(-1); + UNIT_ASSERT_EQUAL(callCount, 1); + + p2.SetValue(23); + UNIT_ASSERT_EQUAL(callCount, 2); + + p3.SetValue(100500); + UNIT_ASSERT_EQUAL(callCount, 3); + } +} diff --git a/library/cpp/threading/future/subscription/ut/ya.make b/library/cpp/threading/future/subscription/ut/ya.make new file mode 100644 index 00000000000..9b7e371509b --- /dev/null +++ b/library/cpp/threading/future/subscription/ut/ya.make @@ -0,0 +1,11 @@ +UNITTEST_FOR(library/cpp/threading/future/subscription) + +SRCS( + subscription_ut.cpp + wait_all_ut.cpp + wait_all_or_exception_ut.cpp + wait_any_ut.cpp + wait_ut_common.cpp +) + +END() diff --git a/library/cpp/threading/future/subscription/wait.h b/library/cpp/threading/future/subscription/wait.h new file mode 100644 index 00000000000..533bab9d8d9 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait.h @@ -0,0 +1,119 @@ +#pragma once + +#include "subscription.h" + +#include +#include +#include + + +#include + +namespace NThreading::NPrivate { + +template +class TWait : public TThrRefBase { +private: + TSubscriptionManagerPtr Manager; + TVector Subscriptions; + bool Unsubscribed = false; + +protected: + TAdaptiveLock Lock; + TPromise Promise; + +public: + template + static TFuture Make(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + TIntrusivePtr w(new TDerived(std::move(manager))); + w->Subscribe(futures, std::forward(executor)); + return w->Promise.GetFuture(); + } + +protected: + TWait(TSubscriptionManagerPtr manager) + : Manager(std::move(manager)) + , Subscriptions() + , Unsubscribed(false) + , Lock() + , Promise(NewPromise()) + { + Y_ENSURE(Manager != nullptr); + } + +protected: + //! Unsubscribes all existing subscriptions + /** Lock should be acquired! + **/ + void Unsubscribe() noexcept { + if (Unsubscribed) { + return; + } + Unsubscribe(Subscriptions); + Subscriptions.clear(); + } + +private: + //! Performs a subscription to the given futures + /** Lock should not be acquired! + @param future - The futures to subscribe to + @param callback - The callback to call for each future + **/ + template + void Subscribe(TFutures const& futures, TCallbackExecutor&& executor) { + auto self = TIntrusivePtr(static_cast(this)); + self->BeforeSubscribe(futures); + auto callback = [self = std::move(self)](const auto& future) mutable { + self->Set(future); + }; + auto subscriptions = Manager->Subscribe(futures, callback, TDerived::RevertOnSignaled, std::forward(executor)); + if (subscriptions.empty()) { + return; + } + with_lock (Lock) { + if (Unsubscribed) { + Unsubscribe(subscriptions); + } else { + Subscriptions = std::move(subscriptions); + } + } + } + + void Unsubscribe(TVector& subscriptions) noexcept { + Manager->Unsubscribe(subscriptions); + Unsubscribed = true; + } +}; + +template +TFuture Wait(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + switch (std::size(futures)) { + case 0: + return MakeFuture(); + case 1: + return std::begin(futures)->IgnoreResult(); + default: + return TWaiter::Make(futures, std::move(manager), std::forward(executor)); + } +} + +template +TFuture Wait(std::initializer_list const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + switch (std::size(futures)) { + case 0: + return MakeFuture(); + case 1: + return std::begin(futures)->IgnoreResult(); + default: + return TWaiter::Make(futures, std::move(manager), std::forward(executor)); + } +} + + +template +TFuture Wait(TFuture const& future1, TFuture const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return TWaiter::Make(std::initializer_list const>({ future1, future2 }), std::move(manager) + , std::forward(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all.cpp b/library/cpp/threading/future/subscription/wait_all.cpp new file mode 100644 index 00000000000..10e7ee75984 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all.cpp @@ -0,0 +1 @@ +#include "wait_all.h" diff --git a/library/cpp/threading/future/subscription/wait_all.h b/library/cpp/threading/future/subscription/wait_all.h new file mode 100644 index 00000000000..8c1e6fea3a5 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all.h @@ -0,0 +1,26 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAll(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAll(std::initializer_list const> futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAll(TFuture const& future1, TFuture const& future2, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H +#include "wait_all_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H diff --git a/library/cpp/threading/future/subscription/wait_all_inl.h b/library/cpp/threading/future/subscription/wait_all_inl.h new file mode 100644 index 00000000000..a3b665f6427 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_inl.h @@ -0,0 +1,80 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H) +#error "you should never include wait_all_inl.h directly" +#endif + +#include "subscription.h" + +#include + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAll final : public NThreading::NPrivate::TWait { +private: + size_t Count = 0; + std::exception_ptr Exception; + + static constexpr bool RevertOnSignaled = false; + + using TBase = NThreading::NPrivate::TWait; + friend TBase; + +private: + TWaitAll(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + , Count(0) + , Exception() + { + } + + template + void BeforeSubscribe(TFutures const& futures) { + Count = std::size(futures); + Y_ENSURE(Count > 0, "It is meaningless to use this class with empty futures set"); + } + + template + void Set(TFuture const& future) { + with_lock (TBase::Lock) { + if (!Exception) { + try { + future.TryRethrow(); + } catch (...) { + Exception = std::current_exception(); + } + } + + if (--Count == 0) { + // there is no need to call Unsubscribe here since all futures are signaled + Y_ASSERT(!TBase::Promise.HasValue() && !TBase::Promise.HasException()); + if (Exception) { + TBase::Promise.SetException(std::move(Exception)); + } else { + TBase::Promise.SetValue(); + } + } + } + } +}; + +} + +template +TFuture WaitAll(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait(futures, std::move(manager), std::forward(executor)); +} + +template +TFuture WaitAll(std::initializer_list const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait(futures, std::move(manager), std::forward(executor)); +} + +template +TFuture WaitAll(TFuture const& future1, TFuture const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait(future1, future2, std::move(manager), std::forward(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception.cpp b/library/cpp/threading/future/subscription/wait_all_or_exception.cpp new file mode 100644 index 00000000000..0c73ddeb84a --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception.cpp @@ -0,0 +1 @@ +#include "wait_all_or_exception.h" diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception.h b/library/cpp/threading/future/subscription/wait_all_or_exception.h new file mode 100644 index 00000000000..10bba2bffa7 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception.h @@ -0,0 +1,28 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAllOrException(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAllOrException(std::initializer_list const> futures + , TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAllOrException(TFuture const& future1, TFuture const& future2 + , TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H +#include "wait_all_or_exception_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h b/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h new file mode 100644 index 00000000000..fcd9782d543 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h @@ -0,0 +1,79 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H) +#error "you should never include wait_all_or_exception_inl.h directly" +#endif + +#include "subscription.h" + +#include + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAllOrException final : public NThreading::NPrivate::TWait +{ +private: + size_t Count = 0; + + static constexpr bool RevertOnSignaled = false; + + using TBase = NThreading::NPrivate::TWait; + friend TBase; + +private: + TWaitAllOrException(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + , Count(0) + { + } + + template + void BeforeSubscribe(TFutures const& futures) { + Count = std::size(futures); + Y_ENSURE(Count > 0, "It is meaningless to use this class with empty futures set"); + } + + template + void Set(TFuture const& future) { + with_lock (TBase::Lock) { + try { + future.TryRethrow(); + if (--Count == 0) { + // there is no need to call Unsubscribe here since all futures are signaled + TBase::Promise.SetValue(); + } + } catch (...) { + Y_ASSERT(!TBase::Promise.HasValue()); + TBase::Unsubscribe(); + if (!TBase::Promise.HasException()) { + TBase::Promise.SetException(std::current_exception()); + } + } + } + } +}; + +} + +template +TFuture WaitAllOrException(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait(futures, std::move(manager), std::forward(executor)); +} + +template +TFuture WaitAllOrException(std::initializer_list const> futures, TSubscriptionManagerPtr manager + , TCallbackExecutor&& executor) +{ + return NThreading::NPrivate::Wait(futures, std::move(manager), std::forward(executor)); +} +template +TFuture WaitAllOrException(TFuture const& future1, TFuture const& future2, TSubscriptionManagerPtr manager + , TCallbackExecutor&& executor) +{ + return NThreading::NPrivate::Wait(future1, future2, std::move(manager) + , std::forward(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp b/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp new file mode 100644 index 00000000000..34ae9edb4e6 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp @@ -0,0 +1,167 @@ +#include "wait_all_or_exception.h" +#include "wait_ut_common.h" + +#include +#include + +#include +#include + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAllOrExceptionTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto w = NWait::WaitAllOrException(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto w = NWait::WaitAllOrException(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p2.SetValue(-11); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAllOrException(p.GetFuture(), f); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAllOrException(f, p.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 2"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAllOrException(std::initializer_list const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAllOrException(TVector>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise(); + auto w = NWait::WaitAllOrException({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAllOrException(TVector>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyWithInitializer) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(42); + auto w = NWait::WaitAllOrException({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(-3); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestManyWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(42); + auto w = NWait::WaitAllOrException(TVector>{ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p2.SetValue(34); + } + + Y_UNIT_TEST(TestManyWithVectorAndIntialError) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + constexpr TStringBuf message = "Test exception 5"; + auto f = MakeErrorFuture(std::make_exception_ptr(yexception() << message)); + auto w = NWait::WaitAllOrException(TVector>{ p1.GetFuture(), p2.GetFuture(), f }); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(); + p2.SetValue(); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(22); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 6"); + std::atomic index = 0; + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [e, &index](size_t size) { + auto exceptionIndex = size / 2; + index = 0; + return [e, exceptionIndex, &index](auto&& p) { + if (index++ == exceptionIndex) { + p.SetException(e); + } else { + p.SetValue(); + } + }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_all_ut.cpp b/library/cpp/threading/future/subscription/wait_all_ut.cpp new file mode 100644 index 00000000000..3bc9762671c --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_ut.cpp @@ -0,0 +1,161 @@ +#include "wait_all.h" +#include "wait_ut_common.h" + +#include +#include + +#include +#include + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAllTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto w = NWait::WaitAll(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto w = NWait::WaitAll(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p2.SetValue(-11); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAll(p.GetFuture(), f); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAll(f, p.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 2"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAll(std::initializer_list const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAll(TVector>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise(); + auto w = NWait::WaitAll({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAll(TVector>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyWithInitializer) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(42); + auto w = NWait::WaitAll({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(-3); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestManyWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(42); + auto w = NWait::WaitAll(TVector>{ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p2.SetValue(34); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAll(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(42); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAll(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 5"); + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAll(futures); } + , [e](size_t) { + return [e](auto&& p) { p.SetException(e); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + e = std::make_exception_ptr(yexception() << "Test exception 6"); + std::atomic index = 0; + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAll(futures); } + , [e, &index](size_t size) { + auto exceptionIndex = size / 2; + index = 0; + return [e, exceptionIndex, &index](auto&& p) { + if (index++ == exceptionIndex) { + p.SetException(e); + } else { + p.SetValue(index); + } + }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_any.cpp b/library/cpp/threading/future/subscription/wait_any.cpp new file mode 100644 index 00000000000..57cc1b2c253 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any.cpp @@ -0,0 +1 @@ +#include "wait_any.h" diff --git a/library/cpp/threading/future/subscription/wait_any.h b/library/cpp/threading/future/subscription/wait_any.h new file mode 100644 index 00000000000..969e307a897 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any.h @@ -0,0 +1,26 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAny(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAny(std::initializer_list const> futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template +[[nodiscard("This method creates TFuture, wait for it")]] +TFuture WaitAny(TFuture const& future1, TFuture const& future2, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H +#include "wait_any_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H diff --git a/library/cpp/threading/future/subscription/wait_any_inl.h b/library/cpp/threading/future/subscription/wait_any_inl.h new file mode 100644 index 00000000000..e80822bfc9c --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any_inl.h @@ -0,0 +1,64 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H) +#error "you should never include wait_any_inl.h directly" +#endif + +#include "subscription.h" + +#include + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAny final : public NThreading::NPrivate::TWait { +private: + static constexpr bool RevertOnSignaled = true; + + using TBase = NThreading::NPrivate::TWait; + friend TBase; + +private: + TWaitAny(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + { + } + + template + void BeforeSubscribe(TFutures const& futures) { + Y_ENSURE(std::size(futures) > 0, "Futures set cannot be empty"); + } + + template + void Set(TFuture const& future) { + with_lock (TBase::Lock) { + TBase::Unsubscribe(); + try { + future.TryRethrow(); + TBase::Promise.TrySetValue(); + } catch (...) { + TBase::Promise.TrySetException(std::current_exception()); + } + } + } +}; + +} + +template +TFuture WaitAny(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait(futures, std::move(manager), std::forward(executor)); +} + +template +TFuture WaitAny(std::initializer_list const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait(futures, std::move(manager), std::forward(executor)); +} + +template +TFuture WaitAny(TFuture const& future1, TFuture const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait(future1, future2, std::move(manager), std::forward(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_any_ut.cpp b/library/cpp/threading/future/subscription/wait_any_ut.cpp new file mode 100644 index 00000000000..262080e8d12 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any_ut.cpp @@ -0,0 +1,166 @@ +#include "wait_any.h" +#include "wait_ut_common.h" + +#include +#include + +#include + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAnyTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto w = NWait::WaitAny(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(w.HasValue()); + p2.SetValue(1); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto w = NWait::WaitAny(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p2.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(-11); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAny(p.GetFuture(), f); + UNIT_ASSERT(w.HasValue()); + + p.SetValue(); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + constexpr TStringBuf message = "Test exception 2"; + auto f = MakeErrorFuture(std::make_exception_ptr(yexception() << message)); + auto w = NWait::WaitAny(f, p.GetFuture()); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p.SetValue(); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAny(std::initializer_list const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAny(TVector>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise(); + auto w = NWait::WaitAny({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAny(TVector>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyUnsignaledWithInitializer) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto p3 = NewPromise(); + auto w = NWait::WaitAny({ p1.GetFuture(), p2.GetFuture(), p3.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(42); + UNIT_ASSERT(w.HasValue()); + + p2.SetValue(-3); + p3.SetValue(12); + } + + Y_UNIT_TEST(TestManyMixedWithInitializer) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(42); + auto w = NWait::WaitAny({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(w.HasValue()); + + p1.SetValue(10); + p2.SetValue(-3); + } + + + Y_UNIT_TEST(TestManyUnsignaledWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto p3 = NewPromise(); + auto w = NWait::WaitAny(TVector>{ p1.GetFuture(), p2.GetFuture(), p3.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p2.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(); + p3.SetValue(); + } + + + Y_UNIT_TEST(TestManyMixedWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAny(TVector>{ p1.GetFuture(), p2.GetFuture(), f }); + UNIT_ASSERT(w.HasValue()); + + p1.SetValue(); + p2.SetValue(); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAny(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAny(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(22); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 5"); + NTest::TestManyStress([](auto&& futures) { return NWait::WaitAny(futures); } + , [e](size_t) { + return [e](auto&& p) { p.SetException(e); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_ut_common.cpp b/library/cpp/threading/future/subscription/wait_ut_common.cpp new file mode 100644 index 00000000000..9f961e73036 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_ut_common.cpp @@ -0,0 +1,26 @@ +#include "wait_ut_common.h" + +#include +#include +#include + +namespace NThreading::NTest::NPrivate { + +void ExecuteAndWait(TVector> jobs, TFuture waiter, size_t threads) { + Y_ENSURE(threads > 0); + Shuffle(jobs.begin(), jobs.end()); + auto pool = CreateThreadPool(threads); + TManualEvent start; + for (auto& j : jobs) { + pool->SafeAddFunc( + [&start, job = std::move(j)]() { + start.WaitI(); + job(); + }); + } + start.Signal(); + waiter.Wait(); + pool->Stop(); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_ut_common.h b/library/cpp/threading/future/subscription/wait_ut_common.h new file mode 100644 index 00000000000..99530dd1f67 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_ut_common.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include + +#include + +#include +#include + +namespace NThreading::NTest { + +namespace NPrivate { + +void ExecuteAndWait(TVector> jobs, TFuture waiter, size_t threads); + +template +void SetConcurrentAndWait(TPromises&& promises, FSetter&& setter, TFuture waiter, size_t threads = 8) { + TVector> jobs; + jobs.reserve(std::size(promises)); + for (auto& p : promises) { + jobs.push_back([p, setter]() mutable {setter(p); }); + } + ExecuteAndWait(std::move(jobs), std::move(waiter), threads); +} + +template +auto MakePromise() { + if constexpr (std::is_same_v) { + return NewPromise(); + } + return NewPromise(); +} + +} + +template +void TestManyStress(FWaiterFactory&& waiterFactory, FSetterFactory&& setterFactory, FChecker&& checker) { + for (size_t i : { 1, 2, 4, 8, 16, 32, 64, 128, 256 }) { + TVector> promises; + TVector> futures; + promises.reserve(i); + futures.reserve(i); + for (size_t j = 0; j < i; ++j) { + auto promise = NPrivate::MakePromise(); + futures.push_back(promise.GetFuture()); + promises.push_back(std::move(promise)); + } + auto waiter = waiterFactory(futures); + NPrivate::SetConcurrentAndWait(std::move(promises), [valueSetter = setterFactory(i)](auto&& p) { valueSetter(p); } + , waiter); + checker(waiter); + } +} + +} diff --git a/library/cpp/threading/future/subscription/ya.make b/library/cpp/threading/future/subscription/ya.make new file mode 100644 index 00000000000..759c80f3394 --- /dev/null +++ b/library/cpp/threading/future/subscription/ya.make @@ -0,0 +1,18 @@ +LIBRARY() + +SRCS( + subscription.cpp + wait_all.cpp + wait_all_or_exception.cpp + wait_any.cpp +) + +PEERDIR( + library/cpp/threading/future +) + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/src/api/client/yc_private/accessservice/sensitive.proto b/src/api/client/yc_private/accessservice/sensitive.proto new file mode 100644 index 00000000000..009620bd991 --- /dev/null +++ b/src/api/client/yc_private/accessservice/sensitive.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +// Based on: +// https://bb.yandexcloud.net/projects/CLOUD/repos/cloud-go/browse/private-api/yandex/cloud/priv/sensitive.proto + +package yandex.cloud; + +import "google/protobuf/descriptor.proto"; + +option go_package = "cloud/proto_extensions"; + +enum SensitiveType { + SENSITIVE_TYPE_UNSPECIFIED = 0; + SENSITIVE_CRC = 1; + SENSITIVE_IAM_TOKEN = 2; + SENSITIVE_REMOVE = 3; + SENSITIVE_YANDEX_PASSPORT_OAUTH_TOKEN = 4; + SENSITIVE_IAM_COOKIE = 5; + SENSITIVE_REFRESH_TOKEN = 6; + SENSITIVE_SESSION_TOKEN = 7; +} + +extend google.protobuf.FieldOptions { + // novikoff: + // Sensitive fields are hidden in logs + // For now could be applied only to string fields + bool sensitive = 110601; + SensitiveType sensitive_type = 110602; +} diff --git a/src/api/client/yc_private/iam/iam_token.proto b/src/api/client/yc_private/iam/iam_token.proto index 3d7f41f6eb8..47b89f82792 100644 --- a/src/api/client/yc_private/iam/iam_token.proto +++ b/src/api/client/yc_private/iam/iam_token.proto @@ -3,8 +3,9 @@ syntax = "proto3"; package yandex.cloud.priv.iam.v1; import "google/protobuf/timestamp.proto"; +import "src/api/client/yc_private/accessservice/sensitive.proto"; message IamToken { - string iam_token = 1; + string iam_token = 1 [(sensitive) = true]; google.protobuf.Timestamp expires_at = 2; } diff --git a/src/api/client/yc_private/iam/iam_token_service.proto b/src/api/client/yc_private/iam/iam_token_service.proto index d900dff629f..b94bd9da098 100644 --- a/src/api/client/yc_private/iam/iam_token_service.proto +++ b/src/api/client/yc_private/iam/iam_token_service.proto @@ -4,6 +4,7 @@ package yandex.cloud.priv.iam.v1; import "google/api/annotations.proto"; import "google/protobuf/timestamp.proto"; +import "src/api/client/yc_private/accessservice/sensitive.proto"; import "src/api/client/yc_private/iam/iam_token_service_subject.proto"; import "src/api/client/yc_private/iam/yandex_passport_cookie.proto"; import "src/api/client/yc_private/iam/oauth_request.proto"; @@ -35,8 +36,8 @@ service IamTokenService { message CreateIamTokenRequest { oneof identity { - string yandex_passport_oauth_token = 1; - string jwt = 2; + string yandex_passport_oauth_token = 1 [(sensitive) = true]; + string jwt = 2 [(sensitive) = true]; string iam_cookie = 3; YandexPassportCookies yandex_passport_cookies = 4; } @@ -67,7 +68,7 @@ message CreateIamTokenForComputeInstanceRequest { } message CreateIamTokenResponse { - string iam_token = 1; + string iam_token = 1 [(sensitive) = true]; google.protobuf.Timestamp issued_at = 4; google.protobuf.Timestamp expires_at = 2; ts.Subject subject = 3; diff --git a/src/api/protos/draft/ydb_maintenance.proto b/src/api/protos/draft/ydb_maintenance.proto index 2d3af6457fd..17f96b9847b 100644 --- a/src/api/protos/draft/ydb_maintenance.proto +++ b/src/api/protos/draft/ydb_maintenance.proto @@ -80,6 +80,12 @@ enum AvailabilityMode { // Ignore any storage group & state storage checks. // Using this mode might cause data unavailability. AVAILABILITY_MODE_FORCE = 3; + + // In this mode: + // - attempts to apply AVAILABILITY_MODE_STRONG; + // - if strong constraints cannot be satisfied, falls back to AVAILABILITY_MODE_WEAK; + // - never escalates to AVAILABILITY_MODE_FORCE. + AVAILABILITY_MODE_SMART = 4; } message MaintenanceTaskOptions { diff --git a/src/api/protos/ydb_discovery.proto b/src/api/protos/ydb_discovery.proto index ea6580f85a1..2779becca22 100644 --- a/src/api/protos/ydb_discovery.proto +++ b/src/api/protos/ydb_discovery.proto @@ -63,6 +63,18 @@ message WhoAmIResult { string user = 1; // List of group SIDs (Security IDs) for the user repeated string groups = 2; + // Whether user is allowed to perform administration operations + bool is_administration_allowed = 3; + // Whether user is allowed to perform monitoring operations + bool is_monitoring_allowed = 4; + // Whether user is allowed to view data + bool is_viewer_allowed = 5; + // Whether user is allowed to access database + bool is_database_allowed = 6; + // Whether user is allowed to register dynamic node + bool is_register_node_allowed = 7; + // Whether user is allowed to bootstrap + bool is_bootstrap_allowed = 8; } message WhoAmIResponse { diff --git a/src/api/protos/ydb_persqueue_v1.proto b/src/api/protos/ydb_persqueue_v1.proto index 04655c3852d..3eac02b2d89 100644 --- a/src/api/protos/ydb_persqueue_v1.proto +++ b/src/api/protos/ydb_persqueue_v1.proto @@ -385,7 +385,7 @@ message StreamingReadClientMessage { int64 partition_session_id = 1; } - // Signal for server that client is not ready to recieve more data from this partition. + // Signal for server that client is not ready to receive more data from this partition. message PauseReadRequest { repeated int64 partition_session_ids = 1; } @@ -490,7 +490,7 @@ message StreamingReadServerMessage { } // Command to create and start a partition session. - // Client must react on this signal by sending StartRead when ready recieve data from this partition. + // Client must react on this signal by sending StartRead when ready receive data from this partition. message StartPartitionSessionRequest { // Partition partition stream description. PartitionSession partition_session = 1; diff --git a/src/api/protos/ydb_topic.proto b/src/api/protos/ydb_topic.proto index 40f2c31ea4f..f3fdf7c7a8d 100644 --- a/src/api/protos/ydb_topic.proto +++ b/src/api/protos/ydb_topic.proto @@ -327,6 +327,8 @@ message StreamReadMessage { bool direct_read = 4; // Indicates that the SDK supports auto partitioning. bool auto_partitioning_support = 5; + // Max in flight bytes per partition + uint64 partition_max_in_flight_bytes = 6; message TopicReadSettings { // Topic path. diff --git a/src/client/discovery/discovery.cpp b/src/client/discovery/discovery.cpp index 2da501ead72..21e0d6102b8 100644 --- a/src/client/discovery/discovery.cpp +++ b/src/client/discovery/discovery.cpp @@ -60,6 +60,12 @@ TWhoAmIResult::TWhoAmIResult(TStatus&& status, const Ydb::Discovery::WhoAmIResul for (const auto& group : groups) { Groups_.emplace_back(group); } + IsAdministrationAllowed_ = proto.is_administration_allowed(); + IsMonitoringAllowed_ = proto.is_monitoring_allowed(); + IsViewerAllowed_ = proto.is_viewer_allowed(); + IsDatabaseAllowed_ = proto.is_database_allowed(); + IsRegisterNodeAllowed_ = proto.is_register_node_allowed(); + IsBootstrapAllowed_ = proto.is_bootstrap_allowed(); } const std::string& TWhoAmIResult::GetUserName() const { @@ -70,6 +76,30 @@ const std::vector& TWhoAmIResult::GetGroups() const { return Groups_; } +bool TWhoAmIResult::IsAdministrationAllowed() const { + return IsAdministrationAllowed_; +} + +bool TWhoAmIResult::IsMonitoringAllowed() const { + return IsMonitoringAllowed_; +} + +bool TWhoAmIResult::IsViewerAllowed() const { + return IsViewerAllowed_; +} + +bool TWhoAmIResult::IsDatabaseAllowed() const { + return IsDatabaseAllowed_; +} + +bool TWhoAmIResult::IsRegisterNodeAllowed() const { + return IsRegisterNodeAllowed_; +} + +bool TWhoAmIResult::IsBootstrapAllowed() const { + return IsBootstrapAllowed_; +} + TNodeLocation::TNodeLocation(const Ydb::Discovery::NodeLocation& location) : DataCenterNum(location.has_data_center_num() ? std::make_optional(location.data_center_num()) : std::nullopt) , RoomNum(location.has_room_num() ? std::make_optional(location.room_num()) : std::nullopt) diff --git a/src/client/driver/driver.cpp b/src/client/driver/driver.cpp index 207c67b6d5f..f9e2af436c4 100644 --- a/src/client/driver/driver.cpp +++ b/src/client/driver/driver.cpp @@ -40,6 +40,7 @@ class TDriverConfig::TImpl : public IConnectionsParams { EDiscoveryMode GetDiscoveryMode() const override { return DiscoveryMode; } size_t GetMaxQueuedRequests() const override { return MaxQueuedRequests; } TTcpKeepAliveSettings GetTcpKeepAliveSettings() const override { return TcpKeepAliveSettings; } + bool GetTcpNoDelay() const override { return TcpNoDelay; } bool GetDrinOnDtors() const override { return DrainOnDtors; } TBalancingPolicy::TImpl GetBalancingSettings() const override { return BalancingSettings; } TDuration GetGRpcKeepAliveTimeout() const override { return GRpcKeepAliveTimeout; } @@ -69,6 +70,7 @@ class TDriverConfig::TImpl : public IConnectionsParams { TCP_KEEPALIVE_COUNT, TCP_KEEPALIVE_INTERVAL }; + bool TcpNoDelay = true; bool DrainOnDtors = true; TBalancingPolicy::TImpl BalancingSettings = TBalancingPolicy::TImpl::UsePreferableLocation(std::nullopt); TDuration GRpcKeepAliveTimeout = TDuration::Seconds(10); @@ -170,6 +172,11 @@ TDriverConfig& TDriverConfig::SetTcpKeepAliveSettings(bool enable, size_t idle, return *this; } +TDriverConfig& TDriverConfig::SetTcpNoDelay(bool enable) { + Impl_->TcpNoDelay = enable; + return *this; +} + TDriverConfig& TDriverConfig::SetGrpcMemoryQuota(uint64_t bytes) { Impl_->MemoryQuota = bytes; return *this; @@ -271,6 +278,7 @@ TDriverConfig TDriver::GetConfig() const { Impl_->TcpKeepAliveSettings_.Count, Impl_->TcpKeepAliveSettings_.Interval ); + config.SetTcpNoDelay(Impl_->TcpNoDelay_); config.SetDrainOnDtors(Impl_->DrainOnDtors_); config.SetBalancingPolicy(std::make_unique(Impl_->BalancingSettings_)); config.SetGRpcKeepAliveTimeout(std::chrono::duration_cast(Impl_->GRpcKeepAliveTimeout_)); diff --git a/src/client/federated_topic/impl/federated_topic.cpp b/src/client/federated_topic/impl/federated_topic.cpp index 5b0f7f2b19b..33f39ec23a1 100644 --- a/src/client/federated_topic/impl/federated_topic.cpp +++ b/src/client/federated_topic/impl/federated_topic.cpp @@ -73,10 +73,10 @@ std::shared_ptr TFederatedTopicClient::CreateReadSession( return Impl_->CreateReadSession(settings); } -// std::shared_ptr TFederatedTopicClient::CreateSimpleBlockingWriteSession( -// const TFederatedWriteSessionSettings& settings) { -// return Impl_->CreateSimpleBlockingWriteSession(settings); -// } +std::shared_ptr TFederatedTopicClient::CreateSimpleBlockingWriteSession( + const TFederatedWriteSessionSettings& settings) { + return Impl_->CreateSimpleBlockingWriteSession(settings); +} std::shared_ptr TFederatedTopicClient::CreateWriteSession(const TFederatedWriteSessionSettings& settings) { return Impl_->CreateWriteSession(settings); diff --git a/src/client/federated_topic/impl/federated_topic_impl.cpp b/src/client/federated_topic/impl/federated_topic_impl.cpp index 45fc189643c..4da1c013a83 100644 --- a/src/client/federated_topic/impl/federated_topic_impl.cpp +++ b/src/client/federated_topic/impl/federated_topic_impl.cpp @@ -13,14 +13,23 @@ TFederatedTopicClient::TImpl::CreateReadSession(const TFederatedReadSessionSetti return std::move(session); } -// std::shared_ptr -// TFederatedTopicClient::TImpl::CreateSimpleBlockingWriteSession(const TFederatedWriteSessionSettings& settings) { -// InitObserver(); -// auto session = std::make_shared(settings, Connections, ClientSettings, GetObserver()); -// session->Start(); -// return std::move(session); +std::shared_ptr +TFederatedTopicClient::TImpl::CreateSimpleBlockingWriteSession(const TFederatedWriteSessionSettings& settings) { + // Split settings.MaxMemoryUsage_ by two. + // One half goes to subsession. Other half goes to federated session internal buffer. + const ui64 splitSize = (settings.MaxMemoryUsage_ + 1) / 2; + TFederatedWriteSessionSettings splitSettings = settings; + splitSettings.MaxMemoryUsage(splitSize); + InitObserver(); -// } + with_lock(Lock) { + if (!splitSettings.EventHandlers_.HandlersExecutor_) { + splitSettings.EventHandlers_.HandlersExecutor(ClientSettings.DefaultHandlersExecutor_); + } + } + return std::make_shared( + splitSettings, Connections, ClientSettings, GetObserver(), ProvidedCodecs, GetSubsessionHandlersExecutor()); +} std::shared_ptr TFederatedTopicClient::TImpl::CreateWriteSession(const TFederatedWriteSessionSettings& settings) { diff --git a/src/client/federated_topic/impl/federated_write_session.cpp b/src/client/federated_topic/impl/federated_write_session.cpp index 8f559f6076a..a6063ea208e 100644 --- a/src/client/federated_topic/impl/federated_write_session.cpp +++ b/src/client/federated_topic/impl/federated_write_session.cpp @@ -1,6 +1,7 @@ #include "federated_write_session.h" #include +#include #include #define INCLUDE_YDB_INTERNAL_H @@ -467,4 +468,89 @@ bool TFederatedWriteSessionImpl::Close(TDuration timeout) { } } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TSimpleBlockingFederatedWriteSession + +TSimpleBlockingFederatedWriteSession::TSimpleBlockingFederatedWriteSession( + const TFederatedWriteSessionSettings& settings, + std::shared_ptr connections, + const TFederatedTopicClientSettings& clientSettings, + std::shared_ptr observer, + std::shared_ptr>> codecs, + IExecutor::TPtr subsessionHandlersExecutor +) { + TFederatedWriteSessionSettings subSettings = settings; + auto& log = connections->GetLog(); + if (settings.EventHandlers_.AcksHandler_) { + LOG_LAZY(log, TLOG_WARNING, "TSimpleBlockingFederatedWriteSession: Cannot use AcksHandler, resetting."); + subSettings.EventHandlers_.AcksHandler({}); + } + if (settings.EventHandlers_.ReadyToAcceptHandler_) { + LOG_LAZY(log, TLOG_WARNING, "TSimpleBlockingFederatedWriteSession: Cannot use ReadyToAcceptHandler, resetting."); + subSettings.EventHandlers_.ReadyToAcceptHandler({}); + } + if (settings.EventHandlers_.SessionClosedHandler_) { + LOG_LAZY(log, TLOG_WARNING, "TSimpleBlockingFederatedWriteSession: Cannot use SessionClosedHandler, resetting."); + subSettings.EventHandlers_.SessionClosedHandler({}); + } + if (settings.EventHandlers_.CommonHandler_) { + LOG_LAZY(log, TLOG_WARNING, "TSimpleBlockingFederatedWriteSession: Cannot use CommonHandler, resetting."); + subSettings.EventHandlers_.CommonHandler({}); + } + + Writer = std::make_shared( + subSettings, std::move(connections), clientSettings, std::move(observer), std::move(codecs), std::move(subsessionHandlersExecutor)); + Writer->Start(); +} + +uint64_t TSimpleBlockingFederatedWriteSession::GetInitSeqNo() { + return Writer->GetInitSeqNo().GetValueSync(); +} + +bool TSimpleBlockingFederatedWriteSession::Write( + std::string_view data, std::optional seqNo, std::optional createTimestamp, const TDuration& blockTimeout +) { + auto message = NTopic::TWriteMessage(std::move(data)) + .SeqNo(seqNo) + .CreateTimestamp(createTimestamp); + return Write(std::move(message), nullptr, blockTimeout); +} + +bool TSimpleBlockingFederatedWriteSession::Write( + NTopic::TWriteMessage&& message, TTransactionBase* tx, const TDuration& blockTimeout +) { + if (tx || message.GetTxPtr()) { + ythrow yexception() << "transactions are not supported"; + } + auto continuationToken = WaitForToken(blockTimeout); + if (continuationToken.has_value()) { + Writer->Write(std::move(*continuationToken), std::move(message)); + return true; + } + return false; +} + +std::optional TSimpleBlockingFederatedWriteSession::WaitForToken(const TDuration& timeout) { + return NTopic::NDetail::WaitForToken(*Writer, Closed, timeout); +} + +NTopic::TWriterCounters::TPtr TSimpleBlockingFederatedWriteSession::GetCounters() { + ythrow yexception() << "GetCounters is not yet implemented for federated write sessions"; +} + +bool TSimpleBlockingFederatedWriteSession::IsAlive() const { + return !Closed.load(); +} + +bool TSimpleBlockingFederatedWriteSession::Close(TDuration closeTimeout) { + Closed.store(true); + return Writer->Close(closeTimeout); +} + +TSimpleBlockingFederatedWriteSession::~TSimpleBlockingFederatedWriteSession() { + if (!Closed.load()) { + Close(TDuration::Zero()); + } +} + } // namespace NYdb::NFederatedTopic diff --git a/src/client/federated_topic/impl/federated_write_session.h b/src/client/federated_topic/impl/federated_write_session.h index 77376740e90..8e8e3e2dd38 100644 --- a/src/client/federated_topic/impl/federated_write_session.h +++ b/src/client/federated_topic/impl/federated_write_session.h @@ -150,6 +150,7 @@ class TFederatedWriteSessionImpl : public NTopic::TContinuationTokenIssuer, class TFederatedWriteSession : public NTopic::IWriteSession, public NTopic::TContextOwner { friend class TFederatedTopicClient::TImpl; + friend class TSimpleBlockingFederatedWriteSession; public: @@ -174,13 +175,13 @@ class TFederatedWriteSession : public NTopic::IWriteSession, return TryGetImpl()->GetInitSeqNo(); } void Write(NTopic::TContinuationToken&& continuationToken, NTopic::TWriteMessage&& message, TTransactionBase* tx = nullptr) override { - if (tx) { + if (tx || message.GetTxPtr()) { ythrow yexception() << "transactions are not supported"; } TryGetImpl()->Write(std::move(continuationToken), std::move(message)); } void WriteEncoded(NTopic::TContinuationToken&& continuationToken, NTopic::TWriteMessage&& params, TTransactionBase* tx = nullptr) override { - if (tx) { + if (tx || params.GetTxPtr()) { ythrow yexception() << "transactions are not supported"; } TryGetImpl()->WriteEncoded(std::move(continuationToken), std::move(params)); @@ -206,4 +207,40 @@ class TFederatedWriteSession : public NTopic::IWriteSession, } }; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TSimpleBlockingFederatedWriteSession + +class TSimpleBlockingFederatedWriteSession : public NTopic::ISimpleBlockingWriteSession { +public: + TSimpleBlockingFederatedWriteSession( + const TFederatedWriteSessionSettings& settings, + std::shared_ptr connections, + const TFederatedTopicClientSettings& clientSettings, + std::shared_ptr observer, + std::shared_ptr>> codecs, + IExecutor::TPtr subsessionHandlersExecutor); + + bool Write(std::string_view data, std::optional seqNo = std::nullopt, std::optional createTimestamp = std::nullopt, + const TDuration& blockTimeout = TDuration::Max()) override; + + bool Write(NTopic::TWriteMessage&& message, + TTransactionBase* tx = nullptr, + const TDuration& blockTimeout = TDuration::Max()) override; + + uint64_t GetInitSeqNo() override; + + bool Close(TDuration closeTimeout = TDuration::Max()) override; + + ~TSimpleBlockingFederatedWriteSession(); + bool IsAlive() const override; + + NTopic::TWriterCounters::TPtr GetCounters() override; + +private: + std::optional WaitForToken(const TDuration& timeout); + + std::shared_ptr Writer; + std::atomic_bool Closed = false; +}; + } // namespace NYdb::NFederatedTopic diff --git a/src/client/federated_topic/ut/fds_mock/fds_mock.h b/src/client/federated_topic/ut/fds_mock/fds_mock.h index 9f9df0ceb2b..dc0e787f1bd 100644 --- a/src/client/federated_topic/ut/fds_mock/fds_mock.h +++ b/src/client/federated_topic/ut/fds_mock/fds_mock.h @@ -57,11 +57,27 @@ class TFederationDiscoveryServiceMock: public Ydb::FederationDiscovery::V1::Fede } while (true); } + void SetAutoRespondSingleDatabase(bool enable) { + with_lock(Lock) { + AutoRespondSingleDatabase = enable; + } + } + virtual grpc::Status ListFederationDatabases(grpc::ServerContext*, const TRequest* request, TResponse* response) override { Y_UNUSED(request); + // Check if auto-respond mode is enabled + with_lock(Lock) { + if (AutoRespondSingleDatabase) { + auto result = ComposeOkResultSingleDatabase(); + Cerr << ">>> Auto-responding with single database" << Endl; + *response = std::move(result.Response); + return result.Status; + } + } + auto p = NThreading::NewPromise(); auto f = p.GetFuture(); @@ -125,6 +141,32 @@ class TFederationDiscoveryServiceMock: public Ydb::FederationDiscovery::V1::Fede return ComposeOkResult(::Ydb::FederationDiscovery::DatabaseInfo::Status::DatabaseInfo_Status_AVAILABLE); } + // Returns a single database response to avoid partition competition between multiple sub-sessions + TGrpcResult ComposeOkResultSingleDatabase() { + Ydb::FederationDiscovery::ListFederationDatabasesResponse okResponse; + + auto op = okResponse.mutable_operation(); + op->set_status(Ydb::StatusIds::SUCCESS); + okResponse.mutable_operation()->set_ready(true); + okResponse.mutable_operation()->set_id("12345"); + + Ydb::FederationDiscovery::ListFederationDatabasesResult mockResult; + mockResult.set_control_plane_endpoint("cp.logbroker-federation:2135"); + mockResult.set_self_location("dc1"); + auto c1 = mockResult.add_federation_databases(); + c1->set_name("dc1"); + c1->set_path("/Root"); + c1->set_id("account-dc1"); + c1->set_endpoint("localhost:" + ToString(Port)); + c1->set_location("dc1"); + c1->set_status(::Ydb::FederationDiscovery::DatabaseInfo::Status::DatabaseInfo_Status_AVAILABLE); + c1->set_weight(1000); + + op->mutable_result()->PackFrom(mockResult); + + return {okResponse, grpc::Status::OK}; + } + TGrpcResult ComposeOkResultUnavailableDatabases() { return ComposeOkResult(::Ydb::FederationDiscovery::DatabaseInfo::Status::DatabaseInfo_Status_UNAVAILABLE); } @@ -173,6 +215,7 @@ class TFederationDiscoveryServiceMock: public Ydb::FederationDiscovery::V1::Fede ui16 Port; std::deque PendingRequests; TAdaptiveLock Lock; + bool AutoRespondSingleDatabase = false; }; } // namespace NYdb::NFederatedTopic::NTests diff --git a/src/client/federated_topic/ut/simple_blocking_write_session_ut.cpp b/src/client/federated_topic/ut/simple_blocking_write_session_ut.cpp new file mode 100644 index 00000000000..51260f74b12 --- /dev/null +++ b/src/client/federated_topic/ut/simple_blocking_write_session_ut.cpp @@ -0,0 +1,402 @@ +#include +#include + +#include +#include + +#include +#include + +#include + +namespace NYdb::NFederatedTopic::NTests { + +// Test fixture providing common setup for federated topic tests +class TSimpleBlockingWriteSessionTestFixture { +public: + explicit TSimpleBlockingWriteSessionTestFixture(const TString& testName) + : Setup(std::make_shared( + testName, false, ::NPersQueue::TTestServer::LOGGED_SERVICES, NActors::NLog::PRI_DEBUG, 1)) + , ThreadPool(CreateThreadPool(2)) + { + Setup->Start(true, true); + + FdsMock.Port = Setup->GetGrpcPort(); + ServicePort = Setup->GetPortManager()->GetPort(4285); + GrpcServer = Setup->StartGrpcService(ServicePort, &FdsMock); + + NYdb::TDriverConfig cfg; + cfg.SetEndpoint(TStringBuilder() << "localhost:" << ServicePort); + cfg.SetDatabase("/Root"); + cfg.SetLog(std::unique_ptr(CreateLogBackend("cerr", ELogPriority::TLOG_DEBUG).Release())); + + Driver = std::make_unique(cfg); + TopicClient = std::make_unique(*Driver); + } + + // Creates a write session and responds to the FDS request with available databases + // The session creation triggers FDS discovery, so we start creation async and then respond + std::shared_ptr CreateWriteSessionWithFdsResponse( + const TFederatedWriteSessionSettings& settings + ) { + // Start session creation asynchronously + auto sessionFuture = NThreading::Async([this, settings]() { + return TopicClient->CreateSimpleBlockingWriteSession(settings); + }, *ThreadPool); + + // Wait for FDS request and respond + auto fdsRequest = FdsMock.WaitNextPendingRequest(); + fdsRequest.Result.SetValue(FdsMock.ComposeOkResultAvailableDatabases()); + + // Wait for session creation to complete + return sessionFuture.GetValueSync(); + } + + std::shared_ptr CreateWriteSessionWithFdsResponse() { + return CreateWriteSessionWithFdsResponse(DefaultWriteSettings()); + } + + // Creates a write session and responds with one unavailable database + std::shared_ptr CreateWriteSessionWithUnavailableDatabaseResponse( + const TFederatedWriteSessionSettings& settings, + int unavailableDbIndex + ) { + auto sessionFuture = NThreading::Async([this, settings]() { + return TopicClient->CreateSimpleBlockingWriteSession(settings); + }, *ThreadPool); + + auto fdsRequest = FdsMock.WaitNextPendingRequest(); + fdsRequest.Result.SetValue(FdsMock.ComposeOkResultWithUnavailableDatabase(unavailableDbIndex)); + + return sessionFuture.GetValueSync(); + } + + TFederatedWriteSessionSettings DefaultWriteSettings() { + return TFederatedWriteSessionSettings() + .Path(Setup->GetTestTopicPath()) + .MessageGroupId("src_id"); + } + + std::shared_ptr CreateReadSession() { + TFederatedReadSessionSettings settings; + settings + .ConsumerName("shared/user") + .MaxMemoryUsageBytes(1_MB) + .AppendTopics(std::string(Setup->GetTestTopicPath())); + return TopicClient->CreateReadSession(settings); + } + + // Creates a read session and responds to FDS + std::shared_ptr CreateReadSessionWithFdsResponse() { + auto session = CreateReadSession(); + + auto fdsRequest = FdsMock.WaitNextPendingRequest(); + fdsRequest.Result.SetValue(FdsMock.ComposeOkResultAvailableDatabases()); + + return session; + } + + // Creates a write session with single-database FDS response (avoids partition competition) + std::shared_ptr CreateWriteSessionWithSingleDatabaseFdsResponse( + const TFederatedWriteSessionSettings& settings + ) { + auto sessionFuture = NThreading::Async([this, settings]() { + return TopicClient->CreateSimpleBlockingWriteSession(settings); + }, *ThreadPool); + + auto fdsRequest = FdsMock.WaitNextPendingRequest(); + fdsRequest.Result.SetValue(FdsMock.ComposeOkResultSingleDatabase()); + + return sessionFuture.GetValueSync(); + } + + std::shared_ptr CreateWriteSessionWithSingleDatabaseFdsResponse() { + return CreateWriteSessionWithSingleDatabaseFdsResponse(DefaultWriteSettings()); + } + + // Creates a read session with single-database FDS response (avoids partition competition) + std::shared_ptr CreateReadSessionWithSingleDatabaseFdsResponse() { + auto session = CreateReadSession(); + + auto fdsRequest = FdsMock.WaitNextPendingRequest(); + fdsRequest.Result.SetValue(FdsMock.ComposeOkResultSingleDatabase()); + + return session; + } + + // Creates a read session with unavailable database response + std::shared_ptr CreateReadSessionWithUnavailableDatabaseResponse(int unavailableDbIndex) { + auto session = CreateReadSession(); + + auto fdsRequest = FdsMock.WaitNextPendingRequest(); + fdsRequest.Result.SetValue(FdsMock.ComposeOkResultWithUnavailableDatabase(unavailableDbIndex)); + + return session; + } + + void ConfirmPartitionStart(std::shared_ptr& session, TDuration timeout = TDuration::Seconds(60)) { + TInstant deadline = TInstant::Now() + timeout; + while (TInstant::Now() < deadline) { + session->WaitEvent().Wait(TDuration::MilliSeconds(100)); + auto event = session->GetEvent(false); + if (!event.has_value()) { + continue; + } + if (auto* startEvent = std::get_if(&*event)) { + startEvent->Confirm(); + return; + } + } + UNIT_FAIL("Timeout waiting for TStartPartitionSessionEvent"); + } + + std::optional ReadOneMessage( + std::shared_ptr& session, + TDuration timeout = TDuration::Seconds(60) + ) { + TInstant deadline = TInstant::Now() + timeout; + while (TInstant::Now() < deadline) { + session->WaitEvent().Wait(TDuration::MilliSeconds(100)); + auto events = session->GetEvents(false, 1); + for (auto& event : events) { + if (auto* dataEvent = std::get_if(&event)) { + auto& messages = dataEvent->GetMessages(); + if (!messages.empty()) { + auto msg = messages[0]; + dataEvent->Commit(); + return msg; + } + } + } + } + return std::nullopt; + } + + TString GetTestTopic() const { return Setup->GetTestTopic(); } + TFederationDiscoveryServiceMock& GetFdsMock() { return FdsMock; } + + // Helper to verify that messages were written correctly by reading them back + void VerifyMessagesWritten(const TVector& expectedMessages) { + // Create read session with single database FDS response + auto readSession = CreateReadSessionWithSingleDatabaseFdsResponse(); + ConfirmPartitionStart(readSession); + + // Enable auto-respond for any subsequent FDS discovery refresh requests + FdsMock.SetAutoRespondSingleDatabase(true); + + // Read and verify all messages + for (size_t i = 0; i < expectedMessages.size(); ++i) { + auto msg = ReadOneMessage(readSession); + UNIT_ASSERT_C(msg.has_value(), "Did not receive message " << i); + UNIT_ASSERT_VALUES_EQUAL_C(msg->GetData(), expectedMessages[i], + "Message " << i << " content mismatch"); + } + + readSession->Close(); + } + + TFederatedTopicClient* GetTopicClient() { return TopicClient.get(); } + IThreadPool* GetThreadPool() { return ThreadPool.Get(); } + +private: + std::shared_ptr Setup; + TFederationDiscoveryServiceMock FdsMock; + ui16 ServicePort; + std::unique_ptr GrpcServer; + std::unique_ptr Driver; + std::unique_ptr TopicClient; + THolder ThreadPool; +}; + +Y_UNIT_TEST_SUITE(SimpleBlockingFederatedWriteSession) { + + Y_UNIT_TEST(BasicWriteAndClose) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto session = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(session); + UNIT_ASSERT(session->IsAlive()); + + TString message = "test message"; + UNIT_ASSERT(session->Write(message)); + session->Close(); + UNIT_ASSERT(!session->IsAlive()); + + // Verify the message was written correctly + fixture.VerifyMessagesWritten({message}); + } + + Y_UNIT_TEST(WriteMultipleMessages) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto session = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(session); + + // Write and verify a single message (multi-message read verification needs more infrastructure) + TString message = "message-0"; + UNIT_ASSERT_C(session->Write(message), "Failed to write message"); + + session->Close(); + + // Verify message was written correctly + fixture.VerifyMessagesWritten({message}); + } + + Y_UNIT_TEST(WriteWithSeqNoAndTimestamp) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto session = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(session); + + TString message = "message with explicit params"; + UNIT_ASSERT(session->Write(message, 1, TInstant::Now())); + session->Close(); + + // Verify the message was written correctly + fixture.VerifyMessagesWritten({message}); + } + + Y_UNIT_TEST(WriteAndReadBack) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + // Create write session first (which responds to first FDS with single database) + auto writeSession = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(writeSession); + + // Create read session (which triggers second FDS request with single database) + auto readSession = fixture.CreateReadSessionWithSingleDatabaseFdsResponse(); + fixture.ConfirmPartitionStart(readSession); + + // Enable auto-respond for any subsequent FDS discovery refresh requests + fixture.GetFdsMock().SetAutoRespondSingleDatabase(true); + + // Write single message for now to verify flow works + UNIT_ASSERT(writeSession->Write("test-message")); + + auto msg = fixture.ReadOneMessage(readSession); + UNIT_ASSERT_C(msg.has_value(), "Did not receive message"); + UNIT_ASSERT_VALUES_EQUAL(msg->GetData(), "test-message"); + + writeSession->Close(); + readSession->Close(); + } + + Y_UNIT_TEST(CloseEmptySession) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto session = fixture.CreateWriteSessionWithFdsResponse(); + UNIT_ASSERT(session); + session->Close(); + } + + Y_UNIT_TEST(WriteWithPreferredDatabase) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + // Use single-database FDS to avoid multi-database complications + auto session = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(session); + + TString message = "message to preferred database"; + UNIT_ASSERT(session->Write(message)); + session->Close(); + + // Verify the message was written correctly + fixture.VerifyMessagesWritten({message}); + } + + Y_UNIT_TEST(WriteWithPreferredDatabaseUnavailableAndFallback) { + // TODO: Enable after fixing test environment to support multiple federated databases/topics + /* + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto settings = fixture.DefaultWriteSettings() + .PreferredDatabase("dc1") + .AllowFallback(true); + + // dc1 is unavailable, but fallback is allowed + auto writeSession = fixture.CreateWriteSessionWithUnavailableDatabaseResponse(settings, 1); + UNIT_ASSERT(writeSession); + + // Create read session to verify where message was written + auto readSession = fixture.CreateReadSessionWithUnavailableDatabaseResponse(1); + fixture.ConfirmPartitionStart(readSession); + + TString message = "message with fallback"; + UNIT_ASSERT(writeSession->Write(message)); + + // Verify message was written to fallback database (not dc1) + auto msg = fixture.ReadOneMessage(readSession); + UNIT_ASSERT(msg.has_value()); + UNIT_ASSERT_VALUES_EQUAL(msg->GetData(), message); + + auto dbName = msg->GetFederatedPartitionSession()->GetDatabaseName(); + UNIT_ASSERT_C(dbName != "dc1", "Message should not be written to unavailable dc1"); + UNIT_ASSERT_C(dbName == "dc2" || dbName == "dc3", "Expected dc2 or dc3, got: " << dbName); + + writeSession->Close(); + readSession->Close(); + */ + } + + Y_UNIT_TEST(WriteLargeMessages) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto session = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(session); + + // Write single large message for verification + TString largeMessage(10 * 1024, 'x'); // 10KB + UNIT_ASSERT_C(session->Write(largeMessage), "Failed to write large message"); + + session->Close(); + + // Verify the message was written correctly + fixture.VerifyMessagesWritten({largeMessage}); + } + + Y_UNIT_TEST(WriteWithTWriteMessage) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto session = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(session); + + TString message = "message via TWriteMessage"; + NTopic::TWriteMessage writeMessage(message); + writeMessage.CreateTimestamp(TInstant::Now()); + UNIT_ASSERT(session->Write(std::move(writeMessage))); + + session->Close(); + + // Verify the message was written correctly + fixture.VerifyMessagesWritten({message}); + } + + Y_UNIT_TEST(IsAliveAfterClose) { + TSimpleBlockingWriteSessionTestFixture fixture(TEST_CASE_NAME); + + auto session = fixture.CreateWriteSessionWithSingleDatabaseFdsResponse(); + UNIT_ASSERT(session); + + // Session should be alive after creation + UNIT_ASSERT(session->IsAlive()); + + // Write a message to verify session is working + UNIT_ASSERT(session->Write("test message")); + + // Session still alive after write + UNIT_ASSERT(session->IsAlive()); + + // Close the session + session->Close(); + + // After close, session should no longer be alive + UNIT_ASSERT(!session->IsAlive()); + + // Write should fail after session is closed + UNIT_ASSERT_C(!session->Write("after close", std::nullopt, std::nullopt, TDuration::MilliSeconds(100)), + "Write should fail after session is closed"); + } + +} + +} // namespace NYdb::NFederatedTopic::NTests diff --git a/src/client/impl/internal/grpc_connections/grpc_connections.cpp b/src/client/impl/internal/grpc_connections/grpc_connections.cpp index 48e170d28c6..7a7ac927d79 100644 --- a/src/client/impl/internal/grpc_connections/grpc_connections.cpp +++ b/src/client/impl/internal/grpc_connections/grpc_connections.cpp @@ -163,9 +163,10 @@ TGRpcConnectionsImpl::TGRpcConnectionsImpl(std::shared_ptr p , MaxMessageSize_(params->GetMaxMessageSize()) , QueuedRequests_(0) , TcpKeepAliveSettings_(params->GetTcpKeepAliveSettings()) + , TcpNoDelay_(params->GetTcpNoDelay()) , SocketIdleTimeout_(TDeadline::SafeDurationCast(params->GetSocketIdleTimeout())) #ifndef YDB_GRPC_BYPASS_CHANNEL_POOL - , ChannelPool_(TcpKeepAliveSettings_, params->GetSocketIdleTimeout()) + , ChannelPool_(TcpKeepAliveSettings_, params->GetSocketIdleTimeout(), TcpNoDelay_) #endif , NetworkThreadsNum_(params->GetNetworkThreadsNum()) , UsePerChannelTcpConnection_(params->GetUsePerChannelTcpConnection()) diff --git a/src/client/impl/internal/grpc_connections/grpc_connections.h b/src/client/impl/internal/grpc_connections/grpc_connections.h index 756d2f0d957..0467b35b9fe 100644 --- a/src/client/impl/internal/grpc_connections/grpc_connections.h +++ b/src/client/impl/internal/grpc_connections/grpc_connections.h @@ -705,6 +705,7 @@ class TGRpcConnectionsImpl std::atomic_int64_t QueuedRequests_; const NYdbGrpc::TTcpKeepAliveSettings TcpKeepAliveSettings_; + const bool TcpNoDelay_; const TDeadline::Duration SocketIdleTimeout_; #ifndef YDB_GRPC_BYPASS_CHANNEL_POOL NYdbGrpc::TChannelPool ChannelPool_; diff --git a/src/client/impl/internal/grpc_connections/params.h b/src/client/impl/internal/grpc_connections/params.h index 2bc9f4567c5..e1378768872 100644 --- a/src/client/impl/internal/grpc_connections/params.h +++ b/src/client/impl/internal/grpc_connections/params.h @@ -25,6 +25,7 @@ class IConnectionsParams { virtual EDiscoveryMode GetDiscoveryMode() const = 0; virtual size_t GetMaxQueuedRequests() const = 0; virtual NYdbGrpc::TTcpKeepAliveSettings GetTcpKeepAliveSettings() const = 0; + virtual bool GetTcpNoDelay() const = 0; virtual bool GetDrinOnDtors() const = 0; virtual TBalancingPolicy::TImpl GetBalancingSettings() const = 0; virtual TDuration GetGRpcKeepAliveTimeout() const = 0; diff --git a/src/client/persqueue_public/impl/write_session_impl.cpp b/src/client/persqueue_public/impl/write_session_impl.cpp index e090309c60b..44915db0586 100644 --- a/src/client/persqueue_public/impl/write_session_impl.cpp +++ b/src/client/persqueue_public/impl/write_session_impl.cpp @@ -661,7 +661,7 @@ void TWriteSessionImpl::OnReadDone(NYdbGrpc::TGrpcStatus&& grpcStatus, size_t co } else { processResult = ProcessServerMessageImpl(); needSetValue = !InitSeqNoSetDone && processResult.InitSeqNo.has_value() && (InitSeqNoSetDone = true); - if (errorStatus.Ok() && processResult.Ok) { + if (errorStatus.Ok() && processResult.Ok && !processResult.HandleResult.DoRestart) { doRead = true; } } @@ -672,10 +672,8 @@ void TWriteSessionImpl::OnReadDone(NYdbGrpc::TGrpcStatus&& grpcStatus, size_t co { std::lock_guard guard(Lock); - if (!errorStatus.Ok()) { - if (processResult.Ok) { // Otherwise, OnError was already called - processResult.HandleResult = RestartImpl(errorStatus); - } + if ((!errorStatus.Ok() && processResult.Ok) || processResult.HandleResult.DoRestart) { // Otherwise, OnError was already called + processResult.HandleResult = RestartImpl(errorStatus); } if (processResult.HandleResult.DoStop) { CloseImpl(std::move(errorStatus)); @@ -785,9 +783,13 @@ TWriteSessionImpl::TProcessSrvMessageResult TWriteSessionImpl::ProcessServerMess writeStat, }); - if (CleanupOnAcknowledged(GetIdImpl(sequenceNumber))) { + if (CleanupOnAcknowledged(GetIdImpl(sequenceNumber), result)) { result.Events.emplace_back(TWriteSessionEvent::TReadyToAcceptEvent{IssueContinuationToken()}); } + + if (result.HandleResult.DoRestart) { + return result; + } } //EventsQueue->PushEvent(std::move(acksEvent)); result.Events.emplace_back(std::move(acksEvent)); @@ -803,16 +805,17 @@ TWriteSessionImpl::TProcessSrvMessageResult TWriteSessionImpl::ProcessServerMess return result; } -bool TWriteSessionImpl::CleanupOnAcknowledged(ui64 id) { +bool TWriteSessionImpl::CleanupOnAcknowledged(ui64 id, TProcessSrvMessageResult& processResult) { bool result = false; LOG_LAZY(DbDriverState->Log, TLOG_DEBUG, LogPrefix() << "Write session: acknoledged message " << id); UpdateTimedCountersImpl(); if (SentOriginalMessages.empty() || SentOriginalMessages.front().Id != id){ - std::cerr << "State before restart was:\n" << StateStr << "\n\n"; + LOG_LAZY(DbDriverState->Log, TLOG_ERR, LogPrefix() << "State before restart was: " << StateStr); DumpState(); - std::cerr << "State on ack with id " << id << " is:\n"; - std::cerr << StateStr << "\n\n"; - Y_ABORT("got unknown ack"); + LOG_LAZY(DbDriverState->Log, TLOG_ERR, LogPrefix() << "State on ack with id " << id << " is:\n" << StateStr); + LOG_LAZY(DbDriverState->Log, TLOG_ERR, LogPrefix() << "Write session: got unknown ack " << id); + processResult.HandleResult.DoRestart = true; + return false; } const auto& sentFront = SentOriginalMessages.front(); diff --git a/src/client/persqueue_public/impl/write_session_impl.h b/src/client/persqueue_public/impl/write_session_impl.h index 483570c944c..86be4fc8d3a 100644 --- a/src/client/persqueue_public/impl/write_session_impl.h +++ b/src/client/persqueue_public/impl/write_session_impl.h @@ -356,7 +356,7 @@ class TWriteSessionImpl : public TContinuationTokenIssuer, //std::string GetDebugIdentity() const; Ydb::PersQueue::V1::StreamingWriteClientMessage GetInitClientMessage(); - bool CleanupOnAcknowledged(ui64 id); + bool CleanupOnAcknowledged(ui64 id, TProcessSrvMessageResult& processResult); bool IsReadyToSendNextImpl(); void DumpState(); ui64 GetNextIdImpl(const std::optional& seqNo); diff --git a/src/client/topic/common/simple_blocking_helpers.h b/src/client/topic/common/simple_blocking_helpers.h new file mode 100644 index 00000000000..ddddb94056a --- /dev/null +++ b/src/client/topic/common/simple_blocking_helpers.h @@ -0,0 +1,48 @@ +#pragma once + +#include + +#include +#include + +namespace NYdb::inline V3::NTopic::NDetail { + +// Common helper for blocking write sessions to wait for a continuation token. +// Used by both TSimpleBlockingWriteSession and TSimpleBlockingFederatedWriteSession. +template +std::optional WaitForToken( + TWriter& writer, + std::atomic_bool& closed, + const TDuration& timeout +) { + TInstant startTime = TInstant::Now(); + TDuration remainingTime = timeout; + + std::optional token; + + while (!closed.load() && remainingTime > TDuration::Zero()) { + writer.WaitEvent().Wait(remainingTime); + + for (auto event : writer.GetEvents(false, std::nullopt)) { + if (auto* readyEvent = std::get_if(&event)) { + Y_ABORT_UNLESS(!token.has_value()); + token = std::move(readyEvent->ContinuationToken); + } else if (std::get_if(&event)) { + // discard (maybe log?) + } else if (std::get_if(&event)) { + closed.store(true); + return std::nullopt; + } + } + + if (token.has_value()) { + return token; + } + + remainingTime = timeout - (TInstant::Now() - startTime); + } + + return std::nullopt; +} + +} // namespace NYdb::NTopic::NDetail diff --git a/src/client/topic/impl/CMakeLists.txt b/src/client/topic/impl/CMakeLists.txt index 6f3bc99989b..96fc3b6ba8b 100644 --- a/src/client/topic/impl/CMakeLists.txt +++ b/src/client/topic/impl/CMakeLists.txt @@ -9,6 +9,7 @@ target_link_libraries(client-ydb_topic-impl PUBLIC persqueue-obfuscate api-grpc-draft api-grpc + threading-future-subscription impl-internal-make_request client-ydb_common_client-impl client-ydb_driver diff --git a/src/client/topic/impl/direct_reader.cpp b/src/client/topic/impl/direct_reader.cpp index 3c85ec02025..4be57ef89c9 100644 --- a/src/client/topic/impl/direct_reader.cpp +++ b/src/client/topic/impl/direct_reader.cpp @@ -517,23 +517,23 @@ void TDirectReadSession::OnReadDone(NYdbGrpc::TGrpcStatus&& grpcStatus, size_t c if (!IsErrorMessage(*ServerMessage)) { if (ServerMessage->server_message_case() != TDirectReadServerMessage::kDirectReadResponse) { - LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "XXXXX subsession got message = " << ServerMessage->ShortDebugString()); + LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "subsession got message = " << ServerMessage->ShortDebugString()); } else { const auto& data = ServerMessage->direct_read_response().partition_data(); const auto partitionSessionId = ServerMessage->direct_read_response().partition_session_id(); auto partitionSessionIt = PartitionSessions.find(partitionSessionId); if (partitionSessionIt == PartitionSessions.end()) { - LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "XXXXX subsession got message = DirectReadResponse partitionSessionId=" << partitionSessionId << " not found"); + LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "subsession got message = DirectReadResponse partitionSessionId=" << partitionSessionId << " not found"); } if (data.batches_size() == 0) { - LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "XXXXX subsession got message = DirectReadResponse EMPTY"); + LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "subsession got message = DirectReadResponse EMPTY"); } else { const auto& firstBatch = data.batches(0); const auto firstOffset = firstBatch.message_data(0).offset(); const auto& lastBatch = data.batches(data.batches_size() - 1); const auto lastOffset = lastBatch.message_data(lastBatch.message_data_size() - 1).offset(); auto partitionId = partitionSessionIt == PartitionSessions.end() ? -1 : partitionSessionIt->second.PartitionId; - LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "XXXXX subsession got message = DirectReadResponse" + LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "subsession got message = DirectReadResponse" << " partitionSessionId = " << partitionSessionId << " partitionId = " << partitionId << " directReadId = " << ServerMessage->direct_read_response().direct_read_id() @@ -788,7 +788,7 @@ void TDirectReadSession::WriteToProcessorImpl(TDirectReadClientMessage&& req) { Y_ABORT_UNLESS(Lock.IsLocked()); if (Processor) { - LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "XXXXX subsession send message = " << req.ShortDebugString()); + LOG_LAZY(Log, TLOG_DEBUG, GetLogPrefix() << "subsession send message = " << req.ShortDebugString()); Processor->Write(std::move(req)); } } diff --git a/src/client/topic/impl/read_session_impl.ipp b/src/client/topic/impl/read_session_impl.ipp index f9a4a67d873..e49e76b4f03 100644 --- a/src/client/topic/impl/read_session_impl.ipp +++ b/src/client/topic/impl/read_session_impl.ipp @@ -530,6 +530,10 @@ inline void TSingleClusterReadSessionImpl::InitImpl(TDeferredActionsset_path(TStringType{topic.Path_}); diff --git a/src/client/topic/impl/topic.cpp b/src/client/topic/impl/topic.cpp index 610a4b921a6..cf8c77e0aac 100644 --- a/src/client/topic/impl/topic.cpp +++ b/src/client/topic/impl/topic.cpp @@ -591,6 +591,16 @@ std::shared_ptr TTopicClient::CreateSimpleBlockingW return Impl_->CreateSimpleWriteSession(settings); } +std::shared_ptr TTopicClient::CreateSimpleBlockingKeyedWriteSession( + const TKeyedWriteSessionSettings& settings) { + return Impl_->CreateSimpleKeyedWriteSession(settings); +} + +std::shared_ptr TTopicClient::CreateKeyedWriteSession( + const TKeyedWriteSessionSettings& settings) { + return Impl_->CreateKeyedWriteSession(settings); +} + std::shared_ptr TTopicClient::CreateWriteSession(const TWriteSessionSettings& settings) { return Impl_->CreateWriteSession(settings); } diff --git a/src/client/topic/impl/topic_impl.cpp b/src/client/topic/impl/topic_impl.cpp index 6fe5e604458..acfe6bde587 100644 --- a/src/client/topic/impl/topic_impl.cpp +++ b/src/client/topic/impl/topic_impl.cpp @@ -61,6 +61,43 @@ std::shared_ptr TTopicClient::TImpl::CreateSimpleWr return std::move(session); } +std::shared_ptr TTopicClient::TImpl::CreateSimpleKeyedWriteSession(const TKeyedWriteSessionSettings& settings) { + auto alteredSettings = settings; + { + std::lock_guard guard(Lock); + if (!settings.CompressionExecutor_) { + alteredSettings.CompressionExecutor(Settings.DefaultCompressionExecutor_); + } + + if (!settings.EventHandlers_.HandlersExecutor_) { + alteredSettings.EventHandlers_.HandlersExecutor(Settings.DefaultHandlersExecutor_); + } + } + + auto session = std::make_shared( + alteredSettings, shared_from_this(), Connections_, DbDriverState_ + ); + return session; +} + +std::shared_ptr TTopicClient::TImpl::CreateKeyedWriteSession(const TKeyedWriteSessionSettings& settings) { + auto alteredSettings = settings; + { + std::lock_guard guard(Lock); + if (!settings.CompressionExecutor_) { + alteredSettings.CompressionExecutor(Settings.DefaultCompressionExecutor_); + } + + if (!settings.EventHandlers_.HandlersExecutor_) { + alteredSettings.EventHandlers_.HandlersExecutor(Settings.DefaultHandlersExecutor_); + } + } + + return std::make_shared( + alteredSettings, shared_from_this(), Connections_, DbDriverState_ + ); +} + std::shared_ptr TTopicClient::TImpl::CreateReadSessionConnectionProcessorFactory() { using TService = Ydb::Topic::V1::TopicService; using TRequest = Ydb::Topic::StreamReadMessage::FromClient; diff --git a/src/client/topic/impl/topic_impl.h b/src/client/topic/impl/topic_impl.h index 1983d4734d6..669ae12c39f 100644 --- a/src/client/topic/impl/topic_impl.h +++ b/src/client/topic/impl/topic_impl.h @@ -316,6 +316,8 @@ class TTopicClient::TImpl : public TClientImplCommon { // Runtime API. std::shared_ptr CreateReadSession(const TReadSessionSettings& settings); std::shared_ptr CreateSimpleWriteSession(const TWriteSessionSettings& settings); + std::shared_ptr CreateSimpleKeyedWriteSession(const TKeyedWriteSessionSettings& settings); + std::shared_ptr CreateKeyedWriteSession(const TKeyedWriteSessionSettings& settings); std::shared_ptr CreateWriteSession(const TWriteSessionSettings& settings); using IReadSessionConnectionProcessorFactory = diff --git a/src/client/topic/impl/write_session.cpp b/src/client/topic/impl/write_session.cpp index 3b51c48e723..dd93430b12d 100644 --- a/src/client/topic/impl/write_session.cpp +++ b/src/client/topic/impl/write_session.cpp @@ -1,6 +1,14 @@ #include "write_session.h" +#include #include +#include +#include + +#include +#include +#include +#include namespace NYdb::inline V3::NTopic { @@ -8,11 +16,12 @@ namespace NYdb::inline V3::NTopic { // TWriteSession TWriteSession::TWriteSession( - const TWriteSessionSettings& settings, - std::shared_ptr client, - std::shared_ptr connections, - TDbDriverStatePtr dbDriverState) - : TContextOwner(settings, std::move(client), std::move(connections), std::move(dbDriverState)) { + const TWriteSessionSettings& settings, + std::shared_ptr client, + std::shared_ptr connections, + TDbDriverStatePtr dbDriverState) + : TContextOwner(settings, std::move(client), std::move(connections), std::move(dbDriverState)) +{ } void TWriteSession::Start(const TDuration& delay) { @@ -36,17 +45,19 @@ NThreading::TFuture TWriteSession::WaitEvent() { } void TWriteSession::WriteEncoded(TContinuationToken&& token, std::string_view data, ECodec codec, ui32 originalSize, - std::optional seqNo, std::optional createTimestamp) { + std::optional seqNo, std::optional createTimestamp) { auto message = TWriteMessage::CompressedMessage(std::move(data), codec, originalSize); - if (seqNo.has_value()) + if (seqNo.has_value()) { message.SeqNo(*seqNo); - if (createTimestamp.has_value()) + } + if (createTimestamp.has_value()) { message.CreateTimestamp(*createTimestamp); + } TryGetImpl()->WriteInternal(std::move(token), std::move(message)); } void TWriteSession::WriteEncoded(TContinuationToken&& token, TWriteMessage&& message, - TTransactionBase* tx) + TTransactionBase* tx) { if (tx) { message.Tx(*tx); @@ -55,17 +66,19 @@ void TWriteSession::WriteEncoded(TContinuationToken&& token, TWriteMessage&& mes } void TWriteSession::Write(TContinuationToken&& token, std::string_view data, std::optional seqNo, - std::optional createTimestamp) { + std::optional createTimestamp) { TWriteMessage message{std::move(data)}; - if (seqNo.has_value()) + if (seqNo.has_value()) { message.SeqNo(*seqNo); - if (createTimestamp.has_value()) + } + if (createTimestamp.has_value()) { message.CreateTimestamp(*createTimestamp); + } TryGetImpl()->WriteInternal(std::move(token), std::move(message)); } void TWriteSession::Write(TContinuationToken&& token, TWriteMessage&& message, - TTransactionBase* tx) { + TTransactionBase* tx) { if (tx) { message.Tx(*tx); } @@ -81,101 +94,1794 @@ TWriteSession::~TWriteSession() { } //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// TSimpleBlockingWriteSession +// TKeyedWriteSession -TSimpleBlockingWriteSession::TSimpleBlockingWriteSession( - const TWriteSessionSettings& settings, - std::shared_ptr client, - std::shared_ptr connections, - TDbDriverStatePtr dbDriverState -) { - auto subSettings = settings; - if (settings.EventHandlers_.AcksHandler_) { - LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use AcksHandler, resetting."); - subSettings.EventHandlers_.AcksHandler({}); +// TKeyedWriteSessionSettings + +std::string TKeyedWriteSessionSettings::DefaultPartitioningKeyHasher(const std::string_view key) { + const std::uint64_t lo = MurmurHash(key.data(), key.size(), std::uint64_t{0}); + const std::uint64_t hi = MurmurHash(key.data(), key.size(), std::uint64_t{0x9E3779B97F4A7C15ull}); // fixed seed + + const std::uint64_t hiBe = InetToHost(hi); + const std::uint64_t loBe = InetToHost(lo); + + std::string out; + out.resize(16); + memcpy(out.data() + 0, &hiBe, 8); + memcpy(out.data() + 8, &loBe, 8); + return out; // 16 bytes +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TPartitionInfo + +bool TKeyedWriteSession::TPartitionInfo::InRange(const std::string_view key) const { + if (FromBound_ > key) { + return false; } - if (settings.EventHandlers_.ReadyToAcceptHandler_) { - LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use ReadyToAcceptHandler, resetting."); - subSettings.EventHandlers_.ReadyToAcceptHandler({}); + if (ToBound_.has_value() && *ToBound_ <= key) { + return false; } - if (settings.EventHandlers_.SessionClosedHandler_) { - LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use SessionClosedHandler, resetting."); - subSettings.EventHandlers_.SessionClosedHandler({}); + return true; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TMessageInfo + +TKeyedWriteSession::TMessageInfo::TMessageInfo(const std::string& key, TWriteMessage&& message, std::uint32_t partition, TTransactionBase* tx) + : Key(key) + , Data(message.Data) + , Codec(message.Codec) + , OriginalSize(message.OriginalSize) + , SeqNo(message.SeqNo_) + , CreateTimestamp(message.CreateTimestamp_) + , TxInMessage(message.Tx_) + , Tx(tx) + , Partition(partition) +{ + for (const auto& [key, value] : message.MessageMeta_) { + MessageMeta.Fields.emplace_back(key, value); } - if (settings.EventHandlers_.CommonHandler_) { - LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use CommonHandler, resetting."); - subSettings.EventHandlers_.CommonHandler({}); +} + +TWriteMessage TKeyedWriteSession::TMessageInfo::BuildMessage() const { + TWriteMessage message(Data); + message.Codec = Codec; + message.OriginalSize = OriginalSize; + message.SeqNo(SeqNo); + message.CreateTimestamp(CreateTimestamp); + for (const auto& [key, value] : MessageMeta.Fields) { + message.MessageMeta_.emplace_back(key, value); } - Writer = std::make_shared(subSettings, client, connections, dbDriverState); - Writer->Start(TDuration::Zero()); + message.Tx(TxInMessage); + return message; } +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TWriteSessionWrapper -uint64_t TSimpleBlockingWriteSession::GetInitSeqNo() { - return Writer->GetInitSeqNo().GetValueSync(); +TKeyedWriteSession::TWriteSessionWrapper::TWriteSessionWrapper(WriteSessionPtr session, std::uint32_t partition) + : Session(std::move(session)) + , Partition(partition) + , QueueSize(0) +{} + +bool TKeyedWriteSession::TWriteSessionWrapper::IsQueueEmpty() const { + return QueueSize == 0; } -bool TSimpleBlockingWriteSession::Write( - std::string_view data, std::optional seqNo, std::optional createTimestamp, const TDuration& blockTimeout -) { - auto message = TWriteMessage(std::move(data)) - .SeqNo(seqNo) - .CreateTimestamp(createTimestamp); - return Write(std::move(message), nullptr, blockTimeout); +bool TKeyedWriteSession::TWriteSessionWrapper::AddToQueue(std::uint64_t delta) { + bool idle = QueueSize == 0; + QueueSize += delta; + return idle; } -bool TSimpleBlockingWriteSession::Write( - TWriteMessage&& message, TTransactionBase* tx, const TDuration& blockTimeout -) { - auto continuationToken = WaitForToken(blockTimeout); - if (continuationToken.has_value()) { - Writer->Write(std::move(*continuationToken), std::move(message), tx); - return true; +bool TKeyedWriteSession::TWriteSessionWrapper::RemoveFromQueue(std::uint64_t delta) { + Y_ABORT_UNLESS(QueueSize >= delta, "RemoveFromQueue: underflow"); + QueueSize -= delta; + return QueueSize == 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TIdleSession + +bool TKeyedWriteSession::TIdleSession::Less(const std::shared_ptr& other) const { + if (EmptySince == other->EmptySince) { + return Session->Partition < other->Session->Partition; + } + + return EmptySince < other->EmptySince; +} + +bool TKeyedWriteSession::TIdleSession::Comparator::operator()( + const std::shared_ptr& first, + const std::shared_ptr& second) const { + return first->Less(second); +} + +bool TKeyedWriteSession::TIdleSession::IsExpired() const { + return TInstant::Now() - EmptySince > IdleTimeout; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TSplittedPartitionWorker + +TKeyedWriteSession::TSplittedPartitionWorker::TSplittedPartitionWorker(TKeyedWriteSession* session, std::uint32_t partitionId) + : Session(session) + , PartitionId(partitionId) +{ + LOG_LAZY(Session->DbDriverState->Log, TLOG_INFO, Session->LogPrefix() << "Creating splitted partition worker for partition " << PartitionId); +} + +std::string TKeyedWriteSession::TSplittedPartitionWorker::GetStateName() const { + switch (State) { + case EState::Init: + return "Init"; + case EState::PendingDescribe: + return "PendingDescribe"; + case EState::GotDescribe: + return "GotDescribe"; + case EState::PendingMaxSeqNo: + return "PendingMaxSeqNo"; + case EState::Done: + return "Done"; + case EState::GotMaxSeqNo: + return "GotMaxSeqNo"; + } +} + +void TKeyedWriteSession::TSplittedPartitionWorker::DoWork() { + std::unique_lock lock(Lock); + std::weak_ptr session = Session->shared_from_this(); + switch (State) { + case EState::Init: + DescribeTopicFuture = Session->Client->DescribeTopic(Session->Settings.Path_, TDescribeTopicSettings()); + lock.unlock(); + DescribeTopicFuture.Subscribe([this, session](const NThreading::TFuture&) { + auto sessionPtr = session.lock(); + if (!sessionPtr) { + return; + } + + { + std::lock_guard lock(Lock); + MoveTo(EState::GotDescribe); + } + + sessionPtr->RunMainWorker(); + }); + lock.lock(); + if (State == EState::Init) { + MoveTo(EState::PendingDescribe); + } + break; + case EState::GotDescribe: + HandleDescribeResult(); + if (State != EState::GotDescribe) { + break; + } + + LaunchGetMaxSeqNoFutures(lock); + if (State == EState::GotDescribe) { + MoveTo(EState::PendingMaxSeqNo); + } + break; + case EState::PendingDescribe: + case EState::PendingMaxSeqNo: + case EState::Done: + break; + case EState::GotMaxSeqNo: + Session->MessagesWorker->RebuildPendingMessagesIndex(PartitionId); + Session->MessagesWorker->ScheduleResendMessages(PartitionId, MaxSeqNo); + for (const auto& child : Session->Partitions[PartitionId].Children_) { + Session->Partitions[child].Locked(false); + } + Session->Partitions[PartitionId].Locked_ = false; + MoveTo(EState::Done); + break; + } +} + +void TKeyedWriteSession::TSplittedPartitionWorker::MoveTo(EState state) { + State = state; + LOG_LAZY(Session->DbDriverState->Log, TLOG_INFO, Session->LogPrefix() << "Moving splitted partition worker for partition " << PartitionId << " to state " << GetStateName()); +} + +void TKeyedWriteSession::TSplittedPartitionWorker::UpdateMaxSeqNo(std::uint64_t maxSeqNo) { + MaxSeqNo = std::max(MaxSeqNo, maxSeqNo); +} + +bool TKeyedWriteSession::TSplittedPartitionWorker::IsDone() { + std::lock_guard lock(Lock); + return State == EState::Done; +} + +bool TKeyedWriteSession::TSplittedPartitionWorker::IsInit() { + std::lock_guard lock(Lock); + return State == EState::Init; +} + +void TKeyedWriteSession::TSplittedPartitionWorker::HandleDescribeResult() { + std::vector newPartitionsIds; + const auto& partitions = DescribeTopicFuture.GetValue().GetTopicDescription().GetPartitions(); + for (const auto& partition : partitions) { + if (partition.GetPartitionId() != PartitionId) { + continue; + } + + LOG_LAZY(Session->DbDriverState->Log, TLOG_ERR, Session->LogPrefix() << "Found partition " << partition.GetPartitionId() << " for partition " << PartitionId << " children: " << partition.GetChildPartitionIds().size()); + for (const auto& childPartitionId : partition.GetChildPartitionIds()) { + newPartitionsIds.push_back(childPartitionId); + } + break; + } + + if (newPartitionsIds.empty()) { + // describe response is incomplete, we need to resend describe request + MoveTo(EState::Init); + Y_ABORT_UNLESS(++Retries < 40, "Too many retries for partition %u", PartitionId); + LOG_LAZY(Session->DbDriverState->Log, TLOG_ERR, Session->LogPrefix() << "Describe response is incomplete, we need to resend describe request for partition " << PartitionId); + return; + } + + std::vector children; + const auto& splittedPartition = Session->Partitions[PartitionId]; + Session->PartitionsIndex.erase(splittedPartition.FromBound_); + + for (const auto& newPartitionId : newPartitionsIds) { + auto partitionDescribeInfo = std::find_if(partitions.begin(), partitions.end(), [newPartitionId](const auto& partition) { + return partition.GetPartitionId() == newPartitionId; + }); + Y_ABORT_UNLESS(partitionDescribeInfo != partitions.end(), "Partition describe info not found"); + Session->PartitionsIndex[partitionDescribeInfo->GetFromBound().value_or("")] = newPartitionId; + Session->Partitions[newPartitionId] = TPartitionInfo() + .PartitionId(newPartitionId) + .FromBound(partitionDescribeInfo->GetFromBound().value_or("")) + .ToBound(partitionDescribeInfo->GetToBound()) + .Locked(true); + children.push_back(newPartitionId); + } + + Session->Partitions[PartitionId].Children(children); +} + +void TKeyedWriteSession::TSplittedPartitionWorker::LaunchGetMaxSeqNoFutures(std::unique_lock& lock) { + Y_ABORT_UNLESS(DescribeTopicFuture.IsReady(), "DescribeTopicFuture is not ready yet"); + + std::unordered_map partitionIdToParentId; + const auto& partitions = DescribeTopicFuture.GetValue().GetTopicDescription().GetPartitions(); + for (const auto& partition : partitions) { + auto parentPartitions = partition.GetParentPartitionIds(); + if (parentPartitions.empty()) { + continue; + } + + // we consider here that each partition has only one parent partition + partitionIdToParentId[partition.GetPartitionId()] = parentPartitions.front(); + } + + std::vector ancestors; + std::uint32_t currentPartitionId = PartitionId; + while (true) { + ancestors.push_back(currentPartitionId); + + auto parentPartitionId = partitionIdToParentId.find(currentPartitionId); + if (parentPartitionId == partitionIdToParentId.end()) { + break; + } + currentPartitionId = parentPartitionId->second; + } + + NotReadyFutures = ancestors.size(); + for (const auto& ancestor : ancestors) { + auto wrappedSession = Session->SessionsWorker->GetWriteSession(ancestor, false); + Y_ABORT_UNLESS(wrappedSession, "Write session not found"); + WriteSessions.push_back(wrappedSession); + + auto future = wrappedSession->Session->GetInitSeqNo(); + std::weak_ptr session = Session->shared_from_this(); + lock.unlock(); + future.Subscribe([this, session, wrappedSession, ancestor](const NThreading::TFuture& result) { + auto sessionPtr = session.lock(); + if (!sessionPtr) { + return; + } + + if (IsDone()) { + return; + } + + bool gotMaxSeqNo = false; + { + std::lock_guard lock(Lock); + if (result.HasException()) { + LOG_LAZY(sessionPtr->DbDriverState->Log, TLOG_ERR, sessionPtr->LogPrefix() << "Failed to get max seq no for partition " << ancestor << " for splitted partition " << PartitionId); + TSessionClosedEvent sessionClosedEvent(EStatus::INTERNAL_ERROR, {}); + sessionPtr->GetSessionClosedEventAndDie(wrappedSession, std::move(sessionClosedEvent)); + MoveTo(EState::Done); + return; + } + + UpdateMaxSeqNo(result.GetValue()); + if (--NotReadyFutures == 0) { + MoveTo(EState::GotMaxSeqNo); + gotMaxSeqNo = true; + } + } + + if (gotMaxSeqNo) { + sessionPtr->RunMainWorker(); + } + }); + lock.lock(); + GetMaxSeqNoFutures.push_back(future); + } + + if (ancestors.empty()) { + LOG_LAZY(Session->DbDriverState->Log, TLOG_INFO, Session->LogPrefix() << "No ancestors found for partition " << PartitionId); + MoveTo(EState::Init); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TEventsWorkerWrapper + +TKeyedWriteSession::TEventsWorker::TEventsWorker(TKeyedWriteSession* session) + : Session(session) +{ + EventsPromise = NThreading::NewPromise(); + EventsFuture = EventsPromise.GetFuture(); + + AddReadyToAcceptEvent(); +} + +void TKeyedWriteSession::TEventsWorker::HandleAcksEvent(std::uint64_t partition, TWriteSessionEvent::TAcksEvent&& event) { + auto [queueIt, _] = PartitionsEventQueues.try_emplace(partition); + queueIt->second.push_back(TWriteSessionEvent::TEvent(std::move(event))); +} + +void TKeyedWriteSession::TEventsWorker::HandleReadyToAcceptEvent(std::uint32_t partition, TWriteSessionEvent::TReadyToAcceptEvent&& event) { + Session->MessagesWorker->HandleContinuationToken(partition, std::move(event.ContinuationToken)); +} + +void TKeyedWriteSession::TEventsWorker::HandleSessionClosedEvent(TSessionClosedEvent&& event, std::uint32_t partition) { + if (event.IsSuccess()) { + return; + } + + Session->Partitions[partition].Locked_ = true; + + if (event.GetStatus() == EStatus::OVERLOADED) { + Session->HandleAutoPartitioning(partition); + return; + } + + if (!CloseEvent.has_value()) { + CloseEvent = std::move(event); + } + Session->NonBlockingClose(); +} + +bool TKeyedWriteSession::TEventsWorker::RunEventLoop(WrappedWriteSessionPtr wrappedSession, std::uint32_t partition) { + while (true) { + auto event = wrappedSession->Session->GetEvent(false); + if (!event) { + break; + } + + if (auto sessionClosedEvent = std::get_if(&*event); sessionClosedEvent) { + HandleSessionClosedEvent(std::move(*sessionClosedEvent), partition); + return true; + } + + if (auto readyToAcceptEvent = std::get_if(&*event)) { + HandleReadyToAcceptEvent(partition, std::move(*readyToAcceptEvent)); + continue; + } + + if (auto acksEvent = std::get_if(&*event)) { + Session->SessionsWorker->OnReadFromSession(wrappedSession); + HandleAcksEvent(partition, std::move(*acksEvent)); + continue; + } } + return false; } -std::optional TSimpleBlockingWriteSession::WaitForToken(const TDuration& timeout) { - TInstant startTime = TInstant::Now(); - TDuration remainingTime = timeout; +std::optional> TKeyedWriteSession::TEventsWorker::DoWork() { + std::unique_lock lock(Lock); + + while (!ReadyFutures.empty()) { + auto idx = *ReadyFutures.begin(); + ReadyFutures.erase(idx); + lock.unlock(); + // RunEventLoop without Lock: sub-session's WaitEvent() completion may run the Subscribe + // callback (ReadyFutures.insert) synchronously; that callback takes Lock -> same-thread deadlock. + auto isSessionClosed = RunEventLoop(Session->SessionsWorker->GetWriteSession(idx), idx); + if (!isSessionClosed) { + SubscribeToPartition(idx); + } else { + UnsubscribeFromPartition(idx); + } + lock.lock(); + } + + if (!Session->Done.load() && TransferEventsToOutputQueue()) { + return EventsPromise; + } + + return std::nullopt; +} + +void TKeyedWriteSession::TEventsWorker::SubscribeToPartition(std::uint32_t partition) { + if (auto it = Session->SplittedPartitionWorkers.find(partition); it != Session->SplittedPartitionWorkers.end()) { + Session->Partitions[partition].Future(NThreading::MakeFuture()); + return; + } + + auto wrappedSession = Session->SessionsWorker->GetWriteSession(partition); + auto newFuture = wrappedSession->Session->WaitEvent(); + std::weak_ptr session = Session->shared_from_this(); + std::weak_ptr self = shared_from_this(); + + newFuture.Subscribe([self, session, partition](const NThreading::TFuture&) { + auto sessionPtr = session.lock(); + if (!sessionPtr) { + return; + } + + auto selfPtr = self.lock(); + if (!selfPtr) { + return; + } + + { + std::lock_guard lock(selfPtr->Lock); + selfPtr->ReadyFutures.insert(partition); + } + sessionPtr->RunMainWorker(); + }); + Session->Partitions[partition].Future(newFuture); +} + +std::optional> TKeyedWriteSession::TEventsWorker::HandleNewMessage() { + std::lock_guard lock(Lock); + if (Session->MessagesWorker->IsMemoryUsageOK()) { + AddReadyToAcceptEvent(); + return EventsPromise; + } + + return std::nullopt; +} + +void TKeyedWriteSession::TEventsWorker::AddReadyToAcceptEvent() { + EventsOutputQueue.push_back(TWriteSessionEvent::TReadyToAcceptEvent(IssueContinuationToken())); +} + +bool TKeyedWriteSession::TEventsWorker::AddSessionClosedIfNeeded() { + if (!Session->Closed.load()) { + return false; + } + + if (!CloseEvent.has_value()) { + CloseEvent = TSessionClosedEvent(EStatus::SUCCESS, {}); + } + + if (EventsOutputQueue.empty() && (Session->MessagesWorker->IsQueueEmpty() || Session->Done.load())) { + EventsOutputQueue.push_back(*CloseEvent); + return true; + } - std::optional token = std::nullopt; + return false; +} - while (IsAlive() && remainingTime > TDuration::Zero()) { - Writer->WaitEvent().Wait(remainingTime); +bool TKeyedWriteSession::TEventsWorker::TransferEventsToOutputQueue() { + bool eventsTransferred = false; + bool shouldAddReadyToAcceptEvent = false; + std::unordered_map> acks; - for (auto event : Writer->GetEvents()) { - if (auto* readyEvent = std::get_if(&event)) { - Y_ABORT_UNLESS(!token.has_value()); - token = std::move(readyEvent->ContinuationToken); - } else if (std::get_if(&event)) { - // discard - } else if (std::get_if(&event)) { - Closed.store(true); - return std::nullopt; + auto messagesWorker = Session->MessagesWorker; + auto buildOutputAckEvent = [&](std::deque& acksQueue, std::uint64_t partition, std::optional expectedSeqNo) -> TWriteSessionEvent::TAcksEvent { + TWriteSessionEvent::TAcksEvent ackEvent; + + if (expectedSeqNo.has_value()) { + if (acksQueue.front().SeqNo != expectedSeqNo.value()) { + LOG_LAZY(Session->DbDriverState->Log, TLOG_ERR, Session->LogPrefix() << "Expected seqNo=" << expectedSeqNo.value() << " but got " << acksQueue.front().SeqNo << " for partition " << partition); } + Y_ENSURE(acksQueue.front().SeqNo == expectedSeqNo.value(), TStringBuilder() << "Expected seqNo=" << expectedSeqNo.value() << " but got " << acksQueue.front().SeqNo << " for partition " << Session->Partitions[partition].PartitionId_); + } + + auto ack = std::move(acksQueue.front()); + ackEvent.Acks.push_back(std::move(ack)); + acksQueue.pop_front(); + return ackEvent; + }; + auto finishWithAck = [messagesWorker, &shouldAddReadyToAcceptEvent]() { + bool wasMemoryUsageOk = messagesWorker->IsMemoryUsageOK(); + messagesWorker->HandleAck(); + if (messagesWorker->IsMemoryUsageOK() && !wasMemoryUsageOk) { + shouldAddReadyToAcceptEvent = true; + } + }; + + while (messagesWorker->HasInFlightMessages()) { + const auto& head = messagesWorker->GetFrontInFlightMessage(); + + auto remainingAcks = acks.find(head.Partition); + if (remainingAcks != acks.end() && remainingAcks->second.size() > 0) { + EventsOutputQueue.push_back(buildOutputAckEvent(remainingAcks->second, head.Partition, head.SeqNo)); + finishWithAck(); + continue; + } + + auto eventsQueueIt = PartitionsEventQueues.find(head.Partition); + if (eventsQueueIt == PartitionsEventQueues.end() || eventsQueueIt->second.empty()) { + // No events for this message yet, stop processing (preserve order) + break; + } + + auto event = std::move(eventsQueueIt->second.front()); + auto acksEvent = std::get_if(&event); + Y_ABORT_UNLESS(acksEvent, "Expected AcksEvent only in PartitionsEventQueues"); + + std::deque acksQueue; + std::copy(acksEvent->Acks.begin(), acksEvent->Acks.end(), std::back_inserter(acksQueue)); + EventsOutputQueue.push_back(buildOutputAckEvent(acksQueue, head.Partition, head.SeqNo)); + acks[head.Partition] = std::move(acksQueue); + eventsQueueIt->second.pop_front(); + eventsTransferred = true; + + finishWithAck(); + } + + // this case handles situation: + // 1st message is written to partition 0 + // 2nd message is written to partition 1 + // 3rd message is written to partition 0 + // 4th message is written to partition 1 + // but AcksEvent for partition 0 looks like: + // [ack1, ack3] + // In this case we can not just forget about ack3, because 3rd message is in-flight + // so we will push 'AcksEvent' back to the queue for partition 0 + for (auto& [partition, acksQueue] : acks) { + if (acksQueue.size() > 0) { + TWriteSessionEvent::TAcksEvent ackEvent; + std::copy(acksQueue.begin(), acksQueue.end(), std::back_inserter(ackEvent.Acks)); + PartitionsEventQueues[partition].push_front(std::move(ackEvent)); } + } - if (token.has_value()) { - return token; + if (shouldAddReadyToAcceptEvent) { + AddReadyToAcceptEvent(); + } + + return eventsTransferred; +} + +std::list::iterator TKeyedWriteSession::TEventsWorker::AckQueueBegin(std::uint32_t partition) { + auto [queueIt, _] = PartitionsEventQueues.try_emplace(partition); + return queueIt->second.begin(); +} + +std::list::iterator TKeyedWriteSession::TEventsWorker::AckQueueEnd(std::uint32_t partition) { + auto [queueIt, _] = PartitionsEventQueues.try_emplace(partition); + return queueIt->second.end(); +} + +TKeyedWriteSession::TEventsWorker::EEventType TKeyedWriteSession::TEventsWorker::GetEventType(const TWriteSessionEvent::TEvent& event) { + if (std::holds_alternative(event)) { + return EEventType::SessionClosed; + } else if (std::holds_alternative(event)) { + return EEventType::ReadyToAccept; + } else if (std::holds_alternative(event)) { + return EEventType::Ack; + } + + Y_ABORT_UNLESS(false, "Unexpected event type"); +} + +std::optional TKeyedWriteSession::TEventsWorker::GetEventImpl(bool block, const std::vector& eventTypes) { + std::unique_lock lock(Lock); + if (EventsOutputQueue.empty() && block) { + lock.unlock(); + WaitEvent().Wait(); + lock.lock(); + } + + if (!EventsOutputQueue.empty()) { + if (!eventTypes.empty() && std::find(eventTypes.begin(), eventTypes.end(), GetEventType(EventsOutputQueue.front())) == eventTypes.end()) { + return std::nullopt; } - remainingTime = timeout - (TInstant::Now() - startTime); + auto event = std::move(EventsOutputQueue.front()); + EventsOutputQueue.pop_front(); + return event; } return std::nullopt; } -TWriterCounters::TPtr TSimpleBlockingWriteSession::GetCounters() { - return Writer->GetCounters(); +std::optional TKeyedWriteSession::TEventsWorker::GetEvent(bool block, const std::vector& eventTypes) { + { + std::unique_lock lock(Lock); + AddSessionClosedIfNeeded(); + } + auto event = GetEventImpl(block, eventTypes); + + return event; } -bool TSimpleBlockingWriteSession::IsAlive() const { - return !Closed.load(); +std::vector TKeyedWriteSession::TEventsWorker::GetEvents(bool block, std::optional maxEventsCount, const std::vector& eventTypes) { + if (maxEventsCount.has_value() && maxEventsCount.value() == 0) { + return {}; + } + + { + std::unique_lock lock(Lock); + AddSessionClosedIfNeeded(); + } + + std::vector events; + while (true) { + auto event = GetEventImpl(block, eventTypes); + if (!event) { + break; + } + + events.push_back(std::move(*event)); + if (maxEventsCount.has_value() && events.size() >= maxEventsCount.value()) { + break; + } + } + + return events; } -bool TSimpleBlockingWriteSession::Close(TDuration closeTimeout) { - Closed.store(true); - return Writer->Close(std::move(closeTimeout)); +NThreading::TFuture TKeyedWriteSession::TEventsWorker::WaitEvent() { + std::unique_lock lock(Lock); + + AddSessionClosedIfNeeded(); + if (!EventsOutputQueue.empty()) { + return NThreading::MakeFuture(); + } + + if (EventsFuture.IsReady() && !Session->Closed.load()) { + EventsPromise = NThreading::NewPromise(); + EventsFuture = EventsPromise.GetFuture(); + } + + return EventsFuture; +} + +void TKeyedWriteSession::TEventsWorker::UnsubscribeFromPartition(std::uint32_t partition) { + ReadyFutures.erase(partition); + Session->Partitions[partition].Future(NThreading::MakeFuture()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TSessionsWorker + +TKeyedWriteSession::TSessionsWorker::TSessionsWorker(TKeyedWriteSession* session) + : Session(session) +{} + +TKeyedWriteSession::WrappedWriteSessionPtr TKeyedWriteSession::TSessionsWorker::GetWriteSession(std::uint32_t partition, bool directToPartition) { + auto sessionIter = SessionsIndex.find(partition); + if (sessionIter == SessionsIndex.end() || !directToPartition) { + return CreateWriteSession(partition, directToPartition); + } + + return sessionIter->second; +} + +std::string TKeyedWriteSession::TSessionsWorker::GetProducerId(std::uint32_t partitionId) { + return std::format("{}_{}", Session->Settings.ProducerIdPrefix_, partitionId); +} + +TKeyedWriteSession::WrappedWriteSessionPtr TKeyedWriteSession::TSessionsWorker::CreateWriteSession(std::uint32_t partition, bool directToPartition) { + auto partitionId = Session->Partitions[partition].PartitionId_; + auto producerId = GetProducerId(partitionId); + TWriteSessionSettings alteredSettings = Session->Settings; + + alteredSettings + .ProducerId(producerId) + .MessageGroupId(producerId) + .MaxMemoryUsage(std::numeric_limits::max()) + .RetryPolicy(Session->RetryPolicy) + .EventHandlers(TWriteSessionSettings::TEventHandlers() + .ReadyToAcceptHandler({}) + .AcksHandler({}) + .SessionClosedHandler({})); + + if (directToPartition) { + alteredSettings.DirectWriteToPartition(true); + alteredSettings.PartitionId(partitionId); + } + auto writeSession = std::make_shared( + Session->Client->CreateWriteSession(alteredSettings), + partition); + + if (directToPartition) { + SessionsIndex.emplace(partition, writeSession); + Session->EventsWorker->SubscribeToPartition(partition); + } + return writeSession; +} + +void TKeyedWriteSession::TSessionsWorker::DestroyWriteSession(TSessionsIndexIterator& it, TDuration closeTimeout, bool mustBeEmpty) { + if (it == SessionsIndex.end() || !it->second) { + return; + } + + auto closeResult = it->second->Session->Close(closeTimeout); + Y_ABORT_UNLESS(!mustBeEmpty || closeResult, "There are still messages in flight"); + const auto partition = it->second->Partition; + it = SessionsIndex.erase(it); + Session->EventsWorker->UnsubscribeFromPartition(partition); +} + +void TKeyedWriteSession::TSessionsWorker::OnReadFromSession(WrappedWriteSessionPtr wrappedSession) { + if (wrappedSession->RemoveFromQueue(1)) { + Y_ABORT_UNLESS(!wrappedSession->IdleSession, "IdleSession is already set"); + auto idleSessionPtr = std::make_shared(wrappedSession.get(), TInstant::Now(), Session->Settings.SubSessionIdleTimeout_); + auto [itIdle, inserted] = IdlerSessions.insert(idleSessionPtr); + Y_ABORT_UNLESS(inserted, "Duplicate idle session for partition"); + IdlerSessionsIndex[wrappedSession->Partition] = itIdle; + wrappedSession->IdleSession = idleSessionPtr; + } +} + +void TKeyedWriteSession::TSessionsWorker::OnWriteToSession(WrappedWriteSessionPtr wrappedSession) { + if (wrappedSession->AddToQueue(1) && wrappedSession->IdleSession) { + auto itIdle = IdlerSessionsIndex.find(wrappedSession->Partition); + if (itIdle != IdlerSessionsIndex.end()) { + IdlerSessions.erase(itIdle->second); + IdlerSessionsIndex.erase(itIdle); + } + wrappedSession->IdleSession.reset(); + } +} + +void TKeyedWriteSession::TSessionsWorker::DoWork() { + while (!IdlerSessions.empty()) { + auto it = IdlerSessions.begin(); + if (!(*it)->IsExpired()) { + break; + } + + const auto partition = (*it)->Session->Partition; + + // Remove idle tracking first to keep containers consistent even if the session + // is already absent from SessionsIndex. + IdlerSessions.erase(it); + IdlerSessionsIndex.erase(partition); + + auto sessionIter = SessionsIndex.find(partition); + if (sessionIter != SessionsIndex.end()) { + sessionIter->second->IdleSession.reset(); + DestroyWriteSession(sessionIter, TDuration::Zero()); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TMessagesWorker + +TKeyedWriteSession::TMessagesWorker::TMessagesWorker(TKeyedWriteSession* session) + : Session(session) +{ +} + +void TKeyedWriteSession::TMessagesWorker::RechoosePartitionIfNeeded(MessageIter message) { + const auto& partitionInfo = Session->Partitions[message->Partition]; + if (partitionInfo.Children_.empty()) { + return; + } + + // this case means that partition was split, so we need to rechoose the partition for the message + auto newPartition = Session->PartitionChooser->ChoosePartition(message->Key); + message->Partition = newPartition; +} + +void TKeyedWriteSession::TMessagesWorker::DoWork() { + auto sessionsWorker = Session->SessionsWorker; + + auto iterateMessagesIndex = [&](std::unordered_map>& messagesIndex, auto stopCondition) { + std::vector partitionsProcessed; + for (auto& [partition, messages] : messagesIndex) { + while (!messages.empty()) { + auto head = messages.front(); + if (stopCondition(head)) { + break; + } + + auto wrappedSession = sessionsWorker->GetWriteSession(head->Partition); + if (!SendMessage(wrappedSession, *head)) { + break; + } + + Session->Metrics.AddWriteLag((TInstant::Now() - head->CreateTimestamp.value_or(TInstant::Now())).MilliSeconds()); + head->Sent = true; + sessionsWorker->OnWriteToSession(wrappedSession); + messages.pop_front(); + } + + if (messages.empty()) { + partitionsProcessed.push_back(partition); + } + } + + for (const auto& partition : partitionsProcessed) { + messagesIndex.erase(partition); + } + }; + + iterateMessagesIndex( + MessagesToResendIndex, + [](MessageIter) { + return false; + } + ); + + iterateMessagesIndex( + PendingMessagesIndex, + [this](MessageIter head) { + return Session->Partitions[head->Partition].Locked_ || + MessagesToResendIndex.contains(head->Partition); + } + ); +} + +bool TKeyedWriteSession::TMessagesWorker::SendMessage(WrappedWriteSessionPtr wrappedSession, const TMessageInfo& message) { + if (!wrappedSession) { + return false; + } + + auto continuationToken = GetContinuationToken(message.Partition); + if (!continuationToken) { + return false; + } + + wrappedSession->Session->Write(std::move(*continuationToken), message.BuildMessage(), message.Tx); + return true; +} + +void TKeyedWriteSession::TMessagesWorker::PushInFlightMessage(std::uint32_t partition, TMessageInfo&& message) { + auto iter = InFlightMessages.insert(InFlightMessages.end(), std::move(message)); + auto [inFlightMessagesIndexIt, _] = InFlightMessagesIndex.try_emplace(partition); + inFlightMessagesIndexIt->second.push_back(iter); + + auto [pendingMessagesIndexIt, __] = PendingMessagesIndex.try_emplace(partition); + pendingMessagesIndexIt->second.push_back(iter); +} + +void TKeyedWriteSession::TMessagesWorker::HandleAck() { + PopInFlightMessage(); +} + +void TKeyedWriteSession::TMessagesWorker::PopInFlightMessage() { + Y_ABORT_UNLESS(!InFlightMessages.empty()); + const std::uint64_t partition = InFlightMessages.front().Partition; + const auto it = InFlightMessages.begin(); + + auto mapIt = InFlightMessagesIndex.find(partition); + if (mapIt != InFlightMessagesIndex.end()) { + auto& list = mapIt->second; + for (auto listIt = list.begin(); listIt != list.end(); ++listIt) { + if (*listIt == it) { + list.erase(listIt); + break; + } + } + if (list.empty()) { + InFlightMessagesIndex.erase(mapIt); + } + } + + Y_ABORT_UNLESS(it->Data.size() <= MemoryUsage, "MemoryUsage is less than the size of the message"); + MemoryUsage -= it->Data.size(); + InFlightMessages.pop_front(); +} + +bool TKeyedWriteSession::TMessagesWorker::IsMemoryUsageOK() const { + return MemoryUsage <= Session->Settings.MaxMemoryUsage_ / 2; +} + +void TKeyedWriteSession::TMessagesWorker::AddMessage(const std::string& key, TWriteMessage&& message, std::uint32_t partition, TTransactionBase* tx) { + MemoryUsage += message.Data.size(); + PushInFlightMessage(partition, TMessageInfo(key, std::move(message), partition, tx)); +} + +std::optional TKeyedWriteSession::TMessagesWorker::GetContinuationToken(std::uint32_t partition) { + auto it = ContinuationTokens.find(partition); + if (it != ContinuationTokens.end() && !it->second.empty()) { + auto token = std::move(it->second.front()); + it->second.pop_front(); + if (it->second.empty()) { + ContinuationTokens.erase(it); + } + return token; + } + + return std::nullopt; +} + +void TKeyedWriteSession::TMessagesWorker::HandleContinuationToken(std::uint32_t partition, TContinuationToken&& continuationToken) { + auto [it, _] = ContinuationTokens.try_emplace(partition); + it->second.push_back(std::move(continuationToken)); +} + +bool TKeyedWriteSession::TMessagesWorker::IsQueueEmpty() const { + return InFlightMessages.empty(); +} + +const TKeyedWriteSession::TMessageInfo& TKeyedWriteSession::TMessagesWorker::GetFrontInFlightMessage() const { + Y_ABORT_UNLESS(!InFlightMessages.empty()); + return InFlightMessages.front(); +} + +bool TKeyedWriteSession::TMessagesWorker::HasInFlightMessages() const { + return !InFlightMessages.empty(); +} + +void TKeyedWriteSession::TMessagesWorker::ScheduleResendMessages(std::uint32_t partition, std::uint64_t afterSeqNo) { + auto it = InFlightMessagesIndex.find(partition); + if (it == InFlightMessagesIndex.end()) { + return; + } + + auto& list = it->second; + auto resendIt = list.begin(); + auto ackQueueIt = Session->EventsWorker->AckQueueBegin(partition); + size_t ackIdx = 0; + auto ackQueueEnd = Session->EventsWorker->AckQueueEnd(partition); + std::vector acksToSend; + + while (resendIt != list.end()) { + if (!(*resendIt)->SeqNo.has_value() || (*resendIt)->SeqNo.value() > afterSeqNo) { + break; + } + + auto seqNo = (*resendIt)->SeqNo.value(); + if (ackQueueIt == ackQueueEnd) { + // this case can happen if the message was sent, but session was closed before the ack was received + TWriteSessionEvent::TWriteAck ack; + ack.SeqNo = seqNo; + acksToSend.push_back(std::move(ack)); + } else { + auto acksEvent = std::get_if(&*ackQueueIt); + if (ackIdx == acksEvent->Acks.size()) { + ++ackQueueIt; + ackIdx = 0; + continue; + } + + if (acksEvent->Acks[ackIdx].SeqNo > seqNo) { + // this case can happen if the message was sent, but session was closed before the ack was received + TWriteSessionEvent::TWriteAck ack; + ack.SeqNo = seqNo; + acksEvent->Acks.insert(acksEvent->Acks.begin() + ackIdx, std::move(ack)); + } + ++ackIdx; + } + ++resendIt; + } + + if (!acksToSend.empty()) { + TWriteSessionEvent::TAcksEvent event; + event.Acks = std::move(acksToSend); + Session->EventsWorker->HandleAcksEvent(partition, std::move(event)); + } + + // IMPORTANT: do not mutate InFlightMessagesIndex while holding references/iterators to its elements. + // try_emplace()/rehash may invalidate 'it' and 'list' -> use-after-free and segfaults. + std::vector> messagesFromOldPartition; + messagesFromOldPartition.reserve(std::distance(resendIt, list.end())); + auto currentSeqNo = resendIt != list.end() ? (*resendIt)->SeqNo.value_or(0) : 0; + for (auto iter = resendIt; iter != list.end(); ++iter) { + if (iter != resendIt && currentSeqNo != 0) { + Y_ABORT_UNLESS((*iter)->SeqNo.value_or(0) > currentSeqNo, "SeqNo is not increasing for partition %d", partition); + } + + auto newPartition = Session->PartitionChooser->ChoosePartition((*iter)->Key); + (*iter)->Partition = newPartition; + messagesFromOldPartition.emplace_back(newPartition, *iter); + + currentSeqNo = (*iter)->SeqNo.value_or(0); + } + + list.erase(resendIt, list.end()); + for (const auto& [newPartition, msgIt] : messagesFromOldPartition) { + auto [inFlightMessagesIndexChainIt, _] = InFlightMessagesIndex.try_emplace(newPartition); + inFlightMessagesIndexChainIt->second.push_back(msgIt); + + if (msgIt->Sent) { + auto [messagesToResendChainIt, __] = MessagesToResendIndex.try_emplace(newPartition); + messagesToResendChainIt->second.push_back(msgIt); + } + } + + InFlightMessagesIndex.erase(partition); +} + +void TKeyedWriteSession::TMessagesWorker::RebuildPendingMessagesIndex(std::uint32_t partition) { + auto [oldPendingMessagesIndexChainIt, __] = PendingMessagesIndex.try_emplace(partition); + std::unordered_map> pendingMessagesForNewPartitions; + for (auto it = oldPendingMessagesIndexChainIt->second.begin(); it != oldPendingMessagesIndexChainIt->second.end(); ++it) { + auto newPartition = Session->PartitionChooser->ChoosePartition((*it)->Key); + auto [pendingMessagesForNewPartitionsIt, __] = pendingMessagesForNewPartitions.try_emplace(newPartition); + pendingMessagesForNewPartitionsIt->second.push_back(*it); + } + + for (const auto& [newPartition, pendingMessagesForNewPartition] : pendingMessagesForNewPartitions) { + auto [pendingMessagesIndexChainIt, __] = PendingMessagesIndex.try_emplace(newPartition); + for (auto reverseIt = pendingMessagesForNewPartition.rbegin(); reverseIt != pendingMessagesForNewPartition.rend(); ++reverseIt) { + pendingMessagesIndexChainIt->second.push_front(*reverseIt); + } + } + + PendingMessagesIndex.erase(partition); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::TKeyedWriteSessionRetryPolicy + +TKeyedWriteSession::TKeyedWriteSessionRetryPolicy::TKeyedWriteSessionRetryPolicy(TKeyedWriteSession* session) + : Session(session) +{} + +typename TKeyedWriteSession::TKeyedWriteSessionRetryPolicy::IRetryState::TPtr TKeyedWriteSession::TKeyedWriteSessionRetryPolicy::CreateRetryState() const { + struct TRetryState : public IRetryState { + TRetryState(TKeyedWriteSession* session) + : Session(session) + {} + ~TRetryState() = default; + TMaybe GetNextRetryDelay(EStatus status) override { + if (status == EStatus::OVERLOADED) { + return Nothing(); + } + + if (!UserRetryState) { + auto policy = Session->Settings.RetryPolicy_ ? Session->Settings.RetryPolicy_ : NYdb::NTopic::IRetryPolicy::GetDefaultPolicy(); + UserRetryState = policy->CreateRetryState(); + } + + return UserRetryState->GetNextRetryDelay(status); + } + + private: + TKeyedWriteSession* Session; + IRetryState::TPtr UserRetryState; + }; + + return std::make_unique(Session); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession::Metrics + +void TKeyedWriteSession::TMetricGauge::Add(std::uint64_t value) { + Sum += value; + MetricCount++; +} + +void TKeyedWriteSession::TMetricGauge::Clear() { + Sum = 0; + MetricCount = 0; +} + +long double TKeyedWriteSession::TMetricGauge::Average() { + if (MetricCount == 0) { + return 0; + } + + return (long double)Sum / (long double)MetricCount; +} + +TKeyedWriteSession::TMetrics::TMetrics(TKeyedWriteSession* session): Session(session) {} + +void TKeyedWriteSession::TMetrics::AddMainWorkerTime(std::uint64_t ms) { + std::lock_guard lock(Lock); + MainWorkerTimeMs.Add(ms); +} + +void TKeyedWriteSession::TMetrics::AddCycleTime(std::uint64_t ms) { + std::lock_guard lock(Lock); + CycleTimeMs.Add(ms); +} + +void TKeyedWriteSession::TMetrics::AddWriteLag(std::uint64_t lagMs) { + std::lock_guard lock(Lock); + WriteLagMs.Add(lagMs); +} + +void TKeyedWriteSession::TMetrics::PrintMetrics() { + std::lock_guard lock(Lock); + LOG_LAZY(Session->DbDriverState->Log, TLOG_ERR, Session->LogPrefix() << "METRICS: MainWorkerTimeMs: " << MainWorkerTimeMs.Average() << " ms, CycleTimeMs: " << CycleTimeMs.Average() << " ms, WriteLagMs: " << WriteLagMs.Average() << " ms"); + MainWorkerTimeMs.Clear(); + CycleTimeMs.Clear(); + WriteLagMs.Clear(); +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession + +TKeyedWriteSession::TKeyedWriteSession( + const TKeyedWriteSessionSettings& settings, + std::shared_ptr client, + std::shared_ptr connections, + TDbDriverStatePtr dbDriverState) + : Connections(connections), + Client(client), + DbDriverState(dbDriverState), + Metrics(this), + Settings(settings) +{ + if (settings.ProducerIdPrefix_.empty()) { + ythrow TContractViolation("ProducerIdPrefix is required for KeyedWriteSession"); + } + + if (!settings.ProducerId_.empty()) { + ythrow TContractViolation("ProducerId should be empty for KeyedWriteSession, use ProducerIdPrefix instead"); + } + + if (!settings.MessageGroupId_.empty()) { + ythrow TContractViolation("MessageGroupId should be empty for KeyedWriteSession"); + } + + TDescribeTopicSettings describeTopicSettings; + auto topicConfig = client->DescribeTopic(settings.Path_, describeTopicSettings).GetValueSync(); + auto partitions = topicConfig.GetTopicDescription().GetPartitions(); + std::sort(partitions.begin(), partitions.end(), [](const auto& a, const auto& b) -> bool { + return a.GetPartitionId() < b.GetPartitionId(); + }); + + auto partitionChooserStrategy = settings.PartitionChooserStrategy_; + auto strategy = topicConfig.GetTopicDescription().GetPartitioningSettings().GetAutoPartitioningSettings().GetStrategy(); + auto autoPartitioningEnabled = (strategy != EAutoPartitioningStrategy::Disabled && + strategy != EAutoPartitioningStrategy::Unspecified); + + for (const auto& partition : partitions) { + auto partitionId = partition.GetPartitionId(); + auto fromBound = partition.GetFromBound().value_or(""); + auto toBound = partition.GetToBound(); + LOG_LAZY(DbDriverState->Log, TLOG_ERR, LogPrefix() << "Adding partition " << partitionId << " from bound " << fromBound << " to bound " << (toBound.has_value() ? toBound.value() : "null")); + Partitions[partitionId] = TPartitionInfo() + .PartitionId(partitionId) + .FromBound(fromBound) + .ToBound(toBound); + } + + for (const auto& partition : partitions) { + auto children = partition.GetChildPartitionIds(); + + std::vector childrenIndices; + childrenIndices.reserve(children.size()); + for (auto child : children) { + childrenIndices.push_back(child); + } + Partitions[partition.GetPartitionId()].Children(childrenIndices); + } + + if (Settings.EventHandlers_.CommonHandler_) { + EventTypesWithHandlers.push_back(TEventsWorker::EEventType::SessionClosed); + EventTypesWithHandlers.push_back(TEventsWorker::EEventType::ReadyToAccept); + EventTypesWithHandlers.push_back(TEventsWorker::EEventType::Ack); + } else { + if (Settings.EventHandlers_.SessionClosedHandler_) { + EventTypesWithHandlers.push_back(TEventsWorker::EEventType::SessionClosed); + } + if (Settings.EventHandlers_.ReadyToAcceptHandler_) { + EventTypesWithHandlers.push_back(TEventsWorker::EEventType::ReadyToAccept); + } + if (Settings.EventHandlers_.AcksHandler_) { + EventTypesWithHandlers.push_back(TEventsWorker::EEventType::Ack); + } + } + + switch (partitionChooserStrategy) { + case TKeyedWriteSessionSettings::EPartitionChooserStrategy::Bound: + PartitioningKeyHasher = settings.PartitioningKeyHasher_; + PartitionChooser = std::make_unique(this); + for (size_t i = 0; i < Partitions.size(); ++i) { + if (i > 0 && Partitions[i].FromBound_.empty() && !Partitions[i].ToBound_.has_value()) { + ythrow TContractViolation("Unbounded partition is not supported for Bound partition chooser strategy"); + } + + if (!Partitions[i].Children_.empty()) { + continue; + } + + PartitionsIndex[Partitions[i].FromBound_] = Partitions[i].PartitionId_; + } + break; + case TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash: + if (autoPartitioningEnabled) { + ythrow TContractViolation("Hash partition chooser strategy is not supported for topic with auto partitioning"); + } + + std::vector partitionsIds; + partitionsIds.reserve(partitions.size()); + for (const auto& partition : partitions) { + partitionsIds.push_back(partition.GetPartitionId()); + } + + PartitionChooser = std::make_unique(std::move(partitionsIds)); + break; + } + + ClosePromise = NThreading::NewPromise(); + CloseFuture = ClosePromise.GetFuture(); + ShutdownPromise = NThreading::NewPromise(); + ShutdownFuture = ShutdownPromise.GetFuture(); + + SessionsWorker = std::make_shared(this); + MessagesWorker = std::make_shared(this); + EventsWorker = std::make_shared(this); + RetryPolicy = std::make_shared(this); + + // Start handlers executor for user callbacks (Acks/ReadyToAccept/SessionClosed/Common). + Settings.EventHandlers_.HandlersExecutor_->Start(); + + CloseFuture.Subscribe([this](const NThreading::TFuture&) { + RunMainWorker(); + }); + + RunMainWorker(); + + LOG_LAZY(DbDriverState->Log, TLOG_INFO, LogPrefix() << "Keyed write session created"); +} + +std::vector TKeyedWriteSession::GetPartitions() const { + std::vector partitions; + partitions.reserve(Partitions.size()); + for (const auto& [partitionId, partitionInfo] : Partitions) { + partitions.push_back(partitionInfo); + } + return partitions; +} + +void TKeyedWriteSession::Write(TContinuationToken&&, const std::string& key, TWriteMessage&& message, TTransactionBase* tx) { + std::optional> eventsPromise; + { + std::lock_guard lock(GlobalLock); + if (Closed.load()) { + return; + } + + if ((message.SeqNo_.has_value() && SeqNoStrategy == ESeqNoStrategy::WithoutSeqNo) + || (!message.SeqNo_.has_value() && SeqNoStrategy == ESeqNoStrategy::WithSeqNo)) { + ythrow TContractViolation("Can not mix messages with and without seqNo"); + } + + if (SeqNoStrategy == ESeqNoStrategy::NotInitialized) { + SeqNoStrategy = message.SeqNo_.has_value() ? ESeqNoStrategy::WithSeqNo : ESeqNoStrategy::WithoutSeqNo; + } + + auto partition = PartitionChooser->ChoosePartition(key); + MessagesWorker->AddMessage(key, std::move(message), partition, tx); + eventsPromise = EventsWorker->HandleNewMessage(); + RunUserEventLoop(); + } + + RunMainWorker(); + if (eventsPromise) { + eventsPromise->TrySetValue(); + } +} + +bool TKeyedWriteSession::Close(TDuration closeTimeout) { + if (Closed.exchange(true)) { + std::lock_guard lock(GlobalLock); + return MessagesWorker->IsQueueEmpty(); + } + + SetCloseDeadline(closeTimeout); + + ClosePromise.TrySetValue(); + ShutdownFuture.Wait(CloseDeadline); + RunUserEventLoop(); + Done.store(true); + + // No need to lock here, because we are waiting for the shutdown future and it will block until the main worker is done + return MessagesWorker->IsQueueEmpty(); +} + +void TKeyedWriteSession::NonBlockingClose() { + Closed.store(true); + Done.store(true); +} + +void TKeyedWriteSession::SetCloseDeadline(const TDuration& closeTimeout) { + std::lock_guard lock(GlobalLock); + CloseDeadline = TInstant::Now() + closeTimeout; +} + +TKeyedWriteSession::~TKeyedWriteSession() { + Close(TDuration::Zero()); + Settings.EventHandlers_.HandlersExecutor_->Stop(); + ShutdownFuture.Wait(); +} + +NThreading::TFuture TKeyedWriteSession::WaitEvent() { + return EventsWorker->WaitEvent(); +} + +std::optional TKeyedWriteSession::GetEvent(bool block) { + if (Settings.EventHandlers_.CommonHandler_) { + return std::nullopt; + } + + return EventsWorker->GetEvent(block); +} + +std::vector TKeyedWriteSession::GetEvents(bool block, std::optional maxEventsCount) { + if (Settings.EventHandlers_.CommonHandler_) { + return {}; + } + + return EventsWorker->GetEvents(block, maxEventsCount); +} + +TDuration TKeyedWriteSession::GetCloseTimeout() { + std::lock_guard lock(GlobalLock); + auto now = TInstant::Now(); + if (CloseDeadline <= now) { + return TDuration::Zero(); + } + return CloseDeadline - now; +} + +bool TKeyedWriteSession::RunSplittedPartitionWorkers() { + if (SplittedPartitionWorkers.empty()) { + return false; + } + + bool needRerun = false; + for (const auto& [partition, splittedPartitionWorker] : SplittedPartitionWorkers) { + if (splittedPartitionWorker->IsDone()) { + continue; + } + + splittedPartitionWorker->DoWork(); + needRerun = needRerun || splittedPartitionWorker->IsInit(); + needRerun = needRerun || splittedPartitionWorker->IsDone(); + } + + return needRerun; +} + +void TKeyedWriteSession::RunUserEventLoop() { + if (!Settings.EventHandlers_.AcksHandler_ && + !Settings.EventHandlers_.ReadyToAcceptHandler_ && + !Settings.EventHandlers_.SessionClosedHandler_ && + !Settings.EventHandlers_.CommonHandler_) { + return; + } + + auto handlersExecutor = Settings.EventHandlers_.HandlersExecutor_; + if (!handlersExecutor) { + return; + } + + while (true) { + auto event = EventsWorker->GetEvent(false, EventTypesWithHandlers); + if (!event) { + break; + } + + if (auto* readyToAcceptEvent = std::get_if(&*event)) { + if (Settings.EventHandlers_.ReadyToAcceptHandler_) { + handlersExecutor->Post( + [this, ev = std::move(*readyToAcceptEvent)]() mutable { + Settings.EventHandlers_.ReadyToAcceptHandler_(ev); + }); + } else if (Settings.EventHandlers_.CommonHandler_) { + handlersExecutor->Post( + [this, ev = std::move(*event)]() mutable { + Settings.EventHandlers_.CommonHandler_(ev); + }); + } + continue; + } + + if (auto* acksEvent = std::get_if(&*event)) { + if (Settings.EventHandlers_.AcksHandler_) { + handlersExecutor->Post( + [this, ev = std::move(*acksEvent)]() mutable { + Settings.EventHandlers_.AcksHandler_(ev); + }); + } else if (Settings.EventHandlers_.CommonHandler_) { + handlersExecutor->Post( + [this, ev = std::move(*event)]() mutable { + Settings.EventHandlers_.CommonHandler_(ev); + }); + } + continue; + } + + if (auto* sessionClosedEvent = std::get_if(&*event)) { + if (Settings.EventHandlers_.SessionClosedHandler_) { + handlersExecutor->Post( + [this, ev = std::move(*sessionClosedEvent)]() mutable { + Settings.EventHandlers_.SessionClosedHandler_(ev); + }); + } else if (Settings.EventHandlers_.CommonHandler_) { + handlersExecutor->Post( + [this, ev = std::move(*event)]() mutable { + Settings.EventHandlers_.CommonHandler_(ev); + }); + } + break; + } + } +} + +void TKeyedWriteSession::GetSessionClosedEventAndDie(WrappedWriteSessionPtr wrappedSession, std::optional sessionClosedEvent) { + std::optional receivedSessionClosedEvent; + while (true) { + auto event = wrappedSession->Session->GetEvent(false); + if (!event) { + break; + } + + if (auto* closedEvent = std::get_if(&*event)) { + receivedSessionClosedEvent = std::move(*closedEvent); + break; + } + } + + if (!receivedSessionClosedEvent || receivedSessionClosedEvent->GetStatus() == EStatus::SUCCESS || receivedSessionClosedEvent->GetStatus() == EStatus::OVERLOADED) { + LOG_LAZY(DbDriverState->Log, TLOG_ERR, LogPrefix() << "Failed to get session closed event"); + EventsWorker->HandleSessionClosedEvent(std::move(*sessionClosedEvent), wrappedSession->Partition); + } else { + EventsWorker->HandleSessionClosedEvent(std::move(*receivedSessionClosedEvent), wrappedSession->Partition); + } +} + +TStringBuilder TKeyedWriteSession::LogPrefix() { + return TStringBuilder() << " SessionId: " << Settings.SessionId_ << " Epoch: " << Epoch.load() << " "; +} + +void TKeyedWriteSession::NextEpoch() { + auto maxEpoch = MAX_EPOCH - 1; + if (Epoch.compare_exchange_weak(maxEpoch, 0)) { + LOG_LAZY(DbDriverState->Log, TLOG_INFO, LogPrefix() << "Epoch overflow, resetting to 0"); + return; + } + + Epoch.fetch_add(1); +} + +void TKeyedWriteSession::RunMainWorker() { + // This function is both "request to run" and the runner itself. + // We must handle two properties: + // - TFuture::Subscribe may call back synchronously when future is already ready. + // - A callback may race with the runner trying to go idle (avoid lost wakeups). + enum : std::uint8_t { + Idle = 0, + Running = 1, + Rerun = 2, + }; + + // Try to become the runner. If already running, just request a rerun. + std::uint8_t state = MainWorkerState.load(std::memory_order_acquire); + for (;;) { + if (state & Running) { + if (MainWorkerState.compare_exchange_weak(state, std::uint8_t(state | Rerun), + std::memory_order_acq_rel, + std::memory_order_acquire)) { + return; + } + continue; + } else { + if (MainWorkerState.compare_exchange_weak(state, Running, + std::memory_order_acq_rel, + std::memory_order_acquire)) { + break; // we are the runner now + } + continue; + } + } + + NextEpoch(); + + auto startWorkerTime = TInstant::Now(); + // Runner loop: process, arm subscription, then either go idle or loop again. + for (;;) { + auto startIter = TInstant::Now(); + // Clear rerun request for this iteration. + MainWorkerState.fetch_and(std::uint8_t(~Rerun), std::memory_order_acq_rel); + bool needRerun = false; + std::optional> eventsPromise; + + { + std::unique_lock lock(GlobalLock); + eventsPromise = EventsWorker->DoWork(); + RunUserEventLoop(); + needRerun = RunSplittedPartitionWorkers(); + if (!Done.load()) { + SessionsWorker->DoWork(); + MessagesWorker->DoWork(); + } + } + + if (eventsPromise) { + eventsPromise->TrySetValue(); + } + + const auto isClosed = Closed.load(); + const auto closeTimeout = GetCloseTimeout(); + if (isClosed && (Done.load() || MessagesWorker->IsQueueEmpty() || closeTimeout == TDuration::Zero())) { + ShutdownPromise.TrySetValue(); + EventsWorker->EventsPromise.TrySetValue(); + ClosePromise.TrySetValue(); + MainWorkerState.store(Idle, std::memory_order_release); + return; + } + + if (needRerun) { + // we need this case to start resending messages if there are any + Metrics.AddCycleTime((TInstant::Now() - startIter).MilliSeconds()); + continue; + } + + // Try to go idle. If someone requested rerun concurrently, keep running. + std::uint8_t cur = MainWorkerState.load(std::memory_order_acquire); + for (;;) { + if (cur & Rerun) { + Metrics.AddCycleTime((TInstant::Now() - startIter).MilliSeconds()); + break; // continue outer loop + } + if (MainWorkerState.compare_exchange_weak(cur, Idle, + std::memory_order_acq_rel, + std::memory_order_acquire)) { + auto workerFinished = TInstant::Now(); + Metrics.AddCycleTime((workerFinished - startIter).MilliSeconds()); + Metrics.AddMainWorkerTime((workerFinished - startWorkerTime).MilliSeconds()); + return; // successfully went idle + } + } + // Rerun was requested; continue the loop without recursion. + } +} + +TInstant TKeyedWriteSession::GetCloseDeadline() { + std::lock_guard lock(GlobalLock); + return CloseDeadline; +} + +void TKeyedWriteSession::HandleAutoPartitioning(std::uint32_t partition) { + LOG_LAZY(DbDriverState->Log, TLOG_ERR, LogPrefix() << "HandleAutoPartitioning: " << partition); + auto splittedPartitionWorker = std::make_shared(this, partition); + SplittedPartitionWorkers.try_emplace(partition, splittedPartitionWorker); +} + +std::string TKeyedWriteSession::GetProducerId(std::uint32_t partition) { + return std::format("{}_{}", Settings.ProducerIdPrefix_, partition); +} + +TWriterCounters::TPtr TKeyedWriteSession::GetCounters() { + return nullptr; +} + +TKeyedWriteSession::TBoundPartitionChooser::TBoundPartitionChooser(TKeyedWriteSession* session) + : Session(session) +{} + +std::uint32_t TKeyedWriteSession::TBoundPartitionChooser::ChoosePartition(const std::string_view key) { + auto hashedKey = Session->PartitioningKeyHasher(key); + + auto lowerBound = Session->PartitionsIndex.lower_bound(hashedKey); + if (lowerBound != Session->PartitionsIndex.end() && lowerBound->first == hashedKey) { + return lowerBound->second; + } + + Y_ABORT_IF(lowerBound == Session->PartitionsIndex.begin(), "Lower bound is the first element"); + return std::prev(lowerBound)->second; +} + +TKeyedWriteSession::THashPartitionChooser::THashPartitionChooser(std::vector&& partitions) + : Partitions(std::move(partitions)) +{} + +std::uint32_t TKeyedWriteSession::THashPartitionChooser::ChoosePartition(const std::string_view key) { + auto hash = MurmurHash(key.data(), key.size()); + return Partitions[hash % Partitions.size()]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TSimpleBlockingWriteSession + +TSimpleBlockingWriteSession::TSimpleBlockingWriteSession( + const TWriteSessionSettings& settings, + std::shared_ptr client, + std::shared_ptr connections, + TDbDriverStatePtr dbDriverState) { + auto subSettings = settings; + if (settings.EventHandlers_.AcksHandler_) { + LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use AcksHandler, resetting."); + subSettings.EventHandlers_.AcksHandler({}); + } + if (settings.EventHandlers_.ReadyToAcceptHandler_) { + LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use ReadyToAcceptHandler, resetting."); + subSettings.EventHandlers_.ReadyToAcceptHandler({}); + } + if (settings.EventHandlers_.SessionClosedHandler_) { + LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use SessionClosedHandler, resetting."); + subSettings.EventHandlers_.SessionClosedHandler({}); + } + if (settings.EventHandlers_.CommonHandler_) { + LOG_LAZY(dbDriverState->Log, TLOG_WARNING, "TSimpleBlockingWriteSession: Cannot use CommonHandler, resetting."); + subSettings.EventHandlers_.CommonHandler({}); + } + Writer = std::make_shared(subSettings, client, connections, dbDriverState); + Writer->Start(TDuration::Zero()); +} + +uint64_t TSimpleBlockingWriteSession::GetInitSeqNo() { + return Writer->GetInitSeqNo().GetValueSync(); +} + +bool TSimpleBlockingWriteSession::Write( + std::string_view data, std::optional seqNo, std::optional createTimestamp, const TDuration& blockTimeout) { + auto message = TWriteMessage(std::move(data)) + .SeqNo(seqNo) + .CreateTimestamp(createTimestamp); + return Write(std::move(message), nullptr, blockTimeout); +} + +bool TSimpleBlockingWriteSession::Write( + TWriteMessage&& message, TTransactionBase* tx, const TDuration& blockTimeout) { + auto continuationToken = WaitForToken(blockTimeout); + if (continuationToken.has_value()) { + Writer->Write(std::move(*continuationToken), std::move(message), tx); + return true; + } + return false; +} + +std::optional TSimpleBlockingWriteSession::WaitForToken(const TDuration& timeout) { + return NDetail::WaitForToken(*Writer, Closed, timeout); +} + +TWriterCounters::TPtr TSimpleBlockingWriteSession::GetCounters() { + return Writer->GetCounters(); +} + +bool TSimpleBlockingWriteSession::IsAlive() const { + return !Closed.load(); +} + +bool TSimpleBlockingWriteSession::Close(TDuration closeTimeout) { + Closed.store(true); + return Writer->Close(std::move(closeTimeout)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TSimpleBlockingKeyedWriteSession + +TSimpleBlockingKeyedWriteSession::TSimpleBlockingKeyedWriteSession( + const TKeyedWriteSessionSettings& settings, + std::shared_ptr client, + std::shared_ptr connections, + TDbDriverStatePtr dbDriverState) + : Writer(std::make_shared(settings, client, connections, dbDriverState)) +{ + ClosePromise = NThreading::NewPromise(); + CloseFuture = ClosePromise.GetFuture(); +} + +void TSimpleBlockingKeyedWriteSession::RunEventLoop() { + while (true) { + auto event = Writer->GetEvent(false); + if (!event) { + break; + } + + if (auto readyToAcceptEvent = std::get_if(&*event)) { + ContinuationTokensQueue.push(std::move(readyToAcceptEvent->ContinuationToken)); + continue; + } + if (std::get_if(&*event)) { + Closed.store(true); + return; + } + if (auto acksEvent = std::get_if(&*event)) { + HandleAcksEvent(std::move(*acksEvent)); + } + } +} + +void TSimpleBlockingKeyedWriteSession::HandleAcksEvent(const TWriteSessionEvent::TAcksEvent& acksEvent) { + for (auto ack : acksEvent.Acks) { + AckedSeqNos.insert(ack.SeqNo); + } +} + +template +bool TSimpleBlockingKeyedWriteSession::Wait(const TDuration& timeout, F&& stopFunc) { + std::unique_lock lock(Lock); + + auto deadline = TInstant::Now() + timeout; + while (true) { + if (TInstant::Now() > deadline) { + return false; + } + + RunEventLoop(); + + if (stopFunc()) { + return true; + } + + if (Closed.load()) { + return false; + } + + std::vector> futures; + futures.push_back(CloseFuture); + futures.push_back(Writer->WaitEvent()); + lock.unlock(); + NThreading::NWait::WaitAny(futures).Wait(deadline); + lock.lock(); + } +} + +std::optional TSimpleBlockingKeyedWriteSession::GetContinuationToken(TDuration timeout) { + std::optional token; + + Wait(timeout, [&]() { + if (!ContinuationTokensQueue.empty()) { + token = std::move(ContinuationTokensQueue.front()); + ContinuationTokensQueue.pop(); + return true; + } + return false; + }); + + return token; +} + +bool TSimpleBlockingKeyedWriteSession::WaitForAck(std::optional seqNo, TDuration timeout) { + return Wait(timeout, [&]() { + if (!seqNo.has_value()) { + if (AckedSeqNos.empty()) { + return false; + } + + AckedSeqNos.erase(AckedSeqNos.begin()); + return true; + } + + if (AckedSeqNos.contains(*seqNo)) { + AckedSeqNos.erase(*seqNo); + return true; + } + return false; + }); +} + +bool TSimpleBlockingKeyedWriteSession::Write(const std::string& key, TWriteMessage&& message, TTransactionBase* tx, TDuration blockTimeout) { + auto continuationToken = GetContinuationToken(blockTimeout); + if (!continuationToken) { + return false; + } + + auto seqNo = message.SeqNo_; + Writer->Write(std::move(*continuationToken), std::move(key), std::move(message), tx); + return WaitForAck(seqNo, blockTimeout); +} + +bool TSimpleBlockingKeyedWriteSession::Close(TDuration closeTimeout) { + Closed.store(true); + ClosePromise.TrySetValue(); + return Writer->Close(closeTimeout); +} + +TWriterCounters::TPtr TSimpleBlockingKeyedWriteSession::GetCounters() { + return nullptr; } -} // namespace NYdb::NTopic +} // namespace NYdb::inline V3::NTopic \ No newline at end of file diff --git a/src/client/topic/impl/write_session.h b/src/client/topic/impl/write_session.h index 3112fd1361f..441d675a508 100644 --- a/src/client/topic/impl/write_session.h +++ b/src/client/topic/impl/write_session.h @@ -1,12 +1,19 @@ #pragma once +#include +#include #include #include #include +#include + #include #include +#include +#include +#include namespace NYdb::inline V3::NTopic { @@ -55,6 +62,392 @@ class TWriteSession : public IWriteSession, void Start(const TDuration& delay); }; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TKeyedWriteSession + +class TKeyedWriteSession : public IKeyedWriteSession, + public TContinuationTokenIssuer, + public std::enable_shared_from_this { +private: + using WriteSessionPtr = std::shared_ptr; + + struct TPartitionInfo { + using TSelf = TPartitionInfo; + + bool InRange(const std::string_view key) const; + + FLUENT_SETTING(std::string, FromBound); + FLUENT_SETTING(std::optional, ToBound); + FLUENT_SETTING(std::uint32_t, PartitionId); + FLUENT_SETTING(std::vector, Children); + FLUENT_SETTING_DEFAULT(bool, Locked, false); + FLUENT_SETTING_DEFAULT(NThreading::TFuture, Future, NThreading::MakeFuture()); + }; + + struct TMessageInfo { + TMessageInfo(const std::string& key, TWriteMessage&& message, std::uint32_t partition, TTransactionBase* tx); + + std::string Key; + std::string Data; + std::optional Codec; + uint32_t OriginalSize = 0; + std::optional SeqNo; + std::optional CreateTimestamp; + TMessageMeta MessageMeta; + std::optional> TxInMessage; + TTransactionBase* Tx; + std::uint32_t Partition; + bool Sent = false; + + TWriteMessage BuildMessage() const; + }; + + struct TIdleSession; + + struct TWriteSessionWrapper { + WriteSessionPtr Session; + const std::uint32_t Partition; + std::uint64_t QueueSize = 0; + std::shared_ptr IdleSession = nullptr; + + TWriteSessionWrapper(WriteSessionPtr session, std::uint32_t partition); + + bool IsQueueEmpty() const; + bool AddToQueue(std::uint64_t delta); + bool RemoveFromQueue(std::uint64_t delta); + }; + + using WrappedWriteSessionPtr = std::shared_ptr; + + struct TIdleSession { + TIdleSession(TWriteSessionWrapper* session, TInstant emptySince, TDuration idleTimeout) + : Session(session) + , EmptySince(emptySince) + , IdleTimeout(idleTimeout) + {} + + const TWriteSessionWrapper* Session; + const TInstant EmptySince; + const TDuration IdleTimeout; + + bool Less(const std::shared_ptr& other) const; + bool IsExpired() const; + + struct Comparator { + bool operator()(const std::shared_ptr& first, const std::shared_ptr& second) const; + }; + }; + + using IdleSessionPtr = std::shared_ptr; + + enum class ESeqNoStrategy { + NotInitialized, + WithoutSeqNo, + WithSeqNo, + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Custom retry policy + + struct TKeyedWriteSessionRetryPolicy : public ::IRetryPolicy { + using TSelf = TKeyedWriteSessionRetryPolicy; + using TPtr = std::shared_ptr; + + TKeyedWriteSessionRetryPolicy(TKeyedWriteSession* session); + ~TKeyedWriteSessionRetryPolicy() = default; + typename IRetryState::TPtr CreateRetryState() const override; + + private: + TKeyedWriteSession* Session; + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Workers + + struct TEventsWorker; + + struct TSessionsWorker { + TSessionsWorker(TKeyedWriteSession* session); + WrappedWriteSessionPtr GetWriteSession(std::uint32_t partition, bool directToPartition = true); + void OnReadFromSession(WrappedWriteSessionPtr wrappedSession); + void OnWriteToSession(WrappedWriteSessionPtr wrappedSession); + void DoWork(); + + private: + void AddIdleSession(WrappedWriteSessionPtr wrappedSession, TInstant emptySince, TDuration idleTimeout); + void RemoveIdleSession(std::uint32_t partition); + WrappedWriteSessionPtr CreateWriteSession(std::uint32_t partition, bool directToPartition = true); + + using TSessionsIndexIterator = std::unordered_map::iterator; + void DestroyWriteSession(TSessionsIndexIterator& it, TDuration closeTimeout, bool mustBeEmpty = true); + + std::string GetProducerId(std::uint32_t partitionId); + + TKeyedWriteSession* Session; + std::set IdlerSessions; + using IdlerSessionsIterator = std::set::iterator; + std::unordered_map IdlerSessionsIndex; + std::unordered_map SessionsIndex; + }; + + struct TMessagesWorker { + TMessagesWorker(TKeyedWriteSession* session); + + void DoWork(); + + void AddMessage(const std::string& key, TWriteMessage&& message, std::uint32_t partition, TTransactionBase* tx); + void ScheduleResendMessages(std::uint32_t partition, std::uint64_t afterSeqNo); + void RebuildPendingMessagesIndex(std::uint32_t partition); + void HandleAck(); + void HandleContinuationToken(std::uint32_t partition, TContinuationToken&& continuationToken); + bool IsMemoryUsageOK() const; + bool IsQueueEmpty() const; + bool HasInFlightMessages() const; + const TMessageInfo& GetFrontInFlightMessage() const; + + private: + using MessageIter = std::list::iterator; + + void PushInFlightMessage(std::uint32_t partition, TMessageInfo&& message); + void PopInFlightMessage(); + bool SendMessage(WrappedWriteSessionPtr wrappedSession, const TMessageInfo& message); + std::optional GetContinuationToken(std::uint32_t partition); + void RechoosePartitionIfNeeded(MessageIter message); + + TKeyedWriteSession* Session; + + std::list InFlightMessages; + std::unordered_map> InFlightMessagesIndex; + std::unordered_map> PendingMessagesIndex; + std::unordered_map> MessagesToResendIndex; + std::unordered_map> ContinuationTokens; + + std::uint64_t MemoryUsage = 0; + + friend class TKeyedWriteSession; + }; + + struct TSplittedPartitionWorker : public std::enable_shared_from_this { + private: + enum class EState { + Init = 0, + PendingDescribe = 1, + GotDescribe = 2, + PendingMaxSeqNo = 3, + GotMaxSeqNo = 4, + Done = 5, + }; + + void MoveTo(EState state); + void UpdateMaxSeqNo(uint64_t maxSeqNo); + void LaunchGetMaxSeqNoFutures(std::unique_lock& lock); + void HandleDescribeResult(); + + public: + TSplittedPartitionWorker(TKeyedWriteSession* session, std::uint32_t partitionId); + void DoWork(); + bool IsDone(); + bool IsInit(); + std::string GetStateName() const; + + private: + TKeyedWriteSession* Session; + NThreading::TFuture DescribeTopicFuture; + EState State = EState::Init; + std::uint32_t PartitionId; + std::uint64_t MaxSeqNo = 0; + std::vector WriteSessions; + std::vector> GetMaxSeqNoFutures; + std::mutex Lock; + std::uint64_t NotReadyFutures = 0; + size_t Retries = 0; + }; + + struct TEventsWorker : public std::enable_shared_from_this { + enum class EEventType { + SessionClosed = 0, + ReadyToAccept = 1, + Ack = 2, + }; + + TEventsWorker(TKeyedWriteSession* session); + + std::optional> DoWork(); + NThreading::TFuture WaitEvent(); + void UnsubscribeFromPartition(std::uint32_t partition); + void SubscribeToPartition(std::uint32_t partition); + std::optional> HandleNewMessage(); + void HandleAcksEvent(std::uint64_t partition, TWriteSessionEvent::TAcksEvent&& event); + std::optional GetEvent(bool block, const std::vector& eventTypes = {}); + std::vector GetEvents(bool block, std::optional maxEventsCount = std::nullopt, const std::vector& eventTypes = {}); + std::list::iterator AckQueueBegin(std::uint32_t partition); + std::list::iterator AckQueueEnd(std::uint32_t partition); + + private: + void HandleSessionClosedEvent(TSessionClosedEvent&& event, std::uint32_t partition); + void HandleReadyToAcceptEvent(std::uint32_t partition, TWriteSessionEvent::TReadyToAcceptEvent&& event); + bool RunEventLoop(WrappedWriteSessionPtr wrappedSession, std::uint32_t partition); + bool TransferEventsToOutputQueue(); + void AddReadyToAcceptEvent(); + bool AddSessionClosedIfNeeded(); + std::optional GetEventImpl(bool block, const std::vector& eventTypes = {}); + EEventType GetEventType(const TWriteSessionEvent::TEvent& event); + + TKeyedWriteSession* Session; + + std::unordered_set ReadyFutures; + std::unordered_map> PartitionsEventQueues; + std::list EventsOutputQueue; + std::mutex Lock; + + NThreading::TPromise EventsPromise; + NThreading::TFuture EventsFuture; + + std::optional CloseEvent; + + friend class TKeyedWriteSession; + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Partition chooser + + struct IPartitionChooser { + virtual std::uint32_t ChoosePartition(const std::string_view key) = 0; + virtual ~IPartitionChooser() = default; + }; + + struct TBoundPartitionChooser : IPartitionChooser { + TBoundPartitionChooser(TKeyedWriteSession* session); + std::uint32_t ChoosePartition(const std::string_view key) override; + private: + TKeyedWriteSession* Session; + }; + + struct THashPartitionChooser : IPartitionChooser { + THashPartitionChooser(std::vector&& partitions); + std::uint32_t ChoosePartition(const std::string_view key) override; + private: + std::vector Partitions; + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + struct TMetricGauge { + std::uint64_t MetricCount = 0; + std::uint64_t Sum = 0; + + long double Average(); + void Add(std::uint64_t value); + void Clear(); + }; + + struct TMetrics { + TMetrics(TKeyedWriteSession* session); + + TMetricGauge MainWorkerTimeMs; + TMetricGauge CycleTimeMs; + TMetricGauge WriteLagMs; + std::mutex Lock; + TKeyedWriteSession* Session; + + void AddMainWorkerTime(std::uint64_t ms); + void AddCycleTime(std::uint64_t ms); + void AddWriteLag(std::uint64_t lagMs); + void PrintMetrics(); + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + void RunMainWorker(); + + void NonBlockingClose(); + + void SetCloseDeadline(const TDuration& closeTimeout); + + TDuration GetCloseTimeout(); + + std::string GetProducerId(std::uint32_t partition); + + void HandleAutoPartitioning(std::uint32_t partition); + + bool RunSplittedPartitionWorkers(); + + void RunUserEventLoop(); + + TInstant GetCloseDeadline(); + + void GetSessionClosedEventAndDie(WrappedWriteSessionPtr wrappedSession, std::optional sessionClosedEvent = std::nullopt); + + TStringBuilder LogPrefix(); + + void NextEpoch(); + +public: + TKeyedWriteSession(const TKeyedWriteSessionSettings& settings, + std::shared_ptr client, + std::shared_ptr connections, + TDbDriverStatePtr dbDriverState); + + void Write(TContinuationToken&& continuationToken, const std::string& key, TWriteMessage&& message, + TTransactionBase* tx = nullptr) override; + + NThreading::TFuture WaitEvent() override; + + std::optional GetEvent(bool block = false) override; + + std::vector GetEvents(bool block = false, std::optional maxEventsCount = std::nullopt) override; + + bool Close(TDuration closeTimeout = TDuration::Max()) override; + + TWriterCounters::TPtr GetCounters() override; + + std::vector GetPartitions() const; + + ~TKeyedWriteSession(); + +private: + std::shared_ptr Connections; + std::shared_ptr Client; + TDbDriverStatePtr DbDriverState; + + TMetrics Metrics; + + std::unordered_map Partitions; + std::map PartitionsIndex; + + TKeyedWriteSessionSettings Settings; + ESeqNoStrategy SeqNoStrategy = ESeqNoStrategy::NotInitialized; + + NThreading::TPromise ClosePromise; + NThreading::TFuture CloseFuture; + NThreading::TPromise ShutdownPromise; + NThreading::TFuture ShutdownFuture; + + std::mutex GlobalLock; + std::atomic_bool Closed = false; + std::atomic_bool Done = false; + TInstant CloseDeadline = TInstant::Now(); + + std::unique_ptr PartitionChooser; + + std::function PartitioningKeyHasher; + + std::shared_ptr EventsWorker; + std::shared_ptr SessionsWorker; + std::unordered_map> SplittedPartitionWorkers; + std::shared_ptr MessagesWorker; + std::shared_ptr RetryPolicy; + + // TFuture::Subscribe may invoke callback synchronously when the future is already ready. + // Also, callbacks may arrive concurrently with the attempt to go idle. + // Use a small state machine to avoid re-entrancy and lost wakeups. + std::atomic MainWorkerState = 0; + std::atomic Epoch = 0; + static constexpr size_t MAX_EPOCH = 1'000'000'000; + + std::vector EventTypesWithHandlers; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // TSimpleBlockingWriteSession @@ -89,5 +482,47 @@ class TSimpleBlockingWriteSession : public ISimpleBlockingWriteSession { std::atomic_bool Closed = false; }; +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// TSimpleBlockingKeyedWriteSession + +class TSimpleBlockingKeyedWriteSession : public ISimpleBlockingKeyedWriteSession { +private: + std::optional GetContinuationToken(TDuration timeout); + + void HandleAcksEvent(const TWriteSessionEvent::TAcksEvent& acksEvent); + + bool WaitForAck(std::optional seqNo, TDuration timeout); + + template + bool Wait(const TDuration& timeout, F&& stopFunc); + + void RunEventLoop(); + +public: + TSimpleBlockingKeyedWriteSession( + const TKeyedWriteSessionSettings& settings, + std::shared_ptr client, + std::shared_ptr connections, + TDbDriverStatePtr dbDriverState); + + + bool Write(const std::string& key, TWriteMessage&& message, TTransactionBase* tx = nullptr, + TDuration blockTimeout = TDuration::Max()) override; + + bool Close(TDuration closeTimeout = TDuration::Max()) override; + + TWriterCounters::TPtr GetCounters() override; + +protected: + std::shared_ptr Writer; + std::unordered_set AckedSeqNos; + std::queue ContinuationTokensQueue; + + NThreading::TPromise ClosePromise; + NThreading::TFuture CloseFuture; + + std::mutex Lock; + std::atomic_bool Closed = false; +}; } // namespace NYdb::NTopic diff --git a/src/client/topic/ut/basic_usage_ut.cpp b/src/client/topic/ut/basic_usage_ut.cpp index 478a88927a4..c43f79a8697 100644 --- a/src/client/topic/ut/basic_usage_ut.cpp +++ b/src/client/topic/ut/basic_usage_ut.cpp @@ -1,5 +1,6 @@ #include "ut_utils/topic_sdk_test_setup.h" +#include #include #include @@ -12,9 +13,11 @@ #include #include #include +#include #include #include +#include #include #include @@ -25,8 +28,13 @@ #include #include +#include #include +#include +#include +#include +#include using namespace std::chrono_literals; @@ -129,6 +137,33 @@ void WriteBinaryProducerIdWithDirectTabletWrite(TTopicSdkTestSetup& setup, UNIT_ASSERT_VALUES_EQUAL(result->Record.GetPartitionResponse().CmdWriteResultSize(), 1); } +static std::string FindKeyForBucket(size_t bucket, size_t bucketsCount) { + for (size_t i = 0; i < 1'000'000; ++i) { + std::string key = "key-" + ToString(i); + if (MurmurHash(key.data(), key.size()) % bucketsCount == bucket) { + return key; + } + } + UNIT_FAIL("Failed to find a key for bucket"); + return {}; +} + +void CreateTopicWithAutoPartitioning(TTopicClient& client) { + TCreateTopicSettings createSettings; + createSettings + .BeginConfigurePartitioningSettings() + .MinActivePartitions(2) + .MaxActivePartitions(100) + .BeginConfigureAutoPartitioningSettings() + .UpUtilizationPercent(2) + .DownUtilizationPercent(1) + .StabilizationWindow(TDuration::Seconds(2)) + .Strategy(EAutoPartitioningStrategy::ScaleUp) + .EndConfigureAutoPartitioningSettings() + .EndConfigurePartitioningSettings(); + client.CreateTopic(TEST_TOPIC, createSettings).Wait(); +} + void WriteAndReadToEndWithRestarts(TReadSessionSettings readSettings, TWriteSessionSettings writeSettings, const std::string& message, std::uint32_t count, TTopicSdkTestSetup& setup, std::shared_ptr decompressor) { auto client = setup.MakeClient(); auto session = client.CreateSimpleBlockingWriteSession(writeSettings); @@ -259,6 +294,16 @@ Y_UNIT_TEST_SUITE(BasicUsage) { UNIT_ASSERT_VALUES_EQUAL(gotMetaProducerId, expectedEncoded); UNIT_ASSERT_VALUES_EQUAL(Base64Decode(expectedEncoded), binaryProducerId); } + + Y_UNIT_TEST(CreateTopicWithManyPartitions) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + const TString name = "test-topic-" + ToString(TInstant::Now().Seconds()); + setup.CreateTopic(name, TEST_CONSUMER, 100); + + auto describe = setup.MakeClient().DescribeTopic(name).GetValueSync(); + UNIT_ASSERT_C(describe.IsSuccess(), describe.GetIssues().ToOneLineString()); + UNIT_ASSERT_VALUES_EQUAL(describe.GetTopicDescription().GetPartitions().size(), 100); + } Y_UNIT_TEST(CreateTopicWithStreamingConsumer) { TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; @@ -794,10 +839,6 @@ Y_UNIT_TEST_SUITE(BasicUsage) { } Y_UNIT_TEST(ReadWithoutConsumerWithRestarts) { - if (EnableDirectRead) { - // TODO(qyryq) Enable the test when LOGBROKER-9364 is done. - return; - } TTopicSdkTestSetup setup(TEST_CASE_NAME); auto compressor = std::make_shared(); auto decompressor = CreateThreadPoolManagedExecutor(1); @@ -810,7 +851,7 @@ Y_UNIT_TEST_SUITE(BasicUsage) { .MaxMemoryUsageBytes(1_MB) .DecompressionExecutor(decompressor) .AppendTopics(topic) - // .DirectRead(EnableDirectRead) + .DirectRead(EnableDirectRead) ; TWriteSessionSettings writeSettings; @@ -893,6 +934,951 @@ Y_UNIT_TEST_SUITE(BasicUsage) { } + Y_UNIT_TEST(KeyedWriteSession_UserEventHandlers) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 2); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + + std::atomic readyCount{0}; + std::atomic acksCount{0}; + std::atomic closedCount{0}; + std::atomic commonCount{0}; + + std::mutex tokensMutex; + std::condition_variable tokensCv; + std::deque readyTokens; + + writeSettings.EventHandlers_.HandlersExecutor(std::make_shared()); + + writeSettings.EventHandlers_.ReadyToAcceptHandler( + [&](TWriteSessionEvent::TReadyToAcceptEvent& ev) { + readyCount.fetch_add(1); + { + std::lock_guard lock(tokensMutex); + readyTokens.emplace_back(std::move(ev.ContinuationToken)); + } + tokensCv.notify_one(); + }); + + writeSettings.EventHandlers_.AcksHandler( + [&](TWriteSessionEvent::TAcksEvent& ev) { + Y_UNUSED(ev); + acksCount.fetch_add(1); + }); + + writeSettings.EventHandlers_.SessionClosedHandler( + [&](const TSessionClosedEvent& ev) { + Y_UNUSED(ev); + closedCount.fetch_add(1); + }); + + writeSettings.EventHandlers_.CommonHandler( + [&](TWriteSessionEvent::TEvent& ev) { + Y_UNUSED(ev); + commonCount.fetch_add(1); + }); + + auto getReadyToken = [&]() -> std::optional { + std::unique_lock lock(tokensMutex); + tokensCv.wait_for(lock, std::chrono::seconds(30), [&]() { return !readyTokens.empty(); }); + if (readyTokens.empty()) { + return std::nullopt; + } + auto token = std::move(readyTokens.front()); + readyTokens.pop_front(); + return token; + }; + + auto session = client.CreateKeyedWriteSession(writeSettings); + + const ui64 messages = 5; + for (ui64 i = 0; i < messages; ++i) { + auto token = getReadyToken(); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + std::string payload = "payload"; + TWriteMessage msg(payload); + msg.SeqNo(i + 1); + session->Write(std::move(*token), "key-" + ToString(i), std::move(msg)); + } + + UNIT_ASSERT_C(session->Close(TDuration::Seconds(30)), "Failed to close keyed write session"); + + UNIT_ASSERT_C(readyCount.load() > 0, "ReadyToAcceptHandler was not called"); + UNIT_ASSERT_C(acksCount.load() == messages, "AcksHandler does not work properly"); + UNIT_ASSERT_C(closedCount.load() > 0, "SessionClosedHandler was not called"); + UNIT_ASSERT_C(commonCount.load() == 0, "CommonHandler should not be called when type-specific handlers are set"); + } + + Y_UNIT_TEST(KeyedWriteSession_ProducerIdPrefixRequired) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 1); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + + UNIT_ASSERT_EXCEPTION(setup.MakeClient().CreateKeyedWriteSession(writeSettings), TContractViolation); + } + + Y_UNIT_TEST(KeyedWriteSession_SessionClosedDueToUserError) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 2); + auto publicClient = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + + auto session = publicClient.CreateKeyedWriteSession(writeSettings); + TKeyedWriteSessionEventLoop eventLoop(session); + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + + std::string payload = "msg0"; + TWriteMessage msg(payload); + msg.SeqNo(0); + session->Write(std::move(*token), "key", std::move(msg)); + + auto readyToAcceptEvent = session->GetEvent(false); + UNIT_ASSERT_C(std::holds_alternative(*readyToAcceptEvent), "ReadyToAcceptEvent is not received"); + + UNIT_ASSERT_C(session->WaitEvent().Wait(TDuration::Seconds(1000)), "Timed out waiting for event"); + auto event = session->GetEvent(false); + UNIT_ASSERT_C(event, "Event is not received"); + auto sessionClosedEvent = std::get_if(&*event); + UNIT_ASSERT_C(sessionClosedEvent, "SessionClosedEvent is not received"); + UNIT_ASSERT_C(sessionClosedEvent->GetStatus() == EStatus::BAD_REQUEST, "Status is not BAD_REQUEST"); + UNIT_ASSERT(!session->Close(TDuration::Seconds(10))); + } + + Y_UNIT_TEST(KeyedWriteSession_NoAutoPartitioning_HashPartitionChooser) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 2); + + // Capture partition ids in the same order as DescribeTopic returns them + // (the keyed session uses the same DescribeTopic ordering to map hash bucket -> partition id). + auto publicClient = setup.MakeClient(); + auto describeTopicSettings = TDescribeTopicSettings().IncludeStats(true); + auto before = publicClient.DescribeTopic(setup.GetTopicPath(TEST_TOPIC), describeTopicSettings).GetValueSync(); + UNIT_ASSERT_C(before.IsSuccess(), before.GetIssues().ToOneLineString()); + const auto& beforePartitions = before.GetTopicDescription().GetPartitions(); + UNIT_ASSERT_VALUES_EQUAL(beforePartitions.size(), 2); + const ui64 partitionId0 = beforePartitions[0].GetPartitionId(); + const ui64 partitionId1 = beforePartitions[1].GetPartitionId(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + + auto session = publicClient.CreateKeyedWriteSession(writeSettings); + + const std::string key0 = FindKeyForBucket(0, 2); + const std::string key1 = FindKeyForBucket(1, 2); + + const ui64 count0 = 7; + const ui64 count1 = 11; + + TKeyedWriteSessionEventLoop eventLoop(session); + + auto seqNo = 1; + for (ui64 i = 0; i < count0; ++i) { + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + + std::string payload = "msg0"; + TWriteMessage msg(payload); + msg.SeqNo(seqNo++); + session->Write(std::move(*token), key0, std::move(msg)); + } + for (ui64 i = 0; i < count1; ++i) { + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + std::string payload = "msg1"; + TWriteMessage msg(payload); + msg.SeqNo(seqNo++); + session->Write(std::move(*token), key1, std::move(msg)); + } + + UNIT_ASSERT(session->Close(TDuration::Seconds(10))); + + auto after = publicClient.DescribeTopic(setup.GetTopicPath(TEST_TOPIC), describeTopicSettings).GetValueSync(); + UNIT_ASSERT_C(after.IsSuccess(), after.GetIssues().ToOneLineString()); + const auto& afterPartitions = after.GetTopicDescription().GetPartitions(); + UNIT_ASSERT_VALUES_EQUAL(afterPartitions.size(), 2); + + std::unordered_map endOffsets; + for (const auto& p : afterPartitions) { + auto stats = p.GetPartitionStats(); + UNIT_ASSERT(stats.has_value()); + endOffsets[p.GetPartitionId()] = stats->GetEndOffset(); + } + + auto it0 = endOffsets.find(partitionId0); + auto it1 = endOffsets.find(partitionId1); + UNIT_ASSERT(it0 != endOffsets.end()); + UNIT_ASSERT(it1 != endOffsets.end()); + + const ui64 endOffset0 = it0->second; + const ui64 endOffset1 = it1->second; + + // Partition ordering in DescribeTopic is not a part of public API contract, so allow swapping. + UNIT_ASSERT_VALUES_EQUAL(endOffset0 + endOffset1, count0 + count1); + UNIT_ASSERT_C( + (endOffset0 == count0 && endOffset1 == count1) || (endOffset0 == count1 && endOffset1 == count0), + TStringBuilder() << "Unexpected end offsets distribution: " + << "partitionId0=" << partitionId0 << " endOffset0=" << endOffset0 << ", " + << "partitionId1=" << partitionId1 << " endOffset1=" << endOffset1 << ", " + << "expected (" << count0 << "," << count1 << ") in any order" + ); + } + + Y_UNIT_TEST(KeyedWriteSession_NoAutoPartitioning_BoundPartitionChooser) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopicWithAutoscale(TEST_TOPIC, TEST_CONSUMER, 5, 10); + + auto publicClient = setup.MakeClient(); + auto describeTopicSettings = TDescribeTopicSettings().IncludeStats(true); + auto before = publicClient.DescribeTopic(setup.GetTopicPath(TEST_TOPIC), describeTopicSettings).GetValueSync(); + + UNIT_ASSERT_C(before.IsSuccess(), before.GetIssues().ToOneLineString()); + const auto& beforePartitions = before.GetTopicDescription().GetPartitions(); + UNIT_ASSERT_VALUES_EQUAL(beforePartitions.size(), 5); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Bound); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + writeSettings.PartitioningKeyHasher([](const std::string_view key) -> std::string { + return std::string{key}; + }); + + auto session = publicClient.CreateKeyedWriteSession(writeSettings); + auto keyedSession = std::dynamic_pointer_cast(session); + const auto& partitions = keyedSession->GetPartitions(); + + TKeyedWriteSessionEventLoop eventLoop(session); + + std::unordered_map keysCount; + for (const auto& p : partitions) { + keysCount[p.PartitionId_] = 0; + } + + for (size_t i = 0; i < 100; ++i) { + auto key = CreateGuidAsString(); + for (const auto& p : partitions) { + if (p.InRange(key)) { + keysCount[p.PartitionId_]++; + break; + } + } + + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + std::string payload = "msg"; + TWriteMessage msg(payload); + msg.SeqNo(i + 1); + session->Write(std::move(*token), key, std::move(msg)); + } + + UNIT_ASSERT(session->Close(TDuration::Seconds(10))); + + auto after = publicClient.DescribeTopic(setup.GetTopicPath(TEST_TOPIC), describeTopicSettings).GetValueSync(); + UNIT_ASSERT_C(after.IsSuccess(), after.GetIssues().ToOneLineString()); + const auto& afterPartitions = after.GetTopicDescription().GetPartitions(); + + std::unordered_map endOffsets; + for (const auto& p : afterPartitions) { + auto stats = p.GetPartitionStats(); + UNIT_ASSERT(stats.has_value()); + endOffsets[p.GetPartitionId()] = stats->GetEndOffset(); + } + + for (const auto& p : partitions) { + auto sb = TStringBuilder() << "partitionId=" << p.PartitionId_ << " endOffset=" << endOffsets[p.PartitionId_] << " keysCount=" << keysCount[p.PartitionId_]; + UNIT_ASSERT_VALUES_EQUAL_C(endOffsets[p.PartitionId_], keysCount[p.PartitionId_], sb.c_str()); + } + } + + Y_UNIT_TEST(KeyedWriteSession_EventLoop_Acks) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 4); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(10)); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + + auto session = client.CreateKeyedWriteSession(writeSettings); + TKeyedWriteSessionEventLoop eventLoop(session); + + const ui64 count = 3000; + for (ui64 i = 1; i <= count; ++i) { + auto key = CreateGuidAsString(); + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + std::string payload = "data"; + TWriteMessage msg(payload); + msg.SeqNo(i); + session->Write(std::move(*token), key, std::move(msg)); + } + + UNIT_ASSERT(eventLoop.WaitForAcks(count, TDuration::Seconds(60))); + eventLoop.CheckAcksOrder(); + UNIT_ASSERT(session->Close(TDuration::Seconds(10))); + } + + Y_UNIT_TEST(KeyedWriteSession_MultiThreadedWrite_Acks) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 3); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + + auto session = client.CreateKeyedWriteSession(writeSettings); + + constexpr ui64 threadsCount = 4; + constexpr ui64 perThread = 25; + constexpr ui64 total = threadsCount * perThread; + + std::atomic nextSeqNo{1}; + std::vector threads; + threads.reserve(threadsCount); + + TKeyedWriteSessionEventLoop eventLoop(session); + + for (ui64 t = 0; t < threadsCount; ++t) { + threads.emplace_back([&, t]() { + auto key = TStringBuilder() << "key-" << t; + for (ui64 i = 0; i < perThread; ++i) { + std::cout << "thread " << t << " writing message " << i << std::endl; + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + const ui64 seqNo = nextSeqNo.fetch_add(1); + std::string payload = "data"; + TWriteMessage msg(payload); + msg.SeqNo(seqNo); + session->Write(std::move(*token), key, std::move(msg)); + } + }); + } + + UNIT_ASSERT(eventLoop.WaitForAcks(total, TDuration::Seconds(60))); + UNIT_ASSERT(session->Close(TDuration::Seconds(10))); + } + + Y_UNIT_TEST(KeyedWriteSession_IdleSessionsTimeout) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 3); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(5)); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + + auto session = client.CreateKeyedWriteSession(writeSettings); + + TKeyedWriteSessionEventLoop eventLoop(session); + constexpr ui64 messages = 100; + ui64 seqNo = 1; + + for (ui64 i = 0; i < messages; ++i) { + auto key = CreateGuidAsString(); + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + std::string payload = "data"; + TWriteMessage msg(payload); + msg.SeqNo(seqNo++); + session->Write(std::move(*token), key, std::move(msg)); + } + + UNIT_ASSERT(eventLoop.WaitForAcks(messages, TDuration::Seconds(60))); + eventLoop.CheckAcksOrder(); + + Sleep(TDuration::Seconds(6)); + + for (ui64 i = 0; i < messages; ++i) { + auto key = CreateGuidAsString(); + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + std::string payload = "data"; + TWriteMessage msg(payload); + msg.SeqNo(seqNo++); + session->Write(std::move(*token), key, std::move(msg)); + } + + UNIT_ASSERT(eventLoop.WaitForAcks(messages * 2, TDuration::Seconds(60))); + eventLoop.CheckAcksOrder(); + } + + Y_UNIT_TEST(KeyedWriteSession_BoundPartitionChooser_SplitPartition_MultiThreadedAcksOrder) { + NKikimr::NPQ::NTest::TTopicSdkTestSetup setup = NKikimr::NPQ::NTest::CreateSetup(); + setup.CreateTopicWithAutoscale(TEST_TOPIC, TEST_CONSUMER, 1, 100); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Bound); + writeSettings.PartitioningKeyHasher([](const std::string_view key) -> std::string { + return std::string{key}; + }); + + auto session = client.CreateKeyedWriteSession(writeSettings); + + constexpr ui64 messages = 1000; + TKeyedWriteSessionEventLoop eventLoop(session); + + std::jthread writer([&]() { + for (ui64 i = 1; i <= messages; ++i) { + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(30)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + auto key = CreateGuidAsString(); + std::string payload = "data"; + TWriteMessage msg(payload); + msg.SeqNo(i); + session->Write(std::move(*token), key, std::move(msg)); + } + }); + + std::jthread splitter([&]() { + Sleep(TDuration::Seconds(1)); + ui64 txId = 1006; + NKikimr::NPQ::NTest::SplitPartition(setup, ++txId, 0, "a"); + }); + + writer.join(); + splitter.join(); + + UNIT_ASSERT(eventLoop.WaitForAcks(messages, TDuration::Seconds(60))); + eventLoop.CheckAcksOrder(); + UNIT_ASSERT(session->Close(TDuration::Seconds(30))); + } + + Y_UNIT_TEST(SimpleBlockingKeyedWriteSession_BasicWrite) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 5); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + + auto session = client.CreateSimpleBlockingKeyedWriteSession(writeSettings); + + const std::string key1 = "key1"; + const std::string key2 = "key2"; + + // Write several messages with different keys + size_t seqNo = 1; + for (int i = 0; i < 5; ++i) { + std::string payload = "message1-" + ToString(i); + TWriteMessage msg(payload); + msg.SeqNo(seqNo++); + bool res = session->Write(key1, std::move(msg)); + UNIT_ASSERT(res); + } + + for (int i = 0; i < 5; ++i) { + std::string payload = "message2-" + ToString(i); + TWriteMessage msg(payload); + msg.SeqNo(seqNo++); + bool res = session->Write(key2, std::move(msg)); + UNIT_ASSERT(res); + } + + UNIT_ASSERT(session->Close(TDuration::Seconds(10))); + } + + Y_UNIT_TEST(SimpleBlockingKeyedWriteSession_NoSeqNo) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 3); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + + auto session = client.CreateSimpleBlockingKeyedWriteSession(writeSettings); + + const ui64 messages = 10; + for (ui64 i = 0; i < messages; ++i) { + std::string payload = "payload-" + ToString(i); + TWriteMessage msg(payload); + bool res = session->Write("key-" + ToString(i % 3), std::move(msg)); + UNIT_ASSERT(res); + } + + bool closeRes = session->Close(TDuration::Seconds(30)); + UNIT_ASSERT(closeRes); + } + + Y_UNIT_TEST(SimpleBlockingKeyedWriteSession_ManyMessages) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 4); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + + auto session = client.CreateSimpleBlockingKeyedWriteSession(writeSettings); + + ui64 seqNo = 1; + + for (ui64 i = 0; i < 1000; ++i) { + auto key = CreateGuidAsString(); + std::string payload = "payload-" + ToString(seqNo); + TWriteMessage msg(payload); + msg.SeqNo(seqNo++); + bool res = session->Write(key, std::move(msg)); + UNIT_ASSERT(res); + } + + bool closeRes = session->Close(TDuration::Seconds(60)); + UNIT_ASSERT(closeRes); + } + + Y_UNIT_TEST(KeyedWriteSession_CloseTimeout) { + TTopicSdkTestSetup setup{TEST_CASE_NAME, TTopicSdkTestSetup::MakeServerSettings(), false}; + setup.CreateTopic(TEST_TOPIC, TEST_CONSUMER, 3); + + auto client = setup.MakeClient(); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + + auto session = client.CreateKeyedWriteSession(writeSettings); + + TKeyedWriteSessionEventLoop eventLoop(session); + + for (int i = 0; i < 1000; ++i) { + auto token = eventLoop.GetContinuationToken(TDuration::Seconds(10)); + UNIT_ASSERT_C(token, "Timed out waiting for ReadyToAcceptEvent"); + std::string payload = "message-" + ToString(i); + TWriteMessage msg(payload); + msg.SeqNo(i + 1); + session->Write(std::move(*token), "key1", std::move(msg)); + } + + // Test Close timeout + const TDuration closeTimeout = TDuration::Seconds(2); + const TInstant startTime = TInstant::Now(); + session->Close(closeTimeout); + const TDuration actualDuration = TInstant::Now() - startTime; + + // Verify that Close didn't block longer than timeout (with some tolerance) + const TDuration maxExpectedDuration = closeTimeout + TDuration::MilliSeconds(100) + closeTimeout / 10; + UNIT_ASSERT_C( + actualDuration <= maxExpectedDuration + maxExpectedDuration / 10, + TStringBuilder() << "Close() took " << actualDuration << " but timeout was " << closeTimeout + ); + + int attempts = 0; + constexpr int maxAttempts = 1100; + for (attempts = 0; attempts < maxAttempts; ++attempts) { + auto event = session->GetEvent(false); + if (!event) { + break; + } + + auto sessionClosedEvent = std::get_if(&*event); + if (!sessionClosedEvent) { + continue; + } + + UNIT_ASSERT(sessionClosedEvent->IsSuccess()); + break; + } + + UNIT_ASSERT(attempts < maxAttempts); + } + + Y_UNIT_TEST(AutoPartitioning_KeyedWriteSession) { + auto settings = TTopicSdkTestSetup::MakeServerSettings(); + settings.PQConfig.SetUseSrcIdMetaMappingInFirstClass(true); + TTopicSdkTestSetup setup{TEST_CASE_NAME, settings, false}; + TTopicClient client = setup.MakeClient(); + + std::queue readyTokens1; + std::queue readyTokens2; + std::optional sessionClosedEvent; + std::unordered_set ackedSeqNos; + bool closed = false; + + auto createMessage = [](std::string_view payload, ui64 seqNo) -> TWriteMessage { + TWriteMessage msg(payload); + msg.SeqNo(seqNo); + return msg; + }; + + TCreateTopicSettings createSettings; + createSettings + .BeginConfigurePartitioningSettings() + .MinActivePartitions(2) + .MaxActivePartitions(100) + .BeginConfigureAutoPartitioningSettings() + .UpUtilizationPercent(2) + .DownUtilizationPercent(1) + .StabilizationWindow(TDuration::Seconds(2)) + .Strategy(EAutoPartitioningStrategy::ScaleUp) + .EndConfigureAutoPartitioningSettings() + .EndConfigurePartitioningSettings(); + client.CreateTopic(TEST_TOPIC, createSettings).Wait(); + + auto describe = client.DescribeTopic(TEST_TOPIC).GetValueSync(); + UNIT_ASSERT_EQUAL(describe.GetTopicDescription().GetPartitions().size(), 2); + + TKeyedWriteSessionSettings writeSettings1; + writeSettings1 + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings1.ProducerIdPrefix("autopartitioning_keyed_1"); + writeSettings1.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Bound); + writeSettings1.SubSessionIdleTimeout(TDuration::Seconds(30)); + + TKeyedWriteSessionSettings writeSettings2 = writeSettings1; + writeSettings2.ProducerIdPrefix("autopartitioning_keyed_2"); + + auto session1 = client.CreateKeyedWriteSession(writeSettings1); + auto session2 = client.CreateKeyedWriteSession(writeSettings2); + auto msgData = TString(1_MB, 'a'); + + std::vector keys; + for (const auto& partition : describe.GetTopicDescription().GetPartitions()) { + keys.push_back(partition.GetFromBound().value_or("")); + } + + auto getQueue = [&](const std::shared_ptr& s) -> std::queue& { + if (s == session1) { + return readyTokens1; + } + if (s == session2) { + return readyTokens2; + } + Y_ABORT("Unknown session pointer in AutoPartitioning_KeyedWriteSession"); + }; + + auto eventLoop = [&](std::shared_ptr s) { + while (true) { + auto event = s->GetEvent(false); + if (!event) { + break; + } + if (auto* ready = std::get_if(&*event)) { + getQueue(s).push(std::move(ready->ContinuationToken)); + continue; + } + if (auto* closedEv = std::get_if(&*event)) { + sessionClosedEvent = std::move(*closedEv); + closed = true; + break; + } + if (auto* acks = std::get_if(&*event)) { + for (const auto& ack : acks->Acks) { + UNIT_ASSERT_C( + ackedSeqNos.insert(ack.SeqNo).second, + "Duplicate ack for seqNo " << ack.SeqNo); + } + } + } + }; + + auto getReadyToken = [&](std::shared_ptr s) -> std::optional { + auto& q = getQueue(s); + while (q.empty() && !closed) { + s->WaitEvent().Wait(TDuration::Seconds(5)); + eventLoop(s); + } + if (q.empty()) { + return std::nullopt; + } + auto t = std::move(q.front()); + q.pop(); + return t; + }; + + auto writeMessage = [&](std::shared_ptr s, std::string_view payload, ui64 seqNo) { + auto token = getReadyToken(s); + UNIT_ASSERT(token); + auto key = keys[seqNo % keys.size()]; + if (key.empty()) { + key = "lalala"; + } + s->Write(std::move(*token), key, createMessage(payload, seqNo)); + }; + + { + writeMessage(session1, msgData, 1); + writeMessage(session1, msgData, 2); + Sleep(TDuration::Seconds(5)); + auto d = client.DescribeTopic(TEST_TOPIC).GetValueSync(); + UNIT_ASSERT_EQUAL(d.GetTopicDescription().GetPartitions().size(), 2); + } + + { + writeMessage(session1, msgData, 3); + writeMessage(session1, msgData, 4); + writeMessage(session1, msgData, 5); + writeMessage(session1, msgData, 6); + writeMessage(session1, msgData, 7); + writeMessage(session2, msgData, 8); + writeMessage(session1, msgData, 9); + writeMessage(session1, msgData, 10); + writeMessage(session2, msgData, 11); + writeMessage(session1, msgData, 12); + Sleep(TDuration::Seconds(30)); + for (int i = 0; i < 50 && ackedSeqNos.size() < 12 && !closed; ++i) { + eventLoop(session1); + eventLoop(session2); + if (ackedSeqNos.size() < 12) { + Sleep(TDuration::MilliSeconds(200)); + } + } + UNIT_ASSERT_EQUAL_C(ackedSeqNos.size(), 12, + "Expected exactly 12 distinct acks, each seqNo exactly once; got " << ackedSeqNos.size()); + } + + auto describeResult = client.DescribeTopic(TEST_TOPIC).GetValueSync(); + auto partitionsCount = describeResult.GetTopicDescription().GetPartitions().size(); + UNIT_ASSERT_C(partitionsCount >= 4, + TStringBuilder() << "Partitions count: " << partitionsCount << ", expected at least 4"); + + writeMessage(session1, msgData, 13); + writeMessage(session1, msgData, 14); + Sleep(TDuration::Seconds(20)); + for (int i = 0; i < 50 && ackedSeqNos.size() < 14 && !closed; ++i) { + eventLoop(session1); + eventLoop(session2); + if (ackedSeqNos.size() < 14) { + Sleep(TDuration::MilliSeconds(200)); + } + } + + UNIT_ASSERT_EQUAL_C(ackedSeqNos.size(), 14, + "Expected exactly 14 distinct acks, each seqNo exactly once; got " << ackedSeqNos.size()); + auto sessionPartitions = dynamic_cast(session1.get())->GetPartitions(); + UNIT_ASSERT_EQUAL_C(sessionPartitions.size(), partitionsCount, + "Expected exactly" << partitionsCount << " partitions, actual: " << sessionPartitions.size()); + + UNIT_ASSERT(session1->Close(TDuration::Seconds(30))); + UNIT_ASSERT(session2->Close(TDuration::Seconds(30))); + } + + Y_UNIT_TEST(AutoPartitioning_KeyedWriteSession_SmallMessages) { + auto settings = TTopicSdkTestSetup::MakeServerSettings(); + settings.PQConfig.SetUseSrcIdMetaMappingInFirstClass(true); + TTopicSdkTestSetup setup{TEST_CASE_NAME, settings, false}; + TTopicClient client = setup.MakeClient(); + + std::queue readyTokens1; + std::queue readyTokens2; + std::optional sessionClosedEvent; + std::unordered_set ackedSeqNos; + bool closed = false; + + auto createMessage = [](std::string_view payload, ui64 seqNo) -> TWriteMessage { + TWriteMessage msg(payload); + msg.SeqNo(seqNo); + return msg; + }; + + TCreateTopicSettings createSettings; + createSettings + .BeginConfigurePartitioningSettings() + .MinActivePartitions(2) + .MaxActivePartitions(100) + .BeginConfigureAutoPartitioningSettings() + .UpUtilizationPercent(2) + .DownUtilizationPercent(1) + .StabilizationWindow(TDuration::Seconds(2)) + .Strategy(EAutoPartitioningStrategy::ScaleUp) + .EndConfigureAutoPartitioningSettings() + .EndConfigurePartitioningSettings(); + client.CreateTopic(TEST_TOPIC, createSettings).Wait(); + + auto describe = client.DescribeTopic(TEST_TOPIC).GetValueSync(); + UNIT_ASSERT_EQUAL(describe.GetTopicDescription().GetPartitions().size(), 2); + + TKeyedWriteSessionSettings writeSettings1; + writeSettings1 + .Path(setup.GetTopicPath(TEST_TOPIC)) + .Codec(ECodec::RAW); + writeSettings1.ProducerIdPrefix("autopartitioning_keyed_small_1"); + writeSettings1.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Bound); + writeSettings1.SubSessionIdleTimeout(TDuration::Seconds(30)); + + TKeyedWriteSessionSettings writeSettings2 = writeSettings1; + writeSettings2.ProducerIdPrefix("autopartitioning_keyed_small_2"); + + auto session1 = client.CreateKeyedWriteSession(writeSettings1); + auto session2 = client.CreateKeyedWriteSession(writeSettings2); + const size_t msgSize = 256_KB; + auto msgData = TString(msgSize, 'a'); + const ui64 totalMessages = 44; + + std::vector keys; + for (const auto& partition : describe.GetTopicDescription().GetPartitions()) { + keys.push_back(partition.GetFromBound().value_or("")); + } + + auto getQueue = [&](const std::shared_ptr& s) -> std::queue& { + if (s == session1) return readyTokens1; + if (s == session2) return readyTokens2; + Y_ABORT("Unknown session pointer in AutoPartitioning_KeyedWriteSession_SmallMessages"); + }; + + auto eventLoop = [&](std::shared_ptr s) { + while (true) { + auto event = s->GetEvent(false); + if (!event) break; + if (auto* ready = std::get_if(&*event)) { + getQueue(s).push(std::move(ready->ContinuationToken)); + continue; + } + if (auto* closedEv = std::get_if(&*event)) { + sessionClosedEvent = std::move(*closedEv); + closed = true; + break; + } + if (auto* acks = std::get_if(&*event)) { + for (const auto& ack : acks->Acks) { + UNIT_ASSERT_C(ackedSeqNos.insert(ack.SeqNo).second, + "Duplicate ack for seqNo " << ack.SeqNo); + } + } + } + }; + + auto getReadyToken = [&](std::shared_ptr s) -> std::optional { + auto& q = getQueue(s); + while (q.empty() && !closed) { + s->WaitEvent().Wait(TDuration::Seconds(5)); + eventLoop(s); + } + if (q.empty()) return std::nullopt; + auto t = std::move(q.front()); + q.pop(); + return t; + }; + + auto writeMessage = [&](std::shared_ptr s, std::string_view payload, ui64 seqNo) { + auto token = getReadyToken(s); + UNIT_ASSERT(token); + auto key = keys[seqNo % keys.size()]; + if (key.empty()) key = "a"; + s->Write(std::move(*token), key, createMessage(payload, seqNo)); + }; + + { + writeMessage(session1, msgData, 1); + writeMessage(session1, msgData, 2); + Sleep(TDuration::Seconds(5)); + auto d = client.DescribeTopic(TEST_TOPIC).GetValueSync(); + UNIT_ASSERT_EQUAL(d.GetTopicDescription().GetPartitions().size(), 2); + } + + { + for (ui64 seq = 3; seq <= totalMessages - 2; ++seq) { + auto s = (seq % 4 == 0) ? session2 : session1; + writeMessage(s, msgData, seq); + } + Sleep(TDuration::Seconds(30)); + for (int i = 0; i < 80 && ackedSeqNos.size() < totalMessages - 2 && !closed; ++i) { + eventLoop(session1); + eventLoop(session2); + if (ackedSeqNos.size() < totalMessages - 2) Sleep(TDuration::MilliSeconds(200)); + } + UNIT_ASSERT_EQUAL_C(ackedSeqNos.size(), totalMessages - 2, + "Expected " << totalMessages - 2 << " acks; got " << ackedSeqNos.size()); + } + + auto describeResult = client.DescribeTopic(TEST_TOPIC).GetValueSync(); + auto partitionsCount = describeResult.GetTopicDescription().GetPartitions().size(); + UNIT_ASSERT_C(partitionsCount >= 3, + TStringBuilder() << "Partitions count: " << partitionsCount << ", expected at least 3 (auto-partitioning)"); + + writeMessage(session1, msgData, totalMessages - 1); + writeMessage(session1, msgData, totalMessages); + Sleep(TDuration::Seconds(20)); + for (int i = 0; i < 80 && ackedSeqNos.size() < totalMessages && !closed; ++i) { + eventLoop(session1); + eventLoop(session2); + if (ackedSeqNos.size() < totalMessages) Sleep(TDuration::MilliSeconds(200)); + } + + UNIT_ASSERT_EQUAL_C(ackedSeqNos.size(), totalMessages, + "Expected " << totalMessages << " acks; got " << ackedSeqNos.size()); + auto sessionPartitions = dynamic_cast(session1.get())->GetPartitions(); + UNIT_ASSERT_EQUAL_C(sessionPartitions.size(), partitionsCount, + "Session partitions " << sessionPartitions.size() << " != topic partitions " << partitionsCount); + + UNIT_ASSERT(session1->Close(TDuration::Seconds(30))); + UNIT_ASSERT(session2->Close(TDuration::Seconds(30))); + } + } // Y_UNIT_TEST_SUITE(BasicUsage) } // namespace diff --git a/src/client/topic/ut/direct_read_ut.cpp b/src/client/topic/ut/direct_read_ut.cpp index ea7d3227263..29e0e6e6955 100644 --- a/src/client/topic/ut/direct_read_ut.cpp +++ b/src/client/topic/ut/direct_read_ut.cpp @@ -47,7 +47,7 @@ Y_UNIT_TEST_SUITE(DirectReadWithServer) { auto readerSettings = TReadSessionSettings() .ConsumerName(setup.GetConsumerName()) .AppendTopics(setup.GetTopicPath()) - // .DirectRead(true) + .DirectRead(true) ; TIntrusivePtr partitionSession; @@ -129,7 +129,7 @@ Y_UNIT_TEST_SUITE(DirectReadWithServer) { auto readerSettings = TReadSessionSettings() .ConsumerName(setup.GetConsumerName()) .AppendTopics(setup.GetTopicPath()) - // .DirectRead(true) + .DirectRead(true) ; TIntrusivePtr partitionSession; @@ -175,6 +175,7 @@ Y_UNIT_TEST_SUITE(DirectReadWithServer) { reader->Close(); } + } // Y_UNIT_TEST_SUITE_F(DirectReadWithServer) } // namespace NYdb::NTopic::NTests diff --git a/src/client/topic/ut/ut_utils/event_loop.cpp b/src/client/topic/ut/ut_utils/event_loop.cpp new file mode 100644 index 00000000000..233b2863e71 --- /dev/null +++ b/src/client/topic/ut/ut_utils/event_loop.cpp @@ -0,0 +1,93 @@ +#include "event_loop.h" +#include + +namespace NYdb::inline V3::NTopic::NTests { + +TKeyedWriteSessionEventLoop::TKeyedWriteSessionEventLoop(std::shared_ptr session) + : Session_(std::move(session)) +{} + +void TKeyedWriteSessionEventLoop::Run() { + while (true) { + auto event = Session_->GetEvent(false); + if (!event) { + break; + } + if (auto* ready = std::get_if(&*event)) { + std::lock_guard lk(Lock_); + ReadyTokens_.push(std::move(ready->ContinuationToken)); + continue; + } + if (std::get_if(&*event)) { + break; + } + if (auto* acks = std::get_if(&*event)) { + std::lock_guard lk(Lock_); + for (const auto& ack : acks->Acks) { + auto [it, inserted] = AckedSeqNos_.insert(ack.SeqNo); + UNIT_ASSERT_C(inserted, TStringBuilder() << "Ack already received: " << ack.SeqNo); + AckOrder_.push_back(ack.SeqNo); + } + } + } +} + +std::optional TKeyedWriteSessionEventLoop::GetContinuationToken(TDuration timeout) { + const TInstant deadline = TInstant::Now() + timeout; + while (TInstant::Now() < deadline) { + { + std::lock_guard lock(Lock_); + if (!ReadyTokens_.empty()) { + auto token = std::move(ReadyTokens_.front()); + ReadyTokens_.pop(); + return token; + } + } + Session_->WaitEvent().Wait(deadline); + Run(); + } + + std::lock_guard lock(Lock_); + if (!ReadyTokens_.empty()) { + auto token = std::move(ReadyTokens_.front()); + ReadyTokens_.pop(); + return token; + } + return std::nullopt; +} + +bool TKeyedWriteSessionEventLoop::WaitForAcks(size_t count, TDuration timeout) { + const TInstant deadline = TInstant::Now() + timeout; + while (TInstant::Now() < deadline) { + { + std::lock_guard lock(Lock_); + if (AckedSeqNos_.size() >= count) { + return true; + } + } + Session_->WaitEvent().Wait(deadline); + Run(); + } + return false; +} + +void TKeyedWriteSessionEventLoop::CheckAcksOrder() { + std::lock_guard lock(Lock_); + size_t expectedAck = 1; + UNIT_ASSERT_C(AckedSeqNos_.size() == AckOrder_.size(), TStringBuilder() << "Unexpected number of acks: got " << AckOrder_.size() << ", expected " << AckedSeqNos_.size()); + size_t index = 0; + for (const auto& ack : AckOrder_) { + TStringBuilder sb; + if (ack != expectedAck) { + for (size_t i = std::min(size_t(0), index - 10); i < std::min(index + 10, AckOrder_.size()); i++) { + sb << "Ack " << i << ": " << AckOrder_[i] << " "; + } + + sb << "Unexpected ack order: got " << ack << ", expected " << expectedAck; + } + UNIT_ASSERT_VALUES_EQUAL_C(ack, expectedAck, sb); + expectedAck++; + } +} + +} // namespace NYdb::inline V3::NTopic::NTests diff --git a/src/client/topic/ut/ut_utils/event_loop.h b/src/client/topic/ut/ut_utils/event_loop.h new file mode 100644 index 00000000000..a1fa6ca60f6 --- /dev/null +++ b/src/client/topic/ut/ut_utils/event_loop.h @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include + +#include +#include +#include +#include + +namespace NYdb::inline V3::NTopic::NTests { + +//! Helper for keyed write session tests: runs event loop and provides continuation tokens. +class TKeyedWriteSessionEventLoop { +public: + explicit TKeyedWriteSessionEventLoop(std::shared_ptr session); + + //! Block until a continuation token is available or timeout. Calls Run() while waiting. + //! Returns nullopt on timeout or if session was closed before a token appeared. + std::optional GetContinuationToken(TDuration timeout); + + bool WaitForAcks(size_t count, TDuration timeout); + void CheckAcksOrder(); + +private: + //! Process all currently available events. Returns true if SessionClosed was seen. + void Run(); + + std::shared_ptr Session_; + std::queue ReadyTokens_; + std::unordered_set AckedSeqNos_; + std::vector AckOrder_; + std::mutex Lock_; +}; + +} // namespace NYdb::inline V3::NTopic::NTests diff --git a/src/library/grpc/client/grpc_client_low.cpp b/src/library/grpc/client/grpc_client_low.cpp index a8fa3c820cc..a54553e2110 100644 --- a/src/library/grpc/client/grpc_client_low.cpp +++ b/src/library/grpc/client/grpc_client_low.cpp @@ -37,18 +37,20 @@ void EnableGRpcTracing() { } #if !defined(YDB_DISABLE_GRPC_SOCKET_MUTATOR) -class TGRpcKeepAliveSocketMutator : public grpc_socket_mutator { +class TGRpcSocketMutator : public grpc_socket_mutator { public: - TGRpcKeepAliveSocketMutator(int idle, int count, int interval) - : Idle_(idle) + TGRpcSocketMutator(bool isKeepAliveEnabled, int idle, int count, int interval, bool tcpNoDelay) + : IsKeepAliveEnabled_(isKeepAliveEnabled) + , Idle_(idle) , Count_(count) , Interval_(interval) + , TcpNoDelay_(tcpNoDelay) { grpc_socket_mutator_init(this, &VTable); } private: - static TGRpcKeepAliveSocketMutator* Cast(grpc_socket_mutator* mutator) { - return static_cast(mutator); + static TGRpcSocketMutator* Cast(grpc_socket_mutator* mutator) { + return static_cast(mutator); } template @@ -56,24 +58,32 @@ class TGRpcKeepAliveSocketMutator : public grpc_socket_mutator { return setsockopt(fd, level, optname, reinterpret_cast(&value), sizeof(value)) == 0; } bool SetOption(int fd) { - if (!SetOption(fd, SOL_SOCKET, SO_KEEPALIVE, 1)) { - std::cerr << std::format("Failed to set SO_KEEPALIVE option: {}", strerror(errno)) << std::endl; - return false; - } + if (IsKeepAliveEnabled_) { + if (!SetOption(fd, SOL_SOCKET, SO_KEEPALIVE, 1)) { + std::cerr << std::format("Failed to set SO_KEEPALIVE option: {}", strerror(errno)) << std::endl; + return false; + } #ifdef _linux_ - if (Idle_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPIDLE, Idle_)) { - std::cerr << std::format("Failed to set TCP_KEEPIDLE option: {}", strerror(errno)) << std::endl; - return false; - } - if (Count_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPCNT, Count_)) { - std::cerr << std::format("Failed to set TCP_KEEPCNT option: {}", strerror(errno)) << std::endl; - return false; + if (Idle_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPIDLE, Idle_)) { + std::cerr << std::format("Failed to set TCP_KEEPIDLE option: {}", strerror(errno)) << std::endl; + return false; + } + if (Count_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPCNT, Count_)) { + std::cerr << std::format("Failed to set TCP_KEEPCNT option: {}", strerror(errno)) << std::endl; + return false; + } + if (Interval_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPINTVL, Interval_)) { + std::cerr << std::format("Failed to set TCP_KEEPINTVL option: {}", strerror(errno)) << std::endl; + return false; + } +#endif } - if (Interval_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPINTVL, Interval_)) { - std::cerr << std::format("Failed to set TCP_KEEPINTVL option: {}", strerror(errno)) << std::endl; + + if (!SetOption(fd, IPPROTO_TCP, TCP_NODELAY, static_cast(TcpNoDelay_))) { + std::cerr << std::format("Failed to set TCP_NODELAY option: {}", strerror(errno)) << std::endl; return false; } -#endif + return true; } static bool Mutate(int fd, grpc_socket_mutator* mutator) { @@ -83,8 +93,8 @@ class TGRpcKeepAliveSocketMutator : public grpc_socket_mutator { static int Compare(grpc_socket_mutator* a, grpc_socket_mutator* b) { const auto* selfA = Cast(a); const auto* selfB = Cast(b); - auto tupleA = std::make_tuple(selfA->Idle_, selfA->Count_, selfA->Interval_); - auto tupleB = std::make_tuple(selfB->Idle_, selfB->Count_, selfB->Interval_); + auto tupleA = std::make_tuple(selfA->IsKeepAliveEnabled_, selfA->Idle_, selfA->Count_, selfA->Interval_, selfA->TcpNoDelay_); + auto tupleB = std::make_tuple(selfB->IsKeepAliveEnabled_, selfB->Idle_, selfB->Count_, selfB->Interval_, selfB->TcpNoDelay_); return tupleA < tupleB ? -1 : tupleA > tupleB ? 1 : 0; } static void Destroy(grpc_socket_mutator* mutator) { @@ -96,17 +106,19 @@ class TGRpcKeepAliveSocketMutator : public grpc_socket_mutator { } static grpc_socket_mutator_vtable VTable; + const bool IsKeepAliveEnabled_; const int Idle_; const int Count_; const int Interval_; + const bool TcpNoDelay_; }; -grpc_socket_mutator_vtable TGRpcKeepAliveSocketMutator::VTable = +grpc_socket_mutator_vtable TGRpcSocketMutator::VTable = { - &TGRpcKeepAliveSocketMutator::Mutate, - &TGRpcKeepAliveSocketMutator::Compare, - &TGRpcKeepAliveSocketMutator::Destroy, - &TGRpcKeepAliveSocketMutator::Mutate2 + &TGRpcSocketMutator::Mutate, + &TGRpcSocketMutator::Compare, + &TGRpcSocketMutator::Destroy, + &TGRpcSocketMutator::Mutate2 }; #endif @@ -133,8 +145,9 @@ void TGRpcRequestProcessorCommon::GetInitialMetadata(std::unordered_multimapsecond); LastUsedQueue_.emplace(Pool_.at(channelId).GetLastUseTime(), channelId); @@ -611,17 +624,15 @@ void TGRpcClientLow::ForgetContext(TContextImpl* context) { } } -grpc_socket_mutator* NImpl::CreateGRpcKeepAliveSocketMutator(const TTcpKeepAliveSettings& TcpKeepAliveSettings_) { +grpc_socket_mutator* NImpl::CreateGRpcSocketMutator(const TTcpKeepAliveSettings& TcpKeepAliveSettings_, bool tcpNoDelay) { #if !defined(YDB_DISABLE_GRPC_SOCKET_MUTATOR) - TGRpcKeepAliveSocketMutator* mutator = nullptr; - if (TcpKeepAliveSettings_.Enabled) { - mutator = new TGRpcKeepAliveSocketMutator( - TcpKeepAliveSettings_.Idle, - TcpKeepAliveSettings_.Count, - TcpKeepAliveSettings_.Interval - ); - } - return mutator; + return new TGRpcSocketMutator( + TcpKeepAliveSettings_.Enabled, + TcpKeepAliveSettings_.Idle, + TcpKeepAliveSettings_.Count, + TcpKeepAliveSettings_.Interval, + tcpNoDelay + ); #endif return nullptr; } diff --git a/src/library/grpc/client/grpc_client_low.h b/src/library/grpc/client/grpc_client_low.h index ac865d29b00..fb0ae5b55af 100644 --- a/src/library/grpc/client/grpc_client_low.h +++ b/src/library/grpc/client/grpc_client_low.h @@ -425,10 +425,10 @@ class IStreamRequestReadWriteProcessor : public IStreamRequestReadProcessor cb); @@ -511,6 +511,7 @@ class TChannelPool { std::unordered_map Pool_; std::multimap LastUsedQueue_; [[maybe_unused]] TTcpKeepAliveSettings TcpKeepAliveSettings_; + [[maybe_unused]] bool TcpNoDelay_; TDuration ExpireTime_; TDuration UpdateReUseTime_; void EraseFromQueueByTime(const TInstant& lastUseTime, const std::string& channelId); @@ -1367,8 +1368,8 @@ class TGRpcClientLow } template - std::unique_ptr> CreateGRpcServiceConnection(const TGRpcClientConfig& config, const TTcpKeepAliveSettings& keepAlive) { - auto mutator = NImpl::CreateGRpcKeepAliveSocketMutator(keepAlive); + std::unique_ptr> CreateGRpcServiceConnection(const TGRpcClientConfig& config, const TTcpKeepAliveSettings& keepAlive, bool tcpNoDelay = true) { + auto mutator = NImpl::CreateGRpcSocketMutator(keepAlive, tcpNoDelay); // will be destroyed inside grpc return std::unique_ptr>(new TServiceConnection(CreateChannelInterface(config, mutator), this)); } diff --git a/tests/integration/topic/basic_usage_it.cpp b/tests/integration/topic/basic_usage_it.cpp index 585c59bcb8b..d4e05d0ce47 100644 --- a/tests/integration/topic/basic_usage_it.cpp +++ b/tests/integration/topic/basic_usage_it.cpp @@ -15,6 +15,7 @@ #include #include +#include namespace NYdb::inline V3::NPersQueue::NTests { @@ -39,6 +40,83 @@ std::uint64_t TSimpleWriteSessionTestAdapter::GetAcquiredMessagesCount() const { return 0; } +class TKeyedWriteSessionTestAdapter { +public: + TKeyedWriteSessionTestAdapter(NTopic::IKeyedWriteSession* session); + + void WaitForAcks(size_t count, TDuration timeout); + std::optional GetContinuationToken(TDuration timeout); + size_t GetAckedSeqNosCount() const; + bool ValidateAcksOrder() const; + +private: + void RunEventLoop(TDuration timeout, size_t stopOnAcksCount, bool stopOnContinuationToken = false); + + NTopic::IKeyedWriteSession* Session; + std::queue tokens; + std::vector ackedSeqNos; +}; + +TKeyedWriteSessionTestAdapter::TKeyedWriteSessionTestAdapter(NTopic::IKeyedWriteSession* session) + : Session(session) +{} + +void TKeyedWriteSessionTestAdapter::WaitForAcks(size_t count, TDuration timeout) { + RunEventLoop(timeout, count, false); +} + +bool TKeyedWriteSessionTestAdapter::ValidateAcksOrder() const { + size_t expectedSeqNo = 1; + for (const auto& seqNo : ackedSeqNos) { + if (seqNo != expectedSeqNo) { + return false; + } + expectedSeqNo++; + } + return true; +} + +size_t TKeyedWriteSessionTestAdapter::GetAckedSeqNosCount() const { + return ackedSeqNos.size(); +} + +std::optional TKeyedWriteSessionTestAdapter::GetContinuationToken(TDuration timeout) { + RunEventLoop(timeout, 0, true); + if (tokens.empty()) { + return std::nullopt; + } + auto token = std::move(tokens.front()); + tokens.pop(); + return token; +} + +void TKeyedWriteSessionTestAdapter::RunEventLoop(TDuration timeout, size_t stopOnAcksCount, bool stopOnContinuationToken) { + auto deadline = TInstant::Now() + timeout; + while (TInstant::Now() < deadline) { + Session->WaitEvent().Wait(deadline); + auto event = Session->GetEvent(false); + if (!event) { + continue; + } + if (auto ev = std::get_if(&*event)) { + tokens.push(std::move(ev->ContinuationToken)); + if (stopOnContinuationToken) { + return; + } + continue; + } + if (auto ev = std::get_if(&*event)) { + for (const auto& ack : ev->Acks) { + ackedSeqNos.push_back(ack.SeqNo); + } + + if (ackedSeqNos.size() >= stopOnAcksCount) { + return; + } + } + } +} + } namespace NYdb::inline V3::NTopic::NTests { @@ -854,6 +932,90 @@ TEST_F(BasicUsage, TEST_NAME(TWriteSession_WriteEncoded_Broken)) { } } +TEST_F(BasicUsage, TEST_NAME(TKeyedWriteSessionBasicWrite_NoAutoPartitioning)) { + // Basic write test for keyed write session. + // Write 10 messages with different keys and check that they are written to different partitions. + // Check that the order of messages is preserved. + constexpr auto TOPIC_NAME = "test-topic-2"; + constexpr auto CONSUMER_NAME = "test-consumer-2"; + + CreateTopic(TOPIC_NAME, CONSUMER_NAME, 5, 5); + + auto driver = MakeDriver(); + TTopicClient client(driver); + + auto describeTopicSettings = TDescribeTopicSettings().IncludeStats(true); + + TKeyedWriteSessionSettings writeSettings; + writeSettings + .Path(GetTopicPath(TOPIC_NAME)) + .Codec(ECodec::RAW); + writeSettings.ProducerIdPrefix(CreateGuidAsString()); + writeSettings.PartitionChooserStrategy(TKeyedWriteSessionSettings::EPartitionChooserStrategy::Hash); + writeSettings.SubSessionIdleTimeout(TDuration::Seconds(30)); + writeSettings.PartitioningKeyHasher([](const std::string_view key) -> std::string { + return std::string{key}; + }); + + auto session = client.CreateKeyedWriteSession(writeSettings); + auto keyedSession = std::dynamic_pointer_cast(session); + + NPersQueue::NTests::TKeyedWriteSessionTestAdapter testAdapter(session.get()); + for (size_t i = 0; i < 100; ++i) { + auto key = CreateGuidAsString(); + auto token = testAdapter.GetContinuationToken(TDuration::Seconds(30)); + ASSERT_TRUE(token.has_value()) << "Timed out waiting for ReadyToAcceptEvent"; + TWriteMessage msg("msg"); + msg.SeqNo(i + 1); + session->Write(std::move(*token), key, std::move(msg)); + } + + testAdapter.WaitForAcks(100, TDuration::Seconds(30)); + ASSERT_TRUE(session->Close(TDuration::Seconds(10))); + ASSERT_EQ(testAdapter.GetAckedSeqNosCount(), 100ull); + ASSERT_TRUE(testAdapter.ValidateAcksOrder()); + + auto after = client.DescribeTopic(GetTopicPath(TOPIC_NAME), describeTopicSettings).GetValueSync(); + ASSERT_TRUE(after.IsSuccess()) << after.GetIssues().ToOneLineString(); + + auto readSettings = TReadSessionSettings() + .ConsumerName(GetConsumerName(CONSUMER_NAME)) + .AppendTopics(GetTopicPath(TOPIC_NAME)) + ; + std::shared_ptr readSession = client.CreateReadSession(readSettings); + std::uint32_t readMessageCount = 0; + while (readMessageCount < 100) { + std::cerr << "Get event on client\n"; + auto event = *readSession->GetEvent(true); + std::visit(TOverloaded { + [&](TReadSessionEvent::TDataReceivedEvent& event) { + readMessageCount += event.GetMessages().size(); + }, + [&](TReadSessionEvent::TCommitOffsetAcknowledgementEvent&) { + FAIL(); + }, + [&](TReadSessionEvent::TStartPartitionSessionEvent& event) { + event.Confirm(); + }, + [&](TReadSessionEvent::TStopPartitionSessionEvent& event) { + event.Confirm(); + }, + [&](TReadSessionEvent::TEndPartitionSessionEvent& event) { + event.Confirm(); + }, + [&](TReadSessionEvent::TPartitionSessionStatusEvent&) { + FAIL() << "Test does not support lock sessions yet"; + }, + [&](TReadSessionEvent::TPartitionSessionClosedEvent&) { + FAIL() << "Test does not support lock sessions yet"; + }, + [&](TSessionClosedEvent&) { + FAIL() << "Session closed"; + } + }, event); + } +} + namespace { enum class EExpectedTestResult { SUCCESS, diff --git a/tests/unit/library/grpc_client/grpc_client_low_ut.cpp b/tests/unit/library/grpc_client/grpc_client_low_ut.cpp index 35a89328cb7..a53ee718334 100644 --- a/tests/unit/library/grpc_client/grpc_client_low_ut.cpp +++ b/tests/unit/library/grpc_client/grpc_client_low_ut.cpp @@ -22,7 +22,7 @@ Y_UNIT_TEST_SUITE(ChannelPoolTests) { 5, // NYdb::TCP_KEEPALIVE_COUNT, unused in UT, but is necessary in constructor 10 // NYdb::TCP_KEEPALIVE_INTERVAL, unused in UT, but is necessary in constructor }; - auto channelPool = TChannelPool(tcpKeepAliveSettings, TDuration::MilliSeconds(250)); + auto channelPool = TChannelPool(tcpKeepAliveSettings, TDuration::MilliSeconds(250), true); std::vector> ChannelInterfacesWeak; {