Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions rabbitmq/include/userver/urabbitmq/client_settings.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <chrono>
#include <cstddef>
#include <optional>
#include <string>
Expand Down Expand Up @@ -74,6 +75,10 @@ struct PoolSettings final {
/// (tcp error/protocol error/write timeout) leads to a errors burst:
/// all outstanding request will fails at once
size_t max_in_flight_requests = 5;

/// Requested AMQP heartbeat interval in seconds.
/// Set to 0 to disable heartbeats.
size_t heartbeat_interval_seconds = 30;
};

class TestsHelper;
Expand Down
7 changes: 7 additions & 0 deletions rabbitmq/include/userver/urabbitmq/typedefs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
/// @brief Convenient typedefs for RabbitMQ entities.

#include <chrono>
#include <unordered_map>

#include <userver/utils/strong_typedef.hpp>

Expand Down Expand Up @@ -67,6 +68,8 @@ enum class MessageType {
/// metadata fields. This struct is used to pass messages to the end user,
/// hiding the actual AMQP message object implementation.
struct ConsumedMessage {
using Headers = std::unordered_map<std::string, std::string>;

struct Metadata {
std::string exchange;
std::string routingKey;
Expand All @@ -75,17 +78,21 @@ struct ConsumedMessage {
Metadata metadata;
std::optional<std::string> reply_to{};
std::optional<std::string> correlation_id{};
Headers headers{};
};

/// @brief Structure holding an AMQP message body along with some of its
/// metadata fields. This struct is used to pass messages from the end user,
/// hiding the actual AMQP message object implementation.
struct Envelope {
using Headers = std::unordered_map<std::string, std::string>;

std::string message;
MessageType type;
std::optional<std::string> reply_to{};
std::optional<std::string> correlation_id{};
std::optional<std::chrono::milliseconds> expiration{};
std::optional<Headers> headers{};
};

} // namespace urabbitmq
Expand Down
5 changes: 5 additions & 0 deletions rabbitmq/src/urabbitmq/client_settings.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <userver/urabbitmq/client_settings.hpp>

#include <stdexcept>
#include <limits>
#include <string>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -100,9 +101,13 @@ PoolSettings Parse(const yaml_config::YamlConfig& config, formats::parse::To<Poo
result.min_pool_size = config["min_pool_size"].As<size_t>(result.min_pool_size);
result.max_pool_size = config["max_pool_size"].As<size_t>(result.max_pool_size);
result.max_in_flight_requests = config["max_in_flight_requests"].As<size_t>(result.max_in_flight_requests);
result.heartbeat_interval_seconds =
config["heartbeat_interval_seconds"].As<size_t>(result.heartbeat_interval_seconds);

UINVARIANT(result.min_pool_size <= result.max_pool_size, "max_pool_size is less than min_pool_size");
UINVARIANT(result.max_pool_size > 0, "max_pool_size is set to zero");
UINVARIANT(result.heartbeat_interval_seconds <= std::numeric_limits<uint16_t>::max(),
"heartbeat_interval_seconds is too large");

return result;
}
Expand Down
5 changes: 5 additions & 0 deletions rabbitmq/src/urabbitmq/component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ properties:
description: |
per-connection limit for requests awaiting response from the broker
default: 5
heartbeat_interval_seconds:
type: integer
description: |
requested AMQP heartbeat interval in seconds; 0 disables heartbeats
default: 30
use_secure_connection:
type: boolean
description: whether to use TLS for connections
Expand Down
11 changes: 10 additions & 1 deletion rabbitmq/src/urabbitmq/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,20 @@ Connection::Connection(
const EndpointInfo& endpoint,
const AuthSettings& auth_settings,
size_t max_in_flight_requests,
size_t heartbeat_interval_seconds,
bool secure,
statistics::ConnectionStatistics& stats,
engine::Deadline deadline
)
: handler_{resolver, endpoint, auth_settings, secure, stats, deadline},
: handler_{
resolver,
endpoint,
auth_settings,
heartbeat_interval_seconds,
secure,
stats,
deadline,
},
connection_{handler_, max_in_flight_requests, deadline},
channel_{connection_},
reliable_channel_{connection_}
Expand Down
1 change: 1 addition & 0 deletions rabbitmq/src/urabbitmq/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Connection final {
const EndpointInfo& endpoint,
const AuthSettings& auth_settings,
size_t max_in_flight_requests,
size_t heartbeat_interval_seconds,
bool secure,
statistics::ConnectionStatistics& stats,
engine::Deadline deadline
Expand Down
1 change: 1 addition & 0 deletions rabbitmq/src/urabbitmq/connection_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ ConnectionPool::ConnectionUniquePtr ConnectionPool::DoCreateConnection(engine::D
endpoint_info_,
auth_settings_,
pool_settings_.max_in_flight_requests,
pool_settings_.heartbeat_interval_seconds,
use_secure_connection_,
stats_,
deadline
Expand Down
7 changes: 7 additions & 0 deletions rabbitmq/src/urabbitmq/consumer_base_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "consumer_base_impl.hpp"

#include <sstream>
#include <string>

#include <fmt/format.h>
Expand Down Expand Up @@ -106,6 +107,12 @@ void ConsumerBaseImpl::OnMessage(const AMQP::Message& message, uint64_t delivery
if (message.hasCorrelationID()) {
consumed.correlation_id = message.correlationID();
}
const auto& headers = message.headers();
for (const auto& key : headers.keys()) {
std::ostringstream stream;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Image

stream << headers.get(key);
consumed.headers.emplace(key, stream.str());
}

bts_.Detach(engine::AsyncNoSpan(
dispatcher_,
Expand Down
15 changes: 13 additions & 2 deletions rabbitmq/src/urabbitmq/impl/amqp_channel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ AMQP::Table CreateHeaders() {
return headers;
}

AMQP::Table CreateHeadersForPublish(const Envelope& envelope) {
auto headers = CreateHeaders();
if (envelope.headers.has_value()) {
for (const auto& [key, value] : envelope.headers.value()) {
headers[key] = value;
}
}

return headers;
}

} // namespace

AmqpChannel::AmqpChannel(AmqpConnection& conn)
Expand Down Expand Up @@ -196,7 +207,7 @@ void AmqpChannel::Publish(
) {
AMQP::Envelope native_envelope{envelope.message.data(), envelope.message.size()};
native_envelope.setPersistent(envelope.type == MessageType::kPersistent);
native_envelope.setHeaders(CreateHeaders());
native_envelope.setHeaders(CreateHeadersForPublish(envelope));
if (envelope.reply_to.has_value()) {
native_envelope.setReplyTo(envelope.reply_to.value().c_str());
}
Expand Down Expand Up @@ -285,7 +296,7 @@ ResponseAwaiter AmqpReliableChannel::Publish(
if (envelope.expiration.has_value()) {
native_envelope.setExpiration(std::to_string(envelope.expiration.value().count()));
}
native_envelope.setHeaders(CreateHeaders());
native_envelope.setHeaders(CreateHeadersForPublish(envelope));

auto awaiter = conn_.GetAwaiter(deadline);

Expand Down
58 changes: 56 additions & 2 deletions rabbitmq/src/urabbitmq/impl/amqp_connection_handler.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "amqp_connection_handler.hpp"

#include <algorithm>
#include <limits>
#include <netinet/in.h>
#include <netinet/tcp.h>

Expand All @@ -11,13 +13,15 @@
#include <userver/urabbitmq/client_settings.hpp>
#include <userver/utils/assert.hpp>

#include <urabbitmq/impl/amqp_connection.hpp>
#include <urabbitmq/statistics/connection_statistics.hpp>

USERVER_NAMESPACE_BEGIN

namespace urabbitmq::impl {

namespace {
constexpr std::chrono::milliseconds kHeartbeatSendTimeout{200};

engine::io::Socket CreateSocket(engine::io::Sockaddr& addr, engine::Deadline deadline) {
engine::io::Socket socket{addr.Domain(), engine::io::SocketType::kTcp};
Expand Down Expand Up @@ -91,24 +95,42 @@ AmqpConnectionHandler::AmqpConnectionHandler(
clients::dns::Resolver& resolver,
const EndpointInfo& endpoint,
const AuthSettings& auth_settings,
size_t heartbeat_interval_seconds,
bool secure,
statistics::ConnectionStatistics& stats,
engine::Deadline deadline
)
: address_{ToAmqpAddress(endpoint, auth_settings, secure)},
socket_{CreateSocketPtr(resolver, address_, auth_settings, deadline)},
reader_{*this, *socket_},
configured_heartbeat_seconds_{
static_cast<uint16_t>(std::min<size_t>(heartbeat_interval_seconds, std::numeric_limits<uint16_t>::max()))},
stats_{stats}
{}

AmqpConnectionHandler::~AmqpConnectionHandler() { reader_.Stop(); }
AmqpConnectionHandler::~AmqpConnectionHandler() {
heartbeat_task_.Stop();
reader_.Stop();
}

void AmqpConnectionHandler::onProperties(AMQP::Connection*, const AMQP::Table&, AMQP::Table& client) {
client["product"] = "uServer AMQP library";
client["copyright"] = "Copyright 2022-2022 Yandex NV";
client["information"] = "https://userver.tech/dd/de2/rabbitmq_driver.html";
}

uint16_t AmqpConnectionHandler::onNegotiate(AMQP::Connection*, uint16_t interval) {
if (interval == 0 || configured_heartbeat_seconds_ == 0) {
negotiated_heartbeat_seconds_.store(0, std::memory_order_relaxed);
return 0;
}

const auto negotiated = static_cast<uint16_t>(std::min<uint16_t>(interval, configured_heartbeat_seconds_));
negotiated_heartbeat_seconds_.store(negotiated, std::memory_order_relaxed);
LOG_INFO() << "RabbitMQ heartbeat negotiated at " << negotiated << "s";
return negotiated;
}

void AmqpConnectionHandler::onData(AMQP::Connection* connection, const char* buffer, size_t size) {
if (IsBroken()) {
// No further actions can be done
Expand Down Expand Up @@ -160,6 +182,7 @@ void AmqpConnectionHandler::onReady(AMQP::Connection*) {
}

void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engine::Deadline deadline) {
connection_ = connection;
reader_.Start(connection);

if (!connection_ready_event_.WaitForEventUntil(deadline)) {
Expand All @@ -169,11 +192,22 @@ void AmqpConnectionHandler::OnConnectionCreated(AmqpConnection* connection, engi

if (error_.has_value()) {
reader_.Stop();
connection_ = nullptr;
throw ConnectionSetupError{"Failed to setup a connection: " + *error_};
}

const auto heartbeat_seconds = negotiated_heartbeat_seconds_.load(std::memory_order_relaxed);
if (heartbeat_seconds > 0) {
const auto heartbeat_period = std::chrono::seconds{std::max<uint16_t>(1, heartbeat_seconds / 2)};
heartbeat_task_.Start("amqp_heartbeat", {heartbeat_period}, [this] { SendHeartbeat(); });
}
}

void AmqpConnectionHandler::OnConnectionDestruction() { reader_.Stop(); }
void AmqpConnectionHandler::OnConnectionDestruction() {
heartbeat_task_.Stop();
connection_ = nullptr;
reader_.Stop();
}

void AmqpConnectionHandler::Invalidate() { broken_ = true; }

Expand All @@ -189,6 +223,26 @@ statistics::ConnectionStatistics& AmqpConnectionHandler::GetStatistics() { retur

const AMQP::Address& AmqpConnectionHandler::GetAddress() const { return address_; }

void AmqpConnectionHandler::SendHeartbeat() {
if (IsBroken() || connection_ == nullptr) {
return;
}

try {
const auto deadline = engine::Deadline::FromDuration(kHeartbeatSendTimeout);
auto lock = AmqpConnectionLocker{*connection_}.Lock(deadline);
connection_->SetOperationDeadline(deadline);
connection_->GetNative().heartbeat();
} catch (const std::exception& ex) {
LOG_WARNING() << "Failed to send AMQP heartbeat: " << ex.what();
Invalidate();
if (connection_ != nullptr) {
auto lock = AmqpConnectionLocker{*connection_}.Lock({});
connection_->GetNative().fail("Underlying connection broke.");
}
}
}

} // namespace urabbitmq::impl

USERVER_NAMESPACE_END
10 changes: 10 additions & 0 deletions rabbitmq/src/urabbitmq/impl/amqp_connection_handler.hpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#pragma once

#include <cstddef>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>

#include <userver/clients/dns/resolver_fwd.hpp>
#include <userver/engine/single_consumer_event.hpp>
#include <userver/utils/periodic_task.hpp>

#include <urabbitmq/impl/io/socket_reader.hpp>

Expand Down Expand Up @@ -45,13 +47,15 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler {
clients::dns::Resolver& resolver,
const EndpointInfo& endpoint,
const AuthSettings& auth_settings,
size_t heartbeat_interval_seconds,
bool secure,
statistics::ConnectionStatistics& stats,
engine::Deadline deadline
);
~AmqpConnectionHandler() override;

void onProperties(AMQP::Connection* connection, const AMQP::Table& server, AMQP::Table& client) override;
uint16_t onNegotiate(AMQP::Connection* connection, uint16_t interval) override;

void onData(AMQP::Connection* connection, const char* buffer, size_t size) override;

Expand All @@ -77,9 +81,15 @@ class AmqpConnectionHandler final : public AMQP::ConnectionHandler {
const AMQP::Address& GetAddress() const;

private:
void SendHeartbeat();

AMQP::Address address_;
std::unique_ptr<engine::io::RwBase> socket_;
io::SocketReader reader_;
utils::PeriodicTask heartbeat_task_;
AmqpConnection* connection_{nullptr};
std::atomic<uint16_t> negotiated_heartbeat_seconds_{0};
uint16_t configured_heartbeat_seconds_{0};

engine::SingleConsumerEvent connection_ready_event_;
std::atomic<bool> broken_{false};
Expand Down