diff --git a/.github/workflows/quick-test.yml b/.github/workflows/quick-test.yml index 9269a7ca74..415ad157d3 100644 --- a/.github/workflows/quick-test.yml +++ b/.github/workflows/quick-test.yml @@ -42,7 +42,6 @@ jobs: libunwind: libunwind-12-dev steps: - uses: actions/checkout@v2 - - uses: bazelbuild/setup-bazelisk@v2 - name: install dependencies # I observed 404s for some packages and added an `apt-get update`. Then, I observed package # conflicts between LLVM 14 and 15, and added the line which removes LLVM 14, the default on diff --git a/c++/.bazelrc b/c++/.bazelrc index 39ec5624d4..ba80b81903 100644 --- a/c++/.bazelrc +++ b/c++/.bazelrc @@ -1,3 +1,4 @@ +common --noenable_bzlmod common --enable_platform_specific_config build:unix --cxxopt='-std=c++20' --host_cxxopt='-std=c++20' --force_pic --verbose_failures @@ -6,14 +7,15 @@ build:unix --cxxopt='-Wextra' --host_cxxopt='-Wextra' build:unix --cxxopt='-Wno-strict-aliasing' --host_cxxopt='-Wno-strict-aliasing' build:unix --cxxopt='-Wno-sign-compare' --host_cxxopt='-Wno-sign-compare' build:unix --cxxopt='-Wno-unused-parameter' --host_cxxopt='-Wno-unused-parameter' +build:unix --cxxopt='-Wno-deprecated-this-capture' --host_cxxopt='-Wno-deprecated-this-capture' # I needed these magic spells to build locally with clang-11 and clang-12 on Ubuntu. clang-13 and up # work out-of-the-box. # TODO(2.0): Remove this when we support g++ again. 
-build:unix --action_env=CXXFLAGS=-stdlib=libc++ -build:unix --action_env=LDFLAGS=-stdlib=libc++ -build:unix --action_env=BAZEL_CXXOPTS=-stdlib=libc++ -build:unix --action_env=BAZEL_LINKOPTS=-lc++:-lm +build:linux --action_env=CXXFLAGS=-stdlib=libc++ +build:linux --action_env=LDFLAGS=-stdlib=libc++ +build:linux --action_env=BAZEL_CXXOPTS=-stdlib=libc++ +build:linux --action_env=BAZEL_LINKOPTS=-lc++:-lm build:linux --config=unix build:macos --config=unix diff --git a/c++/.bazelversion b/c++/.bazelversion index 5e3254243a..66ce77b7ea 100644 --- a/c++/.bazelversion +++ b/c++/.bazelversion @@ -1 +1 @@ -6.1.2 +7.0.0 diff --git a/c++/Makefile.am b/c++/Makefile.am index 1567491d4d..101365bd46 100644 --- a/c++/Makefile.am +++ b/c++/Makefile.am @@ -151,6 +151,7 @@ includekj_HEADERS = \ src/kj/vector.h \ src/kj/string.h \ src/kj/string-tree.h \ + src/kj/glob-filter.h \ src/kj/hash.h \ src/kj/table.h \ src/kj/map.h \ @@ -165,7 +166,6 @@ includekj_HEADERS = \ src/kj/mutex.h \ src/kj/source-location.h \ src/kj/thread.h \ - src/kj/threadlocal.h \ src/kj/filesystem.h \ src/kj/async-prelude.h \ src/kj/async.h \ @@ -229,8 +229,7 @@ includecapnp_HEADERS = \ src/capnp/rpc-twoparty.h \ src/capnp/rpc.capnp.h \ src/capnp/rpc-twoparty.capnp.h \ - src/capnp/persistent.capnp.h \ - src/capnp/ez-rpc.h + src/capnp/persistent.capnp.h includecapnpcompat_HEADERS = \ src/capnp/compat/json.h \ @@ -276,6 +275,7 @@ libkj_la_SOURCES= \ src/kj/string.c++ \ src/kj/string-tree.c++ \ src/kj/source-location.c++ \ + src/kj/glob-filter.c++ \ src/kj/hash.c++ \ src/kj/table.c++ \ src/kj/encoding.c++ \ @@ -374,8 +374,7 @@ libcapnp_rpc_la_SOURCES= \ src/capnp/rpc.capnp.c++ \ src/capnp/rpc-twoparty.c++ \ src/capnp/rpc-twoparty.capnp.c++ \ - src/capnp/persistent.capnp.c++ \ - src/capnp/ez-rpc.c++ + src/capnp/persistent.capnp.c++ libcapnp_json_la_LIBADD = libcapnp.la libkj.la $(PTHREAD_LIBS) libcapnp_json_la_LDFLAGS = -release $(SO_VERSION) -no-undefined @@ -542,7 +541,6 @@ heavy_tests = \ 
src/capnp/serialize-text-test.c++ \ src/capnp/rpc-test.c++ \ src/capnp/rpc-twoparty-test.c++ \ - src/capnp/ez-rpc-test.c++ \ src/capnp/compat/json-test.c++ \ src/capnp/compat/websocket-rpc-test.c++ \ src/capnp/compiler/lexer-test.c++ \ @@ -587,10 +585,10 @@ capnp_test_SOURCES = \ src/kj/io-test.c++ \ src/kj/mutex-test.c++ \ src/kj/time-test.c++ \ - src/kj/threadlocal-test.c++ \ src/kj/filesystem-test.c++ \ src/kj/filesystem-disk-test.c++ \ src/kj/test-test.c++ \ + src/kj/glob-filter-test.c++ \ src/capnp/common-test.c++ \ src/capnp/blob-test.c++ \ src/capnp/endian-test.c++ \ diff --git a/c++/WORKSPACE b/c++/WORKSPACE index 3e5bfe595e..d58ffdbd14 100644 --- a/c++/WORKSPACE +++ b/c++/WORKSPACE @@ -1,7 +1,6 @@ workspace(name = "capnp-cpp") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("//:build/load_br.bzl", "load_brotli") http_archive( name = "bazel_skylib", @@ -16,6 +15,14 @@ load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") bazel_skylib_workspace() +http_archive( + name = "brotli", + sha256 = "e720a6ca29428b803f4ad165371771f5398faba397edf6778837a18599ea13ff", + strip_prefix = "brotli-1.1.0", + type = "tgz", + urls = ["https://github.com/google/brotli/archive/refs/tags/v1.1.0.tar.gz"], +) + http_archive( name = "ssl", sha256 = "873ec711658f65192e9c58554ce058d1cfa4e57e13ab5366ee16f76d1c757efc", @@ -38,6 +45,7 @@ cc_library( "-Dverbose=-1", ] + select({ "@platforms//os:macos": [ "-Wno-implicit-function-declaration" ], + "@platforms//os:linux": [ "-Wno-implicit-function-declaration" ], "//conditions:default": [], }), visibility = ["//visibility:public"], @@ -47,9 +55,7 @@ cc_library( http_archive( name = "zlib", build_file_content = _zlib_build, - sha256 = "8a9ba2898e1d0d774eca6ba5b4627a11e5588ba85c8851336eb38de4683050a7", - strip_prefix = "zlib-1.3", - urls = ["https://zlib.net/zlib-1.3.tar.xz"], + sha256 = "38ef96b8dfe510d42707d9c781877914792541133e1870841463bfa73f883e32", + strip_prefix = "zlib-1.3.1", + urls = 
["https://zlib.net/zlib-1.3.1.tar.xz"], ) - -load_brotli() diff --git a/c++/build/configure.bzl b/c++/build/configure.bzl index a392fa45b2..4891d61cd5 100644 --- a/c++/build/configure.bzl +++ b/c++/build/configure.bzl @@ -1,4 +1,4 @@ -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "int_flag") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") def kj_configure(): """Generates set of flag, settings for kj configuration. diff --git a/c++/build/load_br.bzl b/c++/build/load_br.bzl deleted file mode 100644 index fe6fdfedd6..0000000000 --- a/c++/build/load_br.bzl +++ /dev/null @@ -1,12 +0,0 @@ -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - -# Defined in a bzl file to allow dependents to pull in brotli via capnproto. Using latest brotli -# commit due to macOS compile issues with v1.0.9, switch to a release version later -def load_brotli(): - http_archive( - name = "brotli", - sha256 = "e33f397d86aaa7f3e786bdf01a7b5cff4101cfb20041c04b313b149d34332f64", - strip_prefix = "google-brotli-ed1995b", - type = "tgz", - urls = ["https://github.com/google/brotli/tarball/ed1995b6bda19244070ab5d331111f16f67c8054"], - ) diff --git a/c++/samples/calculator-client.c++ b/c++/samples/calculator-client.c++ index 5d8452921c..bdaa5260ec 100644 --- a/c++/samples/calculator-client.c++ +++ b/c++/samples/calculator-client.c++ @@ -20,7 +20,8 @@ // THE SOFTWARE. #include "calculator.capnp.h" -#include +#include +#include #include #include #include @@ -47,13 +48,26 @@ int main(int argc, const char* argv[]) { return 1; } - capnp::EzRpcClient client(argv[1]); - Calculator::Client calculator = client.getMain(); + // First we need to set up the KJ async event loop. This should happen once + // per thread that needs to perform RPC. + auto io = kj::setupAsyncIo(); // Keep an eye on `waitScope`. Whenever you see it used is a place where we // stop and wait for the server to respond. If a line of code does not use // `waitScope`, then it does not block! 
- auto& waitScope = client.getWaitScope(); + auto& waitScope = io.waitScope; + + // Using KJ APIs, let's parse our network address and connect to it. + kj::Network& network = io.provider->getNetwork(); + kj::Own addr = network.parseAddress(argv[1]).wait(waitScope); + kj::Own conn = addr->connect().wait(waitScope); + + // Now we can start the Cap'n Proto RPC system on this connection. + capnp::TwoPartyClient client(*conn); + + // The server exports a "bootstrap" capability implementing the + // `Calculator` interface. + Calculator::Client calculator = client.bootstrap().castAs(); { // Make a request that just evaluates the literal value 123. diff --git a/c++/samples/calculator-server.c++ b/c++/samples/calculator-server.c++ index c2593be3a9..9253aced53 100644 --- a/c++/samples/calculator-server.c++ +++ b/c++/samples/calculator-server.c++ @@ -20,8 +20,9 @@ // THE SOFTWARE. #include "calculator.capnp.h" +#include +#include #include -#include #include #include @@ -196,12 +197,17 @@ int main(int argc, const char* argv[]) { return 1; } - // Set up a server. - capnp::EzRpcServer server(kj::heap(), argv[1]); + // First we need to set up the KJ async event loop. This should happen once + // per thread that needs to perform RPC. + auto io = kj::setupAsyncIo(); + + // Using KJ APIs, let's parse our network address and listen on it. + kj::Network& network = io.provider->getNetwork(); + kj::Own addr = network.parseAddress(argv[1]).wait(io.waitScope); + kj::Own listener = addr->listen(); // Write the port number to stdout, in case it was chosen automatically. - auto& waitScope = server.getWaitScope(); - uint port = server.getPort().wait(waitScope); + uint port = listener->getPort(); if (port == 0) { // The address format "unix:/path/to/socket" opens a unix domain socket, // in which case the port will be zero. @@ -210,6 +216,9 @@ int main(int argc, const char* argv[]) { std::cout << "Listening on port " << port << "..." << std::endl; } + // Start the RPC server. 
+ capnp::TwoPartyServer server(kj::heap()); + // Run forever, accepting connections and handling requests. - kj::NEVER_DONE.wait(waitScope); + server.listen(*listener).wait(io.waitScope); } diff --git a/c++/src/capnp/BUILD.bazel b/c++/src/capnp/BUILD.bazel index eea8349985..ec0833cef2 100644 --- a/c++/src/capnp/BUILD.bazel +++ b/c++/src/capnp/BUILD.bazel @@ -60,7 +60,6 @@ cc_library( srcs = [ "capability.c++", "dynamic-capability.c++", - "ez-rpc.c++", "membrane.c++", "persistent.capnp.c++", "reconnect.c++", @@ -71,7 +70,6 @@ cc_library( "serialize-async.c++", ], hdrs = [ - "ez-rpc.h", "persistent.capnp.h", "reconnect.h", "rpc.capnp.h", @@ -231,7 +229,6 @@ cc_library( "dynamic-test.c++", "encoding-test.c++", "endian-test.c++", - "ez-rpc-test.c++", "layout-test.c++", "membrane-test.c++", "message-test.c++", diff --git a/c++/src/capnp/CMakeLists.txt b/c++/src/capnp/CMakeLists.txt index 9980fde617..045cd98f8c 100644 --- a/c++/src/capnp/CMakeLists.txt +++ b/c++/src/capnp/CMakeLists.txt @@ -87,7 +87,6 @@ set(capnp-rpc_sources rpc-twoparty.c++ rpc-twoparty.capnp.c++ persistent.capnp.c++ - ez-rpc.c++ ) set(capnp-rpc_headers rpc-prelude.h @@ -96,7 +95,6 @@ set(capnp-rpc_headers rpc.capnp.h rpc-twoparty.capnp.h persistent.capnp.h - ez-rpc.h ) set(capnp-rpc_schemas rpc.capnp @@ -292,7 +290,6 @@ if(BUILD_TESTING) serialize-text-test.c++ rpc-test.c++ rpc-twoparty-test.c++ - ez-rpc-test.c++ compiler/lexer-test.c++ compiler/type-id-test.c++ test-util.c++ diff --git a/c++/src/capnp/arena.h b/c++/src/capnp/arena.h index aeaff8448d..7308912b80 100644 --- a/c++/src/capnp/arena.h +++ b/c++/src/capnp/arena.h @@ -330,7 +330,7 @@ class BuilderArena final: public Arena { SegmentBuilder* addExternalSegment(kj::ArrayPtr content); // Add a new segment to the arena which points to some existing memory region. The segment is - // assumed to be completley full; the arena will never allocate from it. In fact, the segment + // assumed to be completely full; the arena will never allocate from it. 
In fact, the segment // is considered read-only. Any attempt to get a Builder pointing into this segment will throw // an exception. Readers are allowed, however. // diff --git a/c++/src/capnp/capability.h b/c++/src/capnp/capability.h index b74c40c02b..7219845604 100644 --- a/c++/src/capnp/capability.h +++ b/c++/src/capnp/capability.h @@ -133,7 +133,7 @@ class Request: public Params::Builder { // to complete (and possibly other things, if that RPC itself returned a promise capability), // but when using `sendPipelineOnly()`, `whenResolved()` may complete immediately, or never, or // at an arbitrary time. Do not rely on it. - // - Normal path shortening may not work with these capabilities. For exmaple, if the caller + // - Normal path shortening may not work with these capabilities. For example, if the caller // forwards a pipelined capability back to the callee's vat, calls made by the callee to that // capability may continue to proxy through the caller. Conversely, if the callee ends up // returning a capability that points back to the caller's vat, calls on the pipelined diff --git a/c++/src/capnp/cc_capnp_library.bzl b/c++/src/capnp/cc_capnp_library.bzl index 9e4acd35b9..e6d9f5dada 100644 --- a/c++/src/capnp/cc_capnp_library.bzl +++ b/c++/src/capnp/cc_capnp_library.bzl @@ -87,6 +87,7 @@ def cc_capnp_library( data = [], deps = [], src_prefix = "", + tags = ["off-by-default"], visibility = None, target_compatible_with = None, **kwargs): @@ -122,6 +123,8 @@ def cc_capnp_library( srcs = srcs_cpp, hdrs = hdrs, deps = deps + ["@capnp-cpp//src/capnp:capnp_runtime"], + # Allows us to avoid building the library archive when using start_end_lib + tags = tags, visibility = visibility, target_compatible_with = target_compatible_with, **kwargs diff --git a/c++/src/capnp/compat/BUILD.bazel b/c++/src/capnp/compat/BUILD.bazel index 951c9ddc2c..528c919689 100644 --- a/c++/src/capnp/compat/BUILD.bazel +++ b/c++/src/capnp/compat/BUILD.bazel @@ -23,6 +23,32 @@ cc_library( ], ) 
+cc_library( + name = "json-rpc", + srcs = [ + "json-rpc.c++", + ], + hdrs = [ + "json-rpc.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + ":json-rpc_capnp", + "//src/kj/compat:kj-http", + ], +) + +cc_capnp_library( + name = "json-rpc_capnp", + srcs = [ + "json-rpc.capnp", + ], + include_prefix = "capnp/compat", + src_prefix = "src", + visibility = ["//visibility:public"], +) + cc_capnp_library( name = "json-test_capnp", srcs = [ @@ -93,13 +119,22 @@ cc_library( "websocket-rpc-test.c++", ]] +cc_test( + name = "json-rpc-test", + srcs = ["json-rpc-test.c++"], + deps = [ + ":json-rpc", + "//src/capnp:capnp-test", + ], +) + cc_test( name = "json-test", srcs = ["json-test.c++"], deps = [ - "//src/capnp:capnp-test", ":json", - ":json-test_capnp" + ":json-test_capnp", + "//src/capnp:capnp-test", ], ) @@ -107,13 +142,3 @@ cc_library( name = "http-over-capnp-test-as-header", hdrs = ["http-over-capnp-test.c++"], ) - -cc_test( - name = "http-over-capnp-old-test", - srcs = ["http-over-capnp-old-test.c++"], - deps = [ - ":http-over-capnp-test-as-header", - ":http-over-capnp", - "//src/capnp:capnp-test" - ], -) diff --git a/c++/src/capnp/compat/http-over-capnp-old-test.c++ b/c++/src/capnp/compat/http-over-capnp-old-test.c++ deleted file mode 100644 index 9a5aea9b13..0000000000 --- a/c++/src/capnp/compat/http-over-capnp-old-test.c++ +++ /dev/null @@ -1,2 +0,0 @@ -#define TEST_PEER_OPTIMIZATION_LEVEL HttpOverCapnpFactory::LEVEL_1 -#include "http-over-capnp-test.c++" diff --git a/c++/src/capnp/compat/http-over-capnp-test.c++ b/c++/src/capnp/compat/http-over-capnp-test.c++ index 5d1e000069..8886e7e5df 100644 --- a/c++/src/capnp/compat/http-over-capnp-test.c++ +++ b/c++/src/capnp/compat/http-over-capnp-test.c++ @@ -657,14 +657,13 @@ KJ_TEST("HttpService isn't destroyed while call outstanding") { KJ_EXPECT(!called); KJ_EXPECT(!destroyed); - auto req = service.startRequestRequest(); + auto req = service.requestRequest(); auto httpReq = 
req.initRequest(); httpReq.setMethod(capnp::HttpMethod::GET); httpReq.setUrl("/"); - auto serverContext = req.send().wait(waitScope).getContext(); + auto promise = req.send(); service = nullptr; - auto promise = serverContext.whenResolved(); KJ_EXPECT(!promise.poll(waitScope)); KJ_EXPECT(called); @@ -769,7 +768,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto Connect with close") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); kj::Own table = tableBuilder.build(); ConnectWriteCloseService service(*table); kj::HttpServer server(timer, *table, service); @@ -844,7 +843,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto Connect Reject") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); kj::Own table = tableBuilder.build(); ConnectRejectService service(*table); kj::HttpServer server(timer, *table, service); @@ -917,7 +916,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto Connect with startTls") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); kj::Own table = tableBuilder.build(); ConnectWriteRespService service(*table); kj::HttpServer server(timer, *table, service); diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index ce5d702f27..139e6eedd6 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -29,98 +29,91 @@ namespace capnp { using kj::uint; using kj::byte; -class HttpOverCapnpFactory::RequestState final - : public kj::Refcounted, public kj::TaskSet::ErrorHandler { -public: - 
RequestState() { - tasks.emplace(*this); - } +// ======================================================================================= - template - auto wrap(Func&& func) -> decltype(func()) { - if (tasks == kj::none) { - return KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request"); - } else { - return canceler.wrap(func()); +class HttpOverCapnpFactory::CapnpToKjWebSocketAdapter final: public capnp::WebSocket::Server { +public: + CapnpToKjWebSocketAdapter(kj::WebSocket& webSocket, + kj::Promise shorteningPromise, + kj::Own>> onEnd, + kj::Maybe&> selfRef) + : webSocket(webSocket), + shorteningPromise(kj::mv(shorteningPromise)), + onEnd(kj::mv(onEnd)), selfRef(selfRef) { + KJ_IF_SOME(s, selfRef) { + s = *this; } } + CapnpToKjWebSocketAdapter(kj::Own webSocket, + kj::Promise shorteningPromise, + kj::Own>> onEnd) + : webSocket(*webSocket), ownWebSocket(kj::mv(webSocket)), + shorteningPromise(kj::mv(shorteningPromise)), + onEnd(kj::mv(onEnd)) {} + // `onEnd` is resolved if and when the stream (in this direction) ends cleanly. + // + // `selfRef`, if given, will be initialized to point back to this object, and will be nulled + // out in the destructor. This is intended to allow the caller to arrange to call cancel() if + // the capability still exists when the underlying `webSocket` is about to go away. + // + // The second version of the constructor takes ownership of the underlying `webSocket`. In + // this case, a `selfRef` isn't needed since there's no need to call `cancel()`. - void cancel() { - if (tasks != kj::none) { - if (!canceler.isEmpty()) { - canceler.cancel(KJ_EXCEPTION(DISCONNECTED, "request canceled")); - } - tasks = kj::none; - webSocket = kj::none; - } - } + ~CapnpToKjWebSocketAdapter() noexcept(false) { + // The peer dropped the capability, which means the WebSocket stream has ended. We want to + // translate this to a `disconnect()` call on the `kj::WebSocket`, if it is still around. 
- void assertNotCanceled() { - if (tasks == kj::none) { - kj::throwFatalException(KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request")); + // Null out our self-ref, if any. + KJ_IF_SOME(s, selfRef) { + s = nullptr; } - } - void addTask(kj::Promise task) { - KJ_IF_SOME(t, tasks) { - t.add(kj::mv(task)); + // Arrange to call disconnect() and then notify the observer that the stream has finished. + // If an error has occurred, this will also propagate that. + // + // Note that we can't just use `wrap()` here because the canceler is about to be destroyed, + // which would immediately cancel the operation. Luckily, we don't actually need to wrap this + // promise in the canceler because we can assume that the observer listening on this fulfiller + // will cancel the promise themselves at the same time as calling cancel() on this object. + KJ_IF_SOME(e, error) { + onEnd->reject(kj::mv(*e)); + } else KJ_IF_SOME(ws, webSocket) { + onEnd->fulfill(kj::evalNow([&]() { + return ws.disconnect().attach(kj::mv(ownWebSocket)); + })); } else { - // Just drop the task. - } - } - - kj::Promise finishTasks() { - // This is merged into the final promise, so we don't need to worry about wrapping it for - // cancellation. - return KJ_REQUIRE_NONNULL(tasks).onEmpty() - .then([this]() { - KJ_IF_SOME(e, error) { - kj::throwRecoverableException(kj::mv(e)); - } - }); - } - - void taskFailed(kj::Exception&& exception) override { - if (error == kj::none) { - error = kj::mv(exception); + // cancel() was called -- we assume no one is waiting on the fulfiller } } - void holdWebSocket(kj::Own webSocket) { - // Hold on to this WebSocket until cancellation. 
- KJ_REQUIRE(this->webSocket == nullptr); - KJ_REQUIRE(tasks != nullptr); - this->webSocket = kj::mv(webSocket); - } - - void disconnectWebSocket() { - KJ_IF_SOME(t, tasks) { - t.add(kj::evalNow([&]() { return KJ_ASSERT_NONNULL(webSocket)->disconnect(); })); + void cancel() { + // Called when the overall HTTP request completes or is canceled while this capability still + // exists. Since we can't force the peer to drop the capability, we have to disable it. + // Further access to `webSocket` must be blocked since it is no longer valid. + // + // Arguably we could instead use capnp::RevocableServer to accomplish something similar. The + // problem is, we also do actually want to know when the peer drops this capability. With + // RevocableServer, we no longer get notification of that -- the destructor runs when we tell + // it to, rather than when the peer drops the cap. + // + // TODO(cleanup): Could RevocableServer be improved to allow us to notice the drop? + // Alternatively, maybe it's not really that important for us to call disconnect() + // proactively, considering: + // - The application can send a close message for explicit end. + // - A client->server disconnect will presumably cancel the whole request anyway. + // - A server->client disconnect will presumably be followed by the server returning from + // the request() RPC. 
+ + selfRef = kj::none; + webSocket = kj::none; + { auto drop = kj::mv(ownWebSocket); } + if (!canceler.isEmpty()) { + canceler.cancel(KJ_EXCEPTION(DISCONNECTED, "request canceled")); } } -private: - kj::Maybe error; - kj::Maybe> webSocket; - kj::Canceler canceler; - kj::Maybe tasks; -}; - -// ======================================================================================= - -class HttpOverCapnpFactory::CapnpToKjWebSocketAdapter final: public capnp::WebSocket::Server { -public: - CapnpToKjWebSocketAdapter(kj::Own state, kj::WebSocket& webSocket, - kj::Promise shorteningPromise) - : state(kj::mv(state)), webSocket(webSocket), - shorteningPromise(kj::mv(shorteningPromise)) {} - - ~CapnpToKjWebSocketAdapter() noexcept(false) { - state->disconnectWebSocket(); - } - kj::Maybe> shortenPath() override { - auto onAbort = webSocket.whenAborted() + auto onAbort = canceler.wrap(KJ_ASSERT_NONNULL(webSocket).whenAborted()) .then([]() -> kj::Promise { return KJ_EXCEPTION(DISCONNECTED, "WebSocket was aborted"); }); @@ -128,20 +121,60 @@ public: } kj::Promise sendText(SendTextContext context) override { - return state->wrap([&]() { return webSocket.send(context.getParams().getText()); }); + return wrap([&](kj::WebSocket& ws) { return ws.send(context.getParams().getText()); }); } kj::Promise sendData(SendDataContext context) override { - return state->wrap([&]() { return webSocket.send(context.getParams().getData()); }); + return wrap([&](kj::WebSocket& ws) { return ws.send(context.getParams().getData()); }); } kj::Promise close(CloseContext context) override { auto params = context.getParams(); - return state->wrap([&]() { return webSocket.close(params.getCode(), params.getReason()); }); + return wrap([&](kj::WebSocket& ws) { return ws.close(params.getCode(), params.getReason()); }); } private: - kj::Own state; - kj::WebSocket& webSocket; + kj::Maybe webSocket; // becomes none when canceled + kj::Own ownWebSocket; kj::Promise shorteningPromise; + kj::Own>> onEnd; + 
kj::Maybe&> selfRef; + + kj::Canceler canceler; + + kj::Maybe> error; + + kj::WebSocket& getWebSocket() { + return KJ_REQUIRE_NONNULL(webSocket, "request canceled"); + } + + template + kj::Promise wrap(Func&& func) { + KJ_IF_SOME(e, error) { + kj::throwFatalException(kj::cp(*e)); + } + + // Detect cancellation (of the operation) and mark the object broken in this case. + bool done = false; + KJ_DEFER({ + if (!done && error == kj::none) { + error = kj::heap(KJ_EXCEPTION(FAILED, + "a write was canceled before completing, breaking the WebSocket")); + } + }); + + try { + KJ_IF_SOME(ws, webSocket) { + co_await canceler.wrap(func(ws)); + } else { + kj::throwFatalException(KJ_EXCEPTION(DISCONNECTED, "request canceled")); + } + } catch (...) { + auto e = kj::getCaughtExceptionAsKj(); + error = kj::heap(kj::cp(e)); + kj::throwFatalException(kj::mv(e)); + } + + done = true; + } }; class HttpOverCapnpFactory::KjToCapnpWebSocketAdapter final: public kj::WebSocket { @@ -234,6 +267,12 @@ public: uint64_t sentByteCount() override { return sentBytes; } uint64_t receivedByteCount() override { return KJ_ASSERT_NONNULL(in)->receivedByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + // TODO(someday): Optimized pump is tricky with HttpOverCapnp, we may want to revisit + // this but for now we always return none (indicating no preference). + return kj::none; + }; + private: kj::Maybe> in; // One end of a WebSocketPipe, used only for receiving. kj::Maybe out; // Used only for sending. 
@@ -247,18 +286,17 @@ class HttpOverCapnpFactory::ClientRequestContextImpl final : public capnp::HttpService::ClientRequestContext::Server { public: ClientRequestContextImpl(HttpOverCapnpFactory& factory, - kj::Own state, kj::HttpService::Response& kjResponse) - : factory(factory), state(kj::mv(state)), kjResponse(kjResponse) {} + : factory(factory), kjResponse(kjResponse) {} ~ClientRequestContextImpl() noexcept(false) { - // Note this implicitly cancels the upstream pump task. + KJ_IF_SOME(ws, maybeWebSocket) { + ws.cancel(); + } } kj::Promise startResponse(StartResponseContext context) override { - KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); - sent = true; - state->assertNotCanceled(); + KJ_REQUIRE(responsePumpTask == kj::none, "already called startResponse() or startWebSocket()"); auto params = context.getParams(); auto rpcResponse = params.getResponse(); @@ -279,52 +317,74 @@ public: if (hasBody) { auto pipe = kj::newOneWayPipe(); results.setBody(factory.streamFactory.kjToCapnp(kj::mv(pipe.out))); - state->addTask(pipe.in->pumpTo(*bodyStream) + responsePumpTask = pipe.in->pumpTo(*bodyStream) .ignoreResult() - .attach(kj::mv(bodyStream), kj::mv(pipe.in))); + .attach(kj::mv(bodyStream), kj::mv(pipe.in)); } return kj::READY_NOW; } kj::Promise startWebSocket(StartWebSocketContext context) override { - KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); - sent = true; - state->assertNotCanceled(); + KJ_REQUIRE(responsePumpTask == kj::none, "already called startResponse() or startWebSocket()"); auto params = context.getParams(); auto shorteningPaf = kj::newPromiseAndFulfiller>(); - auto ownWebSocket = kjResponse.acceptWebSocket(factory.headersToKj(params.getHeaders())); - auto& webSocket = *ownWebSocket; - state->holdWebSocket(kj::mv(ownWebSocket)); + auto webSocket = kjResponse.acceptWebSocket(factory.headersToKj(params.getHeaders())); auto upWrapper = kj::heap( nullptr, params.getUpSocket(), 
kj::mv(shorteningPaf.fulfiller)); - state->addTask(webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper)) - .catch_([&webSocket=webSocket](kj::Exception&& e) -> kj::Promise { + auto upPumpTask = webSocket->pumpTo(*upWrapper).attach(kj::mv(upWrapper)) + .catch_([&webSocket=*webSocket](kj::Exception&& e) -> kj::Promise { // The pump in the client -> server direction failed. The error may have originated from // either the client or the server. In case it came from the server, we want to call .abort() // to propagate the problem back to the client. If the error came from the client, then // .abort() probably is a noop. webSocket.abort(); return kj::mv(e); - })); + }); auto results = context.getResults(MessageSize { 16, 1 }); - results.setDownSocket(kj::heap( - kj::addRef(*state), webSocket, kj::mv(shorteningPaf.promise))); + auto downPaf = kj::newPromiseAndFulfiller>(); + auto downSocket = kj::heap( + *webSocket, kj::mv(shorteningPaf.promise), kj::mv(downPaf.fulfiller), maybeWebSocket); + results.setDownSocket(kj::mv(downSocket)); + + // Note: This intentionally uses joinPromises and not joinPromisesFailFast, because + // finishPump() isn't called to wait on the final promise until we already expect both + // promises to be done anyway, so if we cancel one as a result of the other failing, it + // provides no benefit and possibly creates confusion. + responsePumpTask = kj::joinPromises(kj::arr(kj::mv(upPumpTask), kj::mv(downPaf.promise))); + + // We need to hold onto this WebSocket until `CapnpToKjWebSocketAdapter` is canceled or + // destroyed. If `responsePumpTask` completes successfully, then `CapnpToKjWebSocketAdapter` + // has to have been destroyed, since `downPaf.promise` doesn't resolve until that point. But + // in the case of request cancellation, it is our own destructor that will call `cancel()` + // on the `CapnpToKjWebSocketAdapter`, so we should make sure the `webSocket` outlives that. 
+ // + // (Additionally, the WebSocket must outlive `responsePumpTask` itself, even when it is + // canceled.) + ownWebSocket = kj::mv(webSocket); return kj::READY_NOW; } + kj::Promise finishPump() { + KJ_IF_SOME(r, responsePumpTask) { + return kj::mv(r); + } else { + return kj::READY_NOW; + } + } + private: HttpOverCapnpFactory& factory; - kj::Own state; - bool sent = false; + kj::Maybe> ownWebSocket; + kj::Maybe> responsePumpTask; + kj::Maybe maybeWebSocket; kj::HttpService::Response& kjResponse; - // Must check state->assertNotCanceled() before using this. }; class HttpOverCapnpFactory::ConnectClientRequestContextImpl final @@ -382,17 +442,10 @@ public: KjToCapnpHttpServiceAdapter(HttpOverCapnpFactory& factory, capnp::HttpService::Client inner) : factory(factory), inner(kj::mv(inner)) {} - template - kj::Promise requestImpl( - Request rpcRequest, + kj::Promise request( kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, - kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse, - AwaitCompletionFunc&& awaitCompletion) { - // Common implementation calling request() or startRequest(). awaitCompletion() waits for - // final completion in a method-specific way. - // - // TODO(cleanup): When we move to C++17 or newer we can use `if constexpr` instead of a - // callback. 
+ kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override { + auto rpcRequest = inner.requestRequest(); auto metadata = rpcRequest.initRequest(); metadata.setMethod(static_cast(method)); @@ -418,24 +471,25 @@ public: maybeRequestBody = requestBody; } - auto state = kj::refcounted(); - auto deferredCancel = kj::defer([state = kj::addRef(*state)]() mutable { - state->cancel(); - }); + ClientRequestContextImpl context(factory, kjResponse); + RevocableServer revocableContext(context); - rpcRequest.setContext( - kj::heap(factory, kj::addRef(*state), kjResponse)); + rpcRequest.setContext(revocableContext.getClient()); auto pipeline = rpcRequest.send(); + // Make sure the request message isn't pinned into memory through the co_await below. + { auto drop = kj::mv(rpcRequest); } + // Pump upstream -- unless we don't expect a request body. + kj::Maybe> pumpRequestFailedReason; kj::Maybe> pumpRequestTask; KJ_IF_SOME(rb, maybeRequestBody) { auto bodyOut = factory.streamFactory.capnpToKjExplicitEnd(pipeline.getRequestBody()); pumpRequestTask = rb.pumpTo(*bodyOut) .then([&bodyOut = *bodyOut](uint64_t) mutable { return bodyOut.end(); - }).eagerlyEvaluate([state = kj::addRef(*state), bodyOut = kj::mv(bodyOut)] + }).eagerlyEvaluate([&pumpRequestFailedReason, bodyOut = kj::mv(bodyOut)] (kj::Exception&& e) mutable { // A DISCONNECTED exception probably means the server decided not to read the whole request // before responding. In that case we simply want the pump to end, so that on this end it @@ -443,34 +497,28 @@ public: // exception in that case. For any other exception, we want to merge the exception with // the final result. if (e.getType() != kj::Exception::Type::DISCONNECTED) { - state->taskFailed(kj::mv(e)); + pumpRequestFailedReason = kj::heap(kj::mv(e)); } }); } // Wait for the server to indicate completion. Meanwhile, if the - // promise is canceled from the client side, we propagate cancellation naturally, and we - // also call state->cancel(). 
- return awaitCompletion(pipeline) - // Once the server indicates it is done, then we can cancel pumping the request, because - // obviously the server won't use it. We should not cancel pumping the response since there - // could be data in-flight still. - .attach(kj::mv(pumpRequestTask)) - // finishTasks() will wait for the respones to complete. - .then([state = kj::mv(state)]() mutable { return state->finishTasks(); }) - .attach(kj::mv(deferredCancel)); - } - - kj::Promise request( - kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, - kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override { - if (factory.peerOptimizationLevel < LEVEL_2) { - return requestImpl(inner.startRequestRequest(), method, url, headers, requestBody, kjResponse, - [](auto& pipeline) { return pipeline.getContext().whenResolved(); }); - } else { - return requestImpl(inner.requestRequest(), method, url, headers, requestBody, kjResponse, - [](auto& pipeline) { return pipeline.ignoreResult(); }); + // promise is canceled from the client side, we propagate cancellation naturally. + co_await pipeline.ignoreResult(); + + // Once the server indicates it is done, then we can cancel pumping the request, because + // obviously the server won't use it. We should not cancel pumping the response since there + // could be data in-flight still. + { auto drop = kj::mv(pumpRequestTask); } + + // If the request pump failed (for a non-disconnect reason) we'd better propagate that + // exception. + KJ_IF_SOME(e, pumpRequestFailedReason) { + kj::throwFatalException(kj::mv(*e)); } + + // Finish pumping the response or WebSocket. (Probably it's already finished.) 
+ co_await context.finishPump(); } kj::Promise connect( @@ -482,8 +530,8 @@ public: rpcRequest.setDown(factory.streamFactory.kjToCapnp(kj::mv(downPipe.out))); rpcRequest.initSettings().setUseTls(settings.useTls); - auto context = kj::heap(factory, tunnel); - RevocableServer revocableContext(*context); + ConnectClientRequestContextImpl context(factory, tunnel); + RevocableServer revocableContext(context); auto builder = capnp::Request< capnp::HttpService::ConnectParams, @@ -493,6 +541,9 @@ public: rpcRequest.setContext(revocableContext.getClient()); RemotePromise pipeline = rpcRequest.send(); + // Make sure the request message isn't pinned into memory through the co_await below. + { auto drop = kj::mv(rpcRequest); } + // We read from `downPipe` (the other side writes into it.) auto downPumpTask = downPipe.in->pumpTo(connection) .then([&connection, down = kj::mv(downPipe.in)](uint64_t) -> kj::Promise { @@ -524,10 +575,7 @@ public: return kj::NEVER_DONE; }); - return pipeline.ignoreResult() - .attach(kj::mv(downPumpTask), kj::mv(upPumpTask), kj::mv(revocableContext)) - // Separate attach to make sure `revocableContext` is destroyed before `context`. - .attach(kj::mv(context)); + co_await pipeline.ignoreResult(); } @@ -542,67 +590,26 @@ kj::Own HttpOverCapnpFactory::capnpToKj(capnp::HttpService::Cli // ======================================================================================= -namespace { - -class NullInputStream final: public kj::AsyncInputStream { - // TODO(cleanup): This class has been replicated in a bunch of places now, make it public - // somewhere. 
- -public: - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return kj::constPromise(); - } - - kj::Maybe tryGetLength() override { - return uint64_t(0); - } - - kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { - return kj::constPromise(); - } -}; - -class NullOutputStream final: public kj::AsyncOutputStream { - // TODO(cleanup): This class has been replicated in a bunch of places now, make it public - // somewhere. - -public: - kj::Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - kj::Promise write(kj::ArrayPtr> pieces) override { - return kj::READY_NOW; - } - kj::Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } - - // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. -}; - -class ResolvedServerRequestContext final: public capnp::HttpService::ServerRequestContext::Server { -public: - // Nothing! It's done. -}; - -} // namespace - -class HttpOverCapnpFactory::HttpServiceResponseImpl +class HttpOverCapnpFactory::HttpServiceResponseImpl final : public kj::HttpService::Response { public: HttpServiceResponseImpl(HttpOverCapnpFactory& factory, capnp::HttpRequest::Reader request, - capnp::HttpService::ClientRequestContext::Client clientContext) + capnp::HttpService::ClientRequestContext::Client clientContext, + kj::Own>> replyFulfiller) : factory(factory), method(validateMethod(request.getMethod())), url(request.getUrl()), headers(factory.headersToKj(request.getHeaders())), - clientContext(kj::mv(clientContext)) {} + clientContext(kj::mv(clientContext)), + replyFulfiller(kj::mv(replyFulfiller)) {} + // `replyFulfiller` is eventually fulfilled with the task that sends the reply back to + // the client. 
kj::Own send( uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers, kj::Maybe expectedBodySize = kj::none) override { - KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()"); + KJ_REQUIRE(replyFulfiller->isWaiting(), "already called send() or acceptWebSocket()"); auto req = clientContext.startResponseRequest(); @@ -622,26 +629,19 @@ public: hasBody = s > 0; } - auto logError = [hasBody](kj::Exception&& e) { - KJ_LOG(INFO, "HTTP-over-RPC startResponse() failed", hasBody, e); - }; if (hasBody) { auto pipeline = req.send(); auto result = factory.streamFactory.capnpToKj(pipeline.getBody()); - replyTask = pipeline.ignoreResult().eagerlyEvaluate(kj::mv(logError)); + replyFulfiller->fulfill(pipeline.ignoreResult()); return result; } else { - replyTask = req.send().ignoreResult().eagerlyEvaluate(kj::mv(logError)); - return kj::heap(); + replyFulfiller->fulfill(req.send().ignoreResult()); + return kj::heap(); } - - // We don't actually wait for replyTask anywhere, because we may be all done with this HTTP - // message before the client gets a chance to respond, and we don't want to force an extra - // network round trip. If the client fails this call that's the client's problem, really. } kj::Own acceptWebSocket(const kj::HttpHeaders& headers) override { - KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()"); + KJ_REQUIRE(replyFulfiller->isWaiting(), "already called send() or acceptWebSocket()"); auto req = clientContext.startWebSocketRequest(); @@ -652,27 +652,22 @@ public: auto pipe = kj::newWebSocketPipe(); auto shorteningPaf = kj::newPromiseAndFulfiller>(); - // We don't need the RequestState mechanism on the server side because - // CapnpToKjWebSocketAdapter wraps a pipe end, and that pipe end can continue to exist beyond - // the lifetime of the request, because the other end will have been dropped. 
We only create - // a RequestState here so that we can reuse the implementation of CapnpToKjWebSocketAdapter - // that needs this for the client side. - auto dummyState = kj::refcounted(); - auto& pipeEnd0Ref = *pipe.ends[0]; - dummyState->holdWebSocket(kj::mv(pipe.ends[0])); + // Note that since CapnpToKjWebSocketAdapter takes ownership of the pipe end, we don't need + // to cancel it later. Dropping the other end of the pipe will have the same effect. + auto upPumpPaf = kj::newPromiseAndFulfiller>(); req.setUpSocket(kj::heap( - kj::mv(dummyState), pipeEnd0Ref, kj::mv(shorteningPaf.promise))); + kj::mv(pipe.ends[0]), kj::mv(shorteningPaf.promise), kj::mv(upPumpPaf.fulfiller))); auto pipeline = req.send(); auto result = kj::heap( kj::mv(pipe.ends[1]), pipeline.getDownSocket(), kj::mv(shorteningPaf.fulfiller)); - // Note we need eagerlyEvaluate() here to force proactively discarding the response object, - // since it holds a reference to `downSocket`. - replyTask = pipeline.ignoreResult() - .eagerlyEvaluate([](kj::Exception&& e) { - KJ_LOG(INFO, "HTTP-over-RPC startWebSocketRequest() failed", e); - }); + replyFulfiller->fulfill(pipeline.ignoreResult() + .then([upPumpTask = kj::mv(upPumpPaf.promise)]() mutable { + // We need to continue pumping the WebSocket in the client->server direction, so let's do + // that as part of the reply task. 
+ return kj::mv(upPumpTask); + })); return result; } @@ -682,7 +677,7 @@ public: kj::StringPtr url; kj::HttpHeaders headers; capnp::HttpService::ClientRequestContext::Client clientContext; - kj::Maybe> replyTask; + kj::Own>> replyFulfiller; static kj::HttpMethod validateMethod(capnp::HttpMethod method) { KJ_REQUIRE(method <= capnp::HttpMethod::UNSUBSCRIBE, "unknown method", method); @@ -750,50 +745,12 @@ public: }; -class HttpOverCapnpFactory::ServerRequestContextImpl final - : public capnp::HttpService::ServerRequestContext::Server, - public HttpServiceResponseImpl { -public: - ServerRequestContextImpl(HttpOverCapnpFactory& factory, - HttpService::Client serviceCap, - kj::Own request, - capnp::HttpService::ClientRequestContext::Client clientContext, - kj::Own requestBodyIn, - kj::HttpService& kjService) - : HttpServiceResponseImpl(factory, *request, kj::mv(clientContext)), - request(kj::mv(request)), - serviceCap(kj::mv(serviceCap)), - // Note we attach `requestBodyIn` to `task` so that we will implicitly cancel reading - // the request body as soon as the service returns. This is important in particular when - // the request body is not fully consumed, in order to propagate cancellation. - task(kjService.request(method, url, headers, *requestBodyIn, *this) - .attach(kj::mv(requestBodyIn))) {} - - kj::Maybe> shortenPath() override { - return task.then([]() -> Capability::Client { - // If all went well, resolve to a settled capability. - // TODO(perf): Could save a message by resolving to a capability hosted by the client, or - // some special "null" capability that isn't an error but is still transmitted by value. - // Otherwise we need a Release message from client -> server just to drop this... 
- return kj::heap(); - }); - } - - KJ_DISALLOW_COPY_AND_MOVE(ServerRequestContextImpl); - -private: - kj::Own request; - HttpService::Client serviceCap; // ensures the inner kj::HttpService isn't destroyed - kj::Promise task; -}; - class HttpOverCapnpFactory::CapnpToKjHttpServiceAdapter final: public capnp::HttpService::Server { public: CapnpToKjHttpServiceAdapter(HttpOverCapnpFactory& factory, kj::Own inner) : factory(factory), inner(kj::mv(inner)) {} - template - kj::Promise requestImpl(CallContext context, Callback&& callback) { + kj::Promise request(RequestContext context) override { // Common implementation of request() and startRequest(). callback() performs the // method-specific stuff at the end. // @@ -818,46 +775,33 @@ public: auto pipe = kj::newOneWayPipe(expectedSize); auto requestBodyCap = factory.streamFactory.kjToCapnp(kj::mv(pipe.out)); - if (kj::isSameType()) { - // For request(), use context.setPipeline() to enable pipelined calls to the request body - // stream before this RPC completes. (We don't bother when using startRequest() because - // it returns immediately anyway, so this would just waste effort.) - PipelineBuilder pipeline; - pipeline.setRequestBody(kj::cp(requestBodyCap)); - context.setPipeline(pipeline.build()); - } + // For request(), use context.setPipeline() to enable pipelined calls to the request body + // stream before this RPC completes. 
+ PipelineBuilder pipeline; + pipeline.setRequestBody(kj::cp(requestBodyCap)); + context.setPipeline(pipeline.build()); results.setRequestBody(kj::mv(requestBodyCap)); requestBody = kj::mv(pipe.in); } else { - requestBody = kj::heap(); + requestBody = kj::heap(); } - return callback(results, metadata, params, requestBody); - } - - kj::Promise request(RequestContext context) override { - return requestImpl(kj::mv(context), - [&](auto& results, auto& metadata, auto& params, auto& requestBody) { - class FinalHttpServiceResponseImpl final: public HttpServiceResponseImpl { - public: - using HttpServiceResponseImpl::HttpServiceResponseImpl; - }; - auto impl = kj::heap(factory, metadata, params.getContext()); - auto promise = inner->request(impl->method, impl->url, impl->headers, *requestBody, *impl); - return promise.attach(kj::mv(requestBody), kj::mv(impl)); + auto replyPaf = kj::newPromiseAndFulfiller>(); + auto replyPromise = replyPaf.promise.then([]() -> kj::Promise { + // The reply may complete before the request promise does. We don't want to cancel the + // request if the reply completed successfully, so return NEVER_DONE here so that the + // exclusiveJoin() below becomes a no-op. + // + // On the other hand, if the reply throws an exception, we want to cancel the request and + // propagate that exception immediately! 
+ return kj::NEVER_DONE; }); - } - kj::Promise startRequest(StartRequestContext context) override { - return requestImpl(kj::mv(context), - [&](auto& results, auto& metadata, auto& params, auto& requestBody) { - results.setContext(kj::heap( - factory, thisCap(), capnp::clone(metadata), params.getContext(), kj::mv(requestBody), - *inner)); - - return kj::READY_NOW; - }); + HttpServiceResponseImpl impl( + factory, metadata, params.getContext(), kj::mv(replyPaf.fulfiller)); + co_await inner->request(impl.method, impl.url, impl.headers, *requestBody, impl) + .exclusiveJoin(kj::mv(replyPromise)); } kj::Promise connect(ConnectContext context) override { @@ -914,19 +858,20 @@ public: return kj::NEVER_DONE; }); - PipelineBuilder pb; - auto eofWrapper = kj::heap(kj::mv(ref2)); - auto up = factory.streamFactory.kjToCapnp(kj::mv(eofWrapper), kj::mv(tlsStarter)); - pb.setUp(kj::cp(up)); + { + PipelineBuilder pb; + auto eofWrapper = kj::heap(kj::mv(ref2)); + auto up = factory.streamFactory.kjToCapnp(kj::mv(eofWrapper), kj::mv(tlsStarter)); + pb.setUp(kj::cp(up)); - context.setPipeline(pb.build()); - context.initResults(capnp::MessageSize { 4, 1 }).setUp(kj::mv(up)); + context.setPipeline(pb.build()); + context.initResults(capnp::MessageSize { 4, 1 }).setUp(kj::mv(up)); + } - auto response = kj::heap( - factory, context.getParams().getContext()); + { auto drop = kj::mv(refcounted); } - return inner->connect(host, headers, *pipe.ends[0], *response, settings).attach( - kj::mv(host), kj::mv(headers), kj::mv(response), kj::mv(pipe)) + HttpOverCapnpConnectResponseImpl response(factory, context.getParams().getContext()); + co_await inner->connect(host, headers, *pipe.ends[0], response, settings) .exclusiveJoin(kj::mv(pumpTask)); } diff --git a/c++/src/capnp/compat/http-over-capnp.h b/c++/src/capnp/compat/http-over-capnp.h index 6b16749118..6a428c1c65 100644 --- a/c++/src/capnp/compat/http-over-capnp.h +++ b/c++/src/capnp/compat/http-over-capnp.h @@ -57,8 +57,11 @@ class 
HttpOverCapnpFactory { // will improve efficiency but breaks compatibility with older peers that don't implement newer // levels. - LEVEL_1, - // Use startRequest(), the original version of the protocol. + // There used to be a LEVEL_1, which used `startRequest()`, the original version of the + // protocol. Support for this level was removed in the v2 branch in order to simplify the code. + // If you have existing servers in the wild implementing this protocol that don't support + // LEVEL_2, then your clients will have to stick to Cap'n Proto 1.x until those servers are all + // updated. LEVEL_2 // Use request(). This is more efficient than startRequest() but won't work with old peers that @@ -66,7 +69,12 @@ class HttpOverCapnpFactory { }; HttpOverCapnpFactory(ByteStreamFactory& streamFactory, HeaderIdBundle headerIds, - OptimizationLevel peerOptimizationLevel = LEVEL_1); + OptimizationLevel peerOptimizationLevel); + // Note: `peerOptimizationLevel` used to be optional and defaulted to LEVEL_1. However, any + // client still setting this to LEVEL_1 will be unable to talk to any server that is running new + // code where LEVEL_1 was removed. So if you hit a compile error because your code is not setting + // this option, you will need to roll back to an older version of Cap'n Proto for now, until you + // can update all code in production to pass LEVEL_2 here. 
kj::Own capnpToKj(capnp::HttpService::Client rpcService); capnp::HttpService::Client kjToCapnp(kj::Own service); @@ -80,8 +88,6 @@ class HttpOverCapnpFactory { kj::Array valueCapnpToKj; kj::HashMap valueKjToCapnp; - class RequestState; - class CapnpToKjWebSocketAdapter; class KjToCapnpWebSocketAdapter; diff --git a/c++/src/capnp/compat/json-rpc.c++ b/c++/src/capnp/compat/json-rpc.c++ index 2a8b88d449..92e1ad46f9 100644 --- a/c++/src/capnp/compat/json-rpc.c++ +++ b/c++/src/capnp/compat/json-rpc.c++ @@ -161,7 +161,7 @@ void JsonRpc::queueError(kj::Maybe id, int code, kj::String error.setMessage(message); // OK to discard result of queueWrite() since it's just one branch of a fork. - queueWrite(codec.encode(jsonResponse)); + (void)queueWrite(codec.encode(jsonResponse)); } kj::Promise JsonRpc::readLoop() { diff --git a/c++/src/capnp/compat/json-rpc.capnp b/c++/src/capnp/compat/json-rpc.capnp index 9380788cd7..33e47e0786 100644 --- a/c++/src/capnp/compat/json-rpc.capnp +++ b/c++/src/capnp/compat/json-rpc.capnp @@ -2,7 +2,7 @@ $import "/capnp/c++.capnp".namespace("capnp::json"); -using Json = import "json.capnp"; +using Json = import "/capnp/compat/json.capnp"; struct RpcMessage { jsonrpc @0 :Text; diff --git a/c++/src/capnp/compiler/capnp.c++ b/c++/src/capnp/compiler/capnp.c++ index 01e15df8c9..09553f64f2 100644 --- a/c++/src/capnp/compiler/capnp.c++ +++ b/c++/src/capnp/compiler/capnp.c++ @@ -39,7 +39,6 @@ #include #include #include "../message.h" -#include #include #include #include diff --git a/c++/src/capnp/compiler/compiler.c++ b/c++/src/capnp/compiler/compiler.c++ index 49c9b3a83f..c0744d1908 100644 --- a/c++/src/capnp/compiler/compiler.c++ +++ b/c++/src/capnp/compiler/compiler.c++ @@ -256,7 +256,7 @@ private: Node rootNode; }; -class Compiler::Impl: public SchemaLoader::LazyLoadCallback { +class Compiler::Impl final : public SchemaLoader::LazyLoadCallback { public: explicit Impl(AnnotationFlag annotationFlag); virtual ~Impl() noexcept(false); diff --git 
a/c++/src/capnp/ez-rpc-test.c++ b/c++/src/capnp/ez-rpc-test.c++ deleted file mode 100644 index 0cd2fdbb5e..0000000000 --- a/c++/src/capnp/ez-rpc-test.c++ +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. 
- -#define CAPNP_TESTING_CAPNP 1 - -#include "ez-rpc.h" -#include "test-util.h" -#include - -namespace capnp { -namespace _ { -namespace { - -TEST(EzRpc, Basic) { - int callCount = 0; - EzRpcServer server(kj::heap(callCount), "localhost"); - - EzRpcClient client("localhost", server.getPort().wait(server.getWaitScope())); - - auto cap = client.getMain(); - auto request = cap.fooRequest(); - request.setI(123); - request.setJ(true); - - EXPECT_EQ(0, callCount); - auto response = request.send().wait(server.getWaitScope()); - EXPECT_EQ("foo", response.getX()); - EXPECT_EQ(1, callCount); -} - -TEST(EzRpc, DeprecatedNames) { - EzRpcServer server("localhost"); - int callCount = 0; - server.exportCap("cap1", kj::heap(callCount)); - server.exportCap("cap2", kj::heap()); - - EzRpcClient client("localhost", server.getPort().wait(server.getWaitScope())); - - auto cap = client.importCap("cap1"); - auto request = cap.fooRequest(); - request.setI(123); - request.setJ(true); - - EXPECT_EQ(0, callCount); - auto response = request.send().wait(server.getWaitScope()); - EXPECT_EQ("foo", response.getX()); - EXPECT_EQ(1, callCount); - - EXPECT_EQ(0, client.importCap("cap2").castAs() - .getCallSequenceRequest().send().wait(server.getWaitScope()).getN()); - EXPECT_EQ(1, client.importCap("cap2").castAs() - .getCallSequenceRequest().send().wait(server.getWaitScope()).getN()); -} - -} // namespace -} // namespace _ -} // namespace capnp diff --git a/c++/src/capnp/ez-rpc.c++ b/c++/src/capnp/ez-rpc.c++ deleted file mode 100644 index 6add953b6c..0000000000 --- a/c++/src/capnp/ez-rpc.c++ +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. 
and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. 
- -#include "ez-rpc.h" -#include "rpc-twoparty.h" -#include -#include -#include -#include -#include - -namespace capnp { - -KJ_THREADLOCAL_PTR(EzRpcContext) threadEzContext = nullptr; - -class EzRpcContext: public kj::Refcounted { -public: - EzRpcContext(): ioContext(kj::setupAsyncIo()) { - threadEzContext = this; - } - - ~EzRpcContext() noexcept(false) { - KJ_REQUIRE(threadEzContext == this, - "EzRpcContext destroyed from different thread than it was created.") { - return; - } - threadEzContext = nullptr; - } - - kj::WaitScope& getWaitScope() { - return ioContext.waitScope; - } - - kj::AsyncIoProvider& getIoProvider() { - return *ioContext.provider; - } - - kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() { - return *ioContext.lowLevelProvider; - } - - static kj::Own getThreadLocal() { - EzRpcContext* existing = threadEzContext; - if (existing != nullptr) { - return kj::addRef(*existing); - } else { - return kj::refcounted(); - } - } - -private: - kj::AsyncIoContext ioContext; -}; - -// ======================================================================================= - -kj::Promise> connectAttach(kj::Own&& addr) { - return addr->connect().attach(kj::mv(addr)); -} - -struct EzRpcClient::Impl { - kj::Own context; - - struct ClientContext { - kj::Own stream; - TwoPartyVatNetwork network; - RpcSystem rpcSystem; - - ClientContext(kj::Own&& stream, ReaderOptions readerOpts) - : stream(kj::mv(stream)), - network(*this->stream, rpc::twoparty::Side::CLIENT, readerOpts), - rpcSystem(makeRpcClient(network)) {} - - Capability::Client getMain() { - word scratch[4]; - memset(scratch, 0, sizeof(scratch)); - MallocMessageBuilder message(scratch); - auto hostId = message.getRoot(); - hostId.setSide(rpc::twoparty::Side::SERVER); - return rpcSystem.bootstrap(hostId); - } - - Capability::Client restore(kj::StringPtr name) { - word scratch[64]; - memset(scratch, 0, sizeof(scratch)); - MallocMessageBuilder message(scratch); - - auto hostIdOrphan = 
message.getOrphanage().newOrphan(); - auto hostId = hostIdOrphan.get(); - hostId.setSide(rpc::twoparty::Side::SERVER); - - auto objectId = message.getRoot(); - objectId.setAs(name); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return rpcSystem.restore(hostId, objectId); -#pragma GCC diagnostic pop - } - }; - - kj::ForkedPromise setupPromise; - - kj::Maybe> clientContext; - // Filled in before `setupPromise` resolves. - - Impl(kj::StringPtr serverAddress, uint defaultPort, - ReaderOptions readerOpts) - : context(EzRpcContext::getThreadLocal()), - setupPromise(context->getIoProvider().getNetwork() - .parseAddress(serverAddress, defaultPort) - .then([](kj::Own&& addr) { - return connectAttach(kj::mv(addr)); - }).then([this, readerOpts](kj::Own&& stream) { - clientContext = kj::heap(kj::mv(stream), - readerOpts); - }).fork()) {} - - Impl(const struct sockaddr* serverAddress, uint addrSize, - ReaderOptions readerOpts) - : context(EzRpcContext::getThreadLocal()), - setupPromise( - connectAttach(context->getIoProvider().getNetwork() - .getSockaddr(serverAddress, addrSize)) - .then([this, readerOpts](kj::Own&& stream) { - clientContext = kj::heap(kj::mv(stream), - readerOpts); - }).fork()) {} - - Impl(int socketFd, ReaderOptions readerOpts) - : context(EzRpcContext::getThreadLocal()), - setupPromise(kj::Promise(kj::READY_NOW).fork()), - clientContext(kj::heap( - context->getLowLevelIoProvider().wrapSocketFd(socketFd), - readerOpts)) {} -}; - -EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort, ReaderOptions readerOpts) - : impl(kj::heap(serverAddress, defaultPort, readerOpts)) {} - -EzRpcClient::EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, ReaderOptions readerOpts) - : impl(kj::heap(serverAddress, addrSize, readerOpts)) {} - -EzRpcClient::EzRpcClient(int socketFd, ReaderOptions readerOpts) - : impl(kj::heap(socketFd, readerOpts)) {} - -EzRpcClient::~EzRpcClient() noexcept(false) {} - 
-Capability::Client EzRpcClient::getMain() { - KJ_IF_SOME(client, impl->clientContext) { - return client->getMain(); - } else { - return impl->setupPromise.addBranch().then([this]() { - return KJ_ASSERT_NONNULL(impl->clientContext)->getMain(); - }); - } -} - -Capability::Client EzRpcClient::importCap(kj::StringPtr name) { - KJ_IF_SOME(client, impl->clientContext) { - return client->restore(name); - } else { - return impl->setupPromise.addBranch().then( - [this,name=kj::heapString(name)]() { - return KJ_ASSERT_NONNULL(impl->clientContext)->restore(name); - }); - } -} - -kj::WaitScope& EzRpcClient::getWaitScope() { - return impl->context->getWaitScope(); -} - -kj::AsyncIoProvider& EzRpcClient::getIoProvider() { - return impl->context->getIoProvider(); -} - -kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() { - return impl->context->getLowLevelIoProvider(); -} - -// ======================================================================================= - -namespace { - -class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter { -public: - bool shouldAllow(const struct sockaddr* addr, uint addrlen) override { - return true; - } -}; - -static DummyFilter DUMMY_FILTER; - -} // namespace - -struct EzRpcServer::Impl final: public SturdyRefRestorer, - public kj::TaskSet::ErrorHandler { - Capability::Client mainInterface; - kj::Own context; - - struct ExportedCap { - kj::String name; - Capability::Client cap = nullptr; - - ExportedCap(kj::StringPtr name, Capability::Client cap) - : name(kj::heapString(name)), cap(cap) {} - - ExportedCap() = default; - ExportedCap(const ExportedCap&) = delete; - ExportedCap(ExportedCap&&) = default; - ExportedCap& operator=(const ExportedCap&) = delete; - ExportedCap& operator=(ExportedCap&&) = default; - // Make std::map happy... 
- }; - - std::map exportMap; - - kj::ForkedPromise portPromise; - - kj::TaskSet tasks; - - struct ServerContext { - kj::Own stream; - TwoPartyVatNetwork network; - RpcSystem rpcSystem; - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - ServerContext(kj::Own&& stream, SturdyRefRestorer& restorer, - ReaderOptions readerOpts) - : stream(kj::mv(stream)), - network(*this->stream, rpc::twoparty::Side::SERVER, readerOpts), - rpcSystem(makeRpcServer(network, restorer)) {} -#pragma GCC diagnostic pop - }; - - Impl(Capability::Client mainInterface, kj::StringPtr bindAddress, uint defaultPort, - ReaderOptions readerOpts) - : mainInterface(kj::mv(mainInterface)), - context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) { - auto paf = kj::newPromiseAndFulfiller(); - portPromise = paf.promise.fork(); - - tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort) - .then([this, portFulfiller=kj::mv(paf.fulfiller), readerOpts](kj::Own&& addr) mutable { - auto listener = addr->listen(); - portFulfiller->fulfill(listener->getPort()); - acceptLoop(kj::mv(listener), readerOpts); - })); - } - - Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts) - : mainInterface(kj::mv(mainInterface)), - context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) { - auto listener = context->getIoProvider().getNetwork() - .getSockaddr(bindAddress, addrSize)->listen(); - portPromise = kj::Promise(listener->getPort()).fork(); - acceptLoop(kj::mv(listener), readerOpts); - } - - Impl(Capability::Client mainInterface, int socketFd, uint port, ReaderOptions readerOpts) - : mainInterface(kj::mv(mainInterface)), - context(EzRpcContext::getThreadLocal()), - portPromise(kj::Promise(port).fork()), - tasks(*this) { - acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER), - readerOpts); - } - - void 
acceptLoop(kj::Own&& listener, ReaderOptions readerOpts) { - auto ptr = listener.get(); - tasks.add(ptr->accept().then([this, listener=kj::mv(listener), readerOpts](kj::Own&& connection) mutable { - acceptLoop(kj::mv(listener), readerOpts); - - auto server = kj::heap(kj::mv(connection), *this, readerOpts); - - // Arrange to destroy the server context when all references are gone, or when the - // EzRpcServer is destroyed (which will destroy the TaskSet). - tasks.add(server->network.onDisconnect().attach(kj::mv(server))); - })); - } - - Capability::Client restore(AnyPointer::Reader objectId) override { - if (objectId.isNull()) { - return mainInterface; - } else { - auto name = objectId.getAs(); - auto iter = exportMap.find(name); - if (iter == exportMap.end()) { - KJ_FAIL_REQUIRE("Server exports no such capability.", name) { break; } - return nullptr; - } else { - return iter->second.cap; - } - } - } - - void taskFailed(kj::Exception&& exception) override { - kj::throwFatalException(kj::mv(exception)); - } -}; - -EzRpcServer::EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress, - uint defaultPort, ReaderOptions readerOpts) - : impl(kj::heap(kj::mv(mainInterface), bindAddress, defaultPort, readerOpts)) {} - -EzRpcServer::EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress, - uint addrSize, ReaderOptions readerOpts) - : impl(kj::heap(kj::mv(mainInterface), bindAddress, addrSize, readerOpts)) {} - -EzRpcServer::EzRpcServer(Capability::Client mainInterface, int socketFd, uint port, - ReaderOptions readerOpts) - : impl(kj::heap(kj::mv(mainInterface), socketFd, port, readerOpts)) {} - -EzRpcServer::EzRpcServer(kj::StringPtr bindAddress, uint defaultPort, - ReaderOptions readerOpts) - : EzRpcServer(nullptr, bindAddress, defaultPort, readerOpts) {} - -EzRpcServer::EzRpcServer(struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts) - : EzRpcServer(nullptr, bindAddress, addrSize, readerOpts) {} - 
-EzRpcServer::EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts) - : EzRpcServer(nullptr, socketFd, port, readerOpts) {} - -EzRpcServer::~EzRpcServer() noexcept(false) {} - -void EzRpcServer::exportCap(kj::StringPtr name, Capability::Client cap) { - Impl::ExportedCap entry(kj::heapString(name), cap); - impl->exportMap[entry.name] = kj::mv(entry); -} - -kj::Promise EzRpcServer::getPort() { - return impl->portPromise.addBranch(); -} - -kj::WaitScope& EzRpcServer::getWaitScope() { - return impl->context->getWaitScope(); -} - -kj::AsyncIoProvider& EzRpcServer::getIoProvider() { - return impl->context->getIoProvider(); -} - -kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() { - return impl->context->getLowLevelIoProvider(); -} - -} // namespace capnp diff --git a/c++/src/capnp/ez-rpc.h b/c++/src/capnp/ez-rpc.h deleted file mode 100644 index ef2649239a..0000000000 --- a/c++/src/capnp/ez-rpc.h +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#pragma once - -#include "rpc.h" -#include "message.h" - -CAPNP_BEGIN_HEADER - -struct sockaddr; - -namespace kj { class AsyncIoProvider; class LowLevelAsyncIoProvider; } - -namespace capnp { - -class EzRpcContext; - -class EzRpcClient { - // Super-simple interface for setting up a Cap'n Proto RPC client. Example: - // - // # Cap'n Proto schema - // interface Adder { - // add @0 (left :Int32, right :Int32) -> (value :Int32); - // } - // - // // C++ client - // int main() { - // capnp::EzRpcClient client("localhost:3456"); - // Adder::Client adder = client.getMain(); - // auto request = adder.addRequest(); - // request.setLeft(12); - // request.setRight(34); - // auto response = request.send().wait(client.getWaitScope()); - // assert(response.getValue() == 46); - // return 0; - // } - // - // // C++ server - // class AdderImpl final: public Adder::Server { - // public: - // kj::Promise add(AddContext context) override { - // auto params = context.getParams(); - // context.getResults().setValue(params.getLeft() + params.getRight()); - // return kj::READY_NOW; - // } - // }; - // - // int main() { - // capnp::EzRpcServer server(kj::heap(), "*:3456"); - // kj::NEVER_DONE.wait(server.getWaitScope()); - // } - // - // This interface is easy, but it hides a lot of useful features available from the lower-level - // classes: - // - The server can only export a small set of public, singleton capabilities under well-known - // string names. This is fine for transient services where no state needs to be kept between - // connections, but hides the power of Cap'n Proto when it comes to long-lived resources. 
- // - EzRpcClient/EzRpcServer automatically set up a `kj::EventLoop` and make it current for the - // thread. Only one `kj::EventLoop` can exist per thread, so you cannot use these interfaces - // if you wish to set up your own event loop. (However, you can safely create multiple - // EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more - // than one EventLoop.) - // - These classes only support simple two-party connections, not multilateral VatNetworks. - // - These classes only support communication over a raw, unencrypted socket. If you want to - // build on an abstract stream (perhaps one which supports encryption), you must use the - // lower-level interfaces. - // - // Some of these restrictions will probably be lifted in future versions, but some things will - // always require using the low-level interfaces directly. If you are interested in working - // at a lower level, start by looking at these interfaces: - // - `kj::setupAsyncIo()` in `kj/async-io.h`. - // - `RpcSystem` in `capnp/rpc.h`. - // - `TwoPartyVatNetwork` in `capnp/rpc-twoparty.h`. - -public: - explicit EzRpcClient(kj::StringPtr serverAddress, uint defaultPort = 0, - ReaderOptions readerOpts = ReaderOptions()); - // Construct a new EzRpcClient and connect to the given address. The connection is formed in - // the background -- if it fails, calls to capabilities returned by importCap() will fail with an - // appropriate exception. - // - // `defaultPort` is the IP port number to use if `serverAddress` does not include it explicitly. - // If unspecified, the port is required in `serverAddress`. - // - // The address is parsed by `kj::Network` in `kj/async-io.h`. See that interface for more info - // on the address format, but basically it's what you'd expect. - // - // `readerOpts` is the ReaderOptions structure used to read each incoming message on the - // connection. 
Setting this may be necessary if you need to receive very large individual - // messages or messages. However, it is recommended that you instead think about how to change - // your protocol to send large data blobs in multiple small chunks -- this is much better for - // both security and performance. See `ReaderOptions` in `message.h` for more details. - - EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, - ReaderOptions readerOpts = ReaderOptions()); - // Like the above constructor, but connects to an already-resolved socket address. Any address - // format supported by `kj::Network` in `kj/async-io.h` is accepted. - - explicit EzRpcClient(int socketFd, ReaderOptions readerOpts = ReaderOptions()); - // Create a client on top of an already-connected socket. - // `readerOpts` acts as in the first constructor. - - ~EzRpcClient() noexcept(false); - - template - typename Type::Client getMain(); - Capability::Client getMain(); - // Get the server's main (aka "bootstrap") interface. - - template - typename Type::Client importCap(kj::StringPtr name) CAPNP_DEPRECATED( - "Change your server to export a main interface, then use getMain() instead."); - Capability::Client importCap(kj::StringPtr name) CAPNP_DEPRECATED( - "Change your server to export a main interface, then use getMain() instead."); - // ** DEPRECATED ** - // - // Ask the sever for the capability with the given name. You may specify a type to automatically - // down-cast to that type. It is up to you to specify the correct expected type. - // - // Named interfaces are deprecated. The new preferred usage pattern is for the server to export - // a "main" interface which itself has methods for getting any other interfaces. - - kj::WaitScope& getWaitScope(); - // Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on - // promises. - - kj::AsyncIoProvider& getIoProvider(); - // Get the underlying AsyncIoProvider set up by the RPC system. 
This is useful if you want - // to do some non-RPC I/O in asynchronous fashion. - - kj::LowLevelAsyncIoProvider& getLowLevelIoProvider(); - // Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you - // want to do some non-RPC I/O in asynchronous fashion. - -private: - struct Impl; - kj::Own impl; -}; - -class EzRpcServer { - // The server counterpart to `EzRpcClient`. See `EzRpcClient` for an example. - -public: - explicit EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress, - uint defaultPort = 0, ReaderOptions readerOpts = ReaderOptions()); - // Construct a new `EzRpcServer` that binds to the given address. An address of "*" means to - // bind to all local addresses. - // - // `defaultPort` is the IP port number to use if `serverAddress` does not include it explicitly. - // If unspecified, a port is chosen automatically, and you must call getPort() to find out what - // it is. - // - // The address is parsed by `kj::Network` in `kj/async-io.h`. See that interface for more info - // on the address format, but basically it's what you'd expect. - // - // The server might not begin listening immediately, especially if `bindAddress` needs to be - // resolved. If you need to wait until the server is definitely up, wait on the promise returned - // by `getPort()`. - // - // `readerOpts` is the ReaderOptions structure used to read each incoming message on the - // connection. Setting this may be necessary if you need to receive very large individual - // messages or messages. However, it is recommended that you instead think about how to change - // your protocol to send large data blobs in multiple small chunks -- this is much better for - // both security and performance. See `ReaderOptions` in `message.h` for more details. 
- - EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts = ReaderOptions()); - // Like the above constructor, but binds to an already-resolved socket address. Any address - // format supported by `kj::Network` in `kj/async-io.h` is accepted. - - EzRpcServer(Capability::Client mainInterface, int socketFd, uint port, - ReaderOptions readerOpts = ReaderOptions()); - // Create a server on top of an already-listening socket (i.e. one on which accept() may be - // called). `port` is returned by `getPort()` -- it serves no other purpose. - // `readerOpts` acts as in the other two above constructors. - - explicit EzRpcServer(kj::StringPtr bindAddress, uint defaultPort = 0, - ReaderOptions readerOpts = ReaderOptions()) - CAPNP_DEPRECATED("Please specify a main interface for your server."); - EzRpcServer(struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts = ReaderOptions()) - CAPNP_DEPRECATED("Please specify a main interface for your server."); - EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts = ReaderOptions()) - CAPNP_DEPRECATED("Please specify a main interface for your server."); - - ~EzRpcServer() noexcept(false); - - void exportCap(kj::StringPtr name, Capability::Client cap); - // Export a capability publicly under the given name, so that clients can import it. - // - // Keep in mind that you can implicitly convert `kj::Own&&` to - // `Capability::Client`, so it's typical to pass something like - // `kj::heap()` as the second parameter. - - kj::Promise getPort(); - // Get the IP port number on which this server is listening. This promise won't resolve until - // the server is actually listening. If the address was not an IP address (e.g. it was a Unix - // domain socket) then getPort() resolves to zero. - - kj::WaitScope& getWaitScope(); - // Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on - // promises. 
- - kj::AsyncIoProvider& getIoProvider(); - // Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want - // to do some non-RPC I/O in asynchronous fashion. - - kj::LowLevelAsyncIoProvider& getLowLevelIoProvider(); - // Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you - // want to do some non-RPC I/O in asynchronous fashion. - -private: - struct Impl; - kj::Own impl; -}; - -// ======================================================================================= -// inline implementation details - -template -inline typename Type::Client EzRpcClient::getMain() { - return getMain().castAs(); -} - -template -inline typename Type::Client EzRpcClient::importCap(kj::StringPtr name) { - return importCap(name).castAs(); -} - -} // namespace capnp - -CAPNP_END_HEADER diff --git a/c++/src/capnp/rpc-test.c++ b/c++/src/capnp/rpc-test.c++ index b8a6dccc11..06963a13f1 100644 --- a/c++/src/capnp/rpc-test.c++ +++ b/c++/src/capnp/rpc-test.c++ @@ -342,6 +342,9 @@ public: } } kj::Promise shutdown() override { + KJ_IF_SOME(e, network.shutdownExceptionToThrow) { + return kj::cp(e); + } KJ_IF_SOME(p, partner) { auto paf = kj::newPromiseAndFulfiller(); p.fulfillOnEnd = kj::mv(paf.fulfiller); @@ -410,11 +413,16 @@ public: } } + void setShutdownExceptionToThrow(kj::Exception&& e) { + shutdownExceptionToThrow = kj::mv(e); + } + private: TestNetwork& network; kj::StringPtr self; uint sent = 0; uint received = 0; + kj::Maybe shutdownExceptionToThrow = kj::none; std::map> connections; std::queue>>> fulfillerQueue; @@ -1303,6 +1311,46 @@ TEST(Rpc, Abort) { EXPECT_TRUE(conn->receiveIncomingMessage().wait(context.waitScope) == nullptr); } +KJ_TEST("handles exceptions thrown during disconnect") { + // This is similar to the earlier "abort" test, but throws an exception on + // connection shutdown, to exercise the RpcConnectionState error handler. 
+ + TestContext context; + + MallocMessageBuilder refMessage(128); + auto hostId = refMessage.initRoot(); + hostId.setHost("server"); + + context.serverNetwork.setShutdownExceptionToThrow( + KJ_EXCEPTION(FAILED, "a_disconnect_exception")); + + auto conn = KJ_ASSERT_NONNULL(context.clientNetwork.connect(hostId)); + + { + // Send an invalid message (Return to non-existent question). + auto msg = conn->newOutgoingMessage(128); + auto body = msg->getBody().initAs().initReturn(); + body.setAnswerId(1234); + body.setCanceled(); + msg->send(); + } + + { + // The internal exception handler of RpcSystemBase logs exceptions thrown + // during disconnect, which the test framework will flag as a failure if we + // don't explicitly tell it to expect the logged output. + KJ_EXPECT_LOG(ERROR, "a_disconnect_exception"); + + // Force outstanding promises to completion. The server should detect the + // invalid message and disconnect, which should cause the connection's + // disconnect() to throw an exception that will then be handled by a + // RpcConnectionState handler. Since other state instances were freed prior + // to the handler invocation, this caused failures in earlier versions of + // the code when run under asan. + kj::Promise(kj::NEVER_DONE).poll(context.waitScope); + } +} + KJ_TEST("loopback bootstrap()") { int callCount = 0; test::TestInterface::Client bootstrap = kj::heap(callCount); diff --git a/c++/src/capnp/rpc-twoparty.h b/c++/src/capnp/rpc-twoparty.h index a2a1cec67c..53db974c91 100644 --- a/c++/src/capnp/rpc-twoparty.h +++ b/c++/src/capnp/rpc-twoparty.h @@ -47,9 +47,6 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, private RpcFlowController::WindowGetter { // A `VatNetwork` that consists of exactly two parties communicating over an arbitrary byte // stream. This is used to implement the common case of a client/server network. - // - // See `ez-rpc.h` for a simple interface for setting up two-party clients and servers. 
- // Use `TwoPartyVatNetwork` only if you need the advanced features. public: TwoPartyVatNetwork(MessageStream& msgStream, diff --git a/c++/src/capnp/rpc.c++ b/c++/src/capnp/rpc.c++ index 84710fc7e8..02775a31e3 100644 --- a/c++/src/capnp/rpc.c++ +++ b/c++/src/capnp/rpc.c++ @@ -503,7 +503,8 @@ public: auto shutdownPromise = dyingConnection->shutdown() .attach(kj::mv(dyingConnection)) .then([]() -> kj::Promise { return kj::READY_NOW; }, - [this, origException = kj::mv(exception)](kj::Exception&& shutdownException) -> kj::Promise { + [self = kj::addRef(*this), origException = kj::mv(exception)]( + kj::Exception&& shutdownException) -> kj::Promise { // Don't report disconnects as an error. if (shutdownException.getType() == kj::Exception::Type::DISCONNECTED) { return kj::READY_NOW; @@ -516,7 +517,7 @@ public: } // We are shutting down after receive error, ignore shutdown exception since underlying // transport is probably broken. - if (receiveIncomingMessageError) { + if (self->receiveIncomingMessageError) { return kj::READY_NOW; } return kj::mv(shutdownException); @@ -713,7 +714,7 @@ private: bool gotReturnForHighQuestionId = false; // Becomes true if we ever get a `Return` message for a high question ID (with top bit set), // which we use in cases where we've hinted to the peer that we don't want a `Return`. If the - // peer sends us one anyway then it seemingly doesn't not implement our hints. We need to stop + // peer sends us one anyway then it seemingly does not implement our hints. We need to stop // using the hints in this case before the high question ID space wraps around since otherwise // we might reuse an ID that the peer thinks is still in use. @@ -1932,7 +1933,7 @@ private: question.paramExports = kj::mv(exports); question.isTailCall = isTailCall; - // Make the QuentionRef and result promise. + // Make the QuestionRef and result promise. 
SendInternalResult result; auto paf = kj::newPromiseAndFulfiller>>(); result.questionRef = kj::refcounted( @@ -2023,7 +2024,7 @@ private: question.paramExports = kj::mv(exports); question.isTailCall = false; - // Make the QuentionRef and result promise. + // Make the QuestionRef and result promise. auto questionRef = kj::refcounted(*connectionState, questionId, kj::none); question.selfRef = *questionRef; @@ -2248,7 +2249,7 @@ private: struct Resolution { kj::Own returnedCap; - // The capabiilty that appeared in the response message in this slot. + // The capability that appeared in the response message in this slot. kj::Own unwrapped; // Exactly what `getInnermostClient(returnedCap)` produced at the time that the return @@ -2462,7 +2463,7 @@ private: if (!responseImpl.hasCapabilities()) { returnMessage.setNoFinishNeeded(true); - // Tell ourselves that a finsih was already received, so that `cleanupAnswerTable()` + // Tell ourselves that a finish was already received, so that `cleanupAnswerTable()` // removes the answer table entry. receivedFinish = true; @@ -2695,7 +2696,7 @@ private: // Cancellation state ---------------------------------- bool receivedFinish = false; - // True if a `Finish` message has been recevied OR we sent a `Return` with `noFinishNedeed`. + // True if a `Finish` message has been received OR we sent a `Return` with `noFinishNeeded`. // In either case, it is our responsibility to clean up the answer table. kj::UnwindDetector unwindDetector; @@ -3450,7 +3451,7 @@ private: // Carol returns a capability from this call that points all the way back though Bob to // Alice. When this return capability passes through Bob, Bob will resolve the previous // promise-pipeline capability to it. However, Bob has to send a Disembargo to Carol before - // completing this resolution. In the meantime, though, Bob returns the final repsonse to + // completing this resolution. In the meantime, though, Bob returns the final response to // Alice. 
Alice then *also* sends a Disembargo to Bob. The Alice -> Bob Disembargo might // arrive at Bob before the Bob -> Carol Disembargo has resolved, in which case the // Disembargo is delivered to a promise capability. diff --git a/c++/src/capnp/rpc.h b/c++/src/capnp/rpc.h index b153307f99..1777c75123 100644 --- a/c++/src/capnp/rpc.h +++ b/c++/src/capnp/rpc.h @@ -68,9 +68,6 @@ class RpcSystem: public _::RpcSystemBase { // // See `makeRpcServer()` and `makeRpcClient()` below for convenient syntax for setting up an // `RpcSystem` given a `VatNetwork`. - // - // See `ez-rpc.h` for an even simpler interface for setting up RPC in a typical two-party - // client/server scenario. public: template makeRpcServer( // MyMainInterface::Client bootstrap = makeMain(); // auto server = makeRpcServer(network, bootstrap); // kj::NEVER_DONE.wait(waitScope); // run forever -// -// See also ez-rpc.h, which has simpler instructions for the common case of a two-party -// client-server RPC connection. template @@ -221,9 +215,6 @@ RpcSystem makeRpcClient( // MyCapability::Client cap = client.restore(hostId, objId).castAs(); // auto response = cap.fooRequest().send().wait(waitScope); // handleMyResponse(response); -// -// See also ez-rpc.h, which has simpler instructions for the common case of a two-party -// client-server RPC connection. template class SturdyRefRestorer: public _::SturdyRefRestorerBase { @@ -372,8 +363,7 @@ class VatNetwork: public _::VatNetworkBase { // to manage object references and make method calls. // // The most common implementation of VatNetwork is TwoPartyVatNetwork (rpc-twoparty.h). Most - // simple client-server apps will want to use it. (You may even want to use the EZ RPC - // interfaces in `ez-rpc.h` and avoid all of this.) + // simple client-server apps will want to use it. // // TODO(someday): Provide a standard implementation for the public internet. 
diff --git a/c++/src/capnp/serialize-async.c++ b/c++/src/capnp/serialize-async.c++ index 436b6c9afe..6fd88e393c 100644 --- a/c++/src/capnp/serialize-async.c++ +++ b/c++/src/capnp/serialize-async.c++ @@ -557,7 +557,7 @@ kj::Promise MessageStream::readMessage( // ======================================================================================= -class BufferedMessageStream::MessageReaderImpl: public FlatArrayMessageReader { +class BufferedMessageStream::MessageReaderImpl final : public FlatArrayMessageReader { public: MessageReaderImpl(BufferedMessageStream& parent, kj::ArrayPtr data, ReaderOptions options) diff --git a/c++/src/kj/BUILD.bazel b/c++/src/kj/BUILD.bazel index e492992c61..de89d41eb6 100644 --- a/c++/src/kj/BUILD.bazel +++ b/c++/src/kj/BUILD.bazel @@ -13,8 +13,7 @@ cc_library( "encoding.c++", "exception.c++", "filesystem.c++", - "filesystem-disk-unix.c++", - "filesystem-disk-win32.c++", + "glob-filter.c++", "hash.c++", "io.c++", "list.c++", @@ -27,11 +26,13 @@ cc_library( "string.c++", "string-tree.c++", "table.c++", - "test-helpers.c++", "thread.c++", "time.c++", "units.c++", - ], + ] + select({ + "@platforms//os:windows": ["filesystem-disk-win32.c++"], + "//conditions:default": ["filesystem-disk-unix.c++"], + }), hdrs = [ "arena.h", "array.h", @@ -42,6 +43,7 @@ cc_library( "exception.h", "filesystem.h", "function.h", + "glob-filter.h", "hash.h", "io.h", "list.h", @@ -61,7 +63,6 @@ cc_library( "table.h", "test.h", "thread.h", - "threadlocal.h", "time.h", "tuple.h", "units.h", @@ -87,12 +88,17 @@ cc_library( srcs = [ "async.c++", "async-io.c++", - "async-io-unix.c++", - "async-io-win32.c++", - "async-unix.c++", - "async-win32.c++", "timer.c++", - ], + ] + select({ + "@platforms//os:windows": [ + "async-io-win32.c++", + "async-win32.c++", + ], + "//conditions:default": [ + "async-io-unix.c++", + "async-unix.c++", + ], + }), hdrs = [ "async.h", "async-inl.h", @@ -100,10 +106,11 @@ cc_library( "async-io-internal.h", "async-prelude.h", "async-queue.h", - 
"async-unix.h", - "async-win32.h", "timer.h", - ], + ] + select({ + "@platforms//os:windows": ["async-win32.h"], + "//conditions:default": ["async-unix.h"], + }), include_prefix = "kj", linkopts = select({ "@platforms//os:windows": [ @@ -120,6 +127,10 @@ cc_library( name = "kj-test", srcs = [ "test.c++", + "test-helpers.c++", + ], + hdrs = [ + "test.h", ], include_prefix = "kj", visibility = ["//visibility:public"], @@ -164,7 +175,7 @@ cc_library( "string-tree-test.c++", "table-test.c++", "test-test.c++", - "threadlocal-test.c++", + "glob-filter-test.c++", "thread-test.c++", "time-test.c++", "tuple-test.c++", @@ -191,25 +202,25 @@ cc_library( cc_test( name = "filesystem-disk-generic-test", srcs = ["filesystem-disk-generic-test.c++"], + target_compatible_with = [ + "@platforms//os:linux", + ], deps = [ ":filesystem-disk-test-base", ":kj-test", ], - target_compatible_with = [ - "@platforms//os:linux", - ], ) cc_test( name = "filesystem-disk-old-kernel-test", srcs = ["filesystem-disk-old-kernel-test.c++"], + target_compatible_with = [ + "@platforms//os:linux", + ], deps = [ ":filesystem-disk-test-base", ":kj-test", ], - target_compatible_with = [ - "@platforms//os:linux", - ], ) cc_test( @@ -246,12 +257,14 @@ cc_test( cc_test( name = "exception-override-symbolizer-test", srcs = ["exception-override-symbolizer-test.c++"], + linkstatic = True, + target_compatible_with = select({ + "@platforms//os:linux": [], + "@platforms//os:macos": [], + "//conditions:default": ["@platforms//:incompatible"], + }), deps = [ ":kj", ":kj-test", ], - linkstatic = True, - target_compatible_with = [ - "@platforms//os:linux", - ], ) diff --git a/c++/src/kj/CMakeLists.txt b/c++/src/kj/CMakeLists.txt index c84f550d16..9cc7a9cac0 100644 --- a/c++/src/kj/CMakeLists.txt +++ b/c++/src/kj/CMakeLists.txt @@ -13,6 +13,7 @@ set(kj_sources_lite mutex.c++ string.c++ source-location.c++ + glob-filter.c++ hash.c++ table.c++ thread.c++ @@ -49,6 +50,7 @@ set(kj_headers string.h string-tree.h 
source-location.h + glob-filter.h hash.h table.h map.h @@ -62,7 +64,6 @@ set(kj_headers function.h mutex.h thread.h - threadlocal.h filesystem.h time.h main.h @@ -252,8 +253,8 @@ if(BUILD_TESTING) io-test.c++ mutex-test.c++ time-test.c++ - threadlocal-test.c++ test-test.c++ + glob-filter-test.c++ std/iostream-test.c++ ) # TODO: Link with librt on Solaris for sched_yield diff --git a/c++/src/kj/array-test.c++ b/c++/src/kj/array-test.c++ index 8693d2e4b0..50fd727f83 100644 --- a/c++/src/kj/array-test.c++ +++ b/c++/src/kj/array-test.c++ @@ -24,6 +24,7 @@ #include #include #include +#include namespace kj { namespace { @@ -509,5 +510,18 @@ TEST(Array, AttachFromArrayPtr) { KJ_EXPECT(destroyed1 == 3, destroyed1); } +struct Std { + template + static std::span from(Array* arr) { + return std::span(arr->begin(), arr->size()); + } +}; + +KJ_TEST("Array::as") { + kj::Array arr = kj::arr(1, 2, 4); + std::span stdArr = arr.as(); + KJ_EXPECT(stdArr.size() == 3); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/array.h b/c++/src/kj/array.h index 52e83bd696..4db3934de8 100644 --- a/c++/src/kj/array.h +++ b/c++/src/kj/array.h @@ -237,6 +237,11 @@ class Array { Array attach(Attachments&&... attachments) KJ_WARN_UNUSED_RESULT; // Like Own::attach(), but attaches to an Array. + template + inline auto as() { return U::from(this); } + // Syntax sugar for invoking U::from. + // Used to chain conversion calls rather than wrap with function. + private: T* ptr; size_t size_; diff --git a/c++/src/kj/async-inl.h b/c++/src/kj/async-inl.h index 8032439059..b8efd04674 100644 --- a/c++/src/kj/async-inl.h +++ b/c++/src/kj/async-inl.h @@ -2062,6 +2062,7 @@ template concept NoWaitScope = !isSameType, WaitScope>(); // Define a Concept to use in our `coroutine_traits` specialization to validate allowable coroutine // parameter types. +// TODO(cleanup): This can be removed by adding KJ_DISALLOW_AS_COROUTINE_PARAM to WaitScope. 
} // namespace kj::_ @@ -2073,6 +2074,9 @@ template struct coroutine_traits, Args...> { // `Args...` are the coroutine's parameter types. + static_assert((!kj::_::isDisallowedInCoroutine() && ...), + "Disallowed type in coroutine"); + static_assert((::kj::_::NoWaitScope && ...), "Coroutines are not allowed to accept `WaitScope` parameters."); // Coroutines should never have access to a WaitScope. If they could, coroutines could be both a diff --git a/c++/src/kj/async-io-internal.h b/c++/src/kj/async-io-internal.h index d030ad9577..bc0df7c30c 100644 --- a/c++/src/kj/async-io-internal.h +++ b/c++/src/kj/async-io-internal.h @@ -58,6 +58,7 @@ class NetworkFilter: public LowLevelAsyncIoProvider::NetworkFilter { Vector denyCidrs; bool allowUnix; bool allowAbstractUnix; + bool allowVsock; bool allowPublic = false; bool allowNetwork = false; diff --git a/c++/src/kj/async-io-test.c++ b/c++/src/kj/async-io-test.c++ index 3b567fe1bc..0107705cfe 100644 --- a/c++/src/kj/async-io-test.c++ +++ b/c++/src/kj/async-io-test.c++ @@ -367,6 +367,49 @@ TEST(AsyncIo, AncillaryMessageHandler) { } #endif +#if __linux__ +TEST(AsyncIo, VmSocket) { + auto ioContext = setupAsyncIo(); + auto& network = ioContext.provider->getNetwork(); + + Own listener; + Own server; + Own client; + + char receiveBuffer[4]; + + int port = ((getpid() >> 10) % (0xffff - 1024 + 1)) + 1024; // wrap to dynamic port range + auto ready = newPromiseAndFulfiller(); + + ready.promise.then([&]() { + return network.parseAddress(kj::str("vsock:2:", port)); + }).then([&](Own&& addr) { + auto promise = addr->connect(); + return promise.then([&,addr=kj::mv(addr)](auto result) mutable { + client = kj::mv(result); + return client->write("foo", 3); + }); + }).detach([](kj::Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + kj::String result = network.parseAddress(kj::str("vsock:2:", port)) + .then([&](Own&& result) { + listener = result->listen(); + ready.fulfiller->fulfill(); + return listener->accept(); + 
}).then([&](auto result) { + server = kj::mv(result); + return server->tryRead(receiveBuffer, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer, n); + }).wait(ioContext.waitScope); + + EXPECT_EQ("foo", result); +} +#endif + String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint portHint = 0) { return network.parseAddress(text, portHint).wait(waitScope)->toString(); } @@ -409,6 +452,10 @@ TEST(AsyncIo, AddressParsing) { EXPECT_EQ("unix-abstract:foo/bar/baz", tryParse(w, network, "unix-abstract:foo/bar/baz")); #endif +#if __linux__ + EXPECT_EQ("vsock:4294967295:1234", tryParse(w, network, "vsock:-1:1234")); +#endif + // We can parse services by name... // // For some reason, Android and some various Linux distros do not support service names. @@ -1265,6 +1312,9 @@ KJ_TEST("Network::restrictPeers()") { #if !_WIN32 KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "unix:/foo")); #endif +#if __linux__ + KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "vsock:-1:80")); +#endif auto addr = restrictedNetwork->parseAddress("127.0.0.1").wait(w); @@ -3080,6 +3130,16 @@ KJ_TEST("AggregateConnectionReceiver") { acceptPromise3.wait(ws); } +KJ_TEST("AggregateConnectionReceiver empty") { + auto aggregate = newAggregateConnectionReceiver({}); + KJ_EXPECT(aggregate->getPort() == 0); + + int value; + uint length = sizeof(value); + + KJ_EXPECT_THROW_MESSAGE("receivers.size() > 0", aggregate->getsockopt(0, 0, &value, &length)); +} + // ======================================================================================= // Tests for optimized pumpTo() between OS handles. 
Note that this is only even optimized on // some OSes (only Linux as of this writing), but the behavior should still be the same on all diff --git a/c++/src/kj/async-io-unix.c++ b/c++/src/kj/async-io-unix.c++ index 91a85546ec..27e3dbfd9b 100644 --- a/c++/src/kj/async-io-unix.c++ +++ b/c++/src/kj/async-io-unix.c++ @@ -60,6 +60,7 @@ #if __linux__ #include +#include #endif #if !defined(SO_PEERCRED) && defined(LOCAL_PEERCRED) @@ -985,6 +986,11 @@ public: return str("unix:", path); } } +#if __linux__ + case AF_VSOCK: { + return str("vsock:", addr.vsock.svm_cid, ":", addr.vsock.svm_port); + } +#endif default: return str("(unknown address family ", addr.generic.sa_family, ")"); } @@ -1043,6 +1049,33 @@ public: return array.finish(); } +#if __linux__ + if (str.startsWith("vsock:")) { + StringPtr path = str.slice(strlen("vsock:")); + + char* endptr; + unsigned int cid = strtoul(path.cStr(), &endptr, 0); + KJ_REQUIRE(*endptr == ':', "missing vsock port"); + unsigned int port = strtoul(endptr + 1, &endptr, 0); + KJ_REQUIRE(*endptr == '\0', "invalid vsock addr"); + + memset(&result.addr.vsock, 0, sizeof(result.addr.vsock)); + result.addr.vsock.svm_family = AF_VSOCK; + result.addr.vsock.svm_cid = cid; + result.addr.vsock.svm_port = port; + result.addrlen = sizeof(struct sockaddr_vm); + + if (!result.parseAllowedBy(filter)) { + KJ_FAIL_REQUIRE("VM sockets blocked by restrictPeers()"); + return Array(); + } + + auto array = kj::heapArrayBuilder(1); + array.add(result); + return array.finish(); + } +#endif + // Try to separate the address and port. 
ArrayPtr addrPart; Maybe portPart; @@ -1195,6 +1228,9 @@ private: struct sockaddr_in inet4; struct sockaddr_in6 inet6; struct sockaddr_un unixDomain; +#if __linux__ + struct sockaddr_vm vsock; +#endif struct sockaddr_storage storage; } addr; @@ -1474,10 +1510,8 @@ public: class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { public: - LowLevelAsyncIoProviderImpl() - : eventPort(), eventLoop(eventPort), waitScope(eventLoop) {} - - inline WaitScope& getWaitScope() { return waitScope; } + LowLevelAsyncIoProviderImpl(UnixEventPort& eventPort) + : eventPort(eventPort) {} Own wrapInputFd(int fd, uint flags = 0) override { return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ); @@ -1539,12 +1573,8 @@ public: Timer& getTimer() override { return eventPort.getTimer(); } - UnixEventPort& getEventPort() { return eventPort; } - private: - UnixEventPort eventPort; - EventLoop eventLoop; - WaitScope waitScope; + UnixEventPort& eventPort; }; // ======================================================================================= @@ -2031,10 +2061,13 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { - LowLevelAsyncIoProviderImpl lowLevel; + UnixEventPort eventPort; + EventLoop eventLoop(eventPort); + WaitScope waitScope(eventLoop); + LowLevelAsyncIoProviderImpl lowLevel(eventPort); auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); - startFunc(ioProvider, *stream, lowLevel.getWaitScope()); + startFunc(ioProvider, *stream, waitScope); }); return { kj::mv(thread), kj::mv(pipe) }; @@ -2053,11 +2086,32 @@ Own newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) { return kj::heap(lowLevel); } +Own newLowLevelAsyncIoProvider(UnixEventPort& eventPort) { + return kj::heap(eventPort); +} + AsyncIoContext setupAsyncIo() { - auto lowLevel = heap(); + struct BasicContext { + UnixEventPort eventPort; + EventLoop 
eventLoop; + WaitScope waitScope; + + BasicContext(): eventLoop(eventPort), waitScope(eventLoop) {} + }; + + auto basicContext = heap(); + auto lowLevel = heap(basicContext->eventPort); auto ioProvider = kj::heap(*lowLevel); - auto& waitScope = lowLevel->getWaitScope(); - auto& eventPort = lowLevel->getEventPort(); + auto& waitScope = basicContext->waitScope; + auto& eventPort = basicContext->eventPort; + + // Historically, `LowLevelAsyncIoProviderImpl` contained the stuff that `BasicContext` now + // contains. However, this made it impossible to create more elaborate EventLoop arrangements + // while still using the default LLAIOP implementation. For backwards-compatibility, + // `setupAsyncIo()` still attaches this context to the LLAIOP, but it's now possible to construct + // these objects directly and LLAIOP on top. + lowLevel = lowLevel.attach(kj::mv(basicContext)); + return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort }; } diff --git a/c++/src/kj/async-io-win32.c++ b/c++/src/kj/async-io-win32.c++ index 87a70aea5b..9333670885 100644 --- a/c++/src/kj/async-io-win32.c++ +++ b/c++/src/kj/async-io-win32.c++ @@ -919,10 +919,7 @@ public: class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { public: - LowLevelAsyncIoProviderImpl() - : eventLoop(eventPort), waitScope(eventLoop) {} - - inline WaitScope& getWaitScope() { return waitScope; } + LowLevelAsyncIoProviderImpl(Win32EventPort& eventPort): eventPort(eventPort) {} Own wrapInputFd(SOCKET fd, uint flags = 0) override { return heap(eventPort, fd, flags); @@ -952,12 +949,8 @@ public: Timer& getTimer() override { return eventPort.getTimer(); } - Win32EventPort& getEventPort() { return eventPort; } - private: - Win32IocpEventPort eventPort; - EventLoop eventLoop; - WaitScope waitScope; + Win32EventPort& eventPort; }; // ======================================================================================= @@ -1153,10 +1146,13 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], 
NEW_FD_FLAGS); auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { - LowLevelAsyncIoProviderImpl lowLevel; + Win32IocpEventPort eventPort; + EventLoop eventLoop(eventPort); + WaitScope waitScope(eventLoop); + LowLevelAsyncIoProviderImpl lowLevel(eventPort); auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); - startFunc(ioProvider, *stream, lowLevel.getWaitScope()); + startFunc(ioProvider, *stream, waitScope); }); return { kj::mv(thread), kj::mv(pipe) }; @@ -1175,13 +1171,34 @@ Own newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) { return kj::heap(lowLevel); } +Own newLowLevelAsyncIoProvider(Win32EventPort& eventPort) { + return kj::heap(eventPort); +} + AsyncIoContext setupAsyncIo() { _::initWinsockOnce(); - auto lowLevel = heap(); + struct BasicContext { + Win32IocpEventPort eventPort; + EventLoop eventLoop; + WaitScope waitScope; + + BasicContext(): eventLoop(eventPort), waitScope(eventLoop) {} + }; + + auto basicContext = heap(); + auto lowLevel = heap(basicContext->eventPort); auto ioProvider = kj::heap(*lowLevel); - auto& waitScope = lowLevel->getWaitScope(); - auto& eventPort = lowLevel->getEventPort(); + auto& waitScope = basicContext->waitScope; + auto& eventPort = basicContext->eventPort; + + // Historically, `LowLevelAsyncIoProviderImpl` contained the stuff that `BasicContext` now + // contains. However, this made it impossible to create more elaborate EventLoop arrangements + // while still using the default LLAIOP implementation. For backwards-compatibility, + // `setupAsyncIo()` still attaches this context to the LLAIOP, but it's now possible to construct + // these objects directly and LLAIOP on top. 
+ lowLevel = lowLevel.attach(kj::mv(basicContext)); + return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort }; } diff --git a/c++/src/kj/async-io.c++ b/c++/src/kj/async-io.c++ index 45140abfa1..c8e36787e6 100644 --- a/c++/src/kj/async-io.c++ +++ b/c++/src/kj/async-io.c++ @@ -80,6 +80,28 @@ Maybe> AsyncInputStream::tryTee(uint64_t) { return kj::none; } +kj::Promise NullStream::tryRead(void* buffer, size_t minBytes, size_t maxBytes) { + return kj::constPromise(); +} +kj::Maybe NullStream::tryGetLength() { + return uint64_t(0); +} +kj::Promise NullStream::pumpTo(kj::AsyncOutputStream& output, uint64_t amount) { + return kj::constPromise(); +} + +kj::Promise NullStream::write(const void* buffer, size_t size) { + return kj::READY_NOW; +} +kj::Promise NullStream::write(kj::ArrayPtr> pieces) { + return kj::READY_NOW; +} +kj::Promise NullStream::whenWriteDisconnected() { + return kj::NEVER_DONE; +} + +void NullStream::shutdownWrite() {} + namespace { class AsyncPump { @@ -401,19 +423,35 @@ private: } template - static auto teeExceptionVoid(F& fulfiller) { + static auto teeExceptionVoid(F& fulfiller, Canceler& canceler) { // Returns a functor that can be passed as the second parameter to .then() to propagate the // exception to a given fulfiller. The functor's return type is void. - return [&fulfiller](kj::Exception&& e) { + // + // All use cases of this helper below are also wrapped in `canceler.wrap()`, and fulfilling + // `fulfiller` may cause the canceler to be canceled. It's possible the canceler will be + // canceled before the exception even gets a chance to propagate out of the wrapped promise, + // which would have the effect of replacing the original exception with a non-useful + // "operation canceled" exception. To avoid this, we must release the canceler before + // fulfilling the fulfiller. 
+ return [&fulfiller, &canceler](kj::Exception&& e) { + canceler.release(); fulfiller.reject(kj::cp(e)); kj::throwRecoverableException(kj::mv(e)); }; } template - static auto teeExceptionSize(F& fulfiller) { + static auto teeExceptionSize(F& fulfiller, Canceler& canceler) { // Returns a functor that can be passed as the second parameter to .then() to propagate the // exception to a given fulfiller. The functor's return type is size_t. - return [&fulfiller](kj::Exception&& e) -> size_t { + // + // All use cases of this helper below are also wrapped in `canceler.wrap()`, and fulfilling + // `fulfiller` may cause the canceler to be canceled. It's possible the canceler will be + // canceled before the exception even gets a chance to propagate out of the wrapped promise, + // which would have the effect of replacing the original exception with a non-useful + // "operation canceled" exception. To avoid this, we must release the canceler before + // fulfilling the fulfiller. + return [&fulfiller, &canceler](kj::Exception&& e) -> size_t { + canceler.release(); fulfiller.reject(kj::cp(e)); kj::throwRecoverableException(kj::mv(e)); return 0; @@ -576,7 +614,7 @@ private: writeBuffer = writeBuffer.slice(amount, writeBuffer.size()); // We pumped the full amount, so we're done pumping. return amount; - }, teeExceptionSize(fulfiller))); + }, teeExceptionSize(fulfiller, canceler))); } // First piece doesn't cover the whole pump. Figure out how many more pieces to add. @@ -630,7 +668,7 @@ private: morePieces = newMorePieces; canceler.release(); return amount; - }, teeExceptionSize(fulfiller))); + }, teeExceptionSize(fulfiller, canceler))); } } @@ -807,7 +845,7 @@ private: // Completed entire pumpTo amount. 
KJ_ASSERT(actual == amount2); return amount2; - }, teeExceptionSize(fulfiller))); + }, teeExceptionSize(fulfiller, canceler))); } void abortRead() override { @@ -1263,7 +1301,7 @@ private: canceler.release(); fulfiller.fulfill(kj::cp(amount)); pipe.endState(*this); - }, teeExceptionVoid(fulfiller))); + }, teeExceptionVoid(fulfiller, canceler))); } auto remainder = pieces.slice(i, pieces.size()); @@ -1292,7 +1330,7 @@ private: fulfiller.fulfill(kj::cp(amount)); pipe.endState(*this); } - }, teeExceptionVoid(fulfiller))); + }, teeExceptionVoid(fulfiller, canceler))); } Promise writeWithFds(ArrayPtr data, @@ -2855,9 +2893,10 @@ public: } uint getPort() override { - return receivers[0]->getPort(); + return receivers.size() > 0 ? receivers[0]->getPort() : 0u; } void getsockopt(int level, int option, void* value, uint* length) override { + KJ_REQUIRE(receivers.size() > 0); return receivers[0]->getsockopt(level, option, value, length); } void setsockopt(int level, int option, const void* value, uint length) override { @@ -2867,6 +2906,7 @@ public: } } void getsockname(struct sockaddr* addr, uint* length) override { + KJ_REQUIRE(receivers.size() > 0); return receivers[0]->getsockname(addr, length); } @@ -3042,7 +3082,7 @@ bool matchesAny(ArrayPtr cidrs, const struct sockaddr* addr) { } NetworkFilter::NetworkFilter() - : allowUnix(true), allowAbstractUnix(true) { + : allowUnix(true), allowAbstractUnix(true), allowVsock(true) { allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0)); allowCidrs.add(CidrRange::inet6({}, {}, 0)); denyCidrs.addAll(reservedCidrs()); @@ -3050,7 +3090,7 @@ NetworkFilter::NetworkFilter() NetworkFilter::NetworkFilter(ArrayPtr allow, ArrayPtr deny, NetworkFilter& next) - : allowUnix(false), allowAbstractUnix(false), next(next) { + : allowUnix(false), allowAbstractUnix(false), allowVsock(false), next(next) { for (auto rule: allow) { if (rule == "local") { allowCidrs.addAll(localCidrs()); @@ -3067,6 +3107,8 @@ NetworkFilter::NetworkFilter(ArrayPtr allow, 
ArrayPtr allow, ArrayPtrsa_family == AF_VSOCK) return allowVsock; +#endif + bool allowed = false; uint allowSpecificity = 0; @@ -3157,15 +3205,23 @@ bool NetworkFilter::shouldAllowParse(const struct sockaddr* addr, uint addrlen) } } else { #endif - if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && - (allowPublic || allowNetwork)) { - matched = true; - } - for (auto& cidr: allowCidrs) { - if (cidr.matchesFamily(addr->sa_family)) { +#if __linux__ + if (addr->sa_family == AF_VSOCK) { + if (allowVsock) matched = true; + } else { +#endif + if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && + (allowPublic || allowNetwork)) { matched = true; } + for (auto& cidr: allowCidrs) { + if (cidr.matchesFamily(addr->sa_family)) { + matched = true; + } + } +#if __linux__ } +#endif #if !_WIN32 } #endif diff --git a/c++/src/kj/async-io.h b/c++/src/kj/async-io.h index 411937668a..68fe673787 100644 --- a/c++/src/kj/async-io.h +++ b/c++/src/kj/async-io.h @@ -171,6 +171,23 @@ class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream { // isn't wrapping a file descriptor. }; +class NullStream final: public AsyncIoStream { + // Convenience class that implements an I/O stream that ignores all writes and returns EOF for + // all reads. + // + // Hint: You can also use this class when you just need an input stream or an output stream. 
+public: + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + kj::Maybe tryGetLength() override; + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override; + + kj::Promise write(const void* buffer, size_t size) override; + kj::Promise write(kj::ArrayPtr> pieces) override; + kj::Promise whenWriteDisconnected() override; + + void shutdownWrite() override; +}; + Promise unoptimizedPumpTo( AsyncInputStream& input, AsyncOutputStream& output, uint64_t amount, uint64_t completedSoFar = 0); @@ -903,6 +920,14 @@ class LowLevelAsyncIoProvider { Own newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel); // Make a new AsyncIoProvider wrapping a `LowLevelAsyncIoProvider`. +#if _WIN32 +Own newLowLevelAsyncIoProvider(Win32EventPort& eventPort); +// Make a new `LowLevelAsyncIoProvider` backed by a `Win32EventPort`. +#else +Own newLowLevelAsyncIoProvider(UnixEventPort& eventPort); +// Make a new `LowLevelAsyncIoProvider` backed by a `UnixEventPort`. +#endif + struct AsyncIoContext { Own lowLevelProvider; Own provider; diff --git a/c++/src/kj/async-unix-test.c++ b/c++/src/kj/async-unix-test.c++ index 448bc24fb4..bb0cc26dda 100644 --- a/c++/src/kj/async-unix-test.c++ +++ b/c++/src/kj/async-unix-test.c++ @@ -39,6 +39,7 @@ #include #include #include "mutex.h" +#include #if KJ_USE_EPOLL #include @@ -1119,6 +1120,188 @@ KJ_TEST("UnixEventPort thread-specific signals") { } #endif +#if KJ_USE_EPOLL +KJ_TEST("UnixEventPoll::getPollableFd() for external waiting") { + kj::UnixEventPort port; + kj::EventLoop loop(port); + kj::WaitScope ws(loop); + + auto portIsReady = [&port](int timout = 0) { + struct pollfd pfd; + memset(&pfd, 0, sizeof(pfd)); + pfd.events = POLLIN; + pfd.fd = port.getPollableFd(); + + int n; + KJ_SYSCALL(n = poll(&pfd, 1, timout)); + return n > 0; + }; + + // Test wakeup on observed FD. 
+ { + int pair[2]; + KJ_SYSCALL(pipe(pair)); + kj::AutoCloseFd in(pair[0]); + kj::AutoCloseFd out(pair[1]); + + kj::UnixEventPort::FdObserver observer(port, in, kj::UnixEventPort::FdObserver::OBSERVE_READ); + auto promise = observer.whenBecomesReadable(); + + KJ_EXPECT(!promise.poll(ws)); + ws.poll(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + KJ_SYSCALL(write(out, "a", 1)); + + KJ_EXPECT(portIsReady()); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup due to queuing work to the event loop in-process. + { + ws.poll(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + auto promise = kj::evalLater([]() {}).eagerlyEvaluate(nullptr); + + KJ_EXPECT(portIsReady()); + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup on timeout. + { + auto promise = port.getTimer().afterDelay(50 * kj::MILLISECONDS); + + KJ_EXPECT(!promise.poll(ws)); + ws.poll(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + usleep(50'000); + + KJ_EXPECT(portIsReady()); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup on time in past. This verifies timerfd_settime() won't just silently fail if the + // time is already past. + { + ws.poll(); + + // Schedule time event in the past. + auto promise = port.getTimer().atTime(kj::origin() + 1 * SECONDS); + + // As of this writing, atTime() doesn't do any special handling of times in the past, e.g. to + // immediately resolve the promise. It goes ahead and schedules them like any other I/O. So + // scheduling such a promise will not immediately schedule work on the event loop, and + // preparePollableFdForSleep() will in fact go and timerfd_settime() to a time in the past. (If + // this changes, we'll need to structure this test differently I guess.) + KJ_EXPECT(!loop.isRunnable()); + + port.preparePollableFdForSleep(); + + // Uhhhh... 
Apparently when timerfd_settime() sets a time in the past, the timerfd does NOT + // immediately become readable. The kernel still needs to process the timer in the background + // before it raises the event. So we will need to give it some time... we give it 10ms here. + KJ_EXPECT(portIsReady(10)); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup when a timer event is created during sleep. + { + ws.poll(); + auto startTime = port.getTimer().now(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + // When sleeping, passage of real time updates `timer.now()`. + usleep(50'000); + KJ_EXPECT(port.getTimer().now() - startTime >= 50 * MILLISECONDS); + + // We can set a timer now, and the epoll FD will wake up when it expires, even though no timer + // was set when `preparePollableFdForSleep()` was called. + auto promise = port.getTimer().afterDelay(50 * MILLISECONDS); + + // It won't expire too early: the delay was added to the real time, not the last time the + // timer was advanced to. + KJ_EXPECT(!portIsReady(10)); + KJ_EXPECT(portIsReady(40)); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } +} + +KJ_TEST("m:n threads:EventLoops") { + // This test shows that it's possible for an EventLoop to switch threads, and for a thread to + // switch event loops. 
+ + UnixEventPort port1; + EventLoop loop1(port1); + + UnixEventPort port2; + EventLoop loop2(port2); + + kj::TimePoint startTime = kj::origin(); + kj::Promise promise1 = nullptr; + PromiseCrossThreadFulfillerPair xpaf { nullptr, {} }; + const Executor* executor; + + { + WaitScope ws1(loop1); + ws1.poll(); + startTime = port1.getTimer().now(); + promise1 = port1.getTimer().afterDelay(10 * kj::MILLISECONDS); + xpaf = kj::newPromiseAndCrossThreadFulfiller(); + executor = &getCurrentThreadExecutor(); + } + + static thread_local uint threadId = 0; + + threadId = 1; + bool executorDone = false; + + kj::Thread thread([&]() noexcept { + threadId = 2; + + WaitScope ws1(loop1); + promise1.wait(ws1); + KJ_EXPECT(port1.getTimer().now() - startTime >= 10 * kj::MILLISECONDS); + KJ_EXPECT(executorDone); + + xpaf.promise.wait(ws1); + }); + + [&]() noexcept { + WaitScope ws2(loop2); + + // The `executor` we captured earlier is tied to loop1, which has changed threads, so code we + // schedule on it will run there. + uint remoteThreadId = executor->executeAsync([&]() { + return threadId; + }).wait(ws2); + executorDone = true; + KJ_EXPECT(remoteThreadId == 2); + KJ_EXPECT(threadId == 1); + + xpaf.fulfiller->fulfill(); + }(); +} +#endif + } // namespace } // namespace kj diff --git a/c++/src/kj/async-unix.c++ b/c++/src/kj/async-unix.c++ index 6de12479ec..b2e6ddc254 100644 --- a/c++/src/kj/async-unix.c++ +++ b/c++/src/kj/async-unix.c++ @@ -23,7 +23,6 @@ #include "async-unix.h" #include "debug.h" -#include "threadlocal.h" #include #include #include @@ -36,6 +35,7 @@ #if KJ_USE_EPOLL #include #include +#include #elif KJ_USE_KQUEUE #include #include @@ -70,7 +70,7 @@ bool threadClaimedChildExits = false; namespace { -KJ_THREADLOCAL_PTR(UnixEventPort) threadEventPort = nullptr; +thread_local UnixEventPort* threadEventPort = nullptr; // This is set to the current UnixEventPort just before epoll_pwait(), then back to null after it // returns. 
@@ -100,7 +100,7 @@ void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { #elif KJ_USE_KQUEUE #if !KJ_HAS_SIGTIMEDWAIT -KJ_THREADLOCAL_PTR(siginfo_t) threadCapture = nullptr; +static thread_local siginfo_t* threadCapture = nullptr; #endif void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { @@ -152,7 +152,7 @@ struct SignalCapture { #endif }; -KJ_THREADLOCAL_PTR(SignalCapture) threadCapture = nullptr; +thread_local SignalCapture* threadCapture = nullptr; } // namespace @@ -542,6 +542,8 @@ void UnixEventPort::wake() const { } bool UnixEventPort::wait() { + sleeping = false; + #ifdef KJ_DEBUG // In debug mode, verify the current signal mask matches the original. { @@ -647,6 +649,19 @@ bool UnixEventPort::processEpollEvents(struct epoll_event events[], int n) { // We were woken. Need to return true. woken = true; + } else if (events[i].data.u64 == 1) { + // timerfd fired. We need to clear it by reading it. + int tfd = KJ_ASSERT_NONNULL(timerFd).get(); + char buffer[16]; + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = read(tfd, buffer, sizeof(buffer))); + KJ_ASSERT(n == 8 || n < 0); + + timerfdIsArmed = false; + + // The purpose of this event is just to wake up the event loop when needed. We'll check the + // timer queue separately, so we don't need to do anything special in response to this event + // here. } else { FdObserver* observer = reinterpret_cast(events[i].data.ptr); observer->fire(events[i].events); @@ -663,6 +678,8 @@ bool UnixEventPort::poll() { // pending signals. Therefore, we need a completely different approach to poll for signals. We // might as well use regular epoll_wait() in this case, too, to save the kernel some effort. + sleeping = false; + if (signalHead != nullptr || childSet != kj::none) { // Use sigtimedwait() to poll for signals. 
@@ -712,6 +729,85 @@ bool UnixEventPort::poll() { return processEpollEvents(events, n); } +int UnixEventPort::getPollableFd() { + return epollFd.get(); +} + +void UnixEventPort::preparePollableFdForSleep() { + KJ_ASSERT(signalHead == nullptr, + "preparePollableFdForSleep() cannot be used when waiting for signals"); + + if (runnable) { + // There is still immediate work in the queue, so force the epoll to be ready immediately. (See + // comments in setRunnable() regarding using wake() for this.) + wake(); + } else { + updateNextTimerEvent(timerImpl.nextEvent()); + + // Flag that we're sleeping, so setRunnable() knows to use wake() if needed. + sleeping = true; + + // Tell the timer we're sleeping, so that it notifies us if anyone tries to create a new time + // event. + timerImpl.setSleeping(*this); + } +} + +void UnixEventPort::setRunnable(bool runnable) { + this->runnable = runnable; + if (runnable && sleeping) { + // A meta event loop is waiting for the epoll to become readable, and an event was queued + // directly to the event loop in the meantime (e.g. due to a promise fulfiller being + // fulfilled). So, we need to cause the epoll to signal readability, to wake the event loop. We + // can use the cross-thread wake mechanism for this. It'll cause the event loop to spuriously + // check for cross-thread events, incurring a mutex lock, but that's not a big deal, and that's + // much better than creating a whole second event FD for this purpose. + wake(); + sleeping = false; + } +} + +void UnixEventPort::updateNextTimerEvent(kj::Maybe time) { + if (time == kj::none && !timerfdIsArmed) { + // No change needed. + return; + } + + // Create the timerfd if needed. 
+ int tfd; + KJ_IF_SOME(f, timerFd) { + tfd = f; + } else { + KJ_SYSCALL(tfd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC | TFD_NONBLOCK)); + timerFd = kj::AutoCloseFd(tfd); + + struct epoll_event event; + memset(&event, 0, sizeof(event)); + event.events = EPOLLIN; + event.data.u64 = 1; + KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, tfd, &event)); + } + + // Update timerfd's expiration time. + struct itimerspec ts; + memset(&ts, 0, sizeof(ts)); + KJ_IF_SOME(t, time) { + auto t2 = t - origin(); + ts.it_value.tv_sec = t2 / SECONDS; + ts.it_value.tv_nsec = (t2 % SECONDS) / NANOSECONDS; + timerfdIsArmed = true; + } else { + // setting the time to zero will disarm it + timerfdIsArmed = false; + } + + KJ_SYSCALL(timerfd_settime(tfd, TFD_TIMER_ABSTIME, &ts, nullptr)); +} + +kj::TimePoint UnixEventPort::getTimeWhileSleeping() { + return clock.now(); +} + #elif KJ_USE_KQUEUE // ======================================================================================= // kqueue FdObserver implementation @@ -1583,6 +1679,16 @@ void UnixEventPort::wake() const { #endif // KJ_USE_EPOLL, else KJ_USE_KQUEUE, else +#if !KJ_USE_EPOLL +void UnixEventPort::updateNextTimerEvent(kj::Maybe time) { + KJ_UNIMPLEMENTED("SleepHooks not used on this platform, this shouldn't be called"); +} + +kj::TimePoint UnixEventPort::getTimeWhileSleeping() { + KJ_UNIMPLEMENTED("SleepHooks not used on this platform, this shouldn't be called"); +} +#endif + } // namespace kj #endif // !_WIN32 diff --git a/c++/src/kj/async-unix.h b/c++/src/kj/async-unix.h index 665305ea70..f0fbe51202 100644 --- a/c++/src/kj/async-unix.h +++ b/c++/src/kj/async-unix.h @@ -64,7 +64,7 @@ struct timespec; namespace kj { -class UnixEventPort: public EventPort { +class UnixEventPort: public EventPort, private TimerImpl::SleepHooks { // An EventPort implementation which can wait for events on file descriptors as well as signals. // This API only makes sense on Unix. 
// @@ -178,11 +178,70 @@ class UnixEventPort: public EventPort { // This method may capture the `SIGCHLD` signal. You must not use `captureSignal(SIGCHLD)` nor // `onSignal(SIGCHLD)` in your own code if you use `captureChildExit()`. +#if KJ_USE_EPOLL + int getPollableFd(); + // Get a file descriptor which represents the EventPort's backing OS event queue, and becomes + // readable when there are events to process. This may be an epoll FD or a kqueue FD depending on + // the OS. + // + // You MUST use preparePollableFdForSleep() before waiting on this FD. + // + // The caller should not perform operations on this FD. It should only be used for polling in + // some other event loop. This can be useful for allowing UnixEventPort to operate embedded in + // an application that uses some other event loop as its main loop. Whenever this FD becomes + // readable, call waitScope.poll() to handle all available events, then + // `preparePollableFdForSleep()`. + // + // It's also possible to use this to drive multiple event loops from the same thread, or even + // multiple threads in an m:n mapping. When an event loop becomes ready, create a temporary + // WaitScope on the thread where you want to run it, pump the loop, and then destroy the + // WaitScope. A WaitScope binds a loop to a thread while it exists, but an event loop is allowed + // to change threads and a thread is allowed to change event loops as long as no WaitScope is + // currently binding them. + // + // TODO(someday): Currently this is only implemented for epoll, NOT for kqueue. But in principle + // it should be possible to implement for kqueue as well. + + void preparePollableFdForSleep(); + // If you plan to monitor the FD returned by getPollableFd() for notifications that this queue is ready, + // you must call preparePollableFdForSleep() after each run of this port's event loop in order to + // ensure that all event types will in fact wake up the queue. 
+ // + // This call is needed in particular to arrange for timer events. Normally, timer events are not + // implemented via the epoll/kqueue at all, but instead the call to epoll_wait() / kevent() is + // given a timeout that causes it to return early when the next timer event is scheduled. This + // doesn't work when the wait is being performed on an external event loop, so instead the + // implementation must arrange to use timerfd (for epoll) or EVFILT_TIMER (for kqueue) so that + // the queue FD actually becomes ready when the next timer event is ready to run. This requires + // some extra syscalls, unfortunately. + // + // A second reason this call is needed is to wake up the event queue if any work is queued + // directly to the event loop via function calls rather than OS events. For example, if anyone + // calls `fulfill()` on a `PromiseFulfiller`, thus queuing work to this event loop, it needs + // to wake up. The EventPort achieves this by implicitly calling `wake()` if any work is queued + // while sleeping -- similar to a cross-thread wakeup. NOTE: Queuing work like this while the + // event loop is asleep is only safe if the EventLoop is still bound to the thread by the + // existence of a WaitScope. If you are destroying the WaitScope while sleeping, then you cannot + // manipulate promises attached to this loop at all while it is asleep. (Cross-thread events, + // e.g. queued via newPromiseAndCrossThreadFulfiller(), or via kj::Executor, are always safe. + // These will cause the queue fd to become ready so that the loop can wake up and respond to the + // events.) + // + // This method will throw an exception if the port is currently waiting on any signals via + // onSignal() or onChildExit(). Although it might theoretically be possible to support signals in + // this mode, deep flaws in the relevant APIs across multiple OSs make it likely not worth + // attempting. 
+#endif + // implements EventPort ------------------------------------------------------ bool wait() override; bool poll() override; void wake() const override; +#if KJ_USE_EPOLL + void setRunnable(bool runnable) override; +#endif + private: class SignalPromiseAdapter; class ChildExitPromiseAdapter; @@ -203,6 +262,11 @@ class UnixEventPort: public EventPort { sigset_t originalMask; AutoCloseFd epollFd; AutoCloseFd eventFd; // Used for cross-thread wakeups. + kj::Maybe timerFd; // Used if preparePollableFdForSleep() is ever called. + + bool sleeping = false; // Was preparePollableFdForSleep() called? + bool runnable = false; // Last value passed to setRunnable(). + bool timerfdIsArmed = false; bool processEpollEvents(struct epoll_event events[], int n); #elif KJ_USE_KQUEUE @@ -234,6 +298,10 @@ class UnixEventPort: public EventPort { static void registerReservedSignal(); #endif static void ignoreSigpipe(); + + // Implements TimerImpl::SleepHooks. + void updateNextTimerEvent(kj::Maybe time) override; + kj::TimePoint getTimeWhileSleeping() override; }; class UnixEventPort::FdObserver: private AsyncObject { diff --git a/c++/src/kj/async.c++ b/c++/src/kj/async.c++ index d4bfdd72ff..80cce9bc8d 100644 --- a/c++/src/kj/async.c++ +++ b/c++/src/kj/async.c++ @@ -42,7 +42,6 @@ #include "async.h" #include "debug.h" #include "vector.h" -#include "threadlocal.h" #include "mutex.h" #include "one-of.h" #include "function.h" @@ -116,7 +115,7 @@ namespace kj { namespace { -KJ_THREADLOCAL_PTR(DisallowAsyncDestructorsScope) disallowAsyncDestructorsScope = nullptr; +thread_local DisallowAsyncDestructorsScope* disallowAsyncDestructorsScope = nullptr; } // namespace @@ -161,7 +160,7 @@ AllowAsyncDestructorsScope::~AllowAsyncDestructorsScope() { namespace { -KJ_THREADLOCAL_PTR(EventLoop) threadLocalEventLoop = nullptr; +thread_local EventLoop* threadLocalEventLoop = nullptr; #define _kJ_ALREADY_READY reinterpret_cast< ::kj::_::Event*>(1) diff --git a/c++/src/kj/async.h b/c++/src/kj/async.h 
index f271ea5b21..f305869477 100644 --- a/c++/src/kj/async.h +++ b/c++/src/kj/async.h @@ -125,7 +125,7 @@ class [[nodiscard]] Promise: protected _::PromiseBase { // `kj::READY_NOW` to an already-fulfilled Promise. You may also implicitly convert a // `kj::Exception` to an already-broken promise of any type. // - // Promises are linear types -- they are moveable but not copyable. If a Promise is destroyed + // Promises are linear types -- they are movable but not copyable. If a Promise is destroyed // or goes out of scope (without being moved elsewhere), any ongoing asynchronous operations // meant to fulfill the promise will be canceled if possible. All methods of `Promise` (unless // otherwise noted) actually consume the promise in the sense of move semantics. (Arguably they @@ -476,7 +476,7 @@ PromiseForResult retryOnDisconnect(Func&& func) KJ_WARN_UNUSED_RESUL template PromiseForResult startFiber( size_t stackSize, Func&& func, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; -// Executes `func()` in a fiber, returning a promise for the eventual reseult. `func()` will be +// Executes `func()` in a fiber, returning a promise for the eventual result. `func()` will be // passed a `WaitScope&` as its parameter, allowing it to call `.wait()` on promises. Thus, `func()` // can be written in a synchronous, blocking style, instead of using `.then()`. 
This is often much // easier to write and read, and may even be significantly faster if it allows the use of stack diff --git a/c++/src/kj/common-test.c++ b/c++/src/kj/common-test.c++ index 248a169e1d..b138777fc2 100644 --- a/c++/src/kj/common-test.c++ +++ b/c++/src/kj/common-test.c++ @@ -23,6 +23,7 @@ #include "test.h" #include #include +#include namespace kj { namespace { @@ -925,5 +926,47 @@ KJ_TEST("kj::ArrayPtr startsWith / endsWith / findFirst / findLast") { KJ_EXPECT(arr.findLast(78).orDefault(100) == 100); } +struct Std { + template + static std::span from(ArrayPtr* arr) { + return std::span(arr->begin(), arr->size()); + } +}; + +KJ_TEST("ArrayPtr::as") { + int rawArray[] = {12, 34, 56, 34, 12}; + ArrayPtr arr(rawArray); + std::span stdPtr = arr.as(); + KJ_EXPECT(stdPtr.size() == 5); +} + +// Verifies the expected values of kj::isDisallowedInCoroutine + +struct DisallowedInCoroutineStruct { + KJ_DISALLOW_AS_COROUTINE_PARAM; +}; +class DisallowedInCoroutinePublic { +public: + KJ_DISALLOW_AS_COROUTINE_PARAM; +}; +class DisallowedInCoroutinePrivate { +private: + KJ_DISALLOW_AS_COROUTINE_PARAM; +}; +struct AllowedInCoroutine {}; + +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(!_::isDisallowedInCoroutine()); +static_assert(!_::isDisallowedInCoroutine()); +static_assert(!_::isDisallowedInCoroutine()); + } // namespace } // namespace kj diff --git a/c++/src/kj/common.h b/c++/src/kj/common.h index 10555fe39f..67718d5951 100644 --- a/c++/src/kj/common.h +++ b/c++/src/kj/common.h @@ -290,7 +290,7 @@ typedef unsigned char byte; #define KJ_DEPRECATED(reason) \ __attribute__((deprecated)) 
#define KJ_UNAVAILABLE(reason) = delete -// If the `unavailable` attribute is not supproted, just mark the method deleted, which at least +// If the `unavailable` attribute is not supported, just mark the method deleted, which at least // makes it a compile-time error to try to call it. Note that on Clang, marking a method deleted // *and* unavailable unfortunately defeats the purpose of the unavailable annotation, as the // generic "deleted" error is reported instead. @@ -1771,7 +1771,7 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { // ArrayPtr ptr = { 1, 2, 3 }; // foo(ptr[1]); // undefined behavior! // Any KJ programmer should be able to recognize that this is UB, because an ArrayPtr does not own -// its content. That's not what this constructor is for, tohugh. This constructor is meant to allow +// its content. That's not what this constructor is for, though. This constructor is meant to allow // code like this: // int foo(ArrayPtr p); // // ... later ... @@ -1927,6 +1927,11 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { // // You must include kj/array.h to call this. + template + inline auto as() { return U::from(this); } + // Syntax sugar for invoking U::from. + // Used to chain conversion calls rather than wrap with function. + private: T* ptr; size_t size_; @@ -2093,6 +2098,49 @@ _::Deferred defer(Func&& func) { #define KJ_DEFER(code) auto KJ_UNIQUE_NAME(_kjDefer) = ::kj::defer([&](){code;}) // Run the given code when the function exits, whether by return or exception. 
+// ======================================================================================= +// IsDisallowedInCoroutine + +namespace _ { + +template +struct IsDisallowedInCoroutine { + static constexpr bool value = false; +}; + +template +struct IsDisallowedInCoroutine::_kj_DissalowedInCoroutine> { + static constexpr bool value = true; +}; + +template +struct IsDisallowedInCoroutine { + static constexpr bool value = true; +}; + +template +constexpr bool isDisallowedInCoroutine() { + return IsDisallowedInCoroutine::value; +} + +} // namespace _ + +#define KJ_DISALLOW_AS_COROUTINE_PARAM \ + using _kj_DissalowedInCoroutine = void; \ + template friend constexpr bool kj::_::isDisallowedInCoroutine(); \ + template friend struct kj::_::IsDisallowedInCoroutine +// Place in the body of a class or struct to indicate that an instance of or reference to this +// type cannot be passed as the parameter to a KJ coroutine. This makes sense, for example, for +// mutex locks, or other types which should never be held across a co_await. +// +// (Types annotated with this likely also should not be used as local variables inside coroutines, +// but there is no way for us to enforce that.) +// +// struct Foo { +// KJ_DISALLOW_AS_COROUTINE_PARAM; +// } +// + } // namespace kj KJ_END_HEADER diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 148d035e1d..a4d85179f3 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -43,8 +43,8 @@ #else // Run the test using in-process two-way pipes. 
#define KJ_HTTP_TEST_SETUP_IO \ - kj::EventLoop eventLoop; \ - kj::WaitScope waitScope(eventLoop) + auto io = kj::setupAsyncIo(); \ + auto& waitScope KJ_UNUSED = io.waitScope #define KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR \ auto capPipe = newCapabilityPipe(); \ auto listener = kj::heap(*capPipe.ends[0]); \ @@ -1879,6 +1879,33 @@ public: } }; +void assertContainsWebSocketClose(kj::ArrayPtr data, uint16_t code, kj::Maybe messageSubstr) { + KJ_ASSERT(data.size() >= 2); // The smallest possible Close frame has size 2. + KJ_ASSERT(data.size() <= 127); // Maximum size for control frames. + KJ_ASSERT((data[0] & 0xf0) == 0x80); // Only the FIN flag is set. + KJ_ASSERT((data[0] & 0x0f) == 8); // OPCODE_CLOSE + + size_t payloadSize = data[1] & 0x7f; + + if (payloadSize == 0) { + // A Close frame with no body has no status code and no reason. + KJ_ASSERT(code == 1005); + KJ_ASSERT(messageSubstr == kj::none); + } else { + KJ_ASSERT(code != 1005); + } + auto payload = data.slice(2); + + KJ_ASSERT(payload.size() >= 2); // The first two bytes are the status code, so we better have at least two bytes. 
+ uint16_t gotCode = (payload[0] << 8) | payload[1]; + KJ_ASSERT(gotCode == code); + + KJ_IF_SOME(needle, messageSubstr) { + auto reason = kj::str(payload.asChars().slice(2)); + KJ_ASSERT(reason.contains(needle), reason, needle); + } +} + KJ_TEST("WebSocket unexpected RSV bits") { KJ_HTTP_TEST_SETUP_IO; auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; @@ -1893,7 +1920,10 @@ KJ_TEST("WebSocket unexpected RSV bits") { 0xF0, 0x05, 'w', 'o', 'r', 'l', 'd' // all RSV bits set, plus FIN }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -1904,7 +1934,8 @@ KJ_TEST("WebSocket unexpected RSV bits") { KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "RSV bits"_kjc); } KJ_TEST("WebSocket unexpected continuation frame") { @@ -1919,7 +1950,10 @@ KJ_TEST("WebSocket unexpected continuation frame") { 0x80, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', // Continuation frame with no start frame, plus FIN }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -1930,7 +1964,8 @@ KJ_TEST("WebSocket unexpected continuation frame") { KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Unexpected continuation frame"_kjc); } KJ_TEST("WebSocket missing continuation frame") { @@ -1946,7 +1981,10 @@ KJ_TEST("WebSocket missing continuation frame") { 0x01, 0x06, 
'w', 'o', 'r', 'l', 'd', '!', // Another start frame }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -1956,7 +1994,8 @@ KJ_TEST("WebSocket missing continuation frame") { KJ_ASSERT(errorCatcher.errors.size() == 1); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Missing continuation frame"_kjc); } KJ_TEST("WebSocket fragmented control frame") { @@ -1971,7 +2010,10 @@ KJ_TEST("WebSocket fragmented control frame") { 0x09, 0x04, 'd', 'a', 't', 'a' // Fragmented ping frame }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -1982,7 +2024,8 @@ KJ_TEST("WebSocket fragmented control frame") { KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Received fragmented control frame"_kjc); } KJ_TEST("WebSocket unknown opcode") { @@ -1997,7 +2040,10 @@ KJ_TEST("WebSocket unknown opcode") { 0x85, 0x04, 'd', 'a', 't', 'a' // 5 is a reserved opcode }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -2008,7 +2054,8 @@ KJ_TEST("WebSocket unknown opcode") { KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); } - clientTask.wait(waitScope); + auto 
nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Unknown opcode 5"_kjc); } KJ_TEST("WebSocket unsolicited pong") { @@ -2387,6 +2434,7 @@ KJ_TEST("WebSocket maximum message size") { WebSocketErrorCatcher errorCatcher; FakeEntropySource maskGenerator; + auto* rawClient = pipe.ends[0].get(); auto client = newWebSocket(kj::mv(pipe.ends[0]), maskGenerator); auto server = newWebSocket(kj::mv(pipe.ends[1]), kj::none, kj::none, errorCatcher); @@ -2394,9 +2442,12 @@ KJ_TEST("WebSocket maximum message size") { auto biggestAllowedString = kj::strArray(kj::repeat(kj::StringPtr("A"), maxSize), ""); auto tooBigString = kj::strArray(kj::repeat(kj::StringPtr("B"), maxSize + 1), ""); + auto rawCloseMessage = kj::heapArray(129); auto clientTask = client->send(biggestAllowedString) .then([&]() { return client->send(tooBigString); }) - .then([&]() { return client->close(1234, "done"); }); + .then([&]() { + return rawClient->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { auto message = server->receive(maxSize).wait(waitScope); @@ -2410,6 +2461,9 @@ KJ_TEST("WebSocket maximum message size") { KJ_ASSERT(errorCatcher.errors.size() == 1); KJ_ASSERT(errorCatcher.errors[0].statusCode == 1009); } + + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1009, "too large"_kjc); } class TestWebSocketService final: public HttpService, private kj::TaskSet::ErrorHandler { @@ -4104,60 +4158,103 @@ KJ_TEST("HttpServer threw exception") { KJ_EXPECT(text.startsWith("HTTP/1.1 500 Internal Server Error"), text); } -KJ_TEST("HttpServer bad request") { - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service); - - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - - static constexpr auto request = "GET / HTTP/1.1\r\nbad 
request\r\n\r\n"_kj; - auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); - auto response = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(writePromise.poll(waitScope)); - writePromise.wait(waitScope); +KJ_TEST("HttpServer bad requests") { + struct TestCase { + kj::StringPtr request; + kj::StringPtr expectedResponse; + bool expectWriteError; + }; - static constexpr auto expectedResponse = - "HTTP/1.1 400 Bad Request\r\n" - "Connection: close\r\n" - "Content-Length: 53\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "ERROR: The headers sent by your client are not valid."_kj; + static auto hugeHeaderRequest = kj::str( + "GET /foo/bar HTTP/1.1\r\n", + "Host: ", kj::strArray(kj::repeat("0", 1024 * 1024), ""), "\r\n", + "\r\n"); - KJ_EXPECT(expectedResponse == response, expectedResponse, response); -} + static TestCase testCases[] { + { + // bad request + .request = "GET / HTTP/1.1\r\nbad request\r\n\r\n"_kj, + .expectedResponse = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 53\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The headers sent by your client are not valid."_kj + }, + { + // invalid method + .request = "bad request\r\n\r\n"_kj, + .expectedResponse = + "HTTP/1.1 501 Not Implemented\r\n" + "Connection: close\r\n" + "Content-Length: 35\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Unrecognized request method."_kj + }, + { + // broken service generates 500 + .request = + "GET /foo/bar HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + .expectedResponse = + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 51\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The HttpService did not generate a response."_kj, + }, + { + // huge header shouldn't break the server + .request = hugeHeaderRequest, + .expectedResponse = + "HTTP/1.1 431 Request Header Fields Too Large\r\n" + "Connection: close\r\n" + "Content-Length: 24\r\n" + 
"Content-Type: text/plain\r\n" + "\r\n" + "ERROR: header too large."_kj, + .expectWriteError = true, + }, + }; -KJ_TEST("HttpServer invalid method") { KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service); + // we need a real timer to test http server grace behavior. + auto& timer = io.provider->getTimer(); - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + for (auto testCase : testCases) { + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - static constexpr auto request = "bad request\r\n\r\n"_kj; - auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); - auto response = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(writePromise.poll(waitScope)); - writePromise.wait(waitScope); + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service, { + .canceledUploadGraceBytes = 1024 * 1024, + }); - static constexpr auto expectedResponse = - "HTTP/1.1 501 Not Implemented\r\n" - "Connection: close\r\n" - "Content-Length: 35\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "ERROR: Unrecognized request method."_kj; + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto request = testCase.request; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + try { + auto response = pipe.ends[1]->readAllText().wait(waitScope); + auto expectedResponse = testCase.expectedResponse; + KJ_EXPECT(expectedResponse == response, expectedResponse, response); + } catch (...) { + auto ex = kj::getCaughtExceptionAsKj(); + KJ_FAIL_REQUIRE("not supposed to happen", ex); + } - KJ_EXPECT(expectedResponse == response, expectedResponse, response); + // write promise should have been resolved already + KJ_EXPECT(writePromise.poll(waitScope)); + try { + writePromise.wait(waitScope); + } catch (...) 
{ + KJ_EXPECT(testCase.expectWriteError, "write error wasn't expected"); + } + } } // Ensure that HttpServerSettings can continue to be constexpr. @@ -5002,6 +5099,72 @@ KJ_TEST("newHttpService from HttpClient WebSockets") { writeResponsesPromise.wait(waitScope); } +KJ_TEST("HttpClient WebSocket: client can have a custom WebSocket error handler") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + // These are WEBSOCKET_REQUEST_HANDSHAKE and WEBSOCKET_RESPONSE_HANDSHAKE but without the "My-Header" header. + // This test isn't about the HTTP handshake, so the headers are just noise. + const char wsRequestHandshake[] = + " HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + const char wsResponseHandshake[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" + "\r\n"; + + const byte badFrame[] = { + 0xF0, 0x02, 'y', 'o' // all RSV bits set, plus FIN + }; + const byte closeFrame[] = { + 0x88, 0xa8, 0xC, 0x22, 0x38, 0x4e, 0x3, 0xea, // FIN, opcode=Close, code=1002 + 'R', 'e', 'c', 'e', 'i', 'v', 'e', 'd', ' ', + 'f', 'r', 'a', 'm', 'e', ' ', + 'h', 'a', 'd', ' ', + 'R', 'S', 'V', ' ', + 'b', 'i', 't', 's', ' ', + '2', ' ', + 'o', 'r', ' ', + '3', ' ', + 's', 'e', 't', + }; + + auto request = kj::str("GET /websocket", wsRequestHandshake); + auto serverPromise = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(wsResponseHandshake)); }) + .then([&]() { return writeA(*pipe.ends[1], badFrame); }) + .then([&]() { return expectRead(*pipe.ends[1], closeFrame); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + { + HttpHeaderTable table; + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + WebSocketErrorCatcher errorCatcher; + 
clientSettings.entropySource = entropySource; + clientSettings.webSocketErrorHandler = errorCatcher; + + auto clientStream = kj::mv(pipe.ends[0]); + auto httpClient = newHttpClient(table, *clientStream, clientSettings); + auto wsClientPromise = httpClient->openWebSocket("/websocket", HttpHeaders(table)) + .then([&](kj::HttpClient::WebSocketResponse resp) { return kj::mv(resp.webSocketOrBody.get>()); }) + .then([](kj::Own webSocket) -> kj::Promise { return webSocket->receive().attach(kj::mv(webSocket)); }) + .eagerlyEvaluate([](kj::Exception e) -> kj::WebSocket::Message { return kj::str("irrelevant value"); }); + + wsClientPromise.wait(waitScope); + KJ_EXPECT(errorCatcher.errors.size() == 1); + } + + serverPromise.wait(waitScope); +} + KJ_TEST("newHttpService from HttpClient WebSockets disconnect") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); @@ -5791,6 +5954,44 @@ KJ_TEST("HttpClient concurrency limiting") { KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); } +KJ_TEST("HttpClientImpl connect()") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable headerTable; + auto client = newHttpClient(headerTable, *pipe.ends[0]); + + auto req = client->connect("foo:123", HttpHeaders(headerTable), {}); + + char buffer[16]; + auto readPromise = req.connection->tryRead(buffer, 16, 16); + + expectRead(*pipe.ends[1], "CONNECT foo:123 HTTP/1.1\r\n\r\n").wait(waitScope); + + { + kj::StringPtr msg = "HTTP/1.1 200 OK\r\n\r\nthis is the"_kj; + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + } + + KJ_EXPECT(!readPromise.poll(waitScope)); + + kj::Promise writePromise = nullptr; + { + kj::StringPtr msg = " connection content!!"_kj; + writePromise = pipe.ends[1]->write(msg.begin(), msg.size()); + } + + KJ_ASSERT(readPromise.poll(waitScope)); + KJ_ASSERT(readPromise.wait(waitScope) == 16); + KJ_EXPECT(kj::str(kj::ArrayPtr(buffer)) == "this is the conn"_kj); + + KJ_EXPECT(req.connection->tryRead(buffer, 16, 
16).wait(waitScope) == 16); + KJ_EXPECT(kj::str(kj::ArrayPtr(buffer)) == "ection content!!"_kj); + + KJ_ASSERT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); +} + #if KJ_HTTP_TEST_USE_OS_PIPE // This test relies on access to the network. KJ_TEST("NetworkHttpClient connect impl") { diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 29970d6af5..3e655ed11b 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -1453,7 +1453,9 @@ public: } kj::Promise readMessage() override { - auto text = co_await readMessageHeaders(); + auto textOrError = co_await readMessageHeaders(); + KJ_REQUIRE(textOrError.is>(), "bad message"); + auto text = textOrError.get>(); headers.clear(); KJ_REQUIRE(headers.tryParse(text), "bad message"); auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET, 0, headers); @@ -1528,7 +1530,7 @@ public: return !lineBreakBeforeNextHeader && leftover == nullptr; } - kj::Promise> readMessageHeaders() { + kj::Promise, HttpHeaders::ProtocolError>> readMessageHeaders() { ++pendingMessageCount; auto paf = kj::newPromiseAndFulfiller(); @@ -1541,28 +1543,38 @@ public: co_return co_await readHeader(HeaderType::MESSAGE, 0, 0); } - kj::Promise readChunkHeader() { + kj::Promise> readChunkHeader() { KJ_REQUIRE(onMessageDone != kj::none); // We use the portion of the header after the end of message headers. 
- auto text = co_await readHeader(HeaderType::CHUNK, messageHeaderEnd, messageHeaderEnd); - KJ_REQUIRE(text.size() > 0) { break; } - - uint64_t value = 0; - for (char c: text) { - if ('0' <= c && c <= '9') { - value = value * 16 + (c - '0'); - } else if ('a' <= c && c <= 'f') { - value = value * 16 + (c - 'a' + 10); - } else if ('A' <= c && c <= 'F') { - value = value * 16 + (c - 'A' + 10); - } else { - KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; } + auto textOrError = co_await readHeader(HeaderType::CHUNK, messageHeaderEnd, messageHeaderEnd); + + KJ_SWITCH_ONEOF(textOrError) { + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + co_return protocolError; + } + KJ_CASE_ONEOF(text, kj::ArrayPtr) { + KJ_REQUIRE(text.size() > 0) { break; } + + uint64_t value = 0; + for (char c: text) { + if ('0' <= c && c <= '9') { + value = value * 16 + (c - '0'); + } else if ('a' <= c && c <= 'f') { + value = value * 16 + (c - 'a' + 10); + } else if ('A' <= c && c <= 'F') { + value = value * 16 + (c - 'A' + 10); + } else { + KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; } + co_return value; + } + } + co_return value; } } - co_return value; + KJ_UNREACHABLE; } inline kj::Promise readRequestHeaders() { @@ -1571,18 +1583,36 @@ public: co_return HttpHeaders::RequestConnectOrProtocolError(resuming); } - auto text = co_await readMessageHeaders(); - headers.clear(); - co_return headers.tryParseRequestOrConnect(text); + auto textOrError = co_await readMessageHeaders(); + KJ_SWITCH_ONEOF(textOrError) { + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + co_return protocolError; + } + KJ_CASE_ONEOF(text, kj::ArrayPtr) { + headers.clear(); + co_return headers.tryParseRequestOrConnect(text); + } + } + + KJ_UNREACHABLE; } inline kj::Promise readResponseHeaders() { // Note: readResponseHeaders() could be called multiple times concurrently when pipelining // requests. 
readMessageHeaders() will serialize these, but it's important not to mess with // state (like calling headers.clear()) before said serialization has taken place. - auto text = co_await readMessageHeaders(); - headers.clear(); - co_return headers.tryParseResponse(text); + auto headersOrError = co_await readMessageHeaders(); + KJ_SWITCH_ONEOF(headersOrError) { + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + co_return protocolError; + } + KJ_CASE_ONEOF(text, kj::ArrayPtr) { + headers.clear(); + co_return headers.tryParseResponse(text); + } + } + + KJ_UNREACHABLE; } inline const HttpHeaders& getHeaders() const { return headers; } @@ -1637,6 +1667,11 @@ public: return { headerBuffer.releaseAsBytes(), leftover.asBytes() }; } + kj::Promise discard(AsyncOutputStream &output, size_t maxBytes) { + // Used to read and discard the input during error handling. + return inner.pumpTo(output, maxBytes).ignoreResult(); + } + private: AsyncInputStream& inner; kj::Array headerBuffer; @@ -1685,7 +1720,7 @@ private: CHUNK }; - kj::Promise> readHeader( + kj::Promise, HttpHeaders::ProtocolError>> readHeader( HeaderType type, size_t bufferStart, size_t bufferEnd) { // Reads the HTTP message header or a chunk header (as in transfer-encoding chunked) and // returns the buffer slice containing it. @@ -1731,7 +1766,12 @@ private: // Can't grow because we'd invalidate the HTTP headers. kj::throwFatalException(KJ_EXCEPTION(FAILED, "invalid HTTP chunk size")); } - KJ_REQUIRE(headerBuffer.size() < MAX_BUFFER, "request headers too large"); + if (headerBuffer.size() >= MAX_BUFFER) { + co_return HttpHeaders::ProtocolError { + .statusCode = 431, + .statusMessage = "Request Header Fields Too Large", + .description = "header too large." 
}; + } auto newBuffer = kj::heapArray(headerBuffer.size() * 2); memcpy(newBuffer.begin(), headerBuffer.begin(), headerBuffer.size()); headerBuffer = kj::mv(newBuffer); @@ -2013,7 +2053,9 @@ public: co_return alreadyRead; } else if (chunkSize == 0) { // Read next chunk header. - auto nextChunkSize = co_await getInner().readChunkHeader(); + auto nextChunkSizeOrError = co_await getInner().readChunkHeader(); + KJ_REQUIRE(nextChunkSizeOrError.is(), "bad header"); + auto nextChunkSize = nextChunkSizeOrError.get(); if (nextChunkSize == 0) { doneReading(); } @@ -2593,7 +2635,7 @@ public: : stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator), compressionConfig(kj::mv(compressionConfigParam)), errorHandler(errorHandler.orDefault(*this)), - sendingPong(kj::mv(waitBeforeSend)), + sendingControlMessage(kj::mv(waitBeforeSend)), recvBuffer(kj::mv(buffer)), recvData(leftover) { #if KJ_HAS_ZLIB KJ_IF_SOME(config, compressionConfig) { @@ -2615,18 +2657,7 @@ public: } kj::Promise close(uint16_t code, kj::StringPtr reason) override { - kj::Array payload; - if (code == 1005) { - KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason"); - - // code 1005 -- leave payload empty - } else { - payload = heapArray(reason.size() + 2); - payload[0] = code >> 8; - payload[1] = code; - memcpy(payload.begin() + 2, reason.begin(), reason.size()); - } - + kj::Array payload = serializeClose(code, reason); auto promise = sendImpl(OPCODE_CLOSE, payload); return promise.attach(kj::mv(payload)); } @@ -2634,14 +2665,14 @@ public: kj::Promise disconnect() override { KJ_REQUIRE(!currentlySending, "another message send is already in progress"); - KJ_IF_SOME(p, sendingPong) { - // We recently sent a pong, make sure it's finished before proceeding. + KJ_IF_SOME(p, sendingControlMessage) { + // We recently sent a control message; make sure it's finished before proceeding. 
currentlySending = true; auto promise = p.then([this]() { currentlySending = false; return disconnect(); }); - sendingPong = kj::none; + sendingControlMessage = kj::none; return promise; } @@ -2652,8 +2683,8 @@ public: } void abort() override { - queuedPong = kj::none; - sendingPong = kj::none; + queuedControlMessage = kj::none; + sendingControlMessage = kj::none; disconnected = true; stream->abortRead(); stream->shutdownWrite(); @@ -2664,6 +2695,10 @@ public: } kj::Promise receive(size_t maxSize) override { + KJ_IF_SOME(ex, receiveException) { + return kj::cp(ex); + } + size_t headerSize = Header::headerSize(recvData.begin(), recvData.size()); if (headerSize > recvData.size()) { @@ -2695,35 +2730,28 @@ public: auto& recvHeader = *reinterpret_cast(recvData.begin()); if (recvHeader.hasRsv2or3()) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Received frame had RSV bits 2 or 3 set", - }); + return sendCloseDueToError(1002, "Received frame had RSV bits 2 or 3 set"); } recvData = recvData.slice(headerSize, recvData.size()); size_t payloadLen = recvHeader.getPayloadLen(); if (payloadLen > maxSize) { - return errorHandler.handleWebSocketProtocolError({ - 1009, kj::str("Message is too large: ", payloadLen, " > ", maxSize) - }); + auto description = kj::str("Message is too large: ", payloadLen, " > ", maxSize); + return sendCloseDueToError(1009, description.asPtr()).attach(kj::mv(description)); } auto opcode = recvHeader.getOpcode(); bool isData = opcode < OPCODE_FIRST_CONTROL; if (opcode == OPCODE_CONTINUATION) { if (fragments.empty()) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Unexpected continuation frame" - }); + return sendCloseDueToError(1002, "Unexpected continuation frame"); } opcode = fragmentOpcode; } else if (isData) { if (!fragments.empty()) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Missing continuation frame" - }); + return sendCloseDueToError(1002, "Missing continuation frame"); } } @@ -2771,9 
+2799,7 @@ public: } else { // Fragmented message, and this isn't the final fragment. if (!isData) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Received fragmented control frame" - }); + return sendCloseDueToError(1002, "Received fragmented control frame"); } message = kj::heapArray(payloadLen); @@ -2804,11 +2830,10 @@ public: // Provide a reasonable error if a compressed frame is received without compression enabled. if (isCompressed && compressionConfig == kj::none) { - return errorHandler.handleWebSocketProtocolError({ - 1002, kj::str( - "Received a WebSocket frame whose compression bit was set, but the compression " - "extension was not negotiated for this connection.") - }); + return sendCloseDueToError( + 1002, + "Received a WebSocket frame whose compression bit was set, but the compression " + "extension was not negotiated for this connection."); } switch (opcode) { @@ -2881,9 +2906,10 @@ public: // Unsolicited pong. Ignore. return receive(maxSize); default: - return errorHandler.handleWebSocketProtocolError({ - 1002, kj::str("Unknown opcode ", opcode) - }); + { + auto description = kj::str("Unknown opcode ", opcode); + return sendCloseDueToError(1002, description.asPtr()).attach(kj::mv(description)); + } } }; @@ -3383,6 +3409,7 @@ private: static constexpr byte OPCODE_PONG = 10; static constexpr byte OPCODE_FIRST_CONTROL = 8; + static constexpr byte OPCODE_MAX = 15; // --------------------------------------------------------------------------- @@ -3399,20 +3426,41 @@ private: bool disconnected = false; bool currentlySending = false; Header sendHeader; - kj::ArrayPtr sendParts[2]; - kj::Maybe> queuedPong; - // queuedPong holds the body of the next pong to write, cleared when the pong is written. If a - // more recent ping arrives before the pong is actually written, we can update this value to - // instead respond to the more recent ping. 
+ struct ControlMessage { + byte opcode; + kj::Array payload; + kj::Maybe>> fulfiller; + + ControlMessage( + byte opcodeParam, + kj::Array payloadParam, + kj::Maybe>> fulfillerParam) + : opcode(opcodeParam), payload(kj::mv(payloadParam)), fulfiller(kj::mv(fulfillerParam)) { + KJ_REQUIRE(opcode <= OPCODE_MAX); + } + }; + + kj::Maybe receiveException; + // If set, all future calls to receive() will throw this exception. - kj::Maybe> sendingPong; - // If a Pong is being sent asynchronously in response to a Ping, this is a promise for the - // completion of that send. + kj::Maybe queuedControlMessage; + // queuedControlMessage holds the body of the next control message to write; it is cleared when the message is + // written. + // + // It may be overwritten; for example, if a more recent ping arrives before the pong is actually written, we can + // update this value to instead respond to the more recent ping. If a bad frame shows up, we can overwrite any + // queued pong with a Close message. + // + // Currently, this holds either a Close or a Pong. + + kj::Maybe> sendingControlMessage; + // If a control message is being sent asynchronously (e.g., a Pong in response to a Ping), this is a + // promise for the completion of that send. // // Additionally, this member is used if we need to block our first send on WebSocket startup, // e.g. because we need to wait for HTTP handshake writes to flush before we can start sending - // WebSocket data. `sendingPong` was overloaded for this use case because the logic is the same. + // WebSocket data. `sendingControlMessage` was overloaded for this use case because the logic is the same. // Perhaps it should be renamed to `blockSend` or `writeQueue`. uint fragmentOpcode = 0; @@ -3435,14 +3483,18 @@ private: currentlySending = true; - KJ_IF_SOME(p, sendingPong) { - // We recently sent a pong, make sure it's finished before proceeding. 
- auto promise = p.then([this, opcode, message]() { - currentlySending = false; - return sendImpl(opcode, message); - }); - sendingPong = kj::none; - return promise; + for (;;) { + KJ_IF_SOME(p, sendingControlMessage) { + // Re-check in case of disconnect on a previous loop iteration. + KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); + + // We recently sent a control message; make sure it's finished before proceeding. + auto localPromise = kj::mv(p); + sendingControlMessage = kj::none; + co_await localPromise; + } else { + break; + } } // We don't stop the application from sending further messages after close() -- this is the @@ -3488,81 +3540,133 @@ private: message = ownMessage; } + kj::ArrayPtr sendParts[2]; sendParts[0] = sendHeader.compose(true, useCompression, opcode, message.size(), mask); sendParts[1] = message; KJ_ASSERT(!sendHeader.hasRsv2or3(), "RSV bits 2 and 3 must be 0, as we do not currently " - "support an extension that would set these bits"); + "support an extension that would set these bits"); - auto promise = stream->write(sendParts).attach(kj::mv(compressedMessage)); - if (!mask.isZero()) { - promise = promise.attach(kj::mv(ownMessage)); + co_await stream->write(sendParts); + currentlySending = false; + + // Send queued control message if needed. + if (queuedControlMessage != kj::none) { + setUpSendingControlMessage(); + }; + sentBytes += sendParts[0].size() + sendParts[1].size();; + } + + void queueClose(uint16_t code, kj::StringPtr reason, kj::Own> fulfiller) { + bool alreadyWaiting = (queuedControlMessage != kj::none); + + // Overwrite any previously-queued message. If there is one, it's just a Pong, and this Close supersedes it. 
+ auto payload = serializeClose(code, reason); + queuedControlMessage = ControlMessage(OPCODE_CLOSE, kj::mv(payload), kj::mv(fulfiller)); + + if (!alreadyWaiting) { + setUpSendingControlMessage(); } - return promise.then([this, size = sendParts[0].size() + sendParts[1].size()]() { - currentlySending = false; + } - // Send queued pong if needed. - if (queuedPong != kj::none) { - setUpSendingPong(); - } - sentBytes += size; - }); + kj::Array serializeClose(uint16_t code, kj::StringPtr reason) { + kj::Array payload; + if (code == 1005) { + KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason"); + + // code 1005 -- leave payload empty + } else { + payload = heapArray(reason.size() + 2); + payload[0] = code >> 8; + payload[1] = code; + memcpy(payload.begin() + 2, reason.begin(), reason.size()); + } + return kj::mv(payload); + } + + kj::Promise sendCloseDueToError(uint16_t code, kj::StringPtr reason){ + auto paf = newPromiseAndFulfiller(); + queueClose(code, reason, kj::mv(paf.fulfiller)); + + return paf.promise.then([this, code, reason]() -> kj::Promise { + return errorHandler.handleWebSocketProtocolError({ + code, reason + }); + }); } void queuePong(kj::Array payload) { - bool alreadyWaitingForPongWrite = (queuedPong != kj::none); + bool alreadyWaitingForPongWrite = false; + + KJ_IF_SOME(controlMessage, queuedControlMessage) { + if (controlMessage.opcode == OPCODE_CLOSE) { + // We're currently sending a Close message, which we only do (at least via queuedControlMessage) when we're + // closing the connection due to error. There's no point queueing a Pong that'll never be sent. + return; + } else { + KJ_ASSERT(controlMessage.opcode == OPCODE_PONG); + alreadyWaitingForPongWrite = true; + } + } // Note: According to spec, if the server receives a second ping before responding to the // previous one, it can opt to respond only to the last ping. So we don't have to check if - // queuedPong is already non-null. 
- queuedPong = kj::mv(payload); + // queuedControlMessage is already non-null. + queuedControlMessage = ControlMessage(OPCODE_PONG, kj::mv(payload), kj::none); if (currentlySending) { // There is a message-send in progress, so we cannot write to the stream now. We will set - // up the pong write at the end of the message-send. + // up the control message write at the end of the message-send. return; } if (alreadyWaitingForPongWrite) { // We were already waiting for a pong to be written; don't need to queue another write. return; } - setUpSendingPong(); + setUpSendingControlMessage(); } - void setUpSendingPong() { - KJ_IF_SOME(promise, sendingPong) { - sendingPong = promise.then([this]() mutable { - return writeQueuedPong(); + void setUpSendingControlMessage() { + KJ_IF_SOME(promise, sendingControlMessage) { + sendingControlMessage = promise.then([this]() mutable { + return writeQueuedControlMessage(); }); } else { - sendingPong = writeQueuedPong(); + sendingControlMessage = writeQueuedControlMessage(); } } - kj::Promise writeQueuedPong() { - KJ_IF_SOME(q, queuedPong) { - kj::Array payload = kj::mv(q); - queuedPong = kj::none; + kj::Promise writeQueuedControlMessage() { + KJ_IF_SOME(q, queuedControlMessage) { + byte opcode = q.opcode; + kj::Array payload = kj::mv(q.payload); + auto maybeFulfiller = kj::mv(q.fulfiller); + queuedControlMessage = kj::none; if (hasSentClose || disconnected) { - return kj::READY_NOW; + KJ_IF_SOME(fulfiller, maybeFulfiller) { + fulfiller->fulfill(); + } + co_return; } - sendParts[0] = sendHeader.compose(true, false, OPCODE_PONG, + kj::ArrayPtr sendParts[2]; + sendParts[0] = sendHeader.compose(true, false, opcode, payload.size(), Mask(maskKeyGenerator)); sendParts[1] = payload; - return stream->write(sendParts).attach(kj::mv(payload)); - } else { - return kj::READY_NOW; + co_await stream->write(sendParts); + KJ_IF_SOME(fulfiller, maybeFulfiller) { + fulfiller->fulfill(); + } } } kj::Promise optimizedPumpTo(WebSocketImpl& other) { - 
KJ_IF_SOME(p, other.sendingPong) { - // We recently sent a pong, make sure it's finished before proceeding. + KJ_IF_SOME(p, other.sendingControlMessage) { + // We recently sent a control message; make sure it's finished before proceeding. auto promise = p.then([this, &other]() { return optimizedPumpTo(other); }); - other.sendingPong = kj::none; + other.sendingControlMessage = kj::none; return promise; } @@ -3797,6 +3901,19 @@ public: return transferredBytes; } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + + kj::Maybe destinationPumpingTo; + kj::Maybe destinationPumpingFrom; + // Tracks the outstanding pumpTo() and tryPumpFrom() calls currently running on the + // WebSocketPipeEnd, which is the destination side of this WebSocketPipeImpl. This is used by + // the source end to implement getPreferredExtensions(). + // + // getPreferredExtensions() does not fit into the model used by all the other methods because it + // is not directional (not a read nor a write call). + private: kj::Maybe state; // Object-oriented state! 
If any method call is blocked waiting on activity from the other end, @@ -3915,6 +4032,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4000,6 +4121,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4084,6 +4209,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4180,6 +4309,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4226,6 +4359,9 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; }; class Aborted final: public WebSocket { @@ -4267,12 +4403,15 @@ private: uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; }; }; class WebSocketPipeEnd final: public WebSocket { public: - WebSocketPipeEnd(kj::Own in, kj::Own out) + WebSocketPipeEnd(kj::Rc&& in, kj::Rc&& out) : in(kj::mv(in)), out(kj::mv(out)) {} ~WebSocketPipeEnd() noexcept(false) { in->abort(); @@ -4299,31 +4438,62 @@ public: return 
out->whenAborted(); } kj::Maybe> tryPumpFrom(WebSocket& other) override { - return out->tryPumpFrom(other); + KJ_REQUIRE(in->destinationPumpingFrom == kj::none, "can only call tryPumpFrom() once at a time"); + // By convention, we store the WebSocket reference on `in`. + in->destinationPumpingFrom = other; + auto deferredUnregister = kj::defer([this]() { in->destinationPumpingFrom = kj::none; }); + KJ_IF_SOME(p, out->tryPumpFrom(other)) { + return p.attach(kj::mv(deferredUnregister)); + } else { + return kj::none; + } } kj::Promise receive(size_t maxSize) override { return in->receive(maxSize); } kj::Promise pumpTo(WebSocket& other) override { - return in->pumpTo(other); + KJ_REQUIRE(in->destinationPumpingTo == kj::none, "can only call pumpTo() once at a time"); + // By convention, we store the WebSocket reference on `in`. + in->destinationPumpingTo = other; + auto deferredUnregister = kj::defer([this]() { in->destinationPumpingTo = kj::none; }); + return in->pumpTo(other).attach(kj::mv(deferredUnregister)); } uint64_t sentByteCount() override { return out->sentByteCount(); } uint64_t receivedByteCount() override { return in->sentByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + // We want to forward this call to whatever WebSocket the other end of the pipe is pumping + // to/from, if any. We'll check them in an arbitrary order and take the first one we see. + // But really, the hope is that both destinationPumpingTo and destinationPumpingFrom are in fact + // the same object. If they aren't the same, then it's not really clear whose extensions we + // should prefer; the choice here is arbitrary. 
+ KJ_IF_SOME(ws, out->destinationPumpingTo) { + KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { + return kj::mv(result); + } + } + KJ_IF_SOME(ws, out->destinationPumpingFrom) { + KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { + return kj::mv(result); + } + } + return kj::none; + }; + private: - kj::Own in; - kj::Own out; + kj::Rc in; + kj::Rc out; }; } // namespace WebSocketPipe newWebSocketPipe() { - auto pipe1 = kj::refcounted(); - auto pipe2 = kj::refcounted(); + auto pipe1 = kj::rc(); + auto pipe2 = kj::rc(); - auto end1 = kj::heap(kj::addRef(*pipe1), kj::addRef(*pipe2)); + auto end1 = kj::heap(pipe1.addRef(), pipe2.addRef()); auto end2 = kj::heap(kj::mv(pipe2), kj::mv(pipe1)); return { { kj::mv(end1), kj::mv(end2) } }; @@ -4379,6 +4549,7 @@ public: if (bytesToCopy > 0) { memcpy(destination, leftover.begin(), bytesToCopy); + leftover = nullptr; leftoverBackingBuffer = nullptr; minBytes -= bytesToCopy; maxBytes -= bytesToCopy; @@ -5106,14 +5277,21 @@ kj::OneOf tryParseExtensionAgreement( "an invalid value.")); return kj::mv(e); } + } // namespace _ (private) + namespace { -class NullInputStream final: public kj::AsyncInputStream { + +class HeadResponseStream final: public kj::AsyncInputStream { + // An input stream which returns no data, but `tryGetLength()` returns a specified value. Used + // for HEAD responses, where the size is known but the body content is not sent. public: - NullInputStream(kj::Maybe expectedLength = size_t(0)) + HeadResponseStream(kj::Maybe expectedLength) : expectedLength(expectedLength) {} kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // TODO(someday): Maybe this should throw? We should not be trying to read the body of a + // HEAD response. 
return constPromise(); } @@ -5129,48 +5307,6 @@ private: kj::Maybe expectedLength; }; -class NullOutputStream final: public kj::AsyncOutputStream { -public: - Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - Promise write(ArrayPtr> pieces) override { - return kj::READY_NOW; - } - Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } - - // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. -}; - -class NullIoStream final: public kj::AsyncIoStream { -public: - void shutdownWrite() override {} - - Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - Promise write(ArrayPtr> pieces) override { - return kj::READY_NOW; - } - Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } - - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return constPromise(); - } - - kj::Maybe tryGetLength() override { - return kj::Maybe((uint64_t)0); - } - - kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { - return constPromise(); - } -}; - class HttpClientImpl final: public HttpClient, private HttpClientErrorHandler { public: @@ -5407,7 +5543,7 @@ public: response.statusText, &httpInput.getHeaders(), upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, settings.entropySource, - kj::mv(compressionParameters)), + kj::mv(compressionParameters), settings.webSocketErrorHandler), }; } else { upgraded = false; @@ -6316,14 +6452,11 @@ public: maxConcurrentRequests(maxConcurrentRequests), countChangedCallback(kj::mv(countChangedCallback)) {} - ~ConcurrencyLimitingHttpClient() noexcept(false) { - if (concurrentRequests > 0) { - static bool logOnce KJ_UNUSED = ([&] { - KJ_LOG(ERROR, "ConcurrencyLimitingHttpClient getting destroyed when concurrent requests " - "are still active", concurrentRequests); - return true; - })(); - } + ~ConcurrencyLimitingHttpClient() noexcept { + // Crash in this case 
because otherwise we'll have UAF later on. + KJ_ASSERT(concurrentRequests == 0, + "ConcurrencyLimitingHttpClient getting destroyed when concurrent requests " + "are still active"); } Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, @@ -6578,7 +6711,7 @@ public: auto requestPaf = kj::newPromiseAndFulfiller>(); responder->setPromise(kj::mv(requestPaf.promise)); - auto in = kj::heap(); + auto in = kj::heap(); auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, *responder) .attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy)); requestPaf.fulfiller->fulfill(kj::mv(promise)); @@ -6742,11 +6875,11 @@ private: headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable { fulfiller->fulfill({ statusCode, statusTextCopy, headersCopy.get(), - kj::heap(expectedBodySize) + kj::heap(expectedBodySize) .attach(kj::mv(statusTextCopy), kj::mv(headersCopy)) }); }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - return kj::heap(); + return kj::heap(); } else { auto pipe = newOneWayPipe(expectedBodySize); @@ -6826,6 +6959,10 @@ private: uint64_t sentByteCount() override { return inner->sentByteCount(); } uint64_t receivedByteCount() override { return inner->receivedByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + return inner->getPreferredExtensions(ctx); + }; + private: kj::Own inner; kj::Maybe> completionTask; @@ -6893,11 +7030,11 @@ private: headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable { fulfiller->fulfill({ statusCode, statusTextCopy, headersCopy.get(), - kj::Own(kj::heap(expectedBodySize) + kj::Own(kj::heap(expectedBodySize) .attach(kj::mv(statusTextCopy), kj::mv(headersCopy))) }); }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - return kj::heap(); + return kj::heap(); } else { auto pipe = newOneWayPipe(expectedBodySize); @@ -7469,6 +7606,21 @@ private: KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { // Bad request. 
+ auto needClientGrace = protocolError.statusCode == 431; + if (needClientGrace) { + // We're going to reply with an error and close the connection. + // The client might not be able to read the error back. Read some data and wait + // a bit to give client a chance to finish writing. + + auto dummy = kj::heap(); + auto lengthGrace = kj::evalNow([&]() { + return httpInput.discard(*dummy, server.settings.canceledUploadGraceBytes); + }).catch_([](kj::Exception&& e) -> void { }) + .attach(kj::mv(dummy)); + auto timeGrace = server.timer.afterDelay(server.settings.canceledUploadGracePeriod); + co_await lengthGrace.exclusiveJoin(kj::mv(timeGrace)); + } + // sendError() uses Response::send(), which requires that we have a currentMethod, but we // never read one. GET seems like the correct choice here. currentMethod = HttpMethod::GET; @@ -7878,7 +8030,12 @@ private: } kj::Own sendWebSocketError(StringPtr errorMessage) { - kj::Exception exception = KJ_EXCEPTION(FAILED, + // The client committed a protocol error during a WebSocket handshake. We will send an error + // response back to them, and throw an exception from `acceptWebSocket()` to our app. We'll + // label this as a DISCONNECTED exception, as if the client had simply closed the connection + // rather than committing a protocol error. This is intended to let the server know that it + // wasn't an error on the server's part. (This is a bit of a hack...) 
+ kj::Exception exception = KJ_EXCEPTION(DISCONNECTED, "received bad WebSocket handshake", errorMessage); webSocketError = sendError( HttpHeaders::ProtocolError { 400, "Bad Request", errorMessage, nullptr }); @@ -7914,6 +8071,10 @@ private: uint64_t sentByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_FAIL_ASSERT(kj::cp(exception)); + }; + private: kj::Exception exception; }; diff --git a/c++/src/kj/compat/http.h b/c++/src/kj/compat/http.h index d77552d65c..abbd969a28 100644 --- a/c++/src/kj/compat/http.h +++ b/c++/src/kj/compat/http.h @@ -689,7 +689,7 @@ class WebSocket { REQUEST, RESPONSE }; - virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) { return kj::none; } + virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) = 0; // If pumpTo() / tryPumpFrom() is able to be optimized only if the other WebSocket is using // certain extensions (e.g. compression settings), then this method returns what those extensions // are. For example, matching extensions between standard WebSockets allows pumping to be @@ -1021,6 +1021,19 @@ class HttpClientErrorHandler { // little reason to override this. }; +class WebSocketErrorHandler { +public: + virtual kj::Exception handleWebSocketProtocolError(WebSocket::ProtocolError protocolError); + // Handles low-level protocol errors in received WebSocket data. + // + // This is called when the WebSocket peer sends us bad data *after* a successful WebSocket + // upgrade, e.g. a continuation frame without a preceding start frame, a frame with an unknown + // opcode, or similar. + // + // You would override this method in order to customize the exception. You cannot prevent the + // exception from being thrown. 
+}; + struct HttpClientSettings { kj::Duration idleTimeout = 5 * kj::SECONDS; // For clients which automatically create new connections, any connection idle for at least this @@ -1046,23 +1059,13 @@ struct HttpClientSettings { }; WebSocketCompressionMode webSocketCompressionMode = NO_COMPRESSION; + kj::Maybe webSocketErrorHandler = kj::none; + // Customize exceptions thrown on WebSocket protocol errors. + kj::Maybe tlsContext; // A reference to a TLS context that will be used when tlsStarter is invoked. }; -class WebSocketErrorHandler { -public: - virtual kj::Exception handleWebSocketProtocolError(WebSocket::ProtocolError protocolError); - // Handles low-level protocol errors in received WebSocket data. - // - // This is called when the WebSocket peer sends us bad data *after* a successful WebSocket - // upgrade, e.g. a continuation frame without a preceding start frame, a frame with an unknown - // opcode, or similar. - // - // You would override this method in order to customize the exception. You cannot prevent the - // exception from being thrown. -}; - kj::Own newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, kj::Network& network, kj::Maybe tlsNetwork, HttpClientSettings settings = HttpClientSettings()); diff --git a/c++/src/kj/exception.c++ b/c++/src/kj/exception.c++ index be582f3d56..2443c284df 100644 --- a/c++/src/kj/exception.c++ +++ b/c++/src/kj/exception.c++ @@ -37,7 +37,6 @@ #include "exception.h" #include "string.h" #include "debug.h" -#include "threadlocal.h" #include "miniposix.h" #include "function.h" #include "main.h" @@ -256,7 +255,7 @@ ArrayPtr getStackTrace(ArrayPtr space, uint ignoreCount) { #if (__GNUC__ && !_WIN32) || __clang__ // Allow dependents to override the implementation of stack symbolication by making it a weak -// symbol. We prefer weak symbols over some sort of callback registration mechanism becasue this +// symbol. 
We prefer weak symbols over some sort of callback registration mechanism because this // allows an alternate symbolication library to be easily linked into tests without changing the // code of the test. __attribute__((weak)) @@ -919,7 +918,7 @@ void Exception::addTraceHere() { namespace { -KJ_THREADLOCAL_PTR(ExceptionImpl) currentException = nullptr; +thread_local ExceptionImpl* currentException = nullptr; void validateExceptionPointer(const ExceptionImpl* e) noexcept { // Occasionally in production we are seeing `currentException` have the value 1. Try to figure @@ -1013,7 +1012,7 @@ kj::Exception getDestructionReason(void* traceSeparator, kj::Exception::Type def namespace { -KJ_THREADLOCAL_PTR(ExceptionCallback) threadLocalCallback = nullptr; +thread_local ExceptionCallback* threadLocalCallback = nullptr; } // namespace diff --git a/c++/src/kj/filesystem-test.c++ b/c++/src/kj/filesystem-test.c++ index f3eae2fe79..b1b3166de2 100644 --- a/c++/src/kj/filesystem-test.c++ +++ b/c++/src/kj/filesystem-test.c++ @@ -23,6 +23,10 @@ #include "test.h" #include +#if __linux__ +#include +#endif // __linux__ + namespace kj { namespace { @@ -454,6 +458,7 @@ KJ_TEST("InMemoryDirectory") { { auto file = dir->openFile(Path("foo"), WriteMode::CREATE); + KJ_EXPECT(file->getFd() == kj::none); clock.expectChanged(*dir); file->writeAll("foobar"); clock.expectUnchanged(*dir); @@ -747,6 +752,23 @@ KJ_TEST("InMemoryDirectory move") { KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); } +KJ_TEST("InMemoryDirectory transfer from self") { + TestClock clock; + + auto dir = newInMemoryDirectory(clock); + + auto file = dir->openFile(Path({"foo"}), WriteMode::CREATE); + + dir->transfer(Path({"bar"}), WriteMode::CREATE, Path({"foo"}), TransferMode::MOVE); + + auto list = dir->listNames(); + KJ_EXPECT(list.size() == 1); + KJ_EXPECT(list[0] == "bar"); + + auto file2 = dir->openFile(Path({"bar"})); + KJ_EXPECT(file.get() == file2.get()); +} + 
KJ_TEST("InMemoryDirectory createTemporary") { TestClock clock; @@ -755,7 +777,39 @@ KJ_TEST("InMemoryDirectory createTemporary") { file->writeAll("foobar"); KJ_EXPECT(file->readAllText() == "foobar"); KJ_EXPECT(dir->listNames() == nullptr); + KJ_EXPECT(file->getFd() == kj::none); +} + +#if __linux__ + +KJ_TEST("InMemoryDirectory backed my memfd") { + // Test memfd-backed in-memory directory. We're not going to test all functionality here, since + // we assume filesystem-disk-test covers fd-backed files in depth. + + TestClock clock; + auto dir = newInMemoryDirectory(clock, memfdInMemoryFileFactory()); + auto file = dir->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT); + + // Write directly to the FD, verify it is reflected in the file object. + int fd = KJ_ASSERT_NONNULL(file->getFd()); + ssize_t n; + KJ_SYSCALL(n = write(fd, "foo", 3)); + KJ_EXPECT(n == 3); + + KJ_EXPECT(file->readAllText() == "foo"_kj); + + // Re-opening the same file produces an alias of the same memfd. + auto file2 = dir->openFile(Path({"foo", "bar"})); + KJ_EXPECT(file2->readAllText() == "foo"_kj); + file->writeAll("bar"_kj); + KJ_EXPECT(file2->readAllText() == "bar"_kj); + KJ_EXPECT(file2->getFd() != kj::none); + KJ_EXPECT(file->stat().hashCode == file2->stat().hashCode); + + KJ_EXPECT(dir->createTemporary()->getFd() != kj::none); } +#endif // __linux__ + } // namespace } // namespace kj diff --git a/c++/src/kj/filesystem.c++ b/c++/src/kj/filesystem.c++ index 8ab1f941bc..367895e16d 100644 --- a/c++/src/kj/filesystem.c++ +++ b/c++/src/kj/filesystem.c++ @@ -28,6 +28,10 @@ #include "mutex.h" #include +#if __linux__ +#include // for memfd_create() +#endif // __linux__ + namespace kj { Path::Path(StringPtr name): Path(heapString(name)) {} @@ -271,7 +275,7 @@ String PathPtr::toWin32StringImpl(bool absolute, bool forApi) const { // False alarm: this is the drive letter. 
} else { KJ_FAIL_REQUIRE( - "colons are prohibited in win32 paths to avoid triggering alterante data streams", + "colons are prohibited in win32 paths to avoid triggering alternate data streams", result) { // Recover by using a different character which we know Win32 syscalls will reject. result[i] = '|'; @@ -978,7 +982,11 @@ private: class InMemoryDirectory final: public Directory, public AtomicRefcounted { public: - InMemoryDirectory(const Clock& clock): impl(clock) {} + InMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory) + : impl(clock, fileFactory) {} + InMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory, + const Directory& copyFrom, bool copyFiles) + : impl(clock, fileFactory, copyFrom, copyFiles) {} Own cloneFsNode() const override { return atomicAddRef(*this); @@ -1154,15 +1162,15 @@ public: if (path.size() == 0) { KJ_FAIL_REQUIRE("can't replace self") { break; } } else if (path.size() == 1) { - // don't need lock just to read the clock ref + // don't need lock just to construct a file return heap>(*this, path[0], - newInMemoryFile(impl.getWithoutLock().clock), mode); + impl.getWithoutLock().newFile(), mode); } else { KJ_IF_SOME(child, tryGetParent(path[0], mode)) { return child->replaceFile(path.slice(1, path.size()), mode); } } - return heap>(newInMemoryFile(impl.getWithoutLock().clock)); + return heap>(impl.getWithoutLock().newFile()); } Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const override { @@ -1194,15 +1202,15 @@ public: if (path.size() == 0) { KJ_FAIL_REQUIRE("can't replace self") { break; } } else if (path.size() == 1) { - // don't need lock just to read the clock ref + // don't need lock just to construct a directory return heap>(*this, path[0], - newInMemoryDirectory(impl.getWithoutLock().clock), mode); + impl.getWithoutLock().newDirectory(), mode); } else { KJ_IF_SOME(child, tryGetParent(path[0], mode)) { return child->replaceSubdir(path.slice(1, path.size()), mode); } } - return 
heap>(newInMemoryDirectory(impl.getWithoutLock().clock)); + return heap>(impl.getWithoutLock().newDirectory()); } Maybe> tryAppendFile(PathPtr path, WriteMode mode) const override { @@ -1256,8 +1264,8 @@ public: } Own createTemporary() const override { - // Don't need lock just to read the clock ref. - return newInMemoryFile(impl.getWithoutLock().clock); + // Don't need lock just to construct a file. + return impl.getWithoutLock().newFile(); } bool tryTransfer(PathPtr toPath, WriteMode toMode, @@ -1270,31 +1278,104 @@ public: KJ_FAIL_REQUIRE("can't replace self") { return false; } } } else if (toPath.size() == 1) { - // tryTransferChild() needs to at least know the node type, so do an lstat. - KJ_IF_SOME(meta, fromDirectory.tryLstat(fromPath)) { - auto lock = impl.lockExclusive(); - KJ_IF_SOME(entry, lock->openEntry(toPath[0], toMode)) { - // Make sure if we just cerated a new entry, and we don't successfully transfer to it, we - // remove the entry before returning. - bool needRollback = entry.node == nullptr; - KJ_DEFER(if (needRollback) { lock->entries.erase(toPath[0]); }); - - if (lock->tryTransferChild(entry, meta.type, meta.lastModified, meta.size, - fromDirectory, fromPath, mode)) { - lock->modified(); - needRollback = false; - return true; - } else { - KJ_FAIL_REQUIRE("InMemoryDirectory can't link an inode of this type", fromPath) { - return false; - } - } - } else { + if (!has(toMode, WriteMode::MODIFY)) { + // Replacement is not allowed, so we'll have to check upfront if the target path exists. + // Unfortunately we have to take a lock and then drop it immediately since we can't keep + // the lock held while accessing `fromDirectory`. 
+ if (impl.lockShared()->tryGetEntry(toPath[0]) != kj::none) { return false; } + } + + OneOf newNode; + FsNode::Metadata meta; + KJ_IF_SOME(m, fromDirectory.tryLstat(fromPath)) { + meta = m; } else { return false; } + + switch (meta.type) { + case FsNode::Type::FILE: { + auto file = KJ_ASSERT_NONNULL( + fromDirectory.tryOpenFile(fromPath, WriteMode::MODIFY), + "source node deleted concurrently during transfer", fromPath); + + if (mode == TransferMode::COPY) { + auto copy = impl.getWithoutLock().newFile(); + copy->copy(0, *file, 0, meta.size); + file = kj::mv(copy); + } + + newNode = FileNode { kj::mv(file) }; + break; + } + case FsNode::Type::DIRECTORY: { + auto subdir = KJ_ASSERT_NONNULL( + fromDirectory.tryOpenSubdir(fromPath, WriteMode::MODIFY), + "source node deleted concurrently during transfer", fromPath); + + switch (mode) { + case TransferMode::COPY: + // Copying is straightforward: Make a deep copy of the entire directory tree, + // including file contents. + subdir = impl.getWithoutLock().copyDirectory(*subdir, /* copyFiles = */ true); + break; + + case TransferMode::LINK: + // To "link", we can safely just place `subdir` directly into our own tree. + break; + + case TransferMode::MOVE: + // Moving may be tricky: + // + // If `fromDirectory` is an `InMemoryDirectory`, then we know that removing the + // subdir just unlinks the object without modifying it, so we can safely just link it + // into our own tree. + // + // However, if `fromDirectory` is a disk directory, then removing the subdir will + // likely perform a recursive delete, thus leaving `subdir` pointing to an empty + // directory. If we link that into our tree, it's useless. So, instead, perform a + // deep copy of the directory tree upfront, into an InMemoryDirectory. However, file + // content need not be copied, since unlinked files keep their contents until closed. 
+ if (kj::dynamicDowncastIfAvailable(fromDirectory) == + kj::none) { + subdir = impl.getWithoutLock().copyDirectory(*subdir, /* copyFiles = */ false); + } + break; + } + + newNode = DirectoryNode { kj::mv(subdir) }; + break; + } + case FsNode::Type::SYMLINK: { + auto link = KJ_ASSERT_NONNULL(fromDirectory.tryReadlink(fromPath), + "source node deleted concurrently during transfer", fromPath); + + newNode = SymlinkNode {meta.lastModified, kj::mv(link)}; + break; + } + default: + KJ_FAIL_REQUIRE("InMemoryDirectory can't link an inode of this type", fromPath); + } + + if (mode == TransferMode::MOVE) { + KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath); + } + + // Take the lock to insert the entry into our map. Remember that it's important we do not + // manipulate `fromDirectory` while the lock is held, since it could be the same directory. + { + auto lock = impl.lockExclusive(); + KJ_IF_SOME(targetEntry, lock->openEntry(toPath[0], toMode)) { + targetEntry.init(kj::mv(newNode));; + } else { + return false; + } + lock->modified(); + } + + return true; } else { // TODO(someday): Ideally we wouldn't create parent directories if fromPath doesn't exist. // This requires a different approach to the code here, though. @@ -1444,6 +1525,7 @@ private: struct Impl { const Clock& clock; + const InMemoryFileFactory& fileFactory; std::map entries; // Note: If this changes to a non-sorted map, listNames() and listEntries() must be updated to @@ -1451,7 +1533,85 @@ private: Date lastModified; - Impl(const Clock& clock): clock(clock), lastModified(clock.now()) {} + Impl(const Clock& clock, const InMemoryFileFactory& fileFactory) + : clock(clock), fileFactory(fileFactory), lastModified(clock.now()) {} + + Impl(const Clock& clock, const InMemoryFileFactory& fileFactory, + const Directory& copyFrom, bool copyFiles) + : clock(clock), fileFactory(fileFactory), lastModified(clock.now()) { + // Implements copyDirectory() (see below). 
+ for (auto& fromEntry: copyFrom.listEntries()) { + kj::Path filename({kj::mv(fromEntry.name)}); + OneOf newNode; + switch (fromEntry.type) { + case FsNode::Type::FILE: { + KJ_IF_SOME(file, copyFrom.tryOpenFile(filename, WriteMode::MODIFY)) { + if (copyFiles) { + auto copy = newFile(); + copy->copy(0, *file, 0, kj::maxValue); + file = kj::mv(copy); + } + + newNode = FileNode { kj::mv(file) }; + break; + } else { + continue; + } + } + + case FsNode::Type::DIRECTORY: { + KJ_IF_SOME(subdir, copyFrom.tryOpenSubdir(filename, WriteMode::MODIFY)) { + subdir = copyDirectory(*subdir, copyFiles); + newNode = DirectoryNode { kj::mv(subdir) }; + break; + } else { + continue; + } + } + + case FsNode::Type::SYMLINK: { + KJ_IF_SOME(link, copyFrom.tryReadlink(filename)) { + KJ_IF_SOME(metadata, copyFrom.tryLstat(filename)) { + newNode = SymlinkNode { metadata.lastModified, kj::mv(link) }; + break; + } else { + continue; + } + } else { + continue; + } + } + + default: + KJ_LOG(ERROR, "couldn't copy node of type not supported by InMemoryDirectory", + filename); + continue; + } + + KJ_ASSERT(newNode != nullptr); + + EntryImpl entry(kj::mv(filename)[0]); + StringPtr nameRef = entry.name; + entry.init(kj::mv(newNode)); + KJ_ASSERT(entries.insert(std::make_pair(nameRef, kj::mv(entry))).second); + } + } + + Own newFile() const { + // Construct a new empty file. Note: This function is expected to work without the lock held. + return fileFactory.create(clock); + } + Own newDirectory() const { + // Construct a new empty directory. Note: This function is expected to work without the lock + // held. + return newInMemoryDirectory(clock, fileFactory); + } + + Own copyDirectory(const Directory& other, bool copyFiles) const { + // Creates an in-memory deep copy of the given directory object. If `copyFiles` is true, then + // file contents are copied too, otherwise they are just linked. 
+ return kj::atomicRefcounted(clock, fileFactory, other, copyFiles); + } Maybe openEntry(kj::StringPtr name, WriteMode mode) { // TODO(perf): We could avoid a copy if the entry exists, at the expense of a double-lookup @@ -1500,82 +1660,6 @@ private: void modified() { lastModified = clock.now(); } - - bool tryTransferChild(EntryImpl& entry, const FsNode::Type type, kj::Maybe lastModified, - kj::Maybe size, const Directory& fromDirectory, - PathPtr fromPath, TransferMode mode) { - switch (type) { - case FsNode::Type::FILE: - KJ_IF_SOME(file, fromDirectory.tryOpenFile(fromPath, WriteMode::MODIFY)) { - if (mode == TransferMode::COPY) { - auto copy = newInMemoryFile(clock); - copy->copy(0, *file, 0, size.orDefault(kj::maxValue)); - entry.set(kj::mv(copy)); - } else { - if (mode == TransferMode::MOVE) { - KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { - return false; - } - } - entry.set(kj::mv(file)); - } - return true; - } else { - KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { - return false; - } - } - case FsNode::Type::DIRECTORY: - KJ_IF_SOME(subdir, fromDirectory.tryOpenSubdir(fromPath, WriteMode::MODIFY)) { - if (mode == TransferMode::COPY) { - auto copy = atomicRefcounted(clock); - auto& cpim = copy->impl.getWithoutLock(); // safe because just-created - for (auto& subEntry: subdir->listEntries()) { - EntryImpl newEntry(kj::mv(subEntry.name)); - Path filename(newEntry.name); - if (!cpim.tryTransferChild(newEntry, subEntry.type, kj::none, kj::none, *subdir, - filename, TransferMode::COPY)) { - KJ_LOG(ERROR, "couldn't copy node of type not supported by InMemoryDirectory", - filename); - } else { - StringPtr nameRef = newEntry.name; - cpim.entries.insert(std::make_pair(nameRef, kj::mv(newEntry))); - } - } - entry.set(kj::mv(copy)); - } else { - if (mode == TransferMode::MOVE) { - KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { - return false; - } - } - 
entry.set(kj::mv(subdir)); - } - return true; - } else { - KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { - return false; - } - } - case FsNode::Type::SYMLINK: - KJ_IF_SOME(content, fromDirectory.tryReadlink(fromPath)) { - // Since symlinks are immutable, we can implement LINK the same as COPY. - entry.init(SymlinkNode { lastModified.orDefault(clock.now()), kj::mv(content) }); - if (mode == TransferMode::MOVE) { - KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { - return false; - } - } - return true; - } else { - KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { - return false; - } - } - default: - return false; - } - } }; kj::MutexGuarded impl; @@ -1633,7 +1717,7 @@ private: } else if (entry.node == nullptr) { KJ_ASSERT(has(mode, WriteMode::CREATE)); lock->modified(); - return entry.init(FileNode { newInMemoryFile(lock->clock) }); + return entry.init(FileNode { lock->newFile() }); } else { KJ_FAIL_REQUIRE("not a file") { return kj::none; } } @@ -1651,7 +1735,7 @@ private: } else if (entry.node == nullptr) { KJ_ASSERT(has(mode, WriteMode::CREATE)); lock->modified(); - return entry.init(DirectoryNode { newInMemoryDirectory(lock->clock) }); + return entry.init(DirectoryNode { lock->newDirectory() }); } else { KJ_FAIL_REQUIRE("not a directory") { return kj::none; } } @@ -1682,7 +1766,7 @@ private: return entry.node.get().directory->clone(); } else if (entry.node == nullptr) { lock->modified(); - return entry.init(DirectoryNode { newInMemoryDirectory(lock->clock) }); + return entry.init(DirectoryNode { lock->newDirectory() }); } // Continue on. 
} @@ -1733,11 +1817,43 @@ private: Own newInMemoryFile(const Clock& clock) { return atomicRefcounted(clock); } -Own newInMemoryDirectory(const Clock& clock) { - return atomicRefcounted(clock); +Own newInMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory) { + return atomicRefcounted(clock, fileFactory); } Own newFileAppender(Own inner) { return heap(kj::mv(inner)); } +const InMemoryFileFactory& defaultInMemoryFileFactory() { + class FactoryImpl: public InMemoryFileFactory { + public: + kj::Own create(const Clock& clock) const override { + return newInMemoryFile(clock); + } + }; + static const FactoryImpl instance; + return instance; +} + +#if __linux__ + +Own newMemfdFile(uint flags) { + int fd; + KJ_SYSCALL(fd = memfd_create("kj-memfd", flags | MFD_CLOEXEC)); + return newDiskFile(AutoCloseFd(fd)); +} + +const InMemoryFileFactory& memfdInMemoryFileFactory() { + class FactoryImpl: public InMemoryFileFactory { + public: + kj::Own create(const Clock& clock) const override { + return newMemfdFile(0); + } + }; + static const FactoryImpl instance; + return instance; +} + +#endif // __linux__ + } // namespace kj diff --git a/c++/src/kj/filesystem.h b/c++/src/kj/filesystem.h index 8b98226727..28c0ae1908 100644 --- a/c++/src/kj/filesystem.h +++ b/c++/src/kj/filesystem.h @@ -796,7 +796,7 @@ class Directory: public ReadableDirectory { virtual Own> replaceFile(PathPtr path, WriteMode mode) const = 0; // Construct a file which, when ready, will be atomically moved to `path`, replacing whatever - // is there already. See `Replacer` for detalis. + // is there already. See `Replacer` for details. // // The `CREATE` and `MODIFY` bits of `mode` are not enforced until commit time, hence // `replaceFile()` has no "try" variant. 
@@ -819,7 +819,7 @@ class Directory: public ReadableDirectory { virtual Own> replaceSubdir(PathPtr path, WriteMode mode) const = 0; // Construct a directory which, when ready, will be atomically moved to `path`, replacing - // whatever is there already. See `Replacer` for detalis. + // whatever is there already. See `Replacer` for details. // // The `CREATE` and `MODIFY` bits of `mode` are not enforced until commit time, hence // `replaceSubdir()` has no "try" variant. @@ -927,8 +927,32 @@ class Filesystem { // ======================================================================================= +class InMemoryFileFactory { + // Used to customize the File implementation used by InMemoryDirectory. +public: + virtual kj::Own create(const Clock& clock) const = 0; +}; + +const InMemoryFileFactory& defaultInMemoryFileFactory(); +// Creates files using `newInMemoryFile()`. + +#if __linux__ + +Own newMemfdFile(uint flags = 0); +// Creates a `File` backed by a Linux memfd. This creates an in-memory file which behaves more +// closely to a disk file, compared to newInMemoryFile(). In particular, the file has a backing +// FD, and memory mapping doesn't have weird quirks. +// +// `flags` will be passed to `memfd_create()`. (The MFD_CLOEXEC flag is always added.) + +const InMemoryFileFactory& memfdInMemoryFileFactory(); +// Creates files using `newMemfdFile()`. + +#endif // __linux__ + Own newInMemoryFile(const Clock& clock); -Own newInMemoryDirectory(const Clock& clock); +Own newInMemoryDirectory(const Clock& clock, + const InMemoryFileFactory& fileFactory = defaultInMemoryFileFactory()); // Construct file and directory objects which reside in-memory. // // InMemoryFile has the following special properties: @@ -942,6 +966,11 @@ Own newInMemoryDirectory(const Clock& clock); // - link() can link directory nodes in addition to files. // - link() and rename() accept any kind of Directory as `fromDirectory` -- it doesn't need to be // another InMemoryDirectory. 
However, for rename(), the from path must be a directory. +// +// `fileFactory` can be customized in order to control the implementation of `File` objects created +// using this `InMemoryDirectory`. This is particularly useful for testing where the application +// expects files to have backing file descriptors or implement memory mapping fully correctly, but +// doesn't care as much about directory behavior. Own newFileAppender(Own inner); // Creates an AppendableFile by wrapping a File. Note that this implementation assumes it is the diff --git a/c++/src/kj/glob-filter-test.c++ b/c++/src/kj/glob-filter-test.c++ new file mode 100644 index 0000000000..f00218c1fa --- /dev/null +++ b/c++/src/kj/glob-filter-test.c++ @@ -0,0 +1,84 @@ +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#include "common.h" +#include "test.h" +#include "glob-filter.h" + +namespace kj { +namespace _ { +namespace { + +KJ_TEST("GlobFilter") { + { + GlobFilter filter("foo"); + + KJ_EXPECT(filter.matches("foo")); + KJ_EXPECT(!filter.matches("bar")); + KJ_EXPECT(!filter.matches("foob")); + KJ_EXPECT(!filter.matches("foobbb")); + KJ_EXPECT(!filter.matches("fobbbb")); + KJ_EXPECT(!filter.matches("bfoo")); + KJ_EXPECT(!filter.matches("bbbbbfoo")); + KJ_EXPECT(filter.matches("bbbbb/foo")); + KJ_EXPECT(filter.matches("bar/baz/foo")); + } + + { + GlobFilter filter("foo*"); + + KJ_EXPECT(filter.matches("foo")); + KJ_EXPECT(!filter.matches("bar")); + KJ_EXPECT(filter.matches("foob")); + KJ_EXPECT(filter.matches("foobbb")); + KJ_EXPECT(!filter.matches("fobbbb")); + KJ_EXPECT(!filter.matches("bfoo")); + KJ_EXPECT(!filter.matches("bbbbbfoo")); + KJ_EXPECT(filter.matches("bbbbb/foo")); + KJ_EXPECT(filter.matches("bar/baz/foo")); + } + + { + GlobFilter filter("foo*bar"); + + KJ_EXPECT(filter.matches("foobar")); + KJ_EXPECT(filter.matches("fooxbar")); + KJ_EXPECT(filter.matches("fooxxxbar")); + KJ_EXPECT(!filter.matches("foo/bar")); + KJ_EXPECT(filter.matches("blah/fooxxxbar")); + KJ_EXPECT(!filter.matches("blah/xxfooxxxbar")); + } + + { + GlobFilter filter("foo?bar"); + + KJ_EXPECT(!filter.matches("foobar")); + KJ_EXPECT(filter.matches("fooxbar")); + KJ_EXPECT(!filter.matches("fooxxxbar")); + KJ_EXPECT(!filter.matches("foo/bar")); + KJ_EXPECT(filter.matches("blah/fooxbar")); + KJ_EXPECT(!filter.matches("blah/xxfooxbar")); + } +} + +} // namespace +} // namespace _ +} // namespace kj diff --git a/c++/src/kj/glob-filter.c++ b/c++/src/kj/glob-filter.c++ new file mode 100644 index 0000000000..d8ec620d0c --- /dev/null +++ b/c++/src/kj/glob-filter.c++ @@ -0,0 +1,109 @@ +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. 
and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "glob-filter.h" + +namespace kj { + +GlobFilter::GlobFilter(const char* pattern): pattern(heapString(pattern)) {} +GlobFilter::GlobFilter(ArrayPtr pattern): pattern(heapString(pattern)) {} + +bool GlobFilter::matches(StringPtr name) { + // Get out your computer science books. We're implementing a non-deterministic finite automaton. + // + // Our NDFA has one "state" corresponding to each character in the pattern. + // + // As you may recall, an NDFA can be transformed into a DFA where every state in the DFA + // represents some combination of states in the NDFA. Therefore, we actually have to store a + // list of states here. (Actually, what we really want is a set of states, but because our + // patterns are mostly non-cyclic a list of states should work fine and be a bit more efficient.) 
+ + // Our state list starts out pointing only at the start of the pattern. + states.resize(0); + states.add(0); + + Vector scratch; + + // Iterate through each character in the name. + for (char c: name) { + // Pull the current set of states off to the side, so that we can populate `states` with the + // new set of states. + Vector oldStates = kj::mv(states); + states = kj::mv(scratch); + states.resize(0); + + // The pattern can omit a leading path. So if we're at a '/' then enter the state machine at + // the beginning on the next char. + if (c == '/' || c == '\\') { + states.add(0); + } + + // Process each state. + for (uint state: oldStates) { + applyState(c, state); + } + + // Store the previous state vector for reuse. + scratch = kj::mv(oldStates); + } + + // If any one state is at the end of the pattern (or at a wildcard just before the end of the + // pattern), we have a match. + for (uint state: states) { + while (state < pattern.size() && pattern[state] == '*') { + ++state; + } + if (state == pattern.size()) { + return true; + } + } + return false; +} + +void GlobFilter::applyState(char c, int state) { + if (state < pattern.size()) { + switch (pattern[state]) { + case '*': + // At a '*', we both re-add the current state and attempt to match the *next* state. + if (c != '/' && c != '\\') { // '*' doesn't match '/'. + states.add(state); + } + applyState(c, state + 1); + break; + + case '?': + // A '?' matches one character (never a '/'). + if (c != '/' && c != '\\') { + states.add(state + 1); + } + break; + + default: + // Any other character matches only itself. 
+ if (c == pattern[state]) { + states.add(state + 1); + } + break; + } + } +} + +} // namespace kj diff --git a/c++/src/kj/threadlocal-test.c++ b/c++/src/kj/glob-filter.h similarity index 59% rename from c++/src/kj/threadlocal-test.c++ rename to c++/src/kj/glob-filter.h index 7d409912e3..583db6f0c1 100644 --- a/c++/src/kj/threadlocal-test.c++ +++ b/c++/src/kj/glob-filter.h @@ -19,51 +19,27 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -#include "threadlocal.h" -#include "debug.h" -#include "thread.h" -#include +#pragma once -namespace kj { -namespace { - -KJ_THREADLOCAL_PTR(uint) tls1 = nullptr; -KJ_THREADLOCAL_PTR(uint) tls2; +#include +#include -TEST(ThreadLocal, Basic) { - // Verify that both started out null. - uint* p = tls1; - EXPECT_EQ(nullptr, p); - p = tls2; - EXPECT_EQ(nullptr, p); +namespace kj { - // Set tls1, then verify that only tls1 changed, not tls2. - uint i = 123; - tls1 = &i; +class GlobFilter { + // Implements glob filters for the --filter flag. - p = tls1; - EXPECT_EQ(&i, p); - p = tls2; - EXPECT_EQ(nullptr, p); +public: + explicit GlobFilter(const char* pattern); + explicit GlobFilter(ArrayPtr pattern); - // Check that in another thread, tls1 starts null but can be changed. - uint j = 456; - bool threadDone = false; - Thread([&]() { - p = tls1; - EXPECT_EQ(nullptr, p); - tls1 = &j; + bool matches(StringPtr name); - p = tls1; - EXPECT_EQ(&j, p); - threadDone = true; - }); - EXPECT_TRUE(threadDone); +private: + String pattern; + Vector states; - // tls1 didn't change in this thread. 
- p = tls1; - EXPECT_EQ(&i, p); -} + void applyState(char c, int state); +}; -} // namespace } // namespace kj diff --git a/c++/src/kj/refcount-test.c++ b/c++/src/kj/refcount-test.c++ index 6666615567..c2254989f8 100644 --- a/c++/src/kj/refcount-test.c++ +++ b/c++/src/kj/refcount-test.c++ @@ -24,10 +24,12 @@ namespace kj { -struct SetTrueInDestructor: public Refcounted { +struct SetTrueInDestructor: public Refcounted, EnableAddRefToThis { SetTrueInDestructor(bool* ptr): ptr(ptr) {} ~SetTrueInDestructor() { *ptr = true; } + kj::Rc newRef() { return addRefToThis(); } + bool* ptr; }; @@ -57,6 +59,99 @@ TEST(Refcount, Basic) { #endif } +KJ_TEST("Rc") { + bool b = false; + + Rc ref1 = kj::rc(&b); + EXPECT_FALSE(ref1->isShared()); + EXPECT_TRUE(ref1 != nullptr); + EXPECT_FALSE(ref1 == nullptr); + + Rc ref2 = ref1.addRef(); + EXPECT_TRUE(ref1->isShared()); + EXPECT_TRUE(ref1 == ref2); + + { + Rc ref3 = ref2.addRef(); + EXPECT_TRUE(ref3->isShared()); + // ref3 is dropped + } + + EXPECT_FALSE(b); + + // start dropping references one by one + + EXPECT_TRUE(ref2->isShared()); + ref1 = nullptr; + EXPECT_TRUE(ref1 == nullptr); + EXPECT_FALSE(ref2->isShared()); + EXPECT_FALSE(b); + EXPECT_FALSE(ref1 == ref2); + + ref2 = nullptr; + EXPECT_TRUE(ref1 == ref2); + + // last reference dropped, SetTrueInDestructor destructor should execute + EXPECT_TRUE(b); +} + +KJ_TEST("Rc Own interop") { + bool b = false; + + Rc ref1 = kj::rc(&b); + + EXPECT_FALSE(b); + auto own = ref1.toOwn(); + EXPECT_TRUE(ref1 == nullptr); + EXPECT_TRUE(own.get() != nullptr); + + EXPECT_FALSE(b); + own = nullptr; + EXPECT_TRUE(b); +} + +struct Child: public SetTrueInDestructor { + Child(bool* ptr): SetTrueInDestructor(ptr) {} +}; + +KJ_TEST("Rc inheritance") { + bool b = false; + + auto child = kj::rc(&b); + + // up casting works automatically + kj::Rc parent = child.addRef(); + + auto down = parent.downcast(); + EXPECT_TRUE(parent == nullptr); + EXPECT_TRUE(down != nullptr); + + EXPECT_FALSE(b); + child = 
nullptr; + EXPECT_FALSE(b); + down = nullptr; + EXPECT_TRUE(b); +} + +KJ_TEST("Refcounted::EnableAddRefToThis") { + bool b = false; + + auto ref1 = kj::rc(&b); + EXPECT_FALSE(ref1->isShared()); + + auto ref2 = ref1->newRef(); + EXPECT_TRUE(ref2->isShared()); + EXPECT_TRUE(ref1->isShared()); + EXPECT_FALSE(b); + + ref1 = nullptr; + EXPECT_FALSE(ref2->isShared()); + EXPECT_FALSE(b); + + ref2 = nullptr; + EXPECT_TRUE(b); +} + struct SetTrueInDestructor2 { // Like above but doesn't inherit Refcounted. @@ -134,4 +229,78 @@ KJ_TEST("RefcountedWrapper") { } } + +struct AtomicSetTrueInDestructor: public AtomicRefcounted, + EnableAddRefToThis { + + AtomicSetTrueInDestructor(bool* ptr): ptr(ptr) {} + ~AtomicSetTrueInDestructor() { *ptr = true; } + + kj::Arc newRef() { return addRefToThis(); } + + bool* ptr; +}; + +KJ_TEST("Arc") { + bool b = false; + + kj::Arc ref1 = kj::arc(&b); + EXPECT_FALSE(ref1->isShared()); + EXPECT_TRUE(ref1 != nullptr); + EXPECT_FALSE(ref1 == nullptr); + + kj::Arc ref2 = ref1.addRef(); + + // can be always cast to Arc + kj::Arc ref3 = ref1.addRef(); + + // addRef works for const references too + kj::Arc ref4 = ref3.addRef(); + + ref1 = nullptr; + EXPECT_TRUE(ref1 == nullptr); + ref2 = nullptr; + EXPECT_TRUE(ref2 == nullptr); + ref3 = nullptr; + EXPECT_TRUE(ref3 == nullptr); + + EXPECT_FALSE(b); + ref4 = nullptr; + EXPECT_TRUE(b); +} + +KJ_TEST("AtomicRefcounted::EnableAddRefToThis") { + bool b = false; + + kj::Arc ref1 = kj::arc(&b); + EXPECT_FALSE(ref1->isShared()); + + kj::Arc ref2 = ref1->newRef(); + EXPECT_TRUE(ref2->isShared()); + EXPECT_TRUE(ref1->isShared()); + EXPECT_FALSE(b); + + ref1 = nullptr; + EXPECT_FALSE(ref2->isShared()); + EXPECT_FALSE(b); + + ref2 = nullptr; + EXPECT_TRUE(b); +} + +KJ_TEST("Arc Own interop") { + bool b = false; + + kj::Arc ref1 = kj::arc(&b); + + EXPECT_FALSE(b); + auto own = ref1.toOwn(); + EXPECT_TRUE(ref1 == nullptr); + EXPECT_TRUE(own.get() != nullptr); + + EXPECT_FALSE(b); + own = nullptr; + EXPECT_TRUE(b); +} 
+ } // namespace kj diff --git a/c++/src/kj/refcount.h b/c++/src/kj/refcount.h index 03b5234d8d..7b1502663c 100644 --- a/c++/src/kj/refcount.h +++ b/c++/src/kj/refcount.h @@ -38,6 +38,12 @@ namespace kj { // ======================================================================================= // Non-atomic (thread-unsafe) refcounting +template +class Rc; + +template +class EnableAddRefToThis; + class Refcounted: private Disposer { // Subclass this to create a class that contains a reference count. Then, use // `kj::refcounted()` to allocate a new refcounted pointer. @@ -78,9 +84,13 @@ class Refcounted: private Disposer { // "mutable" because disposeImpl() is const. Bleh. void disposeImpl(void* pointer) const override; + template static Own addRefInternal(T* object); + template + static Rc addRcRefInternal(T* object); + template friend Own addRef(T& object); template @@ -88,6 +98,15 @@ class Refcounted: private Disposer { template friend class RefcountedWrapper; + + template + friend Rc rc(Params&&... params); + + template + friend class Rc; + + template + friend class EnableAddRefToThis; }; template @@ -98,6 +117,14 @@ inline Own refcounted(Params&&... params) { return Refcounted::addRefInternal(new T(kj::fwd(params)...)); } +template +inline Rc rc(Params&&... params) { + // Allocate a new refcounted instance of T, passing `params` to its constructor. + // Returns smart pointer that can be used to manage references. + + return Refcounted::addRcRefInternal(new T(kj::fwd(params)...)); +} + template Own addRef(T& object) { // Return a new reference to `object`, which must subclass Refcounted and have been allocated @@ -115,6 +142,113 @@ Own Refcounted::addRefInternal(T* object) { return Own(object, *refcounted); } +template +Rc Refcounted::addRcRefInternal(T* object) { + static_assert(kj::canConvert()); + Refcounted* refcounted = object; + ++refcounted->refcount; + return Rc(object); +} + +template +class Rc { + // Smart pointer for reference counted objects. 
+ // + // There are only three ways to obtain new Rc instances: + // - use kj::rc(...) function to create new T. + // - use kj::Rc::addRef() and the existing Rc instance. + // - use EnableAddRefToThis to allow T instance to add new references to itself. + // + // Suggested usage patterns are: + // - return kj::Rc as value from factory functions: + // kj::Rc createMyService(); + // - pass kj::Rc as rvalue to functions that need to extend T's lifetime: + // void setMyService(kj::Rc&& service) + // - store kj::Rc as data member: + // struct MyComputation { kj::Rc service; }; + // - use toOwn to convert kj::Rc instance to kj::Own and use it + // without being concerned of reference counting behavior. + // To improve the transparency of the code, kj::Own shouldn't be used + // to call addRef() without kj::Rc. + +public: + KJ_DISALLOW_COPY(Rc); + Rc() { } + Rc(decltype(nullptr)) { } + inline Rc(Rc&& other) noexcept = default; + + template ()>> + inline Rc(Rc&& other) noexcept : own(kj::mv(other.own)) { } + + kj::Own toOwn() { + // Convert Rc to Own. + // Nullifies the original Rc. 
+ return kj::mv(own); + } + + kj::Rc addRef() { + T* refcounted = own.get(); + if (refcounted != nullptr) { + return Refcounted::addRcRefInternal(refcounted); + } else { + return kj::Rc(); + } + } + + Rc& operator=(decltype(nullptr)) { + own = nullptr; + return *this; + } + + Rc& operator=(Rc&& other) = default; + + template + Rc downcast() { + return Rc(own.template downcast()); + } + + inline bool operator==(const Rc& other) const { return own.get() == other.own.get(); } + inline bool operator==(decltype(nullptr)) const { return own.get() == nullptr; } + inline bool operator!=(decltype(nullptr)) const { return own.get() != nullptr; } + + inline T* operator->() { return own.get(); } + inline const T* operator->() const { return own.get(); } + + // do not expose * to avoid dangling references + +private: + Rc(T* t) : own(t, *t) { } + Rc(Own&& t) : own(kj::mv(t)) { } + + Own own; + + friend class Refcounted; + + template + friend class Rc; + + template + friend class EnableAddRefToThis; +}; + +template +class EnableAddRefToThis { + // Exposes addRefToThis member function for objects to add + // references to themselves. + // Can be used both with Refcounted and AtomicRefcounted objects. + +protected: + auto addRefToThis() const { + const Self* self = static_cast(this); + return Self::addRcRefInternal(self); + } + + auto addRefToThis() { + Self* self = static_cast(this); + return Self::addRcRefInternal(self); + } +}; + template class RefcountedWrapper: public Refcounted { // Adds refcounting as a wrapper around an existing type, allowing you to construct references @@ -181,6 +315,9 @@ Own>> refcountedWrapper(Own&& wrapped) { #endif #endif +template +class Arc; + class AtomicRefcounted: private kj::Disposer { public: AtomicRefcounted() = default; @@ -218,6 +355,19 @@ class AtomicRefcounted: private kj::Disposer { friend kj::Maybe> atomicAddRefWeak(const T& object); template friend kj::Own atomicRefcounted(Params&&... 
params); + + template + static kj::Arc addRcRefInternal(T* object); + template + static kj::Arc addRcRefInternal(const T* object); + + template + friend class Arc; + template + friend kj::Arc arc(Params&&... params); + + template + friend class EnableAddRefToThis; }; template @@ -225,6 +375,11 @@ inline kj::Own atomicRefcounted(Params&&... params) { return AtomicRefcounted::addRefInternal(new T(kj::fwd(params)...)); } +template +inline kj::Arc arc(Params&&... params) { + return AtomicRefcounted::addRcRefInternal(new T(kj::fwd(params)...)); +} + template kj::Own atomicAddRef(T& object) { KJ_IREQUIRE(object.AtomicRefcounted::refcount > 0, @@ -278,6 +433,85 @@ kj::Own AtomicRefcounted::addRefInternal(const T* object) { return kj::Own(object, *refcounted); } +template +kj::Arc AtomicRefcounted::addRcRefInternal(T* object) { + static_assert(kj::canConvert()); + return kj::Arc(addRefInternal(object)); +} + +template +kj::Arc AtomicRefcounted::addRcRefInternal(const T* object) { + static_assert(kj::canConvert()); + return kj::Arc(addRefInternal(object)); +} + +template +class Arc { + // Smart pointer for atomic reference counted objects. + // + // Usage is similar to kj::Rc. + +public: + KJ_DISALLOW_COPY(Arc); + Arc() { } + Arc(decltype(nullptr)) { } + inline Arc(Arc&& other) noexcept = default; + + template ()>> + inline Arc(Arc&& other) noexcept : own(kj::mv(other.own)) { } + + kj::Own toOwn() { + // Convert Arc to Own. + // Nullifies the original Arc. 
+ return kj::mv(own); + } + + kj::Arc addRef() { + T* refcounted = own.get(); + if (refcounted != nullptr) { + return AtomicRefcounted::addRcRefInternal(refcounted); + } else { + return kj::Arc(); + } + } + + Arc& operator=(decltype(nullptr)) { + own = nullptr; + return *this; + } + + Arc& operator=(Arc&& other) = default; + + template + Arc downcast() { + return Arc(own.template downcast()); + } + + inline bool operator==(const Arc& other) const { return own.get() == other.own.get(); } + inline bool operator==(decltype(nullptr)) const { return own.get() == nullptr; } + inline bool operator!=(decltype(nullptr)) const { return own.get() != nullptr; } + + inline T* operator->() { return own.get(); } + inline const T* operator->() const { return own.get(); } + + // do not expose * to avoid dangling references + +private: + Arc(T* t) : own(t, *t) { } + Arc(Own&& t) : own(kj::mv(t)) { } + + Own own; + + friend class AtomicRefcounted; + + template + friend class Arc; + + template + friend class EnableAddRefToThis; +}; + + } // namespace kj KJ_END_HEADER diff --git a/c++/src/kj/string-test.c++ b/c++/src/kj/string-test.c++ index 9c6da1c31d..a1e7167255 100644 --- a/c++/src/kj/string-test.c++ +++ b/c++/src/kj/string-test.c++ @@ -435,6 +435,26 @@ KJ_TEST("StringPtr contains") { KJ_EXPECT(foobar.slice(2).contains(foobar.slice(1)) == false); } +struct Std { + static std::string from(const String* str) { + return std::string(str->cStr()); + } + + static std::string from(const StringPtr* str) { + return std::string(str->cStr()); + } +}; + +KJ_TEST("as") { + String str = kj::str("foo"_kj); + std::string stdStr = str.as(); + KJ_EXPECT(stdStr == "foo"); + + StringPtr ptr = "bar"_kj; + std::string stdPtr = ptr.as(); + KJ_EXPECT(stdPtr == "bar"); +} + } // namespace } // namespace _ (private) } // namespace kj diff --git a/c++/src/kj/string.h b/c++/src/kj/string.h index f47c45f172..84a82485b2 100644 --- a/c++/src/kj/string.h +++ b/c++/src/kj/string.h @@ -176,6 +176,11 @@ class 
StringPtr { // Like ArrayPtr::attach(), but instead promotes a StringPtr into a ConstString. Generally the // attachment should be an object that somehow owns the String that the StringPtr is pointing at. + template + inline auto as() { return T::from(this); } + // Syntax sugar for invoking T::from. + // Used to chain conversion calls rather than wrap with function. + private: inline explicit constexpr StringPtr(ArrayPtr content): content(content) {} friend constexpr StringPtr (::operator "" _kj)(const char* str, size_t n); @@ -322,6 +327,11 @@ class String { template Maybe tryParseAs() const { return StringPtr(*this).tryParseAs(); } + template + inline auto as() { return T::from(this); } + // Syntax sugar for invoking T::from. + // Used to chain conversion calls rather than wrap with function. + private: Array content; }; diff --git a/c++/src/kj/table.h b/c++/src/kj/table.h index 6f670a273b..2a4a83c76c 100644 --- a/c++/src/kj/table.h +++ b/c++/src/kj/table.h @@ -1164,7 +1164,7 @@ class BTreeImpl::MaybeUint { // A nullable uint, using the value zero to mean null and shifting all other values up by 1. 
public: MaybeUint() = default; - inline MaybeUint(uint i): i(i - 1) {} + inline MaybeUint(uint i): i(i + 1) {} inline MaybeUint(decltype(nullptr)): i(0) {} inline bool operator==(decltype(nullptr)) const { return i == 0; } diff --git a/c++/src/kj/test-test.c++ b/c++/src/kj/test-test.c++ index 379c77d45f..1cfa9a31f0 100644 --- a/c++/src/kj/test-test.c++ +++ b/c++/src/kj/test-test.c++ @@ -33,58 +33,6 @@ namespace kj { namespace _ { namespace { -KJ_TEST("GlobFilter") { - { - GlobFilter filter("foo"); - - KJ_EXPECT(filter.matches("foo")); - KJ_EXPECT(!filter.matches("bar")); - KJ_EXPECT(!filter.matches("foob")); - KJ_EXPECT(!filter.matches("foobbb")); - KJ_EXPECT(!filter.matches("fobbbb")); - KJ_EXPECT(!filter.matches("bfoo")); - KJ_EXPECT(!filter.matches("bbbbbfoo")); - KJ_EXPECT(filter.matches("bbbbb/foo")); - KJ_EXPECT(filter.matches("bar/baz/foo")); - } - - { - GlobFilter filter("foo*"); - - KJ_EXPECT(filter.matches("foo")); - KJ_EXPECT(!filter.matches("bar")); - KJ_EXPECT(filter.matches("foob")); - KJ_EXPECT(filter.matches("foobbb")); - KJ_EXPECT(!filter.matches("fobbbb")); - KJ_EXPECT(!filter.matches("bfoo")); - KJ_EXPECT(!filter.matches("bbbbbfoo")); - KJ_EXPECT(filter.matches("bbbbb/foo")); - KJ_EXPECT(filter.matches("bar/baz/foo")); - } - - { - GlobFilter filter("foo*bar"); - - KJ_EXPECT(filter.matches("foobar")); - KJ_EXPECT(filter.matches("fooxbar")); - KJ_EXPECT(filter.matches("fooxxxbar")); - KJ_EXPECT(!filter.matches("foo/bar")); - KJ_EXPECT(filter.matches("blah/fooxxxbar")); - KJ_EXPECT(!filter.matches("blah/xxfooxxxbar")); - } - - { - GlobFilter filter("foo?bar"); - - KJ_EXPECT(!filter.matches("foobar")); - KJ_EXPECT(filter.matches("fooxbar")); - KJ_EXPECT(!filter.matches("fooxxxbar")); - KJ_EXPECT(!filter.matches("foo/bar")); - KJ_EXPECT(filter.matches("blah/fooxbar")); - KJ_EXPECT(!filter.matches("blah/xxfooxbar")); - } -} - KJ_TEST("expect exit from exit") { KJ_EXPECT_EXIT(42, _exit(42)); KJ_EXPECT_EXIT(kj::none, _exit(42)); diff --git 
a/c++/src/kj/test.c++ b/c++/src/kj/test.c++ index cb07d5675b..8b977bc7fa 100644 --- a/c++/src/kj/test.c++ +++ b/c++/src/kj/test.c++ @@ -68,95 +68,6 @@ size_t TestCase::iterCount() { // ======================================================================================= -namespace _ { // private - -GlobFilter::GlobFilter(const char* pattern): pattern(heapString(pattern)) {} -GlobFilter::GlobFilter(ArrayPtr pattern): pattern(heapString(pattern)) {} - -bool GlobFilter::matches(StringPtr name) { - // Get out your computer science books. We're implementing a non-deterministic finite automaton. - // - // Our NDFA has one "state" corresponding to each character in the pattern. - // - // As you may recall, an NDFA can be transformed into a DFA where every state in the DFA - // represents some combination of states in the NDFA. Therefore, we actually have to store a - // list of states here. (Actually, what we really want is a set of states, but because our - // patterns are mostly non-cyclic a list of states should work fine and be a bit more efficient.) - - // Our state list starts out pointing only at the start of the pattern. - states.resize(0); - states.add(0); - - Vector scratch; - - // Iterate through each character in the name. - for (char c: name) { - // Pull the current set of states off to the side, so that we can populate `states` with the - // new set of states. - Vector oldStates = kj::mv(states); - states = kj::mv(scratch); - states.resize(0); - - // The pattern can omit a leading path. So if we're at a '/' then enter the state machine at - // the beginning on the next char. - if (c == '/' || c == '\\') { - states.add(0); - } - - // Process each state. - for (uint state: oldStates) { - applyState(c, state); - } - - // Store the previous state vector for reuse. - scratch = kj::mv(oldStates); - } - - // If any one state is at the end of the pattern (or at a wildcard just before the end of the - // pattern), we have a match. 
- for (uint state: states) { - while (state < pattern.size() && pattern[state] == '*') { - ++state; - } - if (state == pattern.size()) { - return true; - } - } - return false; -} - -void GlobFilter::applyState(char c, int state) { - if (state < pattern.size()) { - switch (pattern[state]) { - case '*': - // At a '*', we both re-add the current state and attempt to match the *next* state. - if (c != '/' && c != '\\') { // '*' doesn't match '/'. - states.add(state); - } - applyState(c, state + 1); - break; - - case '?': - // A '?' matches one character (never a '/'). - if (c != '/' && c != '\\') { - states.add(state + 1); - } - break; - - default: - // Any other character matches only itself. - if (c == pattern[state]) { - states.add(state + 1); - } - break; - } - } -} - -} // namespace _ (private) - -// ======================================================================================= - namespace { class TestExceptionCallback: public ExceptionCallback { @@ -255,7 +166,7 @@ public: } } - _::GlobFilter filter(filePattern); + kj::GlobFilter filter(filePattern); for (TestCase* testCase = testCasesHead; testCase != nullptr; testCase = testCase->next) { if (!testCase->matchedFilter && filter.matches(testCase->file) && diff --git a/c++/src/kj/test.h b/c++/src/kj/test.h index e3287b886a..e40bfbc96b 100644 --- a/c++/src/kj/test.h +++ b/c++/src/kj/test.h @@ -21,10 +21,11 @@ #pragma once -#include "debug.h" -#include "vector.h" -#include "function.h" -#include "windows-sanity.h" // work-around macro conflict with `ERROR` +#include +#include +#include +#include +#include // work-around macro conflict with `ERROR` KJ_BEGIN_HEADER @@ -186,24 +187,6 @@ class LogExpectation: public ExceptionCallback { UnwindDetector unwindDetector; }; -class GlobFilter { - // Implements glob filters for the --filter flag. - // - // Exposed in header only for testing. 
- -public: - explicit GlobFilter(const char* pattern); - explicit GlobFilter(ArrayPtr pattern); - - bool matches(StringPtr name); - -private: - String pattern; - Vector states; - - void applyState(char c, int state); -}; - } // namespace _ (private) } // namespace kj diff --git a/c++/src/kj/threadlocal.h b/c++/src/kj/threadlocal.h deleted file mode 100644 index 613b96e788..0000000000 --- a/c++/src/kj/threadlocal.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2014, Jason Choy -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#pragma once - -// This file declares a macro `KJ_THREADLOCAL_PTR` for declaring thread-local pointer-typed -// variables. Use like: -// KJ_THREADLOCAL_PTR(MyType) foo = nullptr; -// This is equivalent to: -// thread_local MyType* foo = nullptr; -// This can only be used at the global scope. 
-// -// AVOID USING THIS. Use of thread-locals is discouraged because they often have many of the same -// properties as singletons: http://www.object-oriented-security.org/lets-argue/singletons -// -// Also, thread-locals tend to be hostile to event-driven code, which can be particularly -// surprising when using fibers (all fibers in the same thread will share the same threadlocals, -// even though they do not share a stack). -// -// That said, thread-locals are sometimes needed for runtime logistics in the KJ framework. For -// example, the current exception callback and current EventLoop are stored as thread-local -// pointers. Since KJ only ever needs to store pointers, not values, we avoid the question of -// whether these values' destructors need to be run, and we avoid the need for heap allocation. - -#include "common.h" - -KJ_BEGIN_HEADER - -namespace kj { - -#if __GNUC__ - -#define KJ_THREADLOCAL_PTR(type) static __thread type* -// GCC's __thread is lighter-weight than thread_local and is good enough for our purposes. -// -// TODO(cleanup): The above comment was written many years ago. Is it still true? Shouldn't the -// compiler be smart enough to optimize a thread_local of POD type? 
- -#else - -#define KJ_THREADLOCAL_PTR(type) static thread_local type* - -#endif // KJ_USE_PTHREAD_TLS - -} // namespace kj - -KJ_END_HEADER diff --git a/c++/src/kj/timer.c++ b/c++/src/kj/timer.c++ index 5aec7a9a6f..9ce5a3b93f 100644 --- a/c++/src/kj/timer.c++ +++ b/c++/src/kj/timer.c++ @@ -40,28 +40,48 @@ struct TimerImpl::Impl { class TimerImpl::TimerPromiseAdapter { public: - TimerPromiseAdapter(PromiseFulfiller& fulfiller, TimerImpl::Impl& impl, TimePoint time) - : time(time), fulfiller(fulfiller), impl(impl) { - pos = impl.timers.insert(this); + TimerPromiseAdapter(PromiseFulfiller& fulfiller, TimerImpl& parent, TimePoint time) + : time(time), fulfiller(fulfiller), parent(parent) { + pos = parent.impl->timers.insert(this); + + KJ_IF_SOME(h, parent.sleepHooks) { + if (pos == parent.impl->timers.begin()) { + h.updateNextTimerEvent(time); + } + } } ~TimerPromiseAdapter() { - if (pos != impl.timers.end()) { - impl.timers.erase(pos); + if (pos != parent.impl->timers.end()) { + KJ_IF_SOME(h, parent.sleepHooks) { + bool isFirst = pos == parent.impl->timers.begin(); + + parent.impl->timers.erase(pos); + + if (isFirst) { + if (parent.impl->timers.empty()) { + h.updateNextTimerEvent(kj::none); + } else { + h.updateNextTimerEvent((*parent.impl->timers.begin())->time); + } + } + } else { + parent.impl->timers.erase(pos); + } } } void fulfill() { fulfiller.fulfill(); - impl.timers.erase(pos); - pos = impl.timers.end(); + parent.impl->timers.erase(pos); + pos = parent.impl->timers.end(); } const TimePoint time; private: PromiseFulfiller& fulfiller; - TimerImpl::Impl& impl; + TimerImpl& parent; Impl::Timers::const_iterator pos; }; @@ -70,12 +90,21 @@ inline bool TimerImpl::Impl::TimerBefore::operator()( return lhs->time < rhs->time; } +TimePoint TimerImpl::now() const { + KJ_IF_SOME(h, sleepHooks) { + return h.getTimeWhileSleeping(); + } else { + return time; + } +} + Promise TimerImpl::atTime(TimePoint time) { - return newAdaptedPromise(*impl, time); + auto result = 
newAdaptedPromise(*this, time); + return result; } Promise TimerImpl::afterDelay(Duration delay) { - return newAdaptedPromise(*impl, time + delay); + return newAdaptedPromise(*this, now() + delay); } TimerImpl::TimerImpl(TimePoint startTime) @@ -110,6 +139,8 @@ Maybe TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, ui } void TimerImpl::advanceTo(TimePoint newTime) { + sleepHooks = nullptr; + // On Macs running an Intel processor, it has been observed that clock_gettime // may return non monotonic time, even when CLOCK_MONOTONIC is used. // This workaround is to avoid the assert triggering on these machines. diff --git a/c++/src/kj/timer.h b/c++/src/kj/timer.h index eb9443c23b..2ef0c8fc42 100644 --- a/c++/src/kj/timer.h +++ b/c++/src/kj/timer.h @@ -108,6 +108,25 @@ class TimerImpl final: public Timer { void advanceTo(TimePoint newTime); // Set the time to `time` and fire any at() events that have been passed. + class SleepHooks { + public: + virtual void updateNextTimerEvent(kj::Maybe time) = 0; + // Called whenever the value returned by `nextEvent()` changes. + + virtual kj::TimePoint getTimeWhileSleeping() = 0; + // Get the current time. While sleeping, we can't lock time in place and advance it on each + // poll of the event queue, because arbitrary time might have passed outside the control of + // the KJ event loop. + }; + + void setSleeping(SleepHooks& hooks) { sleepHooks = hooks; } + // Hooks needed by UnixEventPort::preparePollableFdForSleep(). When the loop is sleeping, we + // would like for the application to be able to invoke the kj::Timer and for it to basically work + // correctly. This requires that we make some callbacks to the UnixEventPort to keep things + // consistent, since we can't assume the UnixEventPort will be actively polling the TimerImpl. + // + // The sleep hooks are automatically cleared when advanceTo() is next called. 
+ // implements Timer ---------------------------------------------------------- TimePoint now() const override; Promise atTime(TimePoint time) override; @@ -118,6 +137,7 @@ class TimerImpl final: public Timer { class TimerPromiseAdapter; TimePoint time; Own impl; + kj::Maybe sleepHooks; }; // ======================================================================================= @@ -137,8 +157,6 @@ Promise Timer::timeoutAfter(Duration delay, Promise&& promise) { })); } -inline TimePoint TimerImpl::now() const { return time; } - } // namespace kj KJ_END_HEADER