Skip to content
Open
41 changes: 41 additions & 0 deletions core/include/userver/engine/io/sockaddr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ class AddrException : public std::runtime_error {
using std::runtime_error::runtime_error;
};

/// Multicast request related exceptions
class IpMulticastRequestException : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
};

/// Communication domain
enum class AddrDomain {
kUnspecified = AF_UNSPEC, ///< Unspecified
Expand All @@ -39,6 +45,41 @@ static_assert(
"Your socket subsystem looks broken, please contact support chat."
);

/// Native ip multicast request wrapper
class IpMreq final {
public:
/// @brief Creates multicast request. IP version is chosen automatically from ip_multiaddr value.
/// @param ip_multiaddr IP multicast group address (e.g. 239.255.0.1" or "ff02::1")
/// @param interface_index Interface index (0 for default);
IpMreq(const char* ip_multiaddr, unsigned int interface_index);

/// @brief Native multicast request structure pointer.
void* Data() { return &data_; }

/// @brief Native multicast request structure pointer.
const void* Data() const { return &data_; }

/// @brief Returns socket option level.
int GetSocketOptionLevel() const noexcept { return (family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6); }

/// @brief Returns socket option name for joining multicast group.
int GetJoinSocketOptionName() const noexcept { return (family_ == AF_INET ? IP_ADD_MEMBERSHIP : IPV6_JOIN_GROUP); }

/// @brief Returns socket option name for leaving multicast group.
int GetLeaveSocketOption() const noexcept { return (family_ == AF_INET ? IP_DROP_MEMBERSHIP : IPV6_LEAVE_GROUP); }

/// Returns appropriate size for setsockopt based on address family.
/// @param domain Socket domain (AF_INET or AF_INET6)
size_t Size() const noexcept { return (family_ == AF_INET ? sizeof(struct ip_mreqn) : sizeof(struct ipv6_mreq)); }

private:
union Storage {
struct ip_mreqn ip_req;
struct ipv6_mreq ipv6_req;
} data_;
int family_;
};

/// Native socket address wrapper
class Sockaddr final {
public:
Expand Down
10 changes: 10 additions & 0 deletions core/include/userver/engine/io/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ class [[nodiscard]] Socket final : public RwBase {
/// Starts listening for connections on a specified socket (must be bound).
void Listen(int backlog = SOMAXCONN);

/// @brief Joins multicast group to receive multicast datagrams.
/// @snippet src/engine/io/socket_test.cpp multicast socket creation sample
void AddMembership(const IpMreq& mreq);

/// @brief Leaves multicast group previously joined with AddMembership.
void DropMembership(const IpMreq& mreq);

/// Suspends current task until the socket has data available.
/// @returns false on timeout or on task cancellations; true otherwise.
[[nodiscard]] bool WaitReadable(Deadline) override;
Expand Down Expand Up @@ -154,6 +161,9 @@ class [[nodiscard]] Socket final : public RwBase {
/// Sets a socket option.
void SetOption(int layer, int optname, int optval);

/// Sets a socket option with non-trivial optval.
void SetOption(int layer, int optname, const void* optval, socklen_t optlen);

/// @brief Receives at least one byte from the socket.
/// @returns 0 if connection is closed on one side and no data could be
/// received any more, received bytes count otherwise.
Expand Down
20 changes: 20 additions & 0 deletions core/src/engine/io/sockaddr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,26 @@ USERVER_NAMESPACE_BEGIN

namespace engine::io {

IpMreq::IpMreq(const char* ip_multiaddr, unsigned int interface_index) {
data_.ip_req = {};
if (inet_pton(AF_INET, ip_multiaddr, &data_.ip_req.imr_multiaddr) == 1) {
// imr_address field is not set since it's not used if imr_ifindex presents
family_ = AF_INET;
data_.ip_req.imr_ifindex = interface_index;
}
else {
data_.ipv6_req = {};
if (inet_pton(AF_INET6, ip_multiaddr, &data_.ipv6_req.ipv6mr_multiaddr) == 1) {
family_ = AF_INET6;
data_.ipv6_req.ipv6mr_interface = interface_index;
} else {
throw IpMulticastRequestException(
fmt::format("Invalid IP address: {}", ip_multiaddr)
);
}
}
}

Sockaddr Sockaddr::MakeUnixSocketAddress(std::string_view path) {
Sockaddr addr;
auto* sa = addr.As<struct sockaddr_un>();
Expand Down
19 changes: 19 additions & 0 deletions core/src/engine/io/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,14 @@ void Socket::Listen(int backlog) {
IoSystemError>(::listen(Fd(), backlog), "listening on a socket, fd={}, backlog={}", Fd(), backlog);
}

void Socket::AddMembership(const IpMreq& mreq) {
SetOption(mreq.GetSocketOptionLevel(), mreq.GetJoinSocketOptionName(), mreq.Data(), mreq.Size());
}

void Socket::DropMembership(const IpMreq& mreq) {
SetOption(mreq.GetSocketOptionLevel(), mreq.GetLeaveSocketOption(), mreq.Data(), mreq.Size());
}

bool Socket::WaitReadable(Deadline deadline) {
UASSERT(IsValid());
return fd_control_->Read().Wait(deadline);
Expand Down Expand Up @@ -501,6 +509,17 @@ void Socket::SetOption(int layer, int optname, int optval) {
);
}

void Socket::SetOption(int layer, int optname, const void* optval, socklen_t optlen) {
UASSERT(IsValid());
utils::CheckSyscallCustomException<IoSystemError>(
::setsockopt(Fd(), layer, optname, optval, optlen),
"setting socket option {},{} on fd {}",
layer,
optname,
Fd()
);
}

} // namespace engine::io

USERVER_NAMESPACE_END
71 changes: 71 additions & 0 deletions core/src/engine/io/socket_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
#include <array>
#include <cerrno>
#include <cstdlib>
#include <memory>
#include <string_view>
#include <thread>

#include <userver/engine/async.hpp>
#include <userver/engine/condition_variable.hpp>
Expand Down Expand Up @@ -464,6 +466,75 @@ UTEST_MT(Socket, ConcurrentReadWriteUdp, 2) {
/// [send self concurrent]
}

UTEST_MT(Socket, UdpIpMreqMultipleReceiversIPv4, 3) {
const auto deadline = Deadline::FromDuration(utest::kMaxTestWaitTime);

static constexpr uint16_t kPort = 12345;
static constexpr const char* kGroup = "239.255.0.1";
static constexpr int packets_count = 3;

sockaddr_in raw_multiaddr{AF_INET, htons(kPort), {}, {}};
inet_pton(AF_INET, kGroup, &raw_multiaddr.sin_addr);
io::Sockaddr multiaddr(&raw_multiaddr);
io::IpMreq mreq(kGroup, INADDR_ANY);

std::vector<engine::TaskWithResult<void>> tasks;
std::vector<std::shared_ptr<io::Socket>> receivers;
for (int i = 0; i < 2; ++i) {
/// [[multicast socket creation sample]]
const auto& receiver = std::make_shared<io::Socket>(io::AddrDomain::kInet, io::SocketType::kDgram);
int reuse = 1;
receiver->SetOption(SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));

sockaddr_in any{AF_INET, htons(kPort), {}, {}};
any.sin_addr.s_addr = htonl(INADDR_ANY);
receiver->Bind(io::Sockaddr(&any));
receiver->AddMembership(mreq);
/// [[multicast socket creation sample]]
receivers.emplace_back(receiver);

tasks.push_back(engine::AsyncNoSpan([receiver, deadline] {
char c{};
for (int packet_idx = 0; packet_idx < packets_count; ++packet_idx) {
const auto result = receiver->RecvSomeFrom(&c, 1, deadline);
EXPECT_EQ(result.bytes_received, 1);
EXPECT_EQ(c, 'a' + packet_idx);
}
}));
}

io::Socket sender{io::AddrDomain::kInet, io::SocketType::kDgram};
for (int packet_idx = 0; packet_idx < packets_count; ++packet_idx) {
const char data = 'a' + packet_idx;
EXPECT_EQ(sender.SendAllTo(multiaddr, &data, 1, deadline), 1);
}

for (auto& t : tasks) {
t.Get();
}
tasks.clear();

for (int i = 0; i < 2; ++i) {
const auto& receiver = receivers[i];
receiver->DropMembership(mreq);

tasks.push_back(engine::AsyncNoSpan([receiver] {
auto short_deadline = Deadline::FromDuration(std::chrono::milliseconds(300));
char c{};
const auto result = receiver->RecvSomeFrom(&c, 1, short_deadline);
EXPECT_EQ(result.bytes_received, 0);
}));
}

std::this_thread::sleep_for(std::chrono::milliseconds(100));

char data = 'x';
EXPECT_EQ(sender.SendAllTo(multiaddr, &data, 1, deadline), 1);
for (auto& t : tasks) {
UEXPECT_THROW(t.Get(), io::IoTimeout);
}
}

UTEST(Socket, WriteALot) {
const auto deadline = Deadline::FromDuration(utest::kMaxTestWaitTime);

Expand Down
Loading