diff --git a/core/include/userver/engine/io/sockaddr.hpp b/core/include/userver/engine/io/sockaddr.hpp index 01b14a50454e..c38c1824e47f 100644 --- a/core/include/userver/engine/io/sockaddr.hpp +++ b/core/include/userver/engine/io/sockaddr.hpp @@ -26,6 +26,12 @@ class AddrException : public std::runtime_error { using std::runtime_error::runtime_error; }; +/// Multicast request related exceptions +class IpMulticastRequestException : public std::runtime_error { +public: + using std::runtime_error::runtime_error; +}; + /// Communication domain enum class AddrDomain { kUnspecified = AF_UNSPEC, ///< Unspecified @@ -39,6 +45,41 @@ static_assert( "Your socket subsystem looks broken, please contact support chat." ); +/// Native ip multicast request wrapper +class IpMreq final { +public: + /// @brief Creates multicast request. IP version is chosen automatically from ip_multiaddr value. + /// @param ip_multiaddr IP multicast group address (e.g. 239.255.0.1" or "ff02::1") + /// @param interface_index Interface index (0 for default); + IpMreq(const char* ip_multiaddr, unsigned int interface_index); + + /// @brief Native multicast request structure pointer. + void* Data() { return &data_; } + + /// @brief Native multicast request structure pointer. + const void* Data() const { return &data_; } + + /// @brief Returns socket option level. + int GetSocketOptionLevel() const noexcept { return (family_ == AF_INET ? IPPROTO_IP : IPPROTO_IPV6); } + + /// @brief Returns socket option name for joining multicast group. + int GetJoinSocketOptionName() const noexcept { return (family_ == AF_INET ? IP_ADD_MEMBERSHIP : IPV6_JOIN_GROUP); } + + /// @brief Returns socket option name for leaving multicast group. + int GetLeaveSocketOption() const noexcept { return (family_ == AF_INET ? IP_DROP_MEMBERSHIP : IPV6_LEAVE_GROUP); } + + /// Returns appropriate size for setsockopt based on address family. + /// @param domain Socket domain (AF_INET or AF_INET6) + size_t Size() const noexcept { return (family_ == AF_INET ? sizeof(struct ip_mreqn) : sizeof(struct ipv6_mreq)); } + +private: + union Storage { + struct ip_mreqn ip_req; + struct ipv6_mreq ipv6_req; + } data_; + int family_; +}; + /// Native socket address wrapper class Sockaddr final { public: diff --git a/core/include/userver/engine/io/socket.hpp b/core/include/userver/engine/io/socket.hpp index 3bbfc9ddacb6..68a179d56163 100644 --- a/core/include/userver/engine/io/socket.hpp +++ b/core/include/userver/engine/io/socket.hpp @@ -68,6 +68,13 @@ class [[nodiscard]] Socket final : public RwBase { /// Starts listening for connections on a specified socket (must be bound). void Listen(int backlog = SOMAXCONN); + /// @brief Joins multicast group to receive multicast datagrams. + /// @snippet src/engine/io/socket_test.cpp multicast socket creation sample + void AddMembership(const IpMreq& mreq); + + /// @brief Leaves multicast group previously joined with AddMembership. + void DropMembership(const IpMreq& mreq); + /// Suspends current task until the socket has data available. /// @returns false on timeout or on task cancellations; true otherwise. [[nodiscard]] bool WaitReadable(Deadline) override; @@ -154,6 +161,9 @@ class [[nodiscard]] Socket final : public RwBase { /// Sets a socket option. void SetOption(int layer, int optname, int optval); + /// Sets a socket option with non-trivial optval. + void SetOption(int layer, int optname, const void* optval, socklen_t optlen); + /// @brief Receives at least one byte from the socket. /// @returns 0 if connection is closed on one side and no data could be /// received any more, received bytes count otherwise. diff --git a/core/src/engine/io/sockaddr.cpp b/core/src/engine/io/sockaddr.cpp index d49696e63e44..cee4050d41d4 100644 --- a/core/src/engine/io/sockaddr.cpp +++ b/core/src/engine/io/sockaddr.cpp @@ -18,6 +18,26 @@ USERVER_NAMESPACE_BEGIN namespace engine::io { +IpMreq::IpMreq(const char* ip_multiaddr, unsigned int interface_index) { + data_.ip_req = {}; + if (inet_pton(AF_INET, ip_multiaddr, &data_.ip_req.imr_multiaddr) == 1) { + // imr_address field is not set since it's not used if imr_ifindex presents + family_ = AF_INET; + data_.ip_req.imr_ifindex = interface_index; + } + else { + data_.ipv6_req = {}; + if (inet_pton(AF_INET6, ip_multiaddr, &data_.ipv6_req.ipv6mr_multiaddr) == 1) { + family_ = AF_INET6; + data_.ipv6_req.ipv6mr_interface = interface_index; + } else { + throw IpMulticastRequestException( + fmt::format("Invalid IP address: {}", ip_multiaddr) + ); + } + } +} + Sockaddr Sockaddr::MakeUnixSocketAddress(std::string_view path) { Sockaddr addr; auto* sa = addr.As(); diff --git a/core/src/engine/io/socket.cpp b/core/src/engine/io/socket.cpp index e3915eafb1f1..110f9c641652 100644 --- a/core/src/engine/io/socket.cpp +++ b/core/src/engine/io/socket.cpp @@ -233,6 +233,14 @@ void Socket::Listen(int backlog) { IoSystemError>(::listen(Fd(), backlog), "listening on a socket, fd={}, backlog={}", Fd(), backlog); } +void Socket::AddMembership(const IpMreq& mreq) { + SetOption(mreq.GetSocketOptionLevel(), mreq.GetJoinSocketOptionName(), mreq.Data(), mreq.Size()); +} + +void Socket::DropMembership(const IpMreq& mreq) { + SetOption(mreq.GetSocketOptionLevel(), mreq.GetLeaveSocketOption(), mreq.Data(), mreq.Size()); +} + bool Socket::WaitReadable(Deadline deadline) { UASSERT(IsValid()); return fd_control_->Read().Wait(deadline); @@ -501,6 +509,17 @@ void Socket::SetOption(int layer, int optname, int optval) { ); } +void Socket::SetOption(int layer, int optname, const void* optval, socklen_t optlen) { + UASSERT(IsValid()); + utils::CheckSyscallCustomException( + ::setsockopt(Fd(), layer, optname, optval, optlen), + "setting socket option {},{} on fd {}", + layer, + optname, + Fd() + ); +} + } // namespace engine::io USERVER_NAMESPACE_END diff --git a/core/src/engine/io/socket_test.cpp b/core/src/engine/io/socket_test.cpp index 53174e3cdaab..e6ef6eef261f 100644 --- a/core/src/engine/io/socket_test.cpp +++ b/core/src/engine/io/socket_test.cpp @@ -8,7 +8,9 @@ #include #include #include +#include #include +#include #include #include @@ -464,6 +466,75 @@ UTEST_MT(Socket, ConcurrentReadWriteUdp, 2) { /// [send self concurrent] } +UTEST_MT(Socket, UdpIpMreqMultipleReceiversIPv4, 3) { + const auto deadline = Deadline::FromDuration(utest::kMaxTestWaitTime); + + static constexpr uint16_t kPort = 12345; + static constexpr const char* kGroup = "239.255.0.1"; + static constexpr int packets_count = 3; + + sockaddr_in raw_multiaddr{AF_INET, htons(kPort), {}, {}}; + inet_pton(AF_INET, kGroup, &raw_multiaddr.sin_addr); + io::Sockaddr multiaddr(&raw_multiaddr); + io::IpMreq mreq(kGroup, INADDR_ANY); + + std::vector> tasks; + std::vector> receivers; + for (int i = 0; i < 2; ++i) { + /// [[multicast socket creation sample]] + const auto& receiver = std::make_shared(io::AddrDomain::kInet, io::SocketType::kDgram); + int reuse = 1; + receiver->SetOption(SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)); + + sockaddr_in any{AF_INET, htons(kPort), {}, {}}; + any.sin_addr.s_addr = htonl(INADDR_ANY); + receiver->Bind(io::Sockaddr(&any)); + receiver->AddMembership(mreq); + /// [[multicast socket creation sample]] + receivers.emplace_back(receiver); + + tasks.push_back(engine::AsyncNoSpan([receiver, deadline] { + char c{}; + for (int packet_idx = 0; packet_idx < packets_count; ++packet_idx) { + const auto result = receiver->RecvSomeFrom(&c, 1, deadline); + EXPECT_EQ(result.bytes_received, 1); + EXPECT_EQ(c, 'a' + packet_idx); + } + })); + } + + io::Socket sender{io::AddrDomain::kInet, io::SocketType::kDgram}; + for (int packet_idx = 0; packet_idx < packets_count; ++packet_idx) { + const char data = 'a' + packet_idx; + EXPECT_EQ(sender.SendAllTo(multiaddr, &data, 1, deadline), 1); + } + + for (auto& t : tasks) { + t.Get(); + } + tasks.clear(); + + for (int i = 0; i < 2; ++i) { + const auto& receiver = receivers[i]; + receiver->DropMembership(mreq); + + tasks.push_back(engine::AsyncNoSpan([receiver] { + auto short_deadline = Deadline::FromDuration(std::chrono::milliseconds(300)); + char c{}; + const auto result = receiver->RecvSomeFrom(&c, 1, short_deadline); + EXPECT_EQ(result.bytes_received, 0); + })); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + char data = 'x'; + EXPECT_EQ(sender.SendAllTo(multiaddr, &data, 1, deadline), 1); + for (auto& t : tasks) { + UEXPECT_THROW(t.Get(), io::IoTimeout); + } +} + UTEST(Socket, WriteALot) { const auto deadline = Deadline::FromDuration(utest::kMaxTestWaitTime);