diff --git a/CMakeLists.txt b/CMakeLists.txt index beda6f2..69537c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,8 @@ project(FMI) set(CMAKE_CXX_STANDARD 17) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -set(Boost_USE_STATIC_LIBS ON) +option(FMI_BOOST_STATIC "Use static Boost libraries" ON) +set(Boost_USE_STATIC_LIBS ${FMI_BOOST_STATIC}) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") @@ -17,6 +18,7 @@ find_package(AWSSDK COMPONENTS s3 REQUIRED) find_package(hiredis REQUIRED) add_subdirectory(extern/TCPunch/client/) +find_package(ZLIB REQUIRED) add_library(FMI STATIC src/Communicator.cpp src/utils/Configuration.cpp src/comm/Channel.cpp src/comm/ClientServer.cpp src/comm/S3.cpp src/comm/Redis.cpp src/utils/ChannelPolicy.cpp src/comm/PeerToPeer.cpp src/comm/Direct.cpp) @@ -31,4 +33,4 @@ target_link_libraries(FMI ${Boost_Libraries} Boost::log ${AWSSDK_LINK_LIBRARIES} #target_link_libraries(client FMI OpenMP::OpenMP_CXX) -#add_subdirectory(tests) +add_subdirectory(tests) diff --git a/extern/TCPunch b/extern/TCPunch deleted file mode 160000 index befd086..0000000 --- a/extern/TCPunch +++ /dev/null @@ -1 +0,0 @@ -Subproject commit befd086c88fd974684f3e8c6ec05826696dc7ef7 diff --git a/extern/TCPunch/client/CMakeLists.txt b/extern/TCPunch/client/CMakeLists.txt new file mode 100644 index 0000000..5222f85 --- /dev/null +++ b/extern/TCPunch/client/CMakeLists.txt @@ -0,0 +1,50 @@ +cmake_minimum_required(VERSION 3.10) + +set(CMAKE_THREAD_PREFER_PTHREAD TRUE) +set(THREADS_PREFER_PTHREAD_FLAG TRUE) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + + +project(tcpunch VERSION 1.0) + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +find_package(Threads REQUIRED) +add_library(tcpunch STATIC tcpunch.cpp) +target_include_directories(tcpunch PUBLIC + $ + $ +) +target_link_libraries(tcpunch PRIVATE Threads::Threads) + + +include(GNUInstallDirs) +include(CMakePackageConfigHelpers) + +install(FILES tcpunch.h DESTINATION include) +install(TARGETS + tcpunch + EXPORT tcpunchTargets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +) + +install(EXPORT + tcpunchTargets + FILE tcpunchTargets.cmake + DESTINATION "${CMAKE_INSTALL_DATADIR}/tcpunch/cmake" + NAMESPACE tcpunch:: +) +configure_package_config_file( + cmake/tcpunchConfig.cmake.in + "${CMAKE_CURRENT_BINARY_DIR}/tcpunchConfig.cmake" + INSTALL_DESTINATION "${CMAKE_INSTALL_DATADIR}/tcpunch/cmake" +) + +install( + FILES "${CMAKE_CURRENT_BINARY_DIR}/tcpunchConfig.cmake" + DESTINATION "${CMAKE_INSTALL_DATADIR}/tcpunch/cmake" +) + diff --git a/extern/TCPunch/client/cmake/tcpunchConfig.cmake.in b/extern/TCPunch/client/cmake/tcpunchConfig.cmake.in new file mode 100644 index 0000000..f9b2a16 --- /dev/null +++ b/extern/TCPunch/client/cmake/tcpunchConfig.cmake.in @@ -0,0 +1,8 @@ +@PACKAGE_INIT@ + +include(CMakeFindDependencyMacro) +find_dependency(Threads) + +include("${CMAKE_CURRENT_LIST_DIR}/tcpunchTargets.cmake") + +check_required_components(tcpunch) diff --git a/extern/TCPunch/client/tcpunch.cpp b/extern/TCPunch/client/tcpunch.cpp new file mode 100644 index 0000000..0233fe6 --- /dev/null +++ b/extern/TCPunch/client/tcpunch.cpp @@ -0,0 +1,194 @@ +#include "tcpunch.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../common/utils.h" + + +std::atomic connection_established(false); +std::atomic accepting_socket(-1); + +void* peer_listen(void* p) { + auto* info = (PeerConnectionData*)p; + + // Create socket on the port that was previously used to contact the rendezvous server + int listen_socket = socket(AF_INET, SOCK_STREAM, 0); + if (listen_socket == -1) { + error_exit_errno("Socket creation failed: "); + } + int enable_flag = 1; + if (setsockopt(listen_socket, SOL_SOCKET, SO_REUSEADDR, &enable_flag, sizeof(int)) < 0 || + setsockopt(listen_socket, SOL_SOCKET, SO_REUSEPORT, &enable_flag, sizeof(int)) < 0) { + error_exit_errno("Setting REUSE options failed: "); + } + + struct sockaddr_in local_port_data{}; + local_port_data.sin_family = AF_INET; + local_port_data.sin_addr.s_addr = INADDR_ANY; + local_port_data.sin_port = info->port; + + if (bind(listen_socket, (const struct sockaddr *)&local_port_data, sizeof(local_port_data)) < 0) { + error_exit_errno("Could not bind to local port: "); + } + + if (listen(listen_socket, 1) == -1) { + error_exit_errno("Listening on local port failed: "); + } + + struct sockaddr_in peer_info{}; + unsigned int len = sizeof(peer_info); + + while(true) { + int peer = accept(listen_socket, (struct sockaddr*)&peer_info, &len); + if (peer == -1) { +#if DEBUG + std::cout << "Error when connecting to peer" << strerror(errno) << std::endl; +#endif + } else { +#if DEBUG + std::cout << "Succesfully connected to peer, accepting" << std::endl; +#endif + accepting_socket = peer; + connection_established = true; + return 0; + } + } +} + +int pair(const std::string& pairing_name, const std::string& server_address, int port, int timeout_ms) { + connection_established = false; + accepting_socket = -1; + struct timeval timeout; + timeout.tv_sec = timeout_ms / 1000; + timeout.tv_usec = (timeout_ms % 1000) * 1000; + + int socket_rendezvous; + struct sockaddr_in server_data{}; + + socket_rendezvous = socket(AF_INET, SOCK_STREAM, 0); + if (socket_rendezvous == -1) { + error_exit_errno("Could not create socket for rendezvous server: "); + } + + // Enable binding multiple sockets to the same local endpoint, see https://bford.info/pub/net/p2pnat/ for details + int enable_flag = 1; + if (setsockopt(socket_rendezvous, SOL_SOCKET, SO_REUSEADDR, &enable_flag, sizeof(int)) < 0 || + setsockopt(socket_rendezvous, SOL_SOCKET, SO_REUSEPORT, &enable_flag, sizeof(int)) < 0) { + error_exit_errno("Setting REUSE options failed: "); + } + if (setsockopt(socket_rendezvous, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout) < 0 || + setsockopt(socket_rendezvous, SOL_SOCKET, SO_REUSEPORT, &enable_flag, sizeof(int)) < 0) { + error_exit_errno("Setting timeout failed: "); + } + + server_data.sin_family = AF_INET; + server_data.sin_addr.s_addr = inet_addr(server_address.c_str()); + server_data.sin_port = htons(port); + + if (connect(socket_rendezvous, (struct sockaddr *)&server_data, sizeof(server_data)) != 0) { + error_exit_errno("Connection with the rendezvous server failed: "); + } + + if(send(socket_rendezvous, pairing_name.c_str(), pairing_name.length(), MSG_DONTWAIT) == -1) { + error_exit_errno("Failed to send data to rendezvous server: "); + } + + PeerConnectionData public_info; + ssize_t bytes = recv(socket_rendezvous, &public_info, sizeof(public_info), MSG_WAITALL); + if (bytes == -1) { + error_exit_errno("Failed to get data from rendezvous server: "); + } else if(bytes == 0) { + error_exit("Server has disconnected"); + } + + /*pthread_t peer_listen_thread; + int thread_return = pthread_create(&peer_listen_thread, nullptr, peer_listen, (void*) &public_info); + if(thread_return) { + error_exit_errno("Error when creating thread for listening: "); + }*/ + + PeerConnectionData peer_data; + + // Wait until rendezvous server sends info about peer + ssize_t bytes_received = recv(socket_rendezvous, &peer_data, sizeof(peer_data), MSG_WAITALL); + if(bytes_received == -1) { + error_exit_errno("Failed to get peer data from rendezvous server: "); + } else if(bytes_received == 0) { + error_exit("Server has disconnected when waiting for peer data"); + } +#if DEBUG + std::cout << "Peer: " << ip_to_string(&peer_data.ip.s_addr) << ":" << ntohs(peer_data.port) << std::endl; +#endif + + //We do NOT close the socket_rendezvous socket here, otherwise the next binds sometimes fail (although SO_REUSEADDR|SO_REUSEPORT is set)! + + int peer_socket = socket(AF_INET, SOCK_STREAM, 0); + if (setsockopt(peer_socket, SOL_SOCKET, SO_REUSEADDR, &enable_flag, sizeof(int)) < 0 || + setsockopt(peer_socket, SOL_SOCKET, SO_REUSEPORT, &enable_flag, sizeof(int)) < 0) { + error_exit("Setting REUSE options failed"); + } + + //Set socket to non blocking for the following polling operations + if(fcntl(peer_socket, F_SETFL, O_NONBLOCK) != 0) { + error_exit_errno("Setting O_NONBLOCK failed: "); + } + + struct sockaddr_in local_port_addr = {0}; + local_port_addr.sin_family = AF_INET; + local_port_addr.sin_addr.s_addr = INADDR_ANY; + local_port_addr.sin_port = public_info.port; + + if (bind(peer_socket, (const struct sockaddr *)&local_port_addr, sizeof(local_port_addr))) { + error_exit_errno("Binding to same port failed: "); + } + + struct sockaddr_in peer_addr = {0}; + peer_addr.sin_family = AF_INET; + peer_addr.sin_addr.s_addr = peer_data.ip.s_addr; + peer_addr.sin_port = peer_data.port; + + while(!connection_established.load()) { + int peer_status = connect(peer_socket, (struct sockaddr *)&peer_addr, sizeof(struct sockaddr)); + if (peer_status != 0) { + if (errno == EALREADY || errno == EAGAIN || errno == EINPROGRESS) { + continue; + } else if(errno == EISCONN) { + #if DEBUG + std::cout << "Succesfully connected to peer, EISCONN" << std::endl; + #endif + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + } else { + #if DEBUG + std::cout << "Succesfully connected to peer" << std::endl; + #endif + break; + } + } + + /*if(connection_established.load()) { + pthread_join(peer_listen_thread, nullptr); + peer_socket = accepting_socket.load(); + }*/ + + int flags = fcntl(peer_socket, F_GETFL, 0); + flags &= ~(O_NONBLOCK); + fcntl(peer_socket, F_SETFL, flags); + + return peer_socket; +} \ No newline at end of file diff --git a/extern/TCPunch/client/tcpunch.h b/extern/TCPunch/client/tcpunch.h new file mode 100644 index 0000000..a219968 --- /dev/null +++ b/extern/TCPunch/client/tcpunch.h @@ -0,0 +1,16 @@ +#ifndef HOLE_PUNCHING_CLIENT_H +#define HOLE_PUNCHING_CLIENT_H +#include +#include +#include +#include +#include +#include + +#define DEBUG 1 + +struct Timeout : public std::exception {}; + +int pair(const std::string& pairing_name, const std::string& server_address, int port = 10000, int timeout_ms = 0); + +#endif \ No newline at end of file diff --git a/extern/TCPunch/common/utils.h b/extern/TCPunch/common/utils.h new file mode 100644 index 0000000..2e8e652 --- /dev/null +++ b/extern/TCPunch/common/utils.h @@ -0,0 +1,32 @@ +#ifndef HOLEPUNCHINGSERVERCLIENT_UTILS_H +#define HOLEPUNCHINGSERVERCLIENT_UTILS_H + +#include +#include +#include "../client/tcpunch.h" + +typedef struct { + struct in_addr ip; + in_port_t port; +} PeerConnectionData; + +void error_exit(const std::string& error_string) { + throw std::runtime_error{error_string}; +} + +void error_exit_errno(const std::string& error_string) { + if (errno == EAGAIN) { + throw Timeout(); + } else { + std::string err = error_string + strerror(errno); + throw std::runtime_error{err}; + } +} + +std::string ip_to_string(in_addr_t *ip) { + char str_buffer[20]; + inet_ntop(AF_INET, ip, str_buffer, sizeof(str_buffer)); + return {str_buffer}; +} + +#endif //HOLEPUNCHINGSERVERCLIENT_UTILS_H diff --git a/include/Communicator.h b/include/Communicator.h index 8a13e4c..ed83029 100644 --- a/include/Communicator.h +++ b/include/Communicator.h @@ -26,23 +26,45 @@ namespace FMI { template void send(Comm::Data &buf, FMI::Utils::peer_num dest) { std::string channel = policy->get_channel({Utils::send, buf.size_in_bytes()}); - channel_data data {buf.data(), buf.size_in_bytes()}; + auto data = std::make_shared(buf.data(), buf.size_in_bytes(), noop_deleter); channels[channel]->send(data, dest); } + //! Non-blocking send buf to peer dest + template + void send(Comm::Data &buf, FMI::Utils::peer_num dest, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) { + std::string channel = policy->get_channel({Utils::send, buf.size_in_bytes()}); + auto data = std::make_shared(buf.data(), buf.size_in_bytes(), noop_deleter); + channels[channel]->send(data, dest, context, mode, callback); + } + //! Receive data from src and store data into the provided buf template void recv(Comm::Data &buf, FMI::Utils::peer_num src) { std::string channel = policy->get_channel({Utils::send, buf.size_in_bytes()}); - channel_data data {buf.data(), buf.size_in_bytes()}; + auto data = std::make_shared(buf.data(), buf.size_in_bytes(), noop_deleter); channels[channel]->recv(data, src); } + //! Non-blocking receive data from src + template + void recv(Comm::Data &buf, FMI::Utils::peer_num src, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) { + std::string channel = policy->get_channel({Utils::send, buf.size_in_bytes()}); + auto data = std::make_shared(buf.data(), buf.size_in_bytes(), noop_deleter); + channels[channel]->recv(data, src, context, mode, callback); + } + //! Broadcast the data that is in the provided buf of the root peer. Result is stored in buf for all peers. template void bcast(Comm::Data &buf, FMI::Utils::peer_num root) { std::string channel = policy->get_channel({Utils::bcast, buf.size_in_bytes()}); - channel_data data {buf.data(), buf.size_in_bytes()}; + auto data = std::make_shared(buf.data(), buf.size_in_bytes(), noop_deleter); channels[channel]->bcast(data, root); } @@ -60,11 +82,40 @@ namespace FMI { template void gather(Comm::Data &sendbuf, Comm::Data &recvbuf, FMI::Utils::peer_num root) { std::string channel = policy->get_channel({Utils::gather, sendbuf.size_in_bytes()}); - channel_data senddata {sendbuf.data(), sendbuf.size_in_bytes()}; - channel_data recvdata {recvbuf.data(), recvbuf.size_in_bytes()}; + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); channels[channel]->gather(senddata, recvdata, root); } + //! Variable-length gather + template + void gatherv(Comm::Data &sendbuf, Comm::Data &recvbuf, FMI::Utils::peer_num root, + const std::vector& recvcounts, const std::vector& displs) { + std::string channel = policy->get_channel({Utils::gatherv, sendbuf.size_in_bytes()}); + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); + channels[channel]->gatherv(senddata, recvdata, root, recvcounts, displs); + } + + //! All-gather - gather data and distribute to all peers + template + void allgather(Comm::Data &sendbuf, Comm::Data &recvbuf, FMI::Utils::peer_num root) { + std::string channel = policy->get_channel({Utils::allgather, sendbuf.size_in_bytes()}); + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); + channels[channel]->allgather(senddata, recvdata, root); + } + + //! Variable-length all-gather + template + void allgatherv(Comm::Data &sendbuf, Comm::Data &recvbuf, FMI::Utils::peer_num root, + const std::vector& recvcounts, const std::vector& displs) { + std::string channel = policy->get_channel({Utils::allgatherv, sendbuf.size_in_bytes()}); + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); + channels[channel]->allgatherv(senddata, recvdata, root, recvcounts, displs); + } + //! Scatter the data from root's sendbuf to the recvbuf of all peers. /*! * @param sendbuf The data to scatter, size needs to be recvbuf.size * num_peers (i.e., divisible by the number of peers). Only relevant for the root peer. @@ -73,8 +124,8 @@ namespace FMI { template void scatter(Comm::Data &sendbuf, Comm::Data &recvbuf, FMI::Utils::peer_num root) { std::string channel = policy->get_channel({Utils::scatter, recvbuf.size_in_bytes()}); - channel_data senddata {sendbuf.data(), sendbuf.size_in_bytes()}; - channel_data recvdata {recvbuf.data(), recvbuf.size_in_bytes()}; + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); channels[channel]->scatter(senddata, recvdata, root); } @@ -91,8 +142,8 @@ namespace FMI { } bool left_to_right = !(f.commutative && f.associative); std::string channel = policy->get_channel({Utils::reduce, sendbuf.size_in_bytes(), left_to_right}); - channel_data senddata {sendbuf.data(), sendbuf.size_in_bytes()}; - channel_data recvdata {recvbuf.data(), recvbuf.size_in_bytes()}; + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); auto func = convert_to_raw_function(f, sendbuf.size_in_bytes()); raw_function raw_f { func, @@ -115,8 +166,8 @@ namespace FMI { } bool left_to_right = !(f.commutative && f.associative); std::string channel = policy->get_channel({Utils::allreduce, sendbuf.size_in_bytes(), left_to_right}); - channel_data senddata {sendbuf.data(), sendbuf.size_in_bytes()}; - channel_data recvdata {recvbuf.data(), recvbuf.size_in_bytes()}; + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); auto func = convert_to_raw_function(f, sendbuf.size_in_bytes()); raw_function raw_f { func, @@ -138,8 +189,8 @@ namespace FMI { throw std::runtime_error("Dimensions of send and receive data must match"); } std::string channel = policy->get_channel({Utils::scan, sendbuf.size_in_bytes()}); - channel_data senddata {sendbuf.data(), sendbuf.size_in_bytes()}; - channel_data recvdata {recvbuf.data(), recvbuf.size_in_bytes()}; + auto senddata = std::make_shared(sendbuf.data(), sendbuf.size_in_bytes(), noop_deleter); + auto recvdata = std::make_shared(recvbuf.data(), recvbuf.size_in_bytes(), noop_deleter); auto func = convert_to_raw_function(f, sendbuf.size_in_bytes()); raw_function raw_f { func, @@ -149,6 +200,20 @@ namespace FMI { channels[channel]->scan(senddata, recvdata, raw_f); } + //! Progress function for polling non-blocking completion + FMI::Utils::EventProcessStatus progress(FMI::Utils::Operation op = FMI::Utils::send) { + FMI::Utils::EventProcessStatus status = FMI::Utils::EMPTY; + for (auto& [name, channel] : channels) { + auto channel_status = channel->channel_event_progress(op); + if (channel_status == FMI::Utils::PROCESSING) { + status = FMI::Utils::PROCESSING; + } else if (channel_status == FMI::Utils::NOOP && status == FMI::Utils::EMPTY) { + status = FMI::Utils::NOOP; + } + } + return status; + } + //! Add a new channel to the communicator with the given name by providing a pointer to it. void register_channel(std::string name, std::shared_ptr); @@ -192,4 +257,4 @@ namespace FMI { -#endif //FMI_COMMUNICATOR_H +#endif //FMI_COMMUNICATOR_H \ No newline at end of file diff --git a/include/comm/Channel.h b/include/comm/Channel.h index 2bba961..e0ca77d 100644 --- a/include/comm/Channel.h +++ b/include/comm/Channel.h @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include "../utils/Function.h" #include "../utils/Common.h" @@ -22,14 +25,45 @@ struct raw_function { bool commutative; }; +//! Noop deleter for external buffer management +static inline std::function noop_deleter = [](void*) {}; + //! Data that is passed to and from channels /*! * We intentionally use type erasure such that channels do not need to deal about types. * However, the communicator interface ensures that len corresponds to the type in buf and users never directly interact with channel_data. + * Uses shared_ptr for automatic memory management. */ struct channel_data { - char* buf; - std::size_t len; + std::shared_ptr buf; + std::size_t len = 0; + std::shared_ptr orig; // For external buffer ownership + + // Default constructor + channel_data() = default; + + // Allocating constructor - allocates new buffer + explicit channel_data(std::size_t length) + : buf(std::shared_ptr(new char[length])), len(length) {} + + // From raw pointer with custom deleter (for external buffers) + channel_data(char* external_buf, std::size_t length, + std::function deleter) + : buf(std::shared_ptr(external_buf, + [deleter](char* p) { deleter(p); })), + len(length) {} + + // From raw pointer with deleter and original reference + channel_data(char* external_buf, std::size_t length, + std::function deleter, + std::shared_ptr original) + : buf(std::shared_ptr(external_buf, + [deleter](char* p) { deleter(p); })), + len(length), orig(std::move(original)) {} + + // Raw pointer accessor + char* get() { return buf.get(); } + const char* get() const { return buf.get(); } }; @@ -38,14 +72,44 @@ namespace FMI::Comm { //! Interface that defines channel operations. Only provides a few default implementations, the rest is implemented in the specific ClientServer or PeerToPeer channel types. class Channel { public: + //! Initialize channel (for non-blocking setup) + virtual void init() {} + //! Send data to peer with id dest, must match a recv call - virtual void send(channel_data buf, FMI::Utils::peer_num dest) = 0; + virtual void send(std::shared_ptr buf, FMI::Utils::peer_num dest) = 0; //! Receive data from peer with id src, must match a send call - virtual void recv(channel_data buf, FMI::Utils::peer_num src) = 0; + virtual void recv(std::shared_ptr buf, FMI::Utils::peer_num src) = 0; + + //! Non-blocking send with callback + virtual void send(std::shared_ptr buf, FMI::Utils::peer_num dest, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) = 0; + + //! Non-blocking receive with callback + virtual void recv(std::shared_ptr buf, FMI::Utils::peer_num src, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) = 0; + + //! Check if ready to receive from peer + virtual bool checkReceive(FMI::Utils::peer_num src, FMI::Utils::Mode mode) { return true; } + + //! Check if ready to send to peer + virtual bool checkSend(FMI::Utils::peer_num dest, FMI::Utils::Mode mode) { return true; } + + //! Event progress for non-blocking operations + virtual FMI::Utils::EventProcessStatus channel_event_progress(FMI::Utils::Operation op) = 0; //! Broadcast data. Buf only needs to contain useful data for root, the buffer is overwritten for all other peers - virtual void bcast(channel_data buf, FMI::Utils::peer_num root) = 0; + virtual void bcast(std::shared_ptr buf, FMI::Utils::peer_num root) = 0; + + //! Non-blocking broadcast with callback + virtual void bcast(std::shared_ptr buf, FMI::Utils::peer_num root, + FMI::Utils::Mode mode, + std::function callback); //! Barrier synchronization collective. virtual void barrier() = 0; @@ -56,7 +120,53 @@ namespace FMI::Comm { * @param sendbuf Data that is sent to the root * @param recvbuf Buffer to receive data in, only relevant for root. Needs to have a size of (at least) num_peers * sendbuf.size */ - virtual void gather(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root); + virtual void gather(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root); + + //! Variable-length gather - each peer sends different amount of data + virtual void gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs); + + //! Non-blocking variable-length gather + virtual void gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback); + + //! All-gather - gather data from all peers and distribute to all + virtual void allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root); + + //! Non-blocking all-gather + virtual void allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, FMI::Utils::Mode mode, + std::function callback); + + //! Variable-length all-gather + virtual void allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs); + + //! Non-blocking variable-length all-gather + virtual void allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback); //! Scatter data from root to all peers /*! @@ -64,7 +174,7 @@ namespace FMI::Comm { * @param sendbuf Only relevant for root, contains the data that is scattered and needs to have a (divisible) size of num_peers * recvbuf.size * @param recvbuf Buffer to receive the data (of size sendbuf.size / num_peers), needs to be set by all peers */ - virtual void scatter(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root); + virtual void scatter(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root); //! Apply function f to sendbuf of all peers. /*! @@ -74,7 +184,7 @@ namespace FMI::Comm { * @param recvbuf Only relevant for root. Needs to have the same size as sendbuf * @param f Associativity / Commutativity of f controls choice of algorithm, depending on the channel / channel type */ - virtual void reduce(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, raw_function f) = 0; + virtual void reduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, raw_function f) = 0; //! Apply function f to sendbuf of all peers, make result available to everyone. /*! @@ -84,10 +194,13 @@ namespace FMI::Comm { * @param recvbuf Relevant for all peers in contrast to reduce * @param f */ - virtual void allreduce(channel_data sendbuf, channel_data recvbuf, raw_function f); + virtual void allreduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f); //! Inclusive prefix scan, recvbuf / sendbuf needs to be set for all peers - virtual void scan(channel_data sendbuf, channel_data recvbuf, raw_function f) = 0; + virtual void scan(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) = 0; + + //! Get max timeout configuration + virtual int getMaxTimeout() { return 0; } //! Helper utility to set peer id, ID needs to be set before first collective operation void set_peer_id(FMI::Utils::peer_num num) { peer_id = num; } @@ -103,7 +216,7 @@ namespace FMI::Comm { * Note that we provide an explicit finalize function on purpose (and do not use a virtual destructor), * because derived classes may require that some values of parent classes still exist when cleaning up. */ - virtual void finalize() {}; + virtual void finalize() {} //! Create a new channel with the given config and model params /*! diff --git a/include/comm/ClientServer.h b/include/comm/ClientServer.h index 04bd9eb..6430492 100644 --- a/include/comm/ClientServer.h +++ b/include/comm/ClientServer.h @@ -15,34 +15,49 @@ namespace FMI::Comm { explicit ClientServer(std::map params); //! Constructs file / key name based on sender and recipient and then uploads the data. - void send(channel_data buf, FMI::Utils::peer_num dest) override; + void send(std::shared_ptr buf, FMI::Utils::peer_num dest) override; + + //! Non-blocking send (not fully supported in client-server mode) + void send(std::shared_ptr buf, FMI::Utils::peer_num dest, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) override; //! Waits until the object with the expected file / key name appears (or a timeout occurs), then downloads it. - void recv(channel_data buf, FMI::Utils::peer_num dest) override; + void recv(std::shared_ptr buf, FMI::Utils::peer_num dest) override; + + //! Non-blocking recv (not fully supported in client-server mode) + void recv(std::shared_ptr buf, FMI::Utils::peer_num src, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) override; //! Root uploads its data, all other peers download the object - void bcast(channel_data buf, FMI::Utils::peer_num root) override; + void bcast(std::shared_ptr buf, FMI::Utils::peer_num root) override; //! All peers upload a 1 byte file and wait until num_peers files (associated to this operation based on the file name) exist void barrier() override; + //! Event progress (client-server returns NOOP as operations are synchronous) + Utils::EventProcessStatus channel_event_progress(Utils::Operation op) override; + //! All peers upload their data. The root peer downloads these objects and applies the function (as soon as objects become available for associative / commutative functions, left-to-right otherwise) - void reduce(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, raw_function f) override; + void reduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, raw_function f) override; //! All peers upload their data and download the needed files to apply the function. Left-to-right evaluation order is enforced for non-commutative / non-associative functions. - void scan(channel_data sendbuf, channel_data recvbuf, raw_function f) override; + void scan(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) override; //! Function to upload data with a given name / key to the server, needs to be implemented by the channels and should never be invoked directly (use upload instead). - virtual void upload_object(channel_data buf, std::string name) = 0; + virtual void upload_object(std::shared_ptr buf, std::string name) = 0; //! Function to download data with a given name / key from the server, needs to be implemented by the channels. Returns true when download was successful, false when file does not exist. - virtual bool download_object(channel_data buf, std::string name) = 0; + virtual bool download_object(std::shared_ptr buf, std::string name) = 0; //! Try the download (using download_object) until the object appears or the timeout was reached. - virtual void download(channel_data buf, std::string name); + virtual void download(std::shared_ptr buf, std::string name); //! Uploads objects and keeps track of them. - virtual void upload(channel_data buf, std::string name); + virtual void upload(std::shared_ptr buf, std::string name); //! List all the currently existing objects, needs to be implemented by channels. Needed by some collectives that check for the existence of files, but do not care about their content. virtual std::vector get_object_names() = 0; diff --git a/include/comm/Direct.h b/include/comm/Direct.h index 9dc9f86..85d539c 100644 --- a/include/comm/Direct.h +++ b/include/comm/Direct.h @@ -2,26 +2,56 @@ #define FMI_DIRECT_H #include "PeerToPeer.h" +#include namespace FMI::Comm { //! Channel that uses the TCPunch TCP NAT Hole Punching Library for connection establishment. class Direct : public PeerToPeer { public: explicit Direct(std::map params, std::map model_params); + virtual ~Direct(); - void send_object(channel_data buf, Utils::peer_num rcpt_id) override; + //! Initialize non-blocking infrastructure + void init() override; - void recv_object(channel_data buf, Utils::peer_num sender_id) override; + //! Get max timeout configuration + int getMaxTimeout() override; + + //! Blocking send object + void send_object(std::shared_ptr buf, Utils::peer_num rcpt_id) override; + + //! Blocking receive object + void recv_object(std::shared_ptr buf, Utils::peer_num sender_id) override; + + //! Non-blocking send object + void send_object(std::shared_ptr state, + Utils::peer_num rcpt_id, Utils::Mode mode) override; + + //! Non-blocking receive object + void recv_object(std::shared_ptr state, + Utils::peer_num sender_id, Utils::Mode mode) override; + + //! Event progress for non-blocking operations + Utils::EventProcessStatus channel_event_progress(Utils::Operation op) override; double get_latency(Utils::peer_num producer, Utils::peer_num consumer, std::size_t size_in_bytes) override; double get_price(Utils::peer_num producer, Utils::peer_num consumer, std::size_t size_in_bytes) override; private: - //! Contains the socket file descriptor for the communication with the peers. - std::vector sockets; + //! Socket storage by mode (supports both blocking and non-blocking sockets) + std::unordered_map> sockets; + + //! I/O state tracking for non-blocking operations + std::unordered_map>> io_states; + + //! Current operating mode + Utils::Mode mode = Utils::BLOCKING; + std::string hostname; int port; + bool resolve_host_dns; unsigned int max_timeout; // Model params double bandwidth; @@ -33,6 +63,18 @@ namespace FMI::Comm { //! Checks if connection with a peer partner_id is already established, otherwise establishes it using TCPunch. void check_socket(Utils::peer_num partner_id, std::string pair_name); + + //! Check and setup non-blocking socket + void check_socket_nbx(Utils::peer_num partner_id, std::string pair_name); + + //! Generate pairing name with mode distinction + std::string get_pairing_name(Utils::peer_num a, Utils::peer_num b, Utils::Mode mode); + + //! Event handling for progress. Returns iterator to next element. + std::unordered_map>::iterator + handle_event(std::unordered_map>::iterator it, + std::unordered_map>& states, + Utils::Operation op); }; } diff --git a/include/comm/PeerToPeer.h b/include/comm/PeerToPeer.h index 5ebab3d..aa8e1a8 100644 --- a/include/comm/PeerToPeer.h +++ b/include/comm/PeerToPeer.h @@ -2,20 +2,73 @@ #define FMI_PEERTOPEER_H #include "Channel.h" +#include +#include namespace FMI::Comm { + + //! Helper struct for variable-length gather operations + struct GatherVData { + std::size_t buf_len; + std::shared_ptr recvbuf; + std::vector displs; + std::shared_ptr buffer; + Utils::peer_num real_src; + }; + + //! I/O state for tracking non-blocking operations + struct IOState { + std::shared_ptr request; + size_t processed = 0; + Utils::Operation operation = Utils::send; + Utils::fmiContext* context = nullptr; + char dummy = 0; + + std::function callbackResult; + std::function callback = nullptr; + std::chrono::steady_clock::time_point deadline; + + //! Set the request data + void setRequest(const std::shared_ptr& cdata) { + request = cdata; + } + + //! Set the callback with bound arguments + template + void setCallback(Func&& func, Args&&... args) { + callback = std::bind(std::forward(func), std::forward(args)...); + } + }; + //! Peer-To-Peer channel type /*! * This class provides optimized collectives for channels where clients can address each other directly and defines the interface that these channels need to implement. */ class PeerToPeer : public Channel { public: - void send(channel_data buf, FMI::Utils::peer_num dest) override; - - void recv(channel_data buf, FMI::Utils::peer_num src) override; + // Blocking send/recv + void send(std::shared_ptr buf, FMI::Utils::peer_num dest) override; + void recv(std::shared_ptr buf, FMI::Utils::peer_num src) override; + + // Non-blocking send/recv + void send(std::shared_ptr buf, FMI::Utils::peer_num dest, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) override; + void recv(std::shared_ptr buf, FMI::Utils::peer_num src, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) override; //! Binomial tree broadcast implementation - void bcast(channel_data buf, FMI::Utils::peer_num root) override; + void bcast(std::shared_ptr buf, FMI::Utils::peer_num root) override; + + //! Non-blocking broadcast + void bcast(std::shared_ptr buf, FMI::Utils::peer_num root, + FMI::Utils::Mode mode, + std::function callback) override; //! Calls allreduce with a (associative and commutative) NOP operation void barrier() override; @@ -26,28 +79,85 @@ namespace FMI::Comm { * If the ID of the root is not 0, we cannot necessarily receive all values directly in recvbuf because we need to wrap around (e.g., when we get from peer N - 1 the values for N - 1, 0, and 1). * This is solved by allocating a temporary buffer and copying the values. */ - void gather(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root) override; + void gather(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root) override; + + //! Variable-length gather + void gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs) override; + + //! Non-blocking variable-length gather + void gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback) override; + + //! All-gather + void allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root) override; + + //! Non-blocking all-gather + void allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, FMI::Utils::Mode mode, + std::function callback) override; + + //! Variable-length all-gather + void allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs) override; + + //! Non-blocking variable-length all-gather + void allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback) override; //! Binomial tree scatter /*! * Similarly to gather, the root may need to send values from its sendbuf that is not consecutive when its ID is not 0, which is solved with a temporary buffer. */ - void scatter(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root) override; + void scatter(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root) override; //! Calls reduce_no_order for associative and commutative functions, reduce_ltr otherwise - void reduce(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, raw_function f) override; + void reduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, raw_function f) override; //! For associative and commutative functions, allreduce_no_order is called. Otherwise, reduce followed by bcast is used. - void allreduce(channel_data sendbuf, channel_data recvbuf, raw_function f) override; + void allreduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) override; //! For associative and commutative functions, scan_no_order is called. Otherwise, scan_ltr is called - void scan(channel_data sendbuf, channel_data recvbuf, raw_function f) override; + void scan(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) override; + + //! Event progress for non-blocking operations + Utils::EventProcessStatus channel_event_progress(Utils::Operation op) override; //! Send an object to peer with ID peer_id. Needs to be implemented by the channels. - virtual void send_object(channel_data buf, Utils::peer_num peer_id) = 0; + virtual void send_object(std::shared_ptr buf, Utils::peer_num peer_id) = 0; //! Receive an object from peer with ID peer_id. Needs to be implemented by the channels. - virtual void recv_object(channel_data buf, Utils::peer_num peer_id) = 0; + virtual void recv_object(std::shared_ptr buf, Utils::peer_num peer_id) = 0; + + //! Non-blocking send object + virtual void send_object(std::shared_ptr state, + Utils::peer_num peer_id, Utils::Mode mode) = 0; + + //! Non-blocking receive object + virtual void recv_object(std::shared_ptr state, + Utils::peer_num peer_id, Utils::Mode mode) = 0; double get_operation_latency(Utils::OperationInfo op_info) override; @@ -55,19 +165,19 @@ namespace FMI::Comm { protected: //! Reduction with left-to-right evaluation, gather followed by a function evaluation on the root peer. - void reduce_ltr(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, const raw_function& f); + void reduce_ltr(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, const raw_function& f); //! Binomial tree reduction where all peers apply the function in every step. - void reduce_no_order(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, const raw_function& f); + void reduce_no_order(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, const raw_function& f); //! Recursive doubling allreduce implementation. When num_peers is not a power of two, there is an additional message in the beginning and end for every peer where they send their value / receive the reduced value. - void allreduce_no_order(channel_data sendbuf, channel_data recvbuf, const raw_function& f); + void allreduce_no_order(std::shared_ptr sendbuf, std::shared_ptr recvbuf, const raw_function& f); //! Linear function application / sending - void scan_ltr(channel_data sendbuf, channel_data recvbuf, const raw_function& f); + void scan_ltr(std::shared_ptr sendbuf, std::shared_ptr recvbuf, const raw_function& f); //! Binomial tree with up- and down-phase - void scan_no_order(channel_data sendbuf, channel_data recvbuf, const raw_function& f); + void scan_no_order(std::shared_ptr sendbuf, std::shared_ptr recvbuf, const raw_function& f); private: //! Allows to implement all collectives as if root were 0 diff --git a/include/comm/Redis.h b/include/comm/Redis.h index 16edc56..42d894c 100644 --- a/include/comm/Redis.h +++ b/include/comm/Redis.h @@ -14,9 +14,9 @@ namespace FMI::Comm { ~Redis(); - void upload_object(channel_data buf, std::string name) override; + void upload_object(std::shared_ptr buf, std::string name) override; - bool download_object(channel_data buf, std::string name) override; + bool download_object(std::shared_ptr buf, std::string name) override; void delete_object(std::string name) override; diff --git a/include/comm/S3.h b/include/comm/S3.h index cb694ab..88dae70 100644 --- a/include/comm/S3.h +++ b/include/comm/S3.h @@ -16,9 +16,9 @@ namespace FMI::Comm { ~S3(); - void upload_object(channel_data buf, std::string name) override; + void upload_object(std::shared_ptr buf, std::string name) override; - bool download_object(channel_data buf, std::string name) override; + bool download_object(std::shared_ptr buf, std::string name) override; void delete_object(std::string name) override; diff --git a/include/utils/Common.h b/include/utils/Common.h index 7964955..aff3b92 100644 --- a/include/utils/Common.h +++ b/include/utils/Common.h @@ -19,9 +19,49 @@ namespace FMI::Utils { fast, cheap }; + //! Mode for blocking/non-blocking selection + enum Mode { + BLOCKING, + NONBLOCKING + }; + + //! Detailed error codes for non-blocking operations + enum NbxStatus { + SUCCESS, + SEND_FAILED, + RECEIVE_FAILED, + DUMMY_SEND_FAILED, + CONNECTION_CLOSED_BY_PEER, + SOCKET_CREATE_FAILED, + TCP_NODELAY_FAILED, + FCNTL_GET_FAILED, + FCNTL_SET_FAILED, + ADD_EVENT_FAILED, + EPOLL_WAIT_FAILED, + SOCKET_PAIR_FAILED, + SOCKET_SET_SO_RCVTIMEO_FAILED, + SOCKET_SET_SO_SNDTIMEO_FAILED, + SOCKET_SET_TCP_NODELAY_FAILED, + SOCKET_SET_NONBLOCKING_FAILED, + NBX_TIMEOUT + }; + + //! Event processing status for progress tracking + enum EventProcessStatus { + PROCESSING, + EMPTY, + NOOP + }; + + //! Completion context structure for non-blocking operations + struct fmiContext { + int completed; + }; + //! List of currently supported collectives enum Operation { - send, bcast, barrier, gather, scatter, reduce, allreduce, scan + send, recv, bcast, barrier, gather, gatherv, allgather, allgatherv, + scatter, reduce, allreduce, scan }; //! All the information about an operation, passed to the Channel Policy for its decision on which channel to use. diff --git a/src/comm/Channel.cpp b/src/comm/Channel.cpp index 8dc2ea3..92db28d 100644 --- a/src/comm/Channel.cpp +++ b/src/comm/Channel.cpp @@ -16,30 +16,111 @@ std::shared_ptr FMI::Comm::Channel::get_channel(std::string } } -void FMI::Comm::Channel::gather(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root) { +void FMI::Comm::Channel::gather(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root) { if (peer_id != root) { send(sendbuf, root); } else { - auto buffer_length = sendbuf.len; - for (int i = 0; i < num_peers; i++) { + auto buffer_length = sendbuf->len; + for (unsigned int i = 0; i < num_peers; i++) { if (i == root) { - std::memcpy(recvbuf.buf + root * buffer_length, sendbuf.buf, buffer_length); + std::memcpy(recvbuf->get() + root * buffer_length, sendbuf->get(), buffer_length); } else { - channel_data peer_data {recvbuf.buf + i * buffer_length, buffer_length}; + auto peer_data = std::make_shared(recvbuf->get() + i * buffer_length, buffer_length, noop_deleter); recv(peer_data, i); } } } } -void FMI::Comm::Channel::scatter(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root) { +void FMI::Comm::Channel::gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs) { + // Default implementation - subclasses can provide optimized versions + if (peer_id != root) { + send(sendbuf, root); + } else { + // Copy own data + std::memcpy(recvbuf->get() + displs[root], sendbuf->get(), recvcounts[root]); + // Receive from others + for (unsigned int i = 0; i < num_peers; i++) { + if (i != root) { + auto peer_data = std::make_shared(recvbuf->get() + displs[i], recvcounts[i], noop_deleter); + recv(peer_data, i); + } + } + } +} + +void FMI::Comm::Channel::gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback) { + // Default: just call blocking version + gatherv(sendbuf, recvbuf, root, recvcounts, displs); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} + +void FMI::Comm::Channel::allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root) { + // Default: gather + bcast + gather(sendbuf, recvbuf, root); + bcast(recvbuf, root); +} + +void FMI::Comm::Channel::allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, FMI::Utils::Mode mode, + std::function callback) { + // Default: just call blocking version + allgather(sendbuf, recvbuf, root); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} + +void FMI::Comm::Channel::allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs) { + // Default: gatherv + bcast + gatherv(sendbuf, recvbuf, root, recvcounts, displs); + bcast(recvbuf, root); +} + +void FMI::Comm::Channel::allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback) { + // Default: just call blocking version + allgatherv(sendbuf, recvbuf, root, recvcounts, displs); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} + +void FMI::Comm::Channel::scatter(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root) { if (peer_id == root) { - auto buffer_length = recvbuf.len; - for (int i = 0; i < num_peers; i++) { + auto buffer_length = recvbuf->len; + for (unsigned int i = 0; i < num_peers; i++) { if (i == root) { - std::memcpy(recvbuf.buf, sendbuf.buf + root * buffer_length, buffer_length); + std::memcpy(recvbuf->get(), sendbuf->get() + root * buffer_length, buffer_length); } else { - channel_data peer_data {sendbuf.buf + i * buffer_length, buffer_length}; + auto peer_data = std::make_shared(sendbuf->get() + i * buffer_length, buffer_length, noop_deleter); send(peer_data, i); } } @@ -48,7 +129,18 @@ void FMI::Comm::Channel::scatter(channel_data sendbuf, channel_data recvbuf, FMI } } -void FMI::Comm::Channel::allreduce(channel_data sendbuf, channel_data recvbuf, raw_function f) { +void FMI::Comm::Channel::allreduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) { reduce(sendbuf, recvbuf, 0, f); bcast(recvbuf, 0); } + +void FMI::Comm::Channel::bcast(std::shared_ptr buf, FMI::Utils::peer_num root, + FMI::Utils::Mode mode, + std::function callback) { + // Default: just call blocking version + bcast(buf, root); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} diff --git a/src/comm/ClientServer.cpp b/src/comm/ClientServer.cpp index ef9f5c8..8806bb6 100644 --- a/src/comm/ClientServer.cpp +++ b/src/comm/ClientServer.cpp @@ -3,7 +3,7 @@ #include #include -void FMI::Comm::ClientServer::send(channel_data buf, FMI::Utils::peer_num dest) { +void FMI::Comm::ClientServer::send(std::shared_ptr buf, FMI::Utils::peer_num dest) { auto num_operation_entry = num_operations.find("send" + std::to_string(dest)); unsigned int operation_num; if (num_operation_entry == num_operations.end()) { @@ -17,7 +17,18 @@ void FMI::Comm::ClientServer::send(channel_data buf, FMI::Utils::peer_num dest) upload(buf, file_name); } -void FMI::Comm::ClientServer::recv(channel_data buf, FMI::Utils::peer_num dest) { +void FMI::Comm::ClientServer::send(std::shared_ptr buf, FMI::Utils::peer_num dest, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) { + // ClientServer doesn't support true non-blocking - just call blocking version + send(buf, dest); + if (callback) { + callback(Utils::SUCCESS, "", context); + } +} + +void FMI::Comm::ClientServer::recv(std::shared_ptr buf, FMI::Utils::peer_num dest) { auto num_operation_entry = num_operations.find("recv" + std::to_string(dest)); unsigned int operation_num; if (num_operation_entry == num_operations.end()) { @@ -31,7 +42,18 @@ void FMI::Comm::ClientServer::recv(channel_data buf, FMI::Utils::peer_num dest) download(buf, file_name); } -void FMI::Comm::ClientServer::bcast(channel_data buf, FMI::Utils::peer_num root) { +void FMI::Comm::ClientServer::recv(std::shared_ptr buf, FMI::Utils::peer_num src, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) { + // ClientServer doesn't support true non-blocking - just call blocking version + recv(buf, src); + if (callback) { + callback(Utils::SUCCESS, "", context); + } +} + +void FMI::Comm::ClientServer::bcast(std::shared_ptr buf, FMI::Utils::peer_num root) { std::string file_name = comm_name + std::to_string(root) + "_bcast_" + std::to_string(num_operations["bcast"]); num_operations["bcast"]++; if (peer_id == root) { @@ -47,14 +69,15 @@ void FMI::Comm::ClientServer::barrier() { std::string file_name = comm_name + std::to_string(peer_id) + barrier_suffix; num_operations["barrier"]++; char b = '1'; - upload({&b, sizeof(b)}, file_name); + auto buf = std::make_shared(&b, sizeof(b), noop_deleter); + upload(buf, file_name); unsigned int elapsed_time = 0; while (elapsed_time < max_timeout) { auto objects = get_object_names(); auto has_barrier_suffix = [barrier_suffix] (const std::string& s){return s.size() > barrier_suffix.size() && s.compare(s.size() - barrier_suffix.size(), barrier_suffix.size(), barrier_suffix) == 0 ;}; auto num_arrived = std::count_if(objects.begin(), objects.end(), has_barrier_suffix); - if (num_arrived >= num_peers) { + if (num_arrived >= (long)num_peers) { return; } else { elapsed_time += timeout; @@ -64,13 +87,18 @@ void FMI::Comm::ClientServer::barrier() { throw Utils::Timeout(); } +FMI::Utils::EventProcessStatus FMI::Comm::ClientServer::channel_event_progress(Utils::Operation op) { + // ClientServer operations are synchronous + return Utils::NOOP; +} + void FMI::Comm::ClientServer::finalize() { for (const auto& object_name : created_objects) { delete_object(object_name); } } -void FMI::Comm::ClientServer::download(channel_data buf, std::string name) { +void FMI::Comm::ClientServer::download(std::shared_ptr buf, std::string name) { unsigned int elapsed_time = 0; while (elapsed_time < max_timeout) { bool success = download_object(buf, name); @@ -84,38 +112,39 @@ void FMI::Comm::ClientServer::download(channel_data buf, std::string name) { throw Utils::Timeout(); } -void FMI::Comm::ClientServer::upload(channel_data buf, std::string name) { +void FMI::Comm::ClientServer::upload(std::shared_ptr buf, std::string name) { created_objects.push_back(name); upload_object(buf, name); } -void FMI::Comm::ClientServer::reduce(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, raw_function f) { +void FMI::Comm::ClientServer::reduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, raw_function f) { if (peer_id == root) { bool left_to_right = !(f.commutative && f.associative); std::vector received(num_peers, false); std::vector applied(num_peers, false); - auto buffer_length = sendbuf.len; + auto buffer_length = sendbuf->len; std::vector data(buffer_length * num_peers); - std::memcpy(reinterpret_cast(recvbuf.buf), sendbuf.buf, buffer_length); + std::memcpy(reinterpret_cast(recvbuf->get()), sendbuf->get(), buffer_length); received[root] = true; applied[root] = true; unsigned int elapsed_time = 0; while (elapsed_time < max_timeout && std::any_of(applied.begin(), applied.end(), [] (bool v) { return !v; }) ) { // Receive all values - for (int i = 0; i < num_peers; i++) { + for (unsigned int i = 0; i < num_peers; i++) { if (received[i]) { continue; } std::string file_name = comm_name + std::to_string(i) + "_reduce_" + std::to_string(num_operations["reduce"]); - if (download_object({data.data() + i * buffer_length, buffer_length}, file_name)) { + auto peer_buf = std::make_shared(data.data() + i * buffer_length, buffer_length, noop_deleter); + if (download_object(peer_buf, file_name)) { received[i] = true; } } // Apply function where possible bool all_left_applied = true; - for (int i = 0; i < num_peers; i++) { + for (unsigned int i = 0; i < num_peers; i++) { if (received[i] && !applied[i] && (!left_to_right || all_left_applied)) { - f.f(recvbuf.buf, data.data() + i * buffer_length); + f.f(recvbuf->get(), data.data() + i * buffer_length); applied[i] = true; } else if (!received[i]) { all_left_applied = false; @@ -136,7 +165,7 @@ void FMI::Comm::ClientServer::reduce(channel_data sendbuf, channel_data recvbuf, } } -void FMI::Comm::ClientServer::scan(channel_data sendbuf, channel_data recvbuf, raw_function f) { +void FMI::Comm::ClientServer::scan(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) { if (peer_id != num_peers - 1) { std::string file_name = comm_name + std::to_string(peer_id) + "_scan_" + std::to_string(num_operations["scan"]); upload(sendbuf, file_name); @@ -145,28 +174,29 @@ void FMI::Comm::ClientServer::scan(channel_data sendbuf, channel_data recvbuf, r auto num_data = peer_id + 1; std::vector received(num_data, false); std::vector applied(num_data, false); - auto buffer_length = sendbuf.len; + auto buffer_length = sendbuf->len; std::vector data(buffer_length * num_data); - std::memcpy(reinterpret_cast(recvbuf.buf), sendbuf.buf, buffer_length); + std::memcpy(reinterpret_cast(recvbuf->get()), sendbuf->get(), buffer_length); received[peer_id] = true; applied[peer_id] = true; unsigned int elapsed_time = 0; while (elapsed_time < max_timeout && std::any_of(applied.begin(), applied.end(), [] (bool v) { return !v; }) ) { // Receive all values - for (int i = 0; i < num_data; i++) { + for (unsigned int i = 0; i < num_data; i++) { if (received[i]) { continue; } std::string file_name = comm_name + std::to_string(i) + "_scan_" + std::to_string(num_operations["scan"]); - if (download_object({data.data() + i * buffer_length, buffer_length}, file_name)) { + auto peer_buf = std::make_shared(data.data() + i * buffer_length, buffer_length, noop_deleter); + if (download_object(peer_buf, file_name)) { received[i] = true; } } // Apply function where possible bool all_left_applied = true; - for (int i = 0; i < num_peers; i++) { + for (unsigned int i = 0; i < num_peers; i++) { if (received[i] && !applied[i] && (!left_to_right || all_left_applied)) { - f.f(recvbuf.buf, data.data() + i * buffer_length); + f.f(recvbuf->get(), data.data() + i * buffer_length); applied[i] = true; } else if (!received[i]) { all_left_applied = false; @@ -201,6 +231,7 @@ double FMI::Comm::ClientServer::get_operation_latency(FMI::Utils::OperationInfo return upload + download; } case Utils::gather: + case Utils::gatherv: return get_latency(num_peers - 1, 1, size_in_bytes); case Utils::scatter: return get_latency(1, num_peers - 1, size_in_bytes); @@ -212,6 +243,13 @@ double FMI::Comm::ClientServer::get_operation_latency(FMI::Utils::OperationInfo double bcast = get_latency(1, num_peers - 1, size_in_bytes); return reduction + bcast; } + case Utils::allgather: + case Utils::allgatherv: + { + double gather = get_latency(num_peers - 1, 1, size_in_bytes); + double bcast = get_latency(1, num_peers - 1, num_peers * size_in_bytes); + return gather + bcast; + } case Utils::scan: // Pattern is parallel (num_peers - 1, 1), (num_peers - 2, 1), ... -> Slowest one is (num_peers - 1, 1) return get_latency(num_peers - 1, 1, size_in_bytes); @@ -233,6 +271,7 @@ double FMI::Comm::ClientServer::get_operation_price(FMI::Utils::OperationInfo op return upload + download; } case Utils::gather: + case Utils::gatherv: return get_price(num_peers - 1, 1, size_in_bytes); case Utils::scatter: return get_price(1, num_peers - 1, size_in_bytes); @@ -244,13 +283,20 @@ double FMI::Comm::ClientServer::get_operation_price(FMI::Utils::OperationInfo op double bcast = get_price(1, num_peers - 1, size_in_bytes); return reduction + bcast; } + case Utils::allgather: + case Utils::allgatherv: + { + double gather = get_price(num_peers - 1, 1, size_in_bytes); + double bcast = get_price(1, num_peers - 1, num_peers * size_in_bytes); + return gather + bcast; + } case Utils::scan: double costs = 0.; // N - 1 uploads with varying number of consumers - for (int i = 1; i < num_peers; i++) { + for (unsigned int i = 1; i < num_peers; i++) { costs += get_latency(1, num_peers - i, size_in_bytes); } return costs; } throw std::runtime_error("Operation not implemented"); -} +} \ No newline at end of file diff --git a/src/comm/Direct.cpp b/src/comm/Direct.cpp index 1dc8d64..cb73973 100644 --- a/src/comm/Direct.cpp +++ b/src/comm/Direct.cpp @@ -5,16 +5,69 @@ #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include FMI::Comm::Direct::Direct(std::map params, std::map model_params) { + struct addrinfo hints, *res, *p; + int status; + char ipstr[INET6_ADDRSTRLEN]; + hostname = params["host"]; port = std::stoi(params["port"]); + if (model_params["resolve_host_dns"] == "true") { + + memset(&hints, 0, sizeof hints); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + + if ((status = getaddrinfo(hostname.c_str(), nullptr, &hints, &res)) != 0) { + std::cerr << "getaddrinfo error: " << gai_strerror(status) << std::endl; + } else { + // Iterate through the result list and convert each address to a string + for(p = res; p != nullptr; p = p->ai_next) { + void *addr; + + // Get the pointer to the address itself, + struct sockaddr_in *ipv4 = (struct sockaddr_in *)p->ai_addr; + addr = &(ipv4->sin_addr); + + // Convert the IP to a string and print it: + inet_ntop(p->ai_family, addr, ipstr, sizeof ipstr); + std::cout << " resolved dns: " << ipstr << std::endl; + } + + freeaddrinfo(res); // Free the linked list + hostname = ipstr; + + } + + resolve_host_dns = true; + } else { + resolve_host_dns = false; + } + max_timeout = std::stoi(params["max_timeout"]); + std::cout << "max_timeout set to: " << max_timeout << std::endl; bandwidth = std::stod(model_params["bandwidth"]); + overhead = std::stod(model_params["overhead"]); + transfer_price = std::stod(model_params["transfer_price"]); + vm_price = std::stod(model_params["vm_price"]); + requests_per_hour = std::stoi(model_params["requests_per_hour"]); + + if (model_params["include_infrastructure_costs"] == "true") { include_infrastructure_costs = true; } else { @@ -22,9 +75,31 @@ FMI::Comm::Direct::Direct(std::map params, std::map(num_peers, -1); + } +} + +int FMI::Comm::Direct::getMaxTimeout() { + return max_timeout; +} + +void FMI::Comm::Direct::send_object(std::shared_ptr buf, Utils::peer_num rcpt_id) { check_socket(rcpt_id, comm_name + std::to_string(peer_id) + "_" + std::to_string(rcpt_id)); - long sent = ::send(sockets[rcpt_id], buf.buf, buf.len, 0); + long sent = ::send(sockets[Utils::BLOCKING][rcpt_id], buf->get(), buf->len, 0); if (sent == -1) { if (errno == EAGAIN) { throw Utils::Timeout(); @@ -33,10 +108,10 @@ void FMI::Comm::Direct::send_object(channel_data buf, Utils::peer_num rcpt_id) { } } -void FMI::Comm::Direct::recv_object(channel_data buf, Utils::peer_num sender_id) { +void FMI::Comm::Direct::recv_object(std::shared_ptr buf, Utils::peer_num sender_id) { check_socket(sender_id, comm_name + std::to_string(sender_id) + "_" + std::to_string(peer_id)); - long received = ::recv(sockets[sender_id], buf.buf, buf.len, MSG_WAITALL); - if (received == -1 || received < buf.len) { + long received = ::recv(sockets[Utils::BLOCKING][sender_id], buf->get(), buf->len, MSG_WAITALL); + if (received == -1 || received < (long)buf->len) { if (errno == EAGAIN) { throw Utils::Timeout(); } @@ -44,13 +119,170 @@ void FMI::Comm::Direct::recv_object(channel_data buf, Utils::peer_num sender_id) } } +void FMI::Comm::Direct::send_object(std::shared_ptr state, Utils::peer_num rcpt_id, Utils::Mode mode) { + if (mode == Utils::BLOCKING) { + send_object(state->request, rcpt_id); + if (state->callbackResult) { + state->callbackResult(Utils::SUCCESS, "", state->context); + } + return; + } + + // Non-blocking mode + std::string pair_name = get_pairing_name(peer_id, rcpt_id, mode); + check_socket_nbx(rcpt_id, pair_name); + + int sock = sockets[Utils::NONBLOCKING][rcpt_id]; + state->deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(max_timeout); + + // Try to send what we can + ssize_t sent = ::send(sock, state->request->get() + state->processed, + state->request->len - state->processed, MSG_DONTWAIT); + + if (sent > 0) { + state->processed += sent; + } + + if (state->processed >= state->request->len) { + // Complete + if (state->callbackResult) { + state->callbackResult(Utils::SUCCESS, "", state->context); + } + } else { + // Register for progress + io_states[Utils::send][sock] = state; + } +} + +void FMI::Comm::Direct::recv_object(std::shared_ptr state, Utils::peer_num sender_id, Utils::Mode mode) { + if (mode == Utils::BLOCKING) { + recv_object(state->request, sender_id); + if (state->callbackResult) { + state->callbackResult(Utils::SUCCESS, "", state->context); + } + return; + } + + // Non-blocking mode + std::string pair_name = get_pairing_name(sender_id, peer_id, mode); + check_socket_nbx(sender_id, pair_name); + + int sock = sockets[Utils::NONBLOCKING][sender_id]; + state->deadline = std::chrono::steady_clock::now() + std::chrono::milliseconds(max_timeout); + + // Try to receive what we can + ssize_t recvd = ::recv(sock, state->request->get() + state->processed, + state->request->len - state->processed, MSG_DONTWAIT); + + if (recvd > 0) { + state->processed += recvd; + } + + if (state->processed >= state->request->len) { + // Complete + if (state->callbackResult) { + state->callbackResult(Utils::SUCCESS, "", state->context); + } + } else { + // Register for progress + io_states[Utils::recv][sock] = state; + } +} + +FMI::Utils::EventProcessStatus FMI::Comm::Direct::channel_event_progress(Utils::Operation op) { + if (io_states.find(op) == io_states.end() || io_states[op].empty()) { + return Utils::EMPTY; + } + + auto& states = io_states[op]; + bool any_processed = false; + + for (auto it = states.begin(); it != states.end(); ) { + int sock = it->first; + auto& state = it->second; + + // Check timeout + if (std::chrono::steady_clock::now() > state->deadline) { + if (state->callbackResult) { + state->callbackResult(Utils::NBX_TIMEOUT, "Operation timed out", state->context); + } + it = states.erase(it); + continue; + } + + it = handle_event(it, states, op); + any_processed = true; + } + + return any_processed ? Utils::PROCESSING : Utils::NOOP; +} + +std::unordered_map>::iterator +FMI::Comm::Direct::handle_event(std::unordered_map>::iterator it, + std::unordered_map>& states, + Utils::Operation op) { + if (it == states.end()) { + return it; + } + + int socketfd = it->first; + auto& state = it->second; + + // Use poll to check readiness + struct pollfd pfd; + pfd.fd = socketfd; + pfd.events = (op == Utils::send) ? POLLOUT : POLLIN; + pfd.revents = 0; + + int ret = poll(&pfd, 1, 0); // Non-blocking poll + if (ret <= 0) { + return ++it; + } + + if (pfd.revents & POLLOUT) { + // Ready for writing + ssize_t sent = ::send(socketfd, state->request->get() + state->processed, + state->request->len - state->processed, MSG_DONTWAIT); + if (sent > 0) { + state->processed += sent; + } + } + + if (pfd.revents & POLLIN) { + // Ready for reading + ssize_t recvd = ::recv(socketfd, state->request->get() + state->processed, + state->request->len - state->processed, MSG_DONTWAIT); + if (recvd > 0) { + state->processed += recvd; + } else if (recvd == 0) { + // Connection closed + if (state->callbackResult) { + state->callbackResult(Utils::CONNECTION_CLOSED_BY_PEER, "Connection closed", state->context); + } + return states.erase(it); + } + } + + // Check if complete + if (state->processed >= state->request->len) { + if (state->callbackResult) { + state->callbackResult(Utils::SUCCESS, "", state->context); + } + if (state->callback) { + state->callback(); + } + return states.erase(it); + } + return ++it; +} + void FMI::Comm::Direct::check_socket(FMI::Utils::peer_num partner_id, std::string pair_name) { - if (sockets.empty()) { - sockets = std::vector(num_peers, -1); + if (sockets.find(Utils::BLOCKING) == sockets.end() || sockets[Utils::BLOCKING].empty()) { + sockets[Utils::BLOCKING] = std::vector(num_peers, -1); } - if (sockets[partner_id] == -1) { + if (sockets[Utils::BLOCKING][partner_id] == -1) { try { - sockets[partner_id] = pair(pair_name, hostname, port, max_timeout); + sockets[Utils::BLOCKING][partner_id] = pair(pair_name, hostname, port, max_timeout); } catch (Timeout) { throw Utils::Timeout(); } @@ -58,18 +290,54 @@ void FMI::Comm::Direct::check_socket(FMI::Utils::peer_num partner_id, std::strin struct timeval timeout; timeout.tv_sec = max_timeout / 1000; timeout.tv_usec = (max_timeout % 1000) * 1000; - setsockopt(sockets[partner_id], SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout); - setsockopt(sockets[partner_id], SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout, sizeof timeout); + setsockopt(sockets[Utils::BLOCKING][partner_id], SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof timeout); + setsockopt(sockets[Utils::BLOCKING][partner_id], SOL_SOCKET, SO_SNDTIMEO, (const char*)&timeout, sizeof timeout); // Disable Nagle algorithm to avoid 40ms TCP ack delays int one = 1; // SOL_TCP not defined on macOS #if !defined(SOL_TCP) && defined(IPPROTO_TCP) #define SOL_TCP IPPROTO_TCP #endif - setsockopt(sockets[partner_id], SOL_TCP, TCP_NODELAY, &one, sizeof(one)); + setsockopt(sockets[Utils::BLOCKING][partner_id], SOL_TCP, TCP_NODELAY, &one, sizeof(one)); } } +void FMI::Comm::Direct::check_socket_nbx(FMI::Utils::peer_num partner_id, std::string pair_name) { + if (sockets.find(Utils::NONBLOCKING) == sockets.end() || sockets[Utils::NONBLOCKING].empty()) { + sockets[Utils::NONBLOCKING] = std::vector(num_peers, -1); + } + if (sockets[Utils::NONBLOCKING][partner_id] == -1) { + try { + sockets[Utils::NONBLOCKING][partner_id] = pair(pair_name, hostname, port, max_timeout); + } catch (Timeout) { + throw Utils::Timeout(); + } + + int sock = sockets[Utils::NONBLOCKING][partner_id]; + + // Set non-blocking mode + int flags = fcntl(sock, F_GETFL, 0); + if (flags == -1) { + BOOST_LOG_TRIVIAL(error) << "Failed to get socket flags"; + } + if (fcntl(sock, F_SETFL, flags | O_NONBLOCK) == -1) { + BOOST_LOG_TRIVIAL(error) << "Failed to set non-blocking mode"; + } + + // Disable Nagle algorithm + int one = 1; + #if !defined(SOL_TCP) && defined(IPPROTO_TCP) + #define SOL_TCP IPPROTO_TCP + #endif + setsockopt(sock, SOL_TCP, TCP_NODELAY, &one, sizeof(one)); + } +} + +std::string FMI::Comm::Direct::get_pairing_name(Utils::peer_num a, Utils::peer_num b, Utils::Mode mode) { + std::string mode_suffix = (mode == Utils::NONBLOCKING) ? "_nbx" : ""; + return comm_name + std::to_string(a) + "_" + std::to_string(b) + mode_suffix; +} + double FMI::Comm::Direct::get_latency(Utils::peer_num producer, Utils::peer_num consumer, std::size_t size_in_bytes) { double agg_bandwidth = bandwidth; double trans_time = producer * consumer * ((double) size_in_bytes / 1000000.) / agg_bandwidth; @@ -83,4 +351,4 @@ double FMI::Comm::Direct::get_price(Utils::peer_num producer, Utils::peer_num co total_costs += 1. / requests_per_hour * vm_price; } return total_costs; -} +} \ No newline at end of file diff --git a/src/comm/PeerToPeer.cpp b/src/comm/PeerToPeer.cpp index 26ef25d..a5be698 100644 --- a/src/comm/PeerToPeer.cpp +++ b/src/comm/PeerToPeer.cpp @@ -3,15 +3,39 @@ #include #include -void FMI::Comm::PeerToPeer::send(channel_data buf, FMI::Utils::peer_num dest) { +void FMI::Comm::PeerToPeer::send(std::shared_ptr buf, FMI::Utils::peer_num dest) { send_object(buf, dest); } -void FMI::Comm::PeerToPeer::recv(channel_data buf, FMI::Utils::peer_num src) { +void FMI::Comm::PeerToPeer::recv(std::shared_ptr buf, FMI::Utils::peer_num src) { recv_object(buf, src); } -void FMI::Comm::PeerToPeer::bcast(channel_data buf, FMI::Utils::peer_num root) { +void FMI::Comm::PeerToPeer::send(std::shared_ptr buf, FMI::Utils::peer_num dest, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) { + auto state = std::make_shared(); + state->setRequest(buf); + state->context = context; + state->callbackResult = callback; + state->operation = Utils::send; + send_object(state, dest, mode); +} + +void FMI::Comm::PeerToPeer::recv(std::shared_ptr buf, FMI::Utils::peer_num src, + FMI::Utils::fmiContext* context, FMI::Utils::Mode mode, + std::function callback) { + auto state = std::make_shared(); + state->setRequest(buf); + state->context = context; + state->callbackResult = callback; + state->operation = Utils::recv; + recv_object(state, src, mode); +} + +void FMI::Comm::PeerToPeer::bcast(std::shared_ptr buf, FMI::Utils::peer_num root) { int rounds = ceil(log2(num_peers)); Utils::peer_num trans_peer_id = transform_peer_id(peer_id, root, true); for (int i = rounds - 1; i >= 0; i--) { @@ -26,13 +50,26 @@ void FMI::Comm::PeerToPeer::bcast(channel_data buf, FMI::Utils::peer_num root) { } } +void FMI::Comm::PeerToPeer::bcast(std::shared_ptr buf, FMI::Utils::peer_num root, + FMI::Utils::Mode mode, + std::function callback) { + // For now, just call blocking version + bcast(buf, root); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} + void FMI::Comm::PeerToPeer::barrier() { auto nop = [] (char* a, char* b) {}; - char send = 1; - allreduce({&send, sizeof(char)}, {&send, sizeof(char)}, {nop, true, true}); + char send_val = 1; + auto sendbuf = std::make_shared(&send_val, sizeof(char), noop_deleter); + auto recvbuf = std::make_shared(&send_val, sizeof(char), noop_deleter); + allreduce(sendbuf, recvbuf, {nop, true, true}); } -void FMI::Comm::PeerToPeer::reduce(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, raw_function f) { +void FMI::Comm::PeerToPeer::reduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, raw_function f) { bool left_to_right = !(f.commutative && f.associative); if (left_to_right) { reduce_ltr(sendbuf, recvbuf, root, f); @@ -41,49 +78,50 @@ void FMI::Comm::PeerToPeer::reduce(channel_data sendbuf, channel_data recvbuf, F } } -void FMI::Comm::PeerToPeer::reduce_ltr(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, const raw_function& f) { +void FMI::Comm::PeerToPeer::reduce_ltr(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, const raw_function& f) { if (peer_id == root) { - std::size_t tmpbuf_len = sendbuf.len * num_peers; - char* tmpbuf = new char[tmpbuf_len]; - gather(sendbuf, {tmpbuf, tmpbuf_len}, root); - std::memcpy(reinterpret_cast(recvbuf.buf), tmpbuf, sendbuf.len); - for (std::size_t i = sendbuf.len; i < tmpbuf_len; i += sendbuf.len) { - f.f(recvbuf.buf, tmpbuf + i); + std::size_t tmpbuf_len = sendbuf->len * num_peers; + auto tmpbuf = std::make_shared(tmpbuf_len); + gather(sendbuf, tmpbuf, root); + std::memcpy(reinterpret_cast(recvbuf->get()), tmpbuf->get(), sendbuf->len); + for (std::size_t i = sendbuf->len; i < tmpbuf_len; i += sendbuf->len) { + f.f(recvbuf->get(), tmpbuf->get() + i); } - delete[] tmpbuf; } else { - gather(sendbuf, {}, root); + gather(sendbuf, std::make_shared(), root); } } -void FMI::Comm::PeerToPeer::reduce_no_order(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root, const raw_function& f) { +void FMI::Comm::PeerToPeer::reduce_no_order(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root, const raw_function& f) { int rounds = ceil(log2(num_peers)); Utils::peer_num trans_peer_id = transform_peer_id(peer_id, root, true); + + std::shared_ptr local_recvbuf; if (peer_id != root) { - recvbuf.buf = new char[sendbuf.len]; - recvbuf.len = sendbuf.len; + local_recvbuf = std::make_shared(sendbuf->len); + } else { + local_recvbuf = recvbuf; } + for (int i = 0; i < rounds; i++) { Utils::peer_num src = trans_peer_id + (Utils::peer_num) std::pow(2, i); if (trans_peer_id % (int) std::pow(2, i + 1) == 0 && src < num_peers) { Utils::peer_num real_src = transform_peer_id(src, root, false); - recv({recvbuf.buf, recvbuf.len}, real_src); - f.f(sendbuf.buf, recvbuf.buf); + recv(local_recvbuf, real_src); + f.f(sendbuf->get(), local_recvbuf->get()); } else if (trans_peer_id % (int) std::pow(2, i) == 0 && trans_peer_id % (int) std::pow(2, i + 1) != 0){ Utils::peer_num real_dst = transform_peer_id(trans_peer_id - (int) std::pow(2, i), root, false); - send({sendbuf.buf, sendbuf.len}, real_dst); + send(sendbuf, real_dst); } } if (peer_id == root) { - std::memcpy(recvbuf.buf, sendbuf.buf, sendbuf.len); - } else { - delete[] recvbuf.buf; + std::memcpy(recvbuf->get(), sendbuf->get(), sendbuf->len); } } -void FMI::Comm::PeerToPeer::allreduce(channel_data sendbuf, channel_data recvbuf, raw_function f) { +void FMI::Comm::PeerToPeer::allreduce(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) { bool left_to_right = !(f.commutative && f.associative); if (left_to_right) { reduce(sendbuf, recvbuf, 0, f); @@ -93,43 +131,43 @@ void FMI::Comm::PeerToPeer::allreduce(channel_data sendbuf, channel_data recvbuf } } -void FMI::Comm::PeerToPeer::allreduce_no_order(channel_data sendbuf, channel_data recvbuf, const raw_function &f) { +void FMI::Comm::PeerToPeer::allreduce_no_order(std::shared_ptr sendbuf, std::shared_ptr recvbuf, const raw_function &f) { // Non power of two N: First receive from processes with ID >= 2^ceil(log2(N)), send result after reduction int rounds = floor(log2(num_peers)); int nearest_power_two = (int) std::pow(2, rounds); - if (num_peers > nearest_power_two) { - if (peer_id < nearest_power_two && peer_id + nearest_power_two < num_peers) { + if (num_peers > (unsigned int)nearest_power_two) { + if (peer_id < (unsigned int)nearest_power_two && peer_id + nearest_power_two < num_peers) { recv(recvbuf, peer_id + nearest_power_two); - f.f(sendbuf.buf, recvbuf.buf); - } else if (peer_id >= nearest_power_two) { + f.f(sendbuf->get(), recvbuf->get()); + } else if (peer_id >= (unsigned int)nearest_power_two) { send(sendbuf, peer_id - nearest_power_two); } } - if (peer_id < nearest_power_two) { + if (peer_id < (unsigned int)nearest_power_two) { // Actual recursive doubling for (int i = 0; i < rounds; i++) { int peer = peer_id ^ (int) std::pow(2, i); - if (peer < peer_id) { + if (peer < (int)peer_id) { send(sendbuf, peer); recv(recvbuf, peer); } else { recv(recvbuf, peer); send(sendbuf, peer); } - f.f(sendbuf.buf, recvbuf.buf); + f.f(sendbuf->get(), recvbuf->get()); } } - if (num_peers > nearest_power_two) { - if (peer_id < nearest_power_two && peer_id + nearest_power_two < num_peers) { + if (num_peers > (unsigned int)nearest_power_two) { + if (peer_id < (unsigned int)nearest_power_two && peer_id + nearest_power_two < num_peers) { send(sendbuf, peer_id + nearest_power_two); - } else if (peer_id >= nearest_power_two) { + } else if (peer_id >= (unsigned int)nearest_power_two) { recv(sendbuf, peer_id - nearest_power_two); } } - std::memcpy(recvbuf.buf, sendbuf.buf, sendbuf.len); + std::memcpy(recvbuf->get(), sendbuf->get(), sendbuf->len); } -void FMI::Comm::PeerToPeer::scan(channel_data sendbuf, channel_data recvbuf, raw_function f) { +void FMI::Comm::PeerToPeer::scan(std::shared_ptr sendbuf, std::shared_ptr recvbuf, raw_function f) { bool left_to_right = !(f.commutative && f.associative); if (left_to_right) { scan_ltr(sendbuf, recvbuf, f); @@ -138,27 +176,27 @@ void FMI::Comm::PeerToPeer::scan(channel_data sendbuf, channel_data recvbuf, raw } } -void FMI::Comm::PeerToPeer::scan_ltr(channel_data sendbuf, channel_data recvbuf, const raw_function& f) { +void FMI::Comm::PeerToPeer::scan_ltr(std::shared_ptr sendbuf, std::shared_ptr recvbuf, const raw_function& f) { if (peer_id == 0) { send(sendbuf, 1); - std::memcpy(recvbuf.buf, sendbuf.buf, sendbuf.len); + std::memcpy(recvbuf->get(), sendbuf->get(), sendbuf->len); } else { recv(recvbuf, peer_id - 1); - f.f(recvbuf.buf, sendbuf.buf); + f.f(recvbuf->get(), sendbuf->get()); if (peer_id < num_peers - 1) { send(recvbuf, peer_id + 1); } } } -void FMI::Comm::PeerToPeer::scan_no_order(channel_data sendbuf, channel_data recvbuf, const raw_function& f) { +void FMI::Comm::PeerToPeer::scan_no_order(std::shared_ptr sendbuf, std::shared_ptr recvbuf, const raw_function& f) { int rounds = floor(log2(num_peers)); for (int i = 0; i < rounds; i ++) { - if ((peer_id & ((int) std::pow(2, i + 1) - 1)) == (int) std::pow(2, i + 1) - 1) { + if ((peer_id & ((int) std::pow(2, i + 1) - 1)) == (unsigned int)((int) std::pow(2, i + 1) - 1)) { Utils::peer_num src = peer_id - (int) std::pow(2, i); recv(recvbuf, src); - f.f(sendbuf.buf, recvbuf.buf); - } else if ((peer_id & ((int) std::pow(2, i) - 1)) == (int) std::pow(2, i) - 1) { + f.f(sendbuf->get(), recvbuf->get()); + } else if ((peer_id & ((int) std::pow(2, i) - 1)) == (unsigned int)((int) std::pow(2, i) - 1)) { Utils::peer_num dst = peer_id + (int) std::pow(2, i); if (dst < num_peers) { send(sendbuf, dst); @@ -167,27 +205,29 @@ void FMI::Comm::PeerToPeer::scan_no_order(channel_data sendbuf, channel_data rec } } for (int i = rounds; i > 0; i--) { - if ((peer_id & ((int) std::pow(2, i) - 1)) == (int) std::pow(2, i) - 1) { + if ((peer_id & ((int) std::pow(2, i) - 1)) == (unsigned int)((int) std::pow(2, i) - 1)) { Utils::peer_num dst = peer_id + (int) std::pow(2, i - 1); if (dst < num_peers) { send(sendbuf, dst); } - } else if ((peer_id & ((int) std::pow(2, i - 1) - 1)) == (int) std::pow(2, i - 1) - 1) { + } else if ((peer_id & ((int) std::pow(2, i - 1) - 1)) == (unsigned int)((int) std::pow(2, i - 1) - 1)) { int src = peer_id - (int) std::pow(2, i - 1); if (src > 0) { recv(recvbuf, src); - f.f(sendbuf.buf, recvbuf.buf); + f.f(sendbuf->get(), recvbuf->get()); } } } - std::memcpy(recvbuf.buf, sendbuf.buf, sendbuf.len); + std::memcpy(recvbuf->get(), sendbuf->get(), sendbuf->len); } -void FMI::Comm::PeerToPeer::gather(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root) { +void FMI::Comm::PeerToPeer::gather(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root) { int rounds = ceil(log2(num_peers)); Utils::peer_num trans_peer_id = transform_peer_id(peer_id, root, true); - std::size_t single_buffer_size = sendbuf.len; + std::size_t single_buffer_size = sendbuf->len; + // Find needed buffer size and allocate it + std::shared_ptr local_recvbuf; if (peer_id != root) { unsigned int peers_in_buffer = 1; for (int i = rounds - 1; i >= 0; i--) { @@ -196,11 +236,11 @@ void FMI::Comm::PeerToPeer::gather(channel_data sendbuf, channel_data recvbuf, F peers_in_buffer += std::min((Utils::peer_num) std::pow(2, i), num_peers - src); } } - recvbuf.buf = new char[peers_in_buffer * single_buffer_size]; - recvbuf.len = peers_in_buffer * single_buffer_size; - std::memcpy(recvbuf.buf, sendbuf.buf, single_buffer_size); + local_recvbuf = std::make_shared(peers_in_buffer * single_buffer_size); + std::memcpy(local_recvbuf->get(), sendbuf->get(), single_buffer_size); } else { - std::memcpy(recvbuf.buf + single_buffer_size * root, sendbuf.buf, single_buffer_size); + local_recvbuf = recvbuf; + std::memcpy(recvbuf->get() + single_buffer_size * root, sendbuf->get(), single_buffer_size); } for (int i = 0; i < rounds; i++) { @@ -212,78 +252,164 @@ void FMI::Comm::PeerToPeer::gather(channel_data sendbuf, channel_data recvbuf, F Utils::peer_num real_src = transform_peer_id(src, root, false); if (peer_id == root) { - if (real_src * single_buffer_size + buf_len > recvbuf.len) { + if (real_src * single_buffer_size + buf_len > recvbuf->len) { // Need to wraparound with temporary buffer - char *tmp = new char[buf_len]; - recv({tmp, buf_len}, real_src); - unsigned int length_end = recvbuf.len - real_src * single_buffer_size; // How many bytes to copy at end of buffer - std::memcpy(recvbuf.buf + real_src * single_buffer_size, tmp, length_end); - std::memcpy(recvbuf.buf, tmp + length_end, buf_len - length_end); - delete[] tmp; + auto tmp = std::make_shared(buf_len); + recv(tmp, real_src); + unsigned int length_end = recvbuf->len - real_src * single_buffer_size; // How many bytes to copy at end of buffer + std::memcpy(recvbuf->get() + real_src * single_buffer_size, tmp->get(), length_end); + std::memcpy(recvbuf->get(), tmp->get() + length_end, buf_len - length_end); } else { - recv({recvbuf.buf + real_src * single_buffer_size, buf_len}, real_src); + auto peer_buf = std::make_shared(recvbuf->get() + real_src * single_buffer_size, buf_len, noop_deleter); + recv(peer_buf, real_src); } } else { - recv({recvbuf.buf + (src - trans_peer_id) * single_buffer_size, buf_len}, real_src); + auto peer_buf = std::make_shared(local_recvbuf->get() + (src - trans_peer_id) * single_buffer_size, buf_len, noop_deleter); + recv(peer_buf, real_src); } } else if (trans_peer_id % (int) std::pow(2, i) == 0 && trans_peer_id % (int) std::pow(2, i + 1) != 0){ unsigned int responsible_peers = std::min((Utils::peer_num) std::pow(2, i), num_peers - trans_peer_id); std::size_t buf_len = responsible_peers * single_buffer_size; Utils::peer_num real_dst = transform_peer_id(trans_peer_id - (int) std::pow(2, i), root, false); - send({recvbuf.buf, buf_len}, real_dst); + auto send_buf = std::make_shared(local_recvbuf->get(), buf_len, noop_deleter); + send(send_buf, real_dst); } } +} + +void FMI::Comm::PeerToPeer::gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs) { + // Simple implementation: each non-root peer sends to root if (peer_id != root) { - delete[] recvbuf.buf; + send(sendbuf, root); + } else { + // Copy own data + std::memcpy(recvbuf->get() + displs[root], sendbuf->get(), recvcounts[root]); + // Receive from others + for (unsigned int i = 0; i < num_peers; i++) { + if (i != root) { + auto peer_buf = std::make_shared(recvbuf->get() + displs[i], recvcounts[i], noop_deleter); + recv(peer_buf, i); + } + } } } -void FMI::Comm::PeerToPeer::scatter(channel_data sendbuf, channel_data recvbuf, FMI::Utils::peer_num root) { +void FMI::Comm::PeerToPeer::gatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback) { + // For now, just call blocking version + gatherv(sendbuf, recvbuf, root, recvcounts, displs); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} + +void FMI::Comm::PeerToPeer::allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root) { + // Two-phase: gather + bcast + gather(sendbuf, recvbuf, root); + bcast(recvbuf, root); +} + +void FMI::Comm::PeerToPeer::allgather(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, FMI::Utils::Mode mode, + std::function callback) { + // For now, just call blocking version + allgather(sendbuf, recvbuf, root); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} + +void FMI::Comm::PeerToPeer::allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs) { + // Two-phase: gatherv + bcast + gatherv(sendbuf, recvbuf, root, recvcounts, displs); + bcast(recvbuf, root); +} + +void FMI::Comm::PeerToPeer::allgatherv(std::shared_ptr sendbuf, + std::shared_ptr recvbuf, + FMI::Utils::peer_num root, + const std::vector& recvcounts, + const std::vector& displs, + FMI::Utils::Mode mode, + std::function callback) { + // For now, just call blocking version + allgatherv(sendbuf, recvbuf, root, recvcounts, displs); + if (callback) { + callback(Utils::SUCCESS, "", nullptr); + } +} + +void FMI::Comm::PeerToPeer::scatter(std::shared_ptr sendbuf, std::shared_ptr recvbuf, FMI::Utils::peer_num root) { int rounds = ceil(log2(num_peers)); Utils::peer_num trans_peer_id = transform_peer_id(peer_id, root, true); - std::size_t single_buffer_size = recvbuf.len; + std::size_t single_buffer_size = recvbuf->len; + + std::shared_ptr local_sendbuf = sendbuf; + for (int i = rounds - 1; i >= 0; i--) { Utils::peer_num rcpt = trans_peer_id + (Utils::peer_num) std::pow(2, i); - if (trans_peer_id % (int) std::pow(2, i + 1) == 0 && rcpt < num_peers) { unsigned int responsible_peers = std::min((Utils::peer_num) std::pow(2, i), num_peers - rcpt); std::size_t buf_len = responsible_peers * single_buffer_size; Utils::peer_num real_rcpt = transform_peer_id(rcpt, root, false); if (peer_id == root) { - if (real_rcpt * single_buffer_size + buf_len > sendbuf.len) { + if (real_rcpt * single_buffer_size + buf_len > sendbuf->len) { // Wrapping around, need to allocate a temporary buffer - char* tmp = new char[buf_len]; - unsigned int length_end = sendbuf.len - real_rcpt * single_buffer_size; // How many bytes we need to send at end of buffer - std::memcpy(tmp, sendbuf.buf + real_rcpt * single_buffer_size, length_end); + auto tmp = std::make_shared(buf_len); + unsigned int length_end = sendbuf->len - real_rcpt * single_buffer_size; // How many bytes we need to send at end of buffer + std::memcpy(tmp->get(), sendbuf->get() + real_rcpt * single_buffer_size, length_end); // Copy rest from beginning - std::memcpy(tmp + length_end, sendbuf.buf, buf_len - length_end); - send({tmp, buf_len}, real_rcpt); - delete[] tmp; + std::memcpy(tmp->get() + length_end, sendbuf->get(), buf_len - length_end); + send(tmp, real_rcpt); } else { - send({sendbuf.buf + real_rcpt * single_buffer_size, buf_len}, real_rcpt); + auto send_buf = std::make_shared(sendbuf->get() + real_rcpt * single_buffer_size, buf_len, noop_deleter); + send(send_buf, real_rcpt); } } else { - send({sendbuf.buf + (rcpt - trans_peer_id) * single_buffer_size, buf_len}, real_rcpt); + auto send_buf = std::make_shared(local_sendbuf->get() + (rcpt - trans_peer_id) * single_buffer_size, buf_len, noop_deleter); + send(send_buf, real_rcpt); } } else if (trans_peer_id % (int) std::pow(2, i) == 0 && trans_peer_id % (int) std::pow(2, i + 1) != 0){ unsigned int responsible_peers = std::min((Utils::peer_num) std::pow(2, i), num_peers - trans_peer_id); std::size_t buf_len = responsible_peers * single_buffer_size; Utils::peer_num real_src = transform_peer_id(trans_peer_id - (int) std::pow(2, i), root, false); - sendbuf.buf = new char[buf_len]; - sendbuf.len = buf_len; - recv(sendbuf, real_src); + local_sendbuf = std::make_shared(buf_len); + recv(local_sendbuf, real_src); } } if (peer_id == root) { - std::memcpy(recvbuf.buf, sendbuf.buf + peer_id * single_buffer_size, single_buffer_size); + std::memcpy(recvbuf->get(), sendbuf->get() + peer_id * single_buffer_size, single_buffer_size); } else { - std::memcpy(recvbuf.buf, sendbuf.buf, single_buffer_size); - delete[] sendbuf.buf; + std::memcpy(recvbuf->get(), local_sendbuf->get(), single_buffer_size); } } +FMI::Utils::EventProcessStatus FMI::Comm::PeerToPeer::channel_event_progress(FMI::Utils::Operation op) { + // Default implementation - subclasses override for actual non-blocking support + return Utils::NOOP; +} + FMI::Utils::peer_num FMI::Comm::PeerToPeer::transform_peer_id(FMI::Utils::peer_num id, FMI::Utils::peer_num root, bool forward) { if (forward) { return (id + num_peers - root) % num_peers; // Transform s.t. root has id 0 @@ -304,6 +430,7 @@ double FMI::Comm::PeerToPeer::get_operation_latency(FMI::Utils::OperationInfo op return ceil(log2(num_peers)) * get_latency(1, 1, size_in_bytes); } // else, gather used case Utils::gather: + case Utils::gatherv: case Utils::scatter: { // ceil(log2(num_peers)) rounds, doubling buffer size in each round @@ -316,6 +443,19 @@ double FMI::Comm::PeerToPeer::get_operation_latency(FMI::Utils::OperationInfo op latency += get_latency(1, 1, rem_nodes * size_in_bytes); return latency; } + case Utils::allgather: + case Utils::allgatherv: + { + // gather + bcast + double latency = 0.; + for (int i = 1; i <= floor(log2(num_peers)); i++) { + latency += get_latency(1, 1, i * size_in_bytes); + } + int rem_nodes = num_peers - (int) std::pow(2, floor(log2(num_peers))); + latency += get_latency(1, 1, rem_nodes * size_in_bytes); + latency += ceil(log2(num_peers)) * get_latency(1, 1, num_peers * size_in_bytes); + return latency; + } case Utils::barrier: size_in_bytes = 1; case Utils::allreduce: @@ -366,6 +506,7 @@ double FMI::Comm::PeerToPeer::get_operation_price(FMI::Utils::OperationInfo op_i return comm_rounds * get_price(1, 1, size_in_bytes); } // else, gather used case Utils::gather: + case Utils::gatherv: case Utils::scatter: { double costs = 0.; @@ -375,6 +516,18 @@ double FMI::Comm::PeerToPeer::get_operation_price(FMI::Utils::OperationInfo op_i } return costs; } + case Utils::allgather: + case Utils::allgatherv: + { + // gather + bcast + double costs = 0.; + for (int i = 1; i <= ceil(log2(num_peers)); i++) { + costs += std::pow(2, floor(log2(num_peers)) - i) * get_price(1, 1, i * size_in_bytes); + } + double comm_rounds = num_peers - 1; + costs += comm_rounds * get_price(1, 1, num_peers * size_in_bytes); + return costs; + } case Utils::barrier: size_in_bytes = 1; case Utils::allreduce: @@ -403,4 +556,4 @@ double FMI::Comm::PeerToPeer::get_operation_price(FMI::Utils::OperationInfo op_i } throw std::runtime_error("Operation not implemented"); -} +} \ No newline at end of file diff --git a/src/comm/Redis.cpp b/src/comm/Redis.cpp index 74e36db..ca72dc8 100644 --- a/src/comm/Redis.cpp +++ b/src/comm/Redis.cpp @@ -31,23 +31,23 @@ FMI::Comm::Redis::~Redis() { redisFree(context); } -void FMI::Comm::Redis::upload_object(channel_data buf, std::string name) { +void FMI::Comm::Redis::upload_object(std::shared_ptr buf, std::string name) { std::string command = "SET " + name + " %b"; - auto* reply = (redisReply*) redisCommand(context, command.c_str(), buf.buf, buf.len); + auto* reply = (redisReply*) redisCommand(context, command.c_str(), buf->get(), buf->len); if (reply->type == REDIS_REPLY_ERROR) { BOOST_LOG_TRIVIAL(error) << "Error when uploading to Redis: " << reply->str; } freeReplyObject(reply); } -bool FMI::Comm::Redis::download_object(channel_data buf, std::string name) { +bool FMI::Comm::Redis::download_object(std::shared_ptr buf, std::string name) { std::string command = "GET " + name; auto* reply = (redisReply*) redisCommand(context, command.c_str()); if (reply->type == REDIS_REPLY_NIL || reply->type == REDIS_REPLY_ERROR) { freeReplyObject(reply); return false; } else { - std::memcpy(buf.buf, reply->str, std::min(buf.len, reply->len)); + std::memcpy(buf->get(), reply->str, std::min(buf->len, reply->len)); freeReplyObject(reply); return true; } @@ -63,7 +63,7 @@ std::vector FMI::Comm::Redis::get_object_names() { std::vector keys; std::string command = "KEYS *"; auto* reply = (redisReply*) redisCommand(context, command.c_str()); - for (int i = 0; i < reply->elements; i++) { + for (size_t i = 0; i < reply->elements; i++) { keys.emplace_back(reply->element[i]->str); } return keys; @@ -82,5 +82,4 @@ double FMI::Comm::Redis::get_price(Utils::peer_num producer, Utils::peer_num con total_costs += 1. / requests_per_hour * instance_price; } return total_costs; -} - +} \ No newline at end of file diff --git a/src/comm/S3.cpp b/src/comm/S3.cpp index 7dde3a5..b670fb5 100644 --- a/src/comm/S3.cpp +++ b/src/comm/S3.cpp @@ -24,8 +24,9 @@ FMI::Comm::S3::S3(std::map params, std::map(TAG); - client = Aws::MakeUnique(TAG, credentialsProvider, config); + //auto credentialsProvider = Aws::MakeShared(TAG); + //client = Aws::MakeUnique(TAG, credentialsProvider, config); + client = Aws::MakeUnique(TAG, config); } FMI::Comm::S3::~S3() { @@ -35,24 +36,24 @@ FMI::Comm::S3::~S3() { } } -bool FMI::Comm::S3::download_object(channel_data buf, std::string name) { +bool FMI::Comm::S3::download_object(std::shared_ptr buf, std::string name) { Aws::S3::Model::GetObjectRequest request; request.WithBucket(bucket_name).WithKey(name); auto outcome = client->GetObject(request); if (outcome.IsSuccess()) { auto& s = outcome.GetResult().GetBody(); - s.read(buf.buf, buf.len); + s.read(buf->get(), buf->len); return true; } else { return false; } } -void FMI::Comm::S3::upload_object(channel_data buf, std::string name) { +void FMI::Comm::S3::upload_object(std::shared_ptr buf, std::string name) { Aws::S3::Model::PutObjectRequest request; request.WithBucket(bucket_name).WithKey(name); - const std::shared_ptr data = Aws::MakeShared(TAG, buf.buf, buf.len); + const std::shared_ptr data = Aws::MakeShared(TAG, buf->get(), buf->len); request.SetBody(data); auto outcome = client->PutObject(request); @@ -100,5 +101,4 @@ double FMI::Comm::S3::get_price(Utils::peer_num producer, Utils::peer_num consum double expected_polls = (max_timeout / timeout) / 2; double download_costs = producer * consumer * expected_polls * download_price + producer * consumer * ((double) size_in_bytes / 1000000000.) * transfer_price; return upload_costs + download_costs; -} - +} \ No newline at end of file diff --git a/tests/channels.cpp b/tests/channels.cpp index 432251a..e4ffdca 100644 --- a/tests/channels.cpp +++ b/tests/channels.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include BOOST_AUTO_TEST_SUITE(Channels); @@ -45,7 +47,7 @@ std::map redis_test_model_params = { std::map direct_test_params = { {"host", "127.0.0.1"}, {"port", "10000"}, - {"max_timeout", "1000"} + {"max_timeout", "5000"} }; std::map direct_test_model_params = { @@ -54,7 +56,8 @@ std::map direct_test_model_params = { {"transfer_price", "0.0"}, {"vm_price", "0.0134"}, {"requests_per_hour", "1000"}, - {"include_infrastructure_costs", "true"} + {"include_infrastructure_costs", "true"}, + {"resolve_host_dns", "false"} }; std::map< std::string, std::pair< std::map, std::map > > backends = { @@ -83,10 +86,10 @@ BOOST_AUTO_TEST_CASE(sending_receiving) { ch->set_num_peers(2); ch->set_comm_name(comm_name); if (tid == 0) { - channel_data buf {reinterpret_cast(&val), sizeof(val)}; + auto buf = std::make_shared(reinterpret_cast(&val), sizeof(val), noop_deleter); ch->send(buf, 1); } else if (tid == 1) { - channel_data recv_buf {reinterpret_cast(&recv), sizeof(recv)}; + auto recv_buf = std::make_shared(reinterpret_cast(&recv), sizeof(recv), noop_deleter); ch->recv(recv_buf, 0); } ch->finalize(); @@ -112,11 +115,15 @@ BOOST_AUTO_TEST_CASE(sending_receiving_mult_times) { ch->set_num_peers(2); ch->set_comm_name(comm_name); if (tid == 0) { - ch->send({reinterpret_cast(&val1), sizeof(val1)}, 1); - ch->send({reinterpret_cast(&val2), sizeof(val2)}, 1); + auto buf1 = std::make_shared(reinterpret_cast(&val1), sizeof(val1), noop_deleter); + auto buf2 = std::make_shared(reinterpret_cast(&val2), sizeof(val2), noop_deleter); + ch->send(buf1, 1); + ch->send(buf2, 1); } else if (tid == 1) { - ch->recv({reinterpret_cast(&recv1), sizeof(recv1)}, 0); - ch->recv({reinterpret_cast(&recv2), sizeof(recv2)}, 0); + auto recv_buf1 = std::make_shared(reinterpret_cast(&recv1), sizeof(recv1), noop_deleter); + auto recv_buf2 = std::make_shared(reinterpret_cast(&recv2), sizeof(recv2), noop_deleter); + ch->recv(recv_buf1, 0); + ch->recv(recv_buf2, 0); } ch->finalize(); } @@ -131,7 +138,7 @@ BOOST_AUTO_TEST_CASE(bcast) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + FMI::Utils::peer_num root = 14; constexpr int num_peers = 32; int* vals = static_cast(mmap(nullptr, num_peers * sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); @@ -148,7 +155,8 @@ BOOST_AUTO_TEST_CASE(bcast) { ch->set_peer_id(peer_id); ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); - ch->bcast({reinterpret_cast(&vals[peer_id]), sizeof(vals[peer_id])}, root); + auto buf = std::make_shared(reinterpret_cast(&vals[peer_id]), sizeof(vals[peer_id]), noop_deleter); + ch->bcast(buf, root); ch->finalize(); if (peer_id == 0) { int status = 0; @@ -168,7 +176,7 @@ BOOST_AUTO_TEST_CASE(barrier_unsucc) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 4; bool* caught = static_cast(mmap(nullptr, num_peers * sizeof(bool), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); int peer_id = 0; @@ -212,7 +220,7 @@ BOOST_AUTO_TEST_CASE(barrier_succ) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 2; int peer_id = 0; for (int i = 1; i < num_peers; i ++) { @@ -247,7 +255,7 @@ BOOST_AUTO_TEST_CASE(gather_one) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 2; std::vector vals {1,2,3,4}; FMI::Utils::peer_num root = 1; @@ -266,10 +274,13 @@ BOOST_AUTO_TEST_CASE(gather_one) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); if (peer_id == root) { - ch->gather({reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2}, - {reinterpret_cast(rcv_vals), sizeof(int) * num_peers * 2}, root); + auto sendbuf = std::make_shared(reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2, noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(rcv_vals), sizeof(int) * num_peers * 2, noop_deleter); + ch->gather(sendbuf, recvbuf, root); } else { - ch->gather({reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2}, {}, root); + auto sendbuf = std::make_shared(reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2, noop_deleter); + auto recvbuf = std::make_shared(); + ch->gather(sendbuf, recvbuf, root); } ch->finalize(); if (peer_id == 0) { @@ -289,10 +300,10 @@ BOOST_AUTO_TEST_CASE(gather_multiple) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 14; std::vector vals(2 * num_peers); - for (int i = 0; i < vals.size(); i++) { + for (size_t i = 0; i < vals.size(); i++) { vals[i] = i + 1; } FMI::Utils::peer_num root = 0; @@ -311,10 +322,13 @@ BOOST_AUTO_TEST_CASE(gather_multiple) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); if (peer_id == root) { - ch->gather({reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2}, - {reinterpret_cast(rcv_vals), sizeof(int) * num_peers * 2}, root); + auto sendbuf = std::make_shared(reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2, noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(rcv_vals), sizeof(int) * num_peers * 2, noop_deleter); + ch->gather(sendbuf, recvbuf, root); } else { - ch->gather({reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2}, {}, root); + auto sendbuf = std::make_shared(reinterpret_cast(vals.data() + 2 * peer_id), sizeof(vals[0]) * 2, noop_deleter); + auto recvbuf = std::make_shared(); + ch->gather(sendbuf, recvbuf, root); } ch->finalize(); if (peer_id == 0) { @@ -334,7 +348,7 @@ BOOST_AUTO_TEST_CASE(scatter_one) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 2; std::vector root_vals {1,2,3,4}; FMI::Utils::peer_num root = 0; @@ -353,10 +367,13 @@ BOOST_AUTO_TEST_CASE(scatter_one) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); if (peer_id == root) { - ch->scatter({reinterpret_cast(root_vals.data()), sizeof(root_vals[0]) * root_vals.size()}, - {reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2}, root); + auto sendbuf = std::make_shared(reinterpret_cast(root_vals.data()), sizeof(root_vals[0]) * root_vals.size(), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2, noop_deleter); + ch->scatter(sendbuf, recvbuf, root); } else { - ch->scatter({}, {reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2}, root); + auto sendbuf = std::make_shared(); + auto recvbuf = std::make_shared(reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2, noop_deleter); + ch->scatter(sendbuf, recvbuf, root); } ch->finalize(); if (peer_id == 0) { @@ -376,10 +393,10 @@ BOOST_AUTO_TEST_CASE(scatter_multiple) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 14; std::vector root_vals(2 * num_peers); - for (int i = 0; i < root_vals.size(); i++) { + for (size_t i = 0; i < root_vals.size(); i++) { root_vals[i] = i + 1; } FMI::Utils::peer_num root = 3; @@ -398,10 +415,13 @@ BOOST_AUTO_TEST_CASE(scatter_multiple) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); if (peer_id == root) { - ch->scatter({reinterpret_cast(root_vals.data()), sizeof(root_vals[0]) * root_vals.size()}, - {reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2}, root); + auto sendbuf = std::make_shared(reinterpret_cast(root_vals.data()), sizeof(root_vals[0]) * root_vals.size(), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2, noop_deleter); + ch->scatter(sendbuf, recvbuf, root); } else { - ch->scatter({}, {reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2}, root); + auto sendbuf = std::make_shared(); + auto recvbuf = std::make_shared(reinterpret_cast(rcv_vals + peer_id * 2), sizeof(int) * 2, noop_deleter); + ch->scatter(sendbuf, recvbuf, root); } ch->finalize(); if (peer_id == 0) { @@ -421,7 +441,7 @@ BOOST_AUTO_TEST_CASE(reduce_multiple) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + FMI::Utils::peer_num root = 5; constexpr int num_peers = 13; int* res = static_cast(mmap(nullptr, sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); @@ -443,9 +463,13 @@ BOOST_AUTO_TEST_CASE(reduce_multiple) { ch->set_comm_name(comm_name); int val = peer_id + 1; if (peer_id == root) { - ch->reduce({reinterpret_cast(&val), sizeof(int)}, {reinterpret_cast(res), sizeof(int)}, root, {f, true, true}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(res), sizeof(int), noop_deleter); + ch->reduce(sendbuf, recvbuf, root, {f, true, true}); } else { - ch->reduce({reinterpret_cast(&val), sizeof(int)}, {}, root, {f, true, true}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(); + ch->reduce(sendbuf, recvbuf, root, {f, true, true}); } ch->finalize(); @@ -469,7 +493,7 @@ BOOST_AUTO_TEST_CASE(reduce_multiple_ltr) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + FMI::Utils::peer_num root = 0; constexpr int num_peers = 8; int* res = static_cast(mmap(nullptr, sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); @@ -491,9 +515,13 @@ BOOST_AUTO_TEST_CASE(reduce_multiple_ltr) { ch->set_comm_name(comm_name); int val = peer_id + 1; if (peer_id == root) { - ch->reduce({reinterpret_cast(&val), sizeof(int)}, {reinterpret_cast(res), sizeof(int)}, root, {f, false, false}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(res), sizeof(int), noop_deleter); + ch->reduce(sendbuf, recvbuf, root, {f, false, false}); } else { - ch->reduce({reinterpret_cast(&val), sizeof(int)}, {}, root, {f, false, false}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(); + ch->reduce(sendbuf, recvbuf, root, {f, false, false}); } ch->finalize(); @@ -517,7 +545,7 @@ BOOST_AUTO_TEST_CASE(allreduce_multiple) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 8; int* res = static_cast(mmap(nullptr, num_peers * sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); int peer_id = 0; @@ -537,7 +565,9 @@ BOOST_AUTO_TEST_CASE(allreduce_multiple) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); int val = peer_id + 1; - ch->allreduce({reinterpret_cast(&val), sizeof(int)}, {reinterpret_cast(res + peer_id), sizeof(int)}, {f, true, true}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(res + peer_id), sizeof(int), noop_deleter); + ch->allreduce(sendbuf, recvbuf, {f, true, true}); ch->finalize(); if (peer_id == 0) { @@ -562,7 +592,7 @@ BOOST_AUTO_TEST_CASE(allreduce_multiple_ltr) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + FMI::Utils::peer_num root = 0; constexpr int num_peers = 8; int* res = static_cast(mmap(nullptr, num_peers * sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); @@ -583,7 +613,9 @@ BOOST_AUTO_TEST_CASE(allreduce_multiple_ltr) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); int val = peer_id + 1; - ch->allreduce({reinterpret_cast(&val), sizeof(int)}, {reinterpret_cast(res + peer_id), sizeof(int)}, {f, false, false}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(res + peer_id), sizeof(int), noop_deleter); + ch->allreduce(sendbuf, recvbuf, {f, false, false}); ch->finalize(); if (peer_id == 0) { @@ -608,7 +640,7 @@ BOOST_AUTO_TEST_CASE(scan) { auto channel_name = backend_data.first; auto test_params = backend_data.second.first; auto model_params = backend_data.second.second; - + constexpr int num_peers = 32; int* res = static_cast(mmap(nullptr, sizeof(int) * num_peers, PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); int peer_id = 0; @@ -628,7 +660,9 @@ BOOST_AUTO_TEST_CASE(scan) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); int val = peer_id + 1; - ch->scan({reinterpret_cast(&val), sizeof(int)}, {reinterpret_cast(res + peer_id), sizeof(int)}, {f, true, true}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(res + peer_id), sizeof(int), noop_deleter); + ch->scan(sendbuf, recvbuf, {f, true, true}); ch->finalize(); if (peer_id == 0) { int status = 0; @@ -671,7 +705,9 @@ BOOST_AUTO_TEST_CASE(scan_ltr) { ch->set_num_peers(num_peers); ch->set_comm_name(comm_name); int val = peer_id; - ch->scan({reinterpret_cast(&val), sizeof(int)}, {reinterpret_cast(res + peer_id), sizeof(int)}, {f, false, false}); + auto sendbuf = std::make_shared(reinterpret_cast(&val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(res + peer_id), sizeof(int), noop_deleter); + ch->scan(sendbuf, recvbuf, {f, false, false}); ch->finalize(); if (peer_id == 0) { int status = 0; @@ -689,4 +725,321 @@ BOOST_AUTO_TEST_CASE(scan_ltr) { } } +// Variable-length collective tests +BOOST_AUTO_TEST_CASE(gatherv_basic) { + for (auto const & backend_data : backends) { + auto channel_name = backend_data.first; + auto test_params = backend_data.second.first; + auto model_params = backend_data.second.second; + + constexpr int num_peers = 4; + FMI::Utils::peer_num root = 0; + + // Each peer sends different amounts: peer 0 sends 1 int, peer 1 sends 2 ints, etc. + // Note: recvcounts and displs are in BYTES (not elements) for channel_data + constexpr int int_size = sizeof(int); + std::vector recvcounts = {1 * int_size, 2 * int_size, 3 * int_size, 4 * int_size}; + std::vector displs = {0, 1 * int_size, 3 * int_size, 6 * int_size}; // Cumulative byte displacements + int total_size = 10; // 1 + 2 + 3 + 4 elements + + int* rcv_vals = static_cast(mmap(nullptr, total_size * sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + int peer_id = 0; + for (int i = 1; i < num_peers; i++) { + int pid = fork(); + if (pid == 0) { + peer_id = i; + break; + } + } + + auto ch = FMI::Comm::Channel::get_channel(channel_name, test_params, model_params); + ch->set_peer_id(peer_id); + ch->set_num_peers(num_peers); + ch->set_comm_name(comm_name + "_gatherv"); + + // Each peer fills its send buffer with its peer_id + int send_count = peer_id + 1; + std::vector send_vals(send_count, peer_id + 1); + + auto sendbuf = std::make_shared(reinterpret_cast(send_vals.data()), send_count * sizeof(int), noop_deleter); + if (peer_id == root) { + auto recvbuf = std::make_shared(reinterpret_cast(rcv_vals), total_size * sizeof(int), noop_deleter); + ch->gatherv(sendbuf, recvbuf, root, recvcounts, displs); + } else { + auto recvbuf = std::make_shared(); + ch->gatherv(sendbuf, recvbuf, root, recvcounts, displs); + } + + ch->finalize(); + if (peer_id == 0) { + int status = 0; + while (wait(&status) > 0); + // Verify: peer i contributed (i+1) values of value (i+1) + BOOST_CHECK_EQUAL(rcv_vals[0], 1); // peer 0: 1 value of 1 + BOOST_CHECK_EQUAL(rcv_vals[1], 2); // peer 1: 2 values of 2 + BOOST_CHECK_EQUAL(rcv_vals[2], 2); + BOOST_CHECK_EQUAL(rcv_vals[3], 3); // peer 2: 3 values of 3 + BOOST_CHECK_EQUAL(rcv_vals[4], 3); + BOOST_CHECK_EQUAL(rcv_vals[5], 3); + } else { + exit(0); + } + } +} + +BOOST_AUTO_TEST_CASE(allgather_basic) { + for (auto const & backend_data : backends) { + auto channel_name = backend_data.first; + auto test_params = backend_data.second.first; + auto model_params = backend_data.second.second; + + constexpr int num_peers = 4; + FMI::Utils::peer_num root = 0; + + int* rcv_vals = static_cast(mmap(nullptr, num_peers * sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + bool* success = static_cast(mmap(nullptr, num_peers * sizeof(bool), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + + int peer_id = 0; + for (int i = 1; i < num_peers; i++) { + int pid = fork(); + if (pid == 0) { + peer_id = i; + break; + } + } + + auto ch = FMI::Comm::Channel::get_channel(channel_name, test_params, model_params); + ch->set_peer_id(peer_id); + ch->set_num_peers(num_peers); + ch->set_comm_name(comm_name + "_allgather"); + + int send_val = peer_id + 1; + std::vector recv_vals(num_peers); + + auto sendbuf = std::make_shared(reinterpret_cast(&send_val), sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(recv_vals.data()), num_peers * sizeof(int), noop_deleter); + + ch->allgather(sendbuf, recvbuf, root); + + // Verify all peers got the complete data + success[peer_id] = true; + for (int i = 0; i < num_peers; i++) { + if (recv_vals[i] != i + 1) { + success[peer_id] = false; + } + } + + ch->finalize(); + if (peer_id == 0) { + int status = 0; + while (wait(&status) > 0); + for (int i = 0; i < num_peers; i++) { + BOOST_CHECK(success[i]); + } + } else { + exit(0); + } + } +} + +BOOST_AUTO_TEST_CASE(allgatherv_basic) { + for (auto const & backend_data : backends) { + auto channel_name = backend_data.first; + auto test_params = backend_data.second.first; + auto model_params = backend_data.second.second; + + constexpr int num_peers = 3; + FMI::Utils::peer_num root = 0; + + // Each peer sends different amounts + // Note: recvcounts and displs are in BYTES for channel_data + constexpr int int_size = sizeof(int); + std::vector recvcounts = {1 * int_size, 2 * int_size, 3 * int_size}; + std::vector displs = {0, 1 * int_size, 3 * int_size}; + int total_size = 6; // 1 + 2 + 3 elements + + bool* success = static_cast(mmap(nullptr, num_peers * sizeof(bool), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + + int peer_id = 0; + for (int i = 1; i < num_peers; i++) { + int pid = fork(); + if (pid == 0) { + peer_id = i; + break; + } + } + + auto ch = FMI::Comm::Channel::get_channel(channel_name, test_params, model_params); + ch->set_peer_id(peer_id); + ch->set_num_peers(num_peers); + ch->set_comm_name(comm_name + "_allgatherv"); + + int send_count = peer_id + 1; + std::vector send_vals(send_count, peer_id + 1); + std::vector recv_vals(total_size); + + auto sendbuf = std::make_shared(reinterpret_cast(send_vals.data()), send_count * sizeof(int), noop_deleter); + auto recvbuf = std::make_shared(reinterpret_cast(recv_vals.data()), total_size * sizeof(int), noop_deleter); + + ch->allgatherv(sendbuf, recvbuf, root, recvcounts, displs); + + // Verify all peers got the complete data + success[peer_id] = (recv_vals[0] == 1 && + recv_vals[1] == 2 && recv_vals[2] == 2 && + recv_vals[3] == 3 && recv_vals[4] == 3 && recv_vals[5] == 3); + + ch->finalize(); + if (peer_id == 0) { + int status = 0; + while (wait(&status) > 0); + for (int i = 0; i < num_peers; i++) { + BOOST_CHECK(success[i]); + } + } else { + exit(0); + } + } +} + +// Non-blocking tests +BOOST_AUTO_TEST_CASE(nonblocking_send_recv) { + // Test true non-blocking send/recv with Direct channel + auto test_params = direct_test_params; + auto model_params = direct_test_model_params; + + int val = 42; + bool* send_complete = static_cast(mmap(nullptr, sizeof(bool), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + bool* recv_complete = static_cast(mmap(nullptr, sizeof(bool), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + int* recv_result = static_cast(mmap(nullptr, sizeof(int), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + *send_complete = false; + *recv_complete = false; + *recv_result = 0; + + int peer_id = 0; + int pid = fork(); + if (pid == 0) { + peer_id = 1; + } + + auto ch = FMI::Comm::Channel::get_channel("Direct", test_params, model_params); + ch->set_peer_id(peer_id); + ch->set_num_peers(2); + ch->set_comm_name(comm_name + "_nbx"); + ch->init(); + + FMI::Utils::fmiContext ctx{0}; + + if (peer_id == 0) { + auto buf = std::make_shared(reinterpret_cast(&val), sizeof(val), noop_deleter); + ch->send(buf, 1, &ctx, FMI::Utils::NONBLOCKING, + [send_complete](FMI::Utils::NbxStatus status, const std::string& msg, FMI::Utils::fmiContext* ctx) { + if (status == FMI::Utils::SUCCESS) { + *send_complete = true; + } + }); + + // Poll for completion + int timeout_counter = 0; + while (!*send_complete && timeout_counter < 5000) { + ch->channel_event_progress(FMI::Utils::send); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + timeout_counter++; + } + } else { + auto buf = std::make_shared(reinterpret_cast(recv_result), sizeof(int), noop_deleter); + ch->recv(buf, 0, &ctx, FMI::Utils::NONBLOCKING, + [recv_complete](FMI::Utils::NbxStatus status, const std::string& msg, FMI::Utils::fmiContext* ctx) { + if (status == FMI::Utils::SUCCESS) { + *recv_complete = true; + } + }); + + // Poll for completion + int timeout_counter = 0; + while (!*recv_complete && timeout_counter < 5000) { + ch->channel_event_progress(FMI::Utils::recv); + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + timeout_counter++; + } + } + + ch->finalize(); + + if (peer_id == 0) { + int status = 0; + while (wait(&status) > 0); + BOOST_CHECK(*send_complete); + BOOST_CHECK(*recv_complete); + BOOST_CHECK_EQUAL(val, *recv_result); + } else { + exit(0); + } +} + +BOOST_AUTO_TEST_CASE(nonblocking_progress_empty) { + // Test that progress returns correct status when no operations pending + auto ch = FMI::Comm::Channel::get_channel("Direct", direct_test_params, direct_test_model_params); + ch->set_peer_id(0); + ch->set_num_peers(1); + ch->set_comm_name(comm_name + "_progress"); + ch->init(); + + // With no pending operations, should return EMPTY or NOOP + auto status = ch->channel_event_progress(FMI::Utils::send); + BOOST_CHECK(status == FMI::Utils::EMPTY || status == FMI::Utils::NOOP); + + ch->finalize(); +} + +BOOST_AUTO_TEST_CASE(blocking_mode_with_callback) { + // Test that blocking mode still works with callback API + auto test_params = direct_test_params; + auto model_params = direct_test_model_params; + + int val = 123; + int recv_val = 0; + bool* completed = static_cast(mmap(nullptr, 2 * sizeof(bool), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + completed[0] = false; + completed[1] = false; + + int peer_id = 0; + int pid = fork(); + if (pid == 0) { + peer_id = 1; + } + + auto ch = FMI::Comm::Channel::get_channel("Direct", test_params, model_params); + ch->set_peer_id(peer_id); + ch->set_num_peers(2); + ch->set_comm_name(comm_name + "_blocking_cb"); + + FMI::Utils::fmiContext ctx{0}; + + if (peer_id == 0) { + auto buf = std::make_shared(reinterpret_cast(&val), sizeof(val), noop_deleter); + // Use BLOCKING mode with callback - should complete immediately + ch->send(buf, 1, &ctx, FMI::Utils::BLOCKING, + [&completed](FMI::Utils::NbxStatus status, const std::string& msg, FMI::Utils::fmiContext* ctx) { + completed[0] = true; + }); + } else { + auto buf = std::make_shared(reinterpret_cast(&recv_val), sizeof(recv_val), noop_deleter); + ch->recv(buf, 0, &ctx, FMI::Utils::BLOCKING, + [&completed](FMI::Utils::NbxStatus status, const std::string& msg, FMI::Utils::fmiContext* ctx) { + completed[1] = true; + }); + } + + ch->finalize(); + + if (peer_id == 0) { + int status = 0; + while (wait(&status) > 0); + BOOST_CHECK(completed[0]); + BOOST_CHECK(completed[1]); + } else { + exit(0); + } +} + BOOST_AUTO_TEST_SUITE_END(); \ No newline at end of file