diff --git a/rabbitmq/include/userver/urabbitmq/client_settings.hpp b/rabbitmq/include/userver/urabbitmq/client_settings.hpp index 5a0122bb10c1..93683c5c4904 100644 --- a/rabbitmq/include/userver/urabbitmq/client_settings.hpp +++ b/rabbitmq/include/userver/urabbitmq/client_settings.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -74,6 +75,10 @@ struct PoolSettings final { /// (tcp error/protocol error/write timeout) leads to a errors burst: /// all outstanding request will fails at once size_t max_in_flight_requests = 5; + + /// Requested AMQP heartbeat interval in seconds. + /// Set to 0 to disable heartbeats. + size_t heartbeat_interval_seconds = 30; }; class TestsHelper; diff --git a/rabbitmq/include/userver/urabbitmq/typedefs.hpp b/rabbitmq/include/userver/urabbitmq/typedefs.hpp index 4714339b802d..b8d1efc26c0e 100644 --- a/rabbitmq/include/userver/urabbitmq/typedefs.hpp +++ b/rabbitmq/include/userver/urabbitmq/typedefs.hpp @@ -4,6 +4,7 @@ /// @brief Convenient typedefs for RabbitMQ entities. #include +#include #include @@ -67,6 +68,8 @@ enum class MessageType { /// metadata fields. This struct is used to pass messages to the end user, /// hiding the actual AMQP message object implementation. struct ConsumedMessage { + using Headers = std::unordered_map; + struct Metadata { std::string exchange; std::string routingKey; @@ -75,17 +78,21 @@ struct ConsumedMessage { Metadata metadata; std::optional reply_to{}; std::optional correlation_id{}; + Headers headers{}; }; /// @brief Structure holding an AMQP message body along with some of its /// metadata fields. This struct is used to pass messages from the end user, /// hiding the actual AMQP message object implementation. struct Envelope { + using Headers = std::unordered_map; + std::string message; MessageType type; std::optional reply_to{}; std::optional correlation_id{}; std::optional expiration{}; + std::optional headers{}; }; } // namespace urabbitmq diff --git a/rabbitmq/src/tests/publish_consume_rmqtest.cpp b/rabbitmq/src/tests/publish_consume_rmqtest.cpp index 976274cf255c..cac2772257f6 100644 --- a/rabbitmq/src/tests/publish_consume_rmqtest.cpp +++ b/rabbitmq/src/tests/publish_consume_rmqtest.cpp @@ -1,6 +1,9 @@ #include "utils_rmqtest.hpp" #include +#include + +#include #include #include @@ -78,6 +81,41 @@ class ThrowingConsumer final : public urabbitmq::ConsumerBase { engine::ConditionVariable cond_; }; +class MetadataConsumer final : public urabbitmq::ConsumerBase { +public: + using urabbitmq::ConsumerBase::ConsumerBase; + ~MetadataConsumer() override { Stop(); } + + void Process(urabbitmq::ConsumedMessage message) override { + { + auto locked = messages_.Lock(); + locked->emplace_back(std::move(message)); + } + + if (++consumed_ == expected_consumed_) { + event_.Send(); + } + } + + void ExpectConsume(size_t count) { expected_consumed_ = count; } + + std::vector Wait() { + if (expected_consumed_ != 0) { + [[maybe_unused]] auto res = + event_.WaitForEventFor(utest::kMaxTestWaitTime); + } + + auto locked = messages_.Lock(); + return *locked; + } + +private: + concurrent::Variable> messages_; + std::atomic expected_consumed_{0}; + std::atomic consumed_{0}; + engine::SingleConsumerEvent event_; +}; + } // namespace UTEST(Consumer, CreateOnInvalidQueueWorks) { @@ -227,4 +265,145 @@ UTEST(Consumer, ForDifferentQueuesWork) { client->GetAdminChannel(client.GetDeadline()).RemoveQueue(second_queue, client.GetDeadline()); } +UTEST(Consumer, ConsumeMetadataAndHeadersWork) { + ClientWrapper client{}; + client.SetupRmqEntities(); + const urabbitmq::ConsumerSettings settings{client.GetQueue(), 10}; + + struct Case { + std::string name; + std::optional reply_to; + std::optional correlation_id; + urabbitmq::Envelope::Headers headers; + }; + + const std::vector cases{ + {"no-user-headers", std::nullopt, std::nullopt, {}}, + { + "simple-user-headers", + "reply-queue", + "corr-id", + { + {"x-custom-header", "custom-value"}, + {"x-custom-int", "42"}, + }, + }, + { + "many-user-headers", + "reply-many", + "corr-many", + { + {"x-empty", ""}, + {"x-spaces", "a b c"}, + {"x-symbols", R"(!@#$%^&*()[]{}<>/?\\|;:'\",.~-_=+)"}, + {"x-long", std::string(128, 'x')}, + }, + }, + { + "trace-headers-override", + "reply-override", + "corr-override", + { + {"u-trace-id", "trace-from-user"}, + {"u-parent-span-id", "parent-from-user"}, + {"x-another", "value"}, + }, + }, + }; + + for (const auto &case_data : cases) { + urabbitmq::Envelope envelope{ + "payload-" + case_data.name, + urabbitmq::MessageType::kTransient, + }; + envelope.reply_to = case_data.reply_to; + envelope.correlation_id = case_data.correlation_id; + envelope.headers = case_data.headers; + client->PublishReliable(client.GetExchange(), client.GetRoutingKey(), + envelope, client.GetDeadline()); + } + + MetadataConsumer consumer{client.Get(), settings}; + consumer.ExpectConsume(cases.size()); + consumer.Start(); + auto consumed = consumer.Wait(); + + ASSERT_EQ(consumed.size(), cases.size()); + std::unordered_map + consumed_by_payload; + consumed_by_payload.reserve(consumed.size()); + for (const auto &msg : consumed) { + consumed_by_payload.emplace(msg.message, &msg); + } + + for (const auto &case_data : cases) { + const auto payload = "payload-" + case_data.name; + const auto it = consumed_by_payload.find(payload); + ASSERT_NE(it, consumed_by_payload.end()) + << "Missing consumed payload: " << payload; + + const auto &msg = *it->second; + EXPECT_EQ(msg.message, payload); + EXPECT_EQ(msg.metadata.exchange, client.GetExchange().GetUnderlying()); + EXPECT_EQ(msg.metadata.routingKey, client.GetRoutingKey()); + EXPECT_EQ(msg.reply_to, case_data.reply_to); + EXPECT_EQ(msg.correlation_id, case_data.correlation_id); + + for (const auto &[header_key, header_value] : case_data.headers) { + ASSERT_EQ(msg.headers.count(header_key), 1) + << "Missing header '" << header_key << "' in " << payload; + const auto &actual = msg.headers.at(header_key); + EXPECT_NE(actual.find(header_value), std::string::npos) + << "Unexpected value for header '" << header_key << "' in " << payload + << ": " << actual; + } + + ASSERT_EQ(msg.headers.count("u-trace-id"), 1) + << "Missing u-trace-id in " << payload; + ASSERT_EQ(msg.headers.count("u-parent-span-id"), 1) + << "Missing u-parent-span-id in " << payload; + EXPECT_FALSE(msg.headers.at("u-trace-id").empty()); + EXPECT_FALSE(msg.headers.at("u-parent-span-id").empty()); + } +} + +UTEST(Consumer, HeaderFieldStringConversionInvariants) { + // rabbitmq/src/urabbitmq/consumer_base_impl.cpp:112 + AMQP::Array array_field; + array_field.push_back(AMQP::LongString{"arr-string"}); + array_field.push_back(AMQP::Long{123}); + array_field.push_back(AMQP::BooleanSet{true}); + + AMQP::Table nested_table; + nested_table.set("nested-string", "nested-value"); + nested_table.set("nested-int", 7); + + AMQP::Table headers; + headers.set("string", "value"); + headers.set("empty-string", ""); + headers.set("bool-true", true); + headers.set("bool-false", false); + headers.set("uint8", static_cast(255)); + headers.set("int8", static_cast(-100)); + headers.set("uint16", static_cast(65000)); + headers.set("int16", static_cast(-30000)); + headers.set("uint32", static_cast(4000000000U)); + headers.set("int32", static_cast(-2000000000)); + headers.set("uint64", static_cast(9000000000000000000ULL)); + headers.set("int64", static_cast(-9000000000000000000LL)); + headers.set("float", AMQP::Float{3.14f}); + headers.set("double", AMQP::Double{2.718281828}); + headers.set("decimal", AMQP::DecimalField{2, 12345}); + headers.set("void", nullptr); + headers.set("array", array_field); + headers.set("table", nested_table); + + for (const auto &key : headers.keys()) { + EXPECT_NO_THROW({ + [[maybe_unused]] const auto value = std::string(headers.get(key)); + }) << "Failed for key: " + << key; + } +} + USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/client_settings.cpp b/rabbitmq/src/urabbitmq/client_settings.cpp index 85d25ba5adf0..ba7682449243 100644 --- a/rabbitmq/src/urabbitmq/client_settings.cpp +++ b/rabbitmq/src/urabbitmq/client_settings.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -100,9 +101,13 @@ PoolSettings Parse(const yaml_config::YamlConfig& config, formats::parse::To(result.min_pool_size); result.max_pool_size = config["max_pool_size"].As(result.max_pool_size); result.max_in_flight_requests = config["max_in_flight_requests"].As(result.max_in_flight_requests); + result.heartbeat_interval_seconds = + config["heartbeat_interval_seconds"].As(result.heartbeat_interval_seconds); UINVARIANT(result.min_pool_size <= result.max_pool_size, "max_pool_size is less than min_pool_size"); UINVARIANT(result.max_pool_size > 0, "max_pool_size is set to zero"); + UINVARIANT(result.heartbeat_interval_seconds <= std::numeric_limits::max(), + "heartbeat_interval_seconds is too large"); return result; } diff --git a/rabbitmq/src/urabbitmq/component.yaml b/rabbitmq/src/urabbitmq/component.yaml index cdc35b89df77..757c5500a6c2 100644 --- a/rabbitmq/src/urabbitmq/component.yaml +++ b/rabbitmq/src/urabbitmq/component.yaml @@ -23,6 +23,11 @@ properties: description: | per-connection limit for requests awaiting response from the broker default: 5 + heartbeat_interval_seconds: + type: integer + description: | + requested AMQP heartbeat interval in seconds; 0 disables heartbeats + default: 30 use_secure_connection: type: boolean description: whether to use TLS for connections diff --git a/rabbitmq/src/urabbitmq/connection.cpp b/rabbitmq/src/urabbitmq/connection.cpp index f54e752b5bf4..4d8fcac026ed 100644 --- a/rabbitmq/src/urabbitmq/connection.cpp +++ b/rabbitmq/src/urabbitmq/connection.cpp @@ -9,11 +9,20 @@ Connection::Connection( const EndpointInfo& endpoint, const AuthSettings& auth_settings, size_t max_in_flight_requests, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline ) - : handler_{resolver, endpoint, auth_settings, secure, stats, deadline}, + : handler_{ + resolver, + endpoint, + auth_settings, + heartbeat_interval_seconds, + secure, + stats, + deadline, + }, connection_{handler_, max_in_flight_requests, deadline}, channel_{connection_}, reliable_channel_{connection_} diff --git a/rabbitmq/src/urabbitmq/connection.hpp b/rabbitmq/src/urabbitmq/connection.hpp index 372596ffadb9..a53796b948a6 100644 --- a/rabbitmq/src/urabbitmq/connection.hpp +++ b/rabbitmq/src/urabbitmq/connection.hpp @@ -29,6 +29,7 @@ class Connection final { const EndpointInfo& endpoint, const AuthSettings& auth_settings, size_t max_in_flight_requests, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline diff --git a/rabbitmq/src/urabbitmq/connection_pool.cpp b/rabbitmq/src/urabbitmq/connection_pool.cpp index e1a9baf3a245..febb66583cc4 100644 --- a/rabbitmq/src/urabbitmq/connection_pool.cpp +++ b/rabbitmq/src/urabbitmq/connection_pool.cpp @@ -92,6 +92,7 @@ ConnectionPool::ConnectionUniquePtr ConnectionPool::DoCreateConnection(engine::D endpoint_info_, auth_settings_, pool_settings_.max_in_flight_requests, + pool_settings_.heartbeat_interval_seconds, use_secure_connection_, stats_, deadline diff --git a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp index 3bb22d1ea41b..1d5723b9b724 100644 --- a/rabbitmq/src/urabbitmq/consumer_base_impl.cpp +++ b/rabbitmq/src/urabbitmq/consumer_base_impl.cpp @@ -93,18 +93,26 @@ void ConsumerBaseImpl::Stop() { bool ConsumerBaseImpl::IsBroken() const { return broken_ || !connection_ptr_.IsUsable(); } void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery_tag) { - std::string span_name{fmt::format("consume_{}_{}", queue_name_, consumer_tag_.value_or("ctag:unknown"))}; - std::string trace_id = message.headers().get("u-trace-id"); - std::string parent_span_id = message.headers().get("u-parent-span-id"); - ConsumedMessage consumed; - consumed.message = std::string(message.body(), message.bodySize()); - consumed.metadata.exchange = message.exchange(); - consumed.metadata.routingKey = message.routingkey(); - if (message.hasReplyTo()) { - consumed.reply_to = message.replyTo(); - } + const auto &headers = message.headers(); + std::string span_name{fmt::format("consume_{}_{}", queue_name_, + consumer_tag_.value_or("ctag:unknown"))}; + std::string trace_id = headers.get("u-trace-id"); + std::string parent_span_id = headers.get("u-parent-span-id"); + ConsumedMessage consumed; + consumed.message = std::string(message.body(), message.bodySize()); + consumed.metadata.exchange = message.exchange(); + consumed.metadata.routingKey = message.routingkey(); + if (message.hasReplyTo()) { + consumed.reply_to = message.replyTo(); + } if (message.hasCorrelationID()) { - consumed.correlation_id = message.correlationID(); + consumed.correlation_id = message.correlationID(); + } + + const auto keys = headers.keys(); + consumed.headers.reserve(keys.size()); + for (const auto &key : keys) { + consumed.headers.emplace(key, std::string(headers.get(key))); } bts_.Detach(engine::AsyncNoSpan( diff --git a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp index c374dc41bd5d..d6aa45d5b2ef 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_channel.cpp @@ -94,6 +94,17 @@ AMQP::Table CreateHeaders() { return headers; } +AMQP::Table CreateHeadersForPublish(const Envelope& envelope) { + auto headers = CreateHeaders(); + if (envelope.headers.has_value()) { + for (const auto& [key, value] : envelope.headers.value()) { + headers[key] = value; + } + } + + return headers; +} + } // namespace AmqpChannel::AmqpChannel(AmqpConnection& conn) @@ -196,7 +207,7 @@ void AmqpChannel::Publish( ) { AMQP::Envelope native_envelope{envelope.message.data(), envelope.message.size()}; native_envelope.setPersistent(envelope.type == MessageType::kPersistent); - native_envelope.setHeaders(CreateHeaders()); + native_envelope.setHeaders(CreateHeadersForPublish(envelope)); if (envelope.reply_to.has_value()) { native_envelope.setReplyTo(envelope.reply_to.value().c_str()); } @@ -285,7 +296,7 @@ ResponseAwaiter AmqpReliableChannel::Publish( if (envelope.expiration.has_value()) { native_envelope.setExpiration(std::to_string(envelope.expiration.value().count())); } - native_envelope.setHeaders(CreateHeaders()); + native_envelope.setHeaders(CreateHeadersForPublish(envelope)); auto awaiter = conn_.GetAwaiter(deadline); diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp index e0857e5510da..68ddded55c44 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp @@ -1,5 +1,7 @@ #include "amqp_connection_handler.hpp" +#include +#include #include #include @@ -11,6 +13,7 @@ #include #include +#include #include USERVER_NAMESPACE_BEGIN @@ -18,6 +21,7 @@ USERVER_NAMESPACE_BEGIN namespace urabbitmq::impl { namespace { +constexpr std::chrono::milliseconds kHeartbeatSendTimeout{200}; engine::io::Socket CreateSocket(engine::io::Sockaddr& addr, engine::Deadline deadline) { engine::io::Socket socket{addr.Domain(), engine::io::SocketType::kTcp}; @@ -91,6 +95,7 @@ AmqpConnectionHandler::AmqpConnectionHandler( clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -98,10 +103,15 @@ AmqpConnectionHandler::AmqpConnectionHandler( : address_{ToAmqpAddress(endpoint, auth_settings, secure)}, socket_{CreateSocketPtr(resolver, address_, auth_settings, deadline)}, reader_{*this, *socket_}, + configured_heartbeat_seconds_{ + static_cast(std::min(heartbeat_interval_seconds, std::numeric_limits::max()))}, stats_{stats} {} -AmqpConnectionHandler::~AmqpConnectionHandler() { reader_.Stop(); } +AmqpConnectionHandler::~AmqpConnectionHandler() { + heartbeat_task_.Stop(); + reader_.Stop(); +} void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, AMQP::Table& client) { client["product"] = "uServer AMQP library"; @@ -109,6 +119,18 @@ void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, client["information"] = "https://userver.tech/dd/de2/rabbitmq_driver.html"; } +uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, uint16_t interval) { + if (interval == 0 || configured_heartbeat_seconds_ == 0) { + negotiated_heartbeat_seconds_.store(0, std::memory_order_relaxed); + return 0; + } + + const auto negotiated = static_cast(std::min(interval, configured_heartbeat_seconds_)); + negotiated_heartbeat_seconds_.store(negotiated, std::memory_order_relaxed); + LOG_INFO() << "RabbitMQ heartbeat negotiated at " << negotiated << "s"; + return negotiated; +} + void AmqpConnectionHandler::onData(AMQP::Connection* connection, const char* buffer, size_t size) { if (IsBroken()) { // No further actions can be done @@ -160,6 +182,7 @@ void AmqpConnectionHandler::onReady(AMQP::Connection*) { } void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engine::Deadline deadline) { + connection_ = connection; reader_.Start(connection); if (!connection_ready_event_.WaitForEventUntil(deadline)) { @@ -169,11 +192,22 @@ void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engi if (error_.has_value()) { reader_.Stop(); + connection_ = nullptr; throw ConnectionSetupError{"Failed to setup a connection: " + *error_}; } + + const auto heartbeat_seconds = negotiated_heartbeat_seconds_.load(std::memory_order_relaxed); + if (heartbeat_seconds > 0) { + const auto heartbeat_period = std::chrono::seconds{std::max(1, heartbeat_seconds / 2)}; + heartbeat_task_.Start("amqp_heartbeat", {heartbeat_period}, [this] { SendHeartbeat(); }); + } } -void AmqpConnectionHandler::OnConnectionDestruction() { reader_.Stop(); } +void AmqpConnectionHandler::OnConnectionDestruction() { + heartbeat_task_.Stop(); + connection_ = nullptr; + reader_.Stop(); +} void AmqpConnectionHandler::Invalidate() { broken_ = true; } @@ -189,6 +223,26 @@ statistics::ConnectionStatistics& AmqpConnectionHandler::GetStatistics() { retur const AMQP::Address& AmqpConnectionHandler::GetAddress() const { return address_; } +void AmqpConnectionHandler::SendHeartbeat() { + if (IsBroken() || connection_ == nullptr) { + return; + } + + try { + const auto deadline = engine::Deadline::FromDuration(kHeartbeatSendTimeout); + auto lock = AmqpConnectionLocker{*connection_}.Lock(deadline); + connection_->SetOperationDeadline(deadline); + connection_->GetNative().heartbeat(); + } catch (const std::exception& ex) { + LOG_WARNING() << "Failed to send AMQP heartbeat: " << ex.what(); + Invalidate(); + if (connection_ != nullptr) { + auto lock = AmqpConnectionLocker{*connection_}.Lock({}); + connection_->GetNative().fail("Underlying connection broke."); + } + } +} + } // namespace urabbitmq::impl USERVER_NAMESPACE_END diff --git a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp index 31b7a8a88806..5f8af1c347a1 100644 --- a/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp +++ b/rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -7,6 +8,7 @@ #include #include +#include #include @@ -45,6 +47,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { clients::dns::Resolver& resolver, const EndpointInfo& endpoint, const AuthSettings& auth_settings, + size_t heartbeat_interval_seconds, bool secure, statistics::ConnectionStatistics& stats, engine::Deadline deadline @@ -52,6 +55,7 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { ~AmqpConnectionHandler() override; void onProperties(AMQP::Connection* connection, const AMQP::Table& server, AMQP::Table& client) override; + uint16_t onNegotiate(AMQP::Connection* connection, uint16_t interval) override; void onData(AMQP::Connection* connection, const char* buffer, size_t size) override; @@ -77,9 +81,15 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler { const AMQP::Address& GetAddress() const; private: + void SendHeartbeat(); + AMQP::Address address_; std::unique_ptr socket_; io::SocketReader reader_; + utils::PeriodicTask heartbeat_task_; + AmqpConnection* connection_{nullptr}; + std::atomic negotiated_heartbeat_seconds_{0}; + uint16_t configured_heartbeat_seconds_{0}; engine::SingleConsumerEvent connection_ready_event_; std::atomic broken_{false};