From ff92c499286b610abf6bb090f0621ffa452b17d6 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Sat, 9 Dec 2023 22:41:39 -0600 Subject: [PATCH 01/58] Fix leftover handling in HTTP CONNECT implementation. A buffer was read after being freed when two unusual things happened together: 1) Some of the initial connection content arrived immediately with the HTTP response headers, causing there to be a `leftover` buffer. This is rare because the response headers would normally be returned when the connection was opened, but the nature of TCP doesn't allow the server to return any bytes until an additional network round trip has occurred. 2) The application tried to read with a `minBytes` greater than the size of this leftover. (This is unusual because most use cases set `minBytes` to 1, which can't possibly be larger than a non-empty leftover size.) In this case, the leftover's backing buffer was destroyed, but the ArrayPtr pointing into it was not reset, so the next read would try to read it again. I have not observed this problem happening in practice. I just noticed the bug while reading the code. 
--- c++/src/kj/compat/http-test.c++ | 54 ++++++++++++++++++++++++++++----- c++/src/kj/compat/http.c++ | 1 + 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 7f7f520d9a..ebf0215b7f 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -5643,6 +5643,44 @@ KJ_TEST("HttpClient concurrency limiting") { KJ_EXPECT(callbackEvents == kj::ArrayPtr({ {0, 0} })); } +KJ_TEST("HttpClientImpl connect()") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + HttpHeaderTable headerTable; + auto client = newHttpClient(headerTable, *pipe.ends[0]); + + auto req = client->connect("foo:123", HttpHeaders(headerTable), {}); + + char buffer[16]; + auto readPromise = req.connection->tryRead(buffer, 16, 16); + + expectRead(*pipe.ends[1], "CONNECT foo:123 HTTP/1.1\r\n\r\n").wait(waitScope); + + { + kj::StringPtr msg = "HTTP/1.1 200 OK\r\n\r\nthis is the"_kj; + pipe.ends[1]->write(msg.begin(), msg.size()).wait(waitScope); + } + + KJ_EXPECT(!readPromise.poll(waitScope)); + + kj::Promise writePromise = nullptr; + { + kj::StringPtr msg = " connection content!!"_kj; + writePromise = pipe.ends[1]->write(msg.begin(), msg.size()); + } + + KJ_ASSERT(readPromise.poll(waitScope)); + KJ_ASSERT(readPromise.wait(waitScope) == 16); + KJ_EXPECT(kj::str(kj::ArrayPtr(buffer)) == "this is the conn"_kj); + + KJ_EXPECT(req.connection->tryRead(buffer, 16, 16).wait(waitScope) == 16); + KJ_EXPECT(kj::str(kj::ArrayPtr(buffer)) == "ection content!!"_kj); + + KJ_ASSERT(writePromise.poll(waitScope)); + writePromise.wait(waitScope); +} + #if KJ_HTTP_TEST_USE_OS_PIPE // This test relies on access to the network. 
KJ_TEST("NetworkHttpClient connect impl") { @@ -7154,9 +7192,9 @@ struct HttpRangeTestCase { kj::OneOf, HttpEverythingRange, HttpUnsatisfiableRange> expected; HttpRangeTestCase(kj::StringPtr value, uint64_t contentLength) : - value(value), contentLength(contentLength), expected(HttpUnsatisfiableRange {}) {} + value(value), contentLength(contentLength), expected(HttpUnsatisfiableRange {}) {} HttpRangeTestCase(kj::StringPtr value, uint64_t contentLength, HttpEverythingRange expected) : - value(value), contentLength(contentLength), expected(expected) {} + value(value), contentLength(contentLength), expected(expected) {} HttpRangeTestCase(kj::StringPtr value, uint64_t contentLength, InitializeableArray expected) : value(value), contentLength(contentLength), expected(kj::mv(expected)) {} }; @@ -7171,7 +7209,7 @@ KJ_TEST("Range header parsing") { {" Bytes =0-1"_kjc, 2, HttpEverythingRange {}}, // Check fails with other units {"nibbles=0-1"_kjc, 2}, - + // ===== Interval ===== // Check valid ranges accepted {"bytes=0-1"_kjc, 8, {{0,1}}}, @@ -7191,7 +7229,7 @@ KJ_TEST("Range header parsing") { {"bytes=0-2,1-3"_kjc, 5, {{0,2},{1,3}}}, // Check unsatisfiable ranges ignored {"bytes=1-2,7-8"_kjc, 5, {{1,2}}}, - + // ===== Prefix ===== // Check valid ranges accepted {"bytes=2-"_kjc, 8, {{2,7}}}, @@ -7201,7 +7239,7 @@ KJ_TEST("Range header parsing") { {"bytes=5-"_kjc, 2}, // Check multiple valid ranges accepted {"bytes= 1- ,6-, 10-11 "_kjc, 12, {{1,11},{6,11},{10,11}}}, - + // ===== Suffix ===== // Check valid ranges accepted {"bytes=-2"_kjc, 8, {{6,7}}}, @@ -7214,7 +7252,7 @@ KJ_TEST("Range header parsing") { // Check unsatisfiable empty range ignored {"bytes=-0"_kjc, 2}, {"bytes=0-1,-0,2-3"_kjc, 4, {{0,1},{2,3}}}, - + // ===== Invalid ===== // Check range with no start or end rejected {"bytes=-"_kjc, 2}, @@ -7226,7 +7264,7 @@ KJ_TEST("Range header parsing") { {"bytes="_kjc, 2}, {"bytes"_kjc, 2}, }; - + for (auto& testCase : RANGE_TEST_CASES) { auto ranges = 
tryParseHttpRangeHeader(testCase.value, testCase.contentLength); KJ_SWITCH_ONEOF(testCase.expected) { @@ -7255,7 +7293,7 @@ KJ_TEST("Range header parsing") { KJ_FAIL_ASSERT("Expected ", testCase.value, testCase.contentLength, "to be unsatisfiable"); } } - } + } } } } diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 664941d91e..a9ccf6f285 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -4337,6 +4337,7 @@ public: if (bytesToCopy > 0) { memcpy(destination, leftover.begin(), bytesToCopy); + leftover = nullptr; leftoverBackingBuffer = nullptr; minBytes -= bytesToCopy; maxBytes -= bytesToCopy; From f234612afd4ac5e18d36743f02a80c81b96ce93b Mon Sep 17 00:00:00 2001 From: Felix Hanau Date: Thu, 14 Dec 2023 04:07:55 +0000 Subject: [PATCH 02/58] Add support for bazel 7.0 and clang 17 --- c++/.bazelrc | 10 ++++++---- c++/.bazelversion | 2 +- c++/WORKSPACE | 1 + 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/c++/.bazelrc b/c++/.bazelrc index 39ec5624d4..ba80b81903 100644 --- a/c++/.bazelrc +++ b/c++/.bazelrc @@ -1,3 +1,4 @@ +common --noenable_bzlmod common --enable_platform_specific_config build:unix --cxxopt='-std=c++20' --host_cxxopt='-std=c++20' --force_pic --verbose_failures @@ -6,14 +7,15 @@ build:unix --cxxopt='-Wextra' --host_cxxopt='-Wextra' build:unix --cxxopt='-Wno-strict-aliasing' --host_cxxopt='-Wno-strict-aliasing' build:unix --cxxopt='-Wno-sign-compare' --host_cxxopt='-Wno-sign-compare' build:unix --cxxopt='-Wno-unused-parameter' --host_cxxopt='-Wno-unused-parameter' +build:unix --cxxopt='-Wno-deprecated-this-capture' --host_cxxopt='-Wno-deprecated-this-capture' # I needed these magic spells to build locally with clang-11 and clang-12 on Ubuntu. clang-13 and up # work out-of-the-box. # TODO(2.0): Remove this when we support g++ again. 
-build:unix --action_env=CXXFLAGS=-stdlib=libc++ -build:unix --action_env=LDFLAGS=-stdlib=libc++ -build:unix --action_env=BAZEL_CXXOPTS=-stdlib=libc++ -build:unix --action_env=BAZEL_LINKOPTS=-lc++:-lm +build:linux --action_env=CXXFLAGS=-stdlib=libc++ +build:linux --action_env=LDFLAGS=-stdlib=libc++ +build:linux --action_env=BAZEL_CXXOPTS=-stdlib=libc++ +build:linux --action_env=BAZEL_LINKOPTS=-lc++:-lm build:linux --config=unix build:macos --config=unix diff --git a/c++/.bazelversion b/c++/.bazelversion index 5e3254243a..66ce77b7ea 100644 --- a/c++/.bazelversion +++ b/c++/.bazelversion @@ -1 +1 @@ -6.1.2 +7.0.0 diff --git a/c++/WORKSPACE b/c++/WORKSPACE index 3e5bfe595e..efc25f9167 100644 --- a/c++/WORKSPACE +++ b/c++/WORKSPACE @@ -38,6 +38,7 @@ cc_library( "-Dverbose=-1", ] + select({ "@platforms//os:macos": [ "-Wno-implicit-function-declaration" ], + "@platforms//os:linux": [ "-Wno-implicit-function-declaration" ], "//conditions:default": [], }), visibility = ["//visibility:public"], From a89c8685779231f8e7bebcd0592c835cb56508ba Mon Sep 17 00:00:00 2001 From: Mike Aizatsky Date: Wed, 20 Dec 2023 08:51:42 -0800 Subject: [PATCH 03/58] as() syntax sugar methods for primitive types (#1882) String, StringPtr, Array, ArrayPtr gain as() method which is a syntax sugar for T::from(this). This enables defining conversion domains (such as Std in tests) and use `.as()` chained calls for more fluid expressions. 
--- c++/src/kj/array-test.c++ | 14 ++++++++++++++ c++/src/kj/array.h | 5 +++++ c++/src/kj/common-test.c++ | 15 +++++++++++++++ c++/src/kj/common.h | 5 +++++ c++/src/kj/string-test.c++ | 20 ++++++++++++++++++++ c++/src/kj/string.h | 10 ++++++++++ 6 files changed, 69 insertions(+) diff --git a/c++/src/kj/array-test.c++ b/c++/src/kj/array-test.c++ index 8693d2e4b0..50fd727f83 100644 --- a/c++/src/kj/array-test.c++ +++ b/c++/src/kj/array-test.c++ @@ -24,6 +24,7 @@ #include #include #include +#include namespace kj { namespace { @@ -509,5 +510,18 @@ TEST(Array, AttachFromArrayPtr) { KJ_EXPECT(destroyed1 == 3, destroyed1); } +struct Std { + template + static std::span from(Array* arr) { + return std::span(arr->begin(), arr->size()); + } +}; + +KJ_TEST("Array::as") { + kj::Array arr = kj::arr(1, 2, 4); + std::span stdArr = arr.as(); + KJ_EXPECT(stdArr.size() == 3); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/array.h b/c++/src/kj/array.h index 52e83bd696..4db3934de8 100644 --- a/c++/src/kj/array.h +++ b/c++/src/kj/array.h @@ -237,6 +237,11 @@ class Array { Array attach(Attachments&&... attachments) KJ_WARN_UNUSED_RESULT; // Like Own::attach(), but attaches to an Array. + template + inline auto as() { return U::from(this); } + // Syntax sugar for invoking U::from. + // Used to chain conversion calls rather than wrap with function. 
+ private: T* ptr; size_t size_; diff --git a/c++/src/kj/common-test.c++ b/c++/src/kj/common-test.c++ index 248a169e1d..af5ad74caf 100644 --- a/c++/src/kj/common-test.c++ +++ b/c++/src/kj/common-test.c++ @@ -23,6 +23,7 @@ #include "test.h" #include #include +#include namespace kj { namespace { @@ -925,5 +926,19 @@ KJ_TEST("kj::ArrayPtr startsWith / endsWith / findFirst / findLast") { KJ_EXPECT(arr.findLast(78).orDefault(100) == 100); } +struct Std { + template + static std::span from(ArrayPtr* arr) { + return std::span(arr->begin(), arr->size()); + } +}; + +KJ_TEST("ArrayPtr::as") { + int rawArray[] = {12, 34, 56, 34, 12}; + ArrayPtr arr(rawArray); + std::span stdPtr = arr.as(); + KJ_EXPECT(stdPtr.size() == 5); +} + } // namespace } // namespace kj diff --git a/c++/src/kj/common.h b/c++/src/kj/common.h index 10555fe39f..ac4d906e9f 100644 --- a/c++/src/kj/common.h +++ b/c++/src/kj/common.h @@ -1927,6 +1927,11 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { // // You must include kj/array.h to call this. + template + inline auto as() { return U::from(this); } + // Syntax sugar for invoking U::from. + // Used to chain conversion calls rather than wrap with function. 
+ private: T* ptr; size_t size_; diff --git a/c++/src/kj/string-test.c++ b/c++/src/kj/string-test.c++ index 9c6da1c31d..a1e7167255 100644 --- a/c++/src/kj/string-test.c++ +++ b/c++/src/kj/string-test.c++ @@ -435,6 +435,26 @@ KJ_TEST("StringPtr contains") { KJ_EXPECT(foobar.slice(2).contains(foobar.slice(1)) == false); } +struct Std { + static std::string from(const String* str) { + return std::string(str->cStr()); + } + + static std::string from(const StringPtr* str) { + return std::string(str->cStr()); + } +}; + +KJ_TEST("as") { + String str = kj::str("foo"_kj); + std::string stdStr = str.as(); + KJ_EXPECT(stdStr == "foo"); + + StringPtr ptr = "bar"_kj; + std::string stdPtr = ptr.as(); + KJ_EXPECT(stdPtr == "bar"); +} + } // namespace } // namespace _ (private) } // namespace kj diff --git a/c++/src/kj/string.h b/c++/src/kj/string.h index f47c45f172..84a82485b2 100644 --- a/c++/src/kj/string.h +++ b/c++/src/kj/string.h @@ -176,6 +176,11 @@ class StringPtr { // Like ArrayPtr::attach(), but instead promotes a StringPtr into a ConstString. Generally the // attachment should be an object that somehow owns the String that the StringPtr is pointing at. + template + inline auto as() { return T::from(this); } + // Syntax sugar for invoking T::from. + // Used to chain conversion calls rather than wrap with function. + private: inline explicit constexpr StringPtr(ArrayPtr content): content(content) {} friend constexpr StringPtr (::operator "" _kj)(const char* str, size_t n); @@ -322,6 +327,11 @@ class String { template Maybe tryParseAs() const { return StringPtr(*this).tryParseAs(); } + template + inline auto as() { return T::from(this); } + // Syntax sugar for invoking T::from. + // Used to chain conversion calls rather than wrap with function. + private: Array content; }; From cc1692129ed21272533e5bca888c71c1d293a92a Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Sat, 23 Dec 2023 13:57:57 -0600 Subject: [PATCH 04/58] Allow InMemoryDirectory to be backed by memfds. 
In this mode, InMemoryDirectory works as normal, but when it creates a File, instead of using InMemoryFile, it uses DiskFile wrapping a memfd. This creates a nice compromise for writing tests of code that depends on files being backed by real file descriptors, but does not need the same for directories. (If we wanted similar functionality on other operating systems, we could create an InMemoryFileFactory that backs files with anonymous temporary files on disk...) --- c++/src/kj/filesystem-test.c++ | 37 +++++++++++++++ c++/src/kj/filesystem.c++ | 83 +++++++++++++++++++++++++++------- c++/src/kj/filesystem.h | 31 ++++++++++++- 3 files changed, 133 insertions(+), 18 deletions(-) diff --git a/c++/src/kj/filesystem-test.c++ b/c++/src/kj/filesystem-test.c++ index f3eae2fe79..d1c3bc377f 100644 --- a/c++/src/kj/filesystem-test.c++ +++ b/c++/src/kj/filesystem-test.c++ @@ -23,6 +23,10 @@ #include "test.h" #include +#if __linux__ +#include +#endif // __linux__ + namespace kj { namespace { @@ -454,6 +458,7 @@ KJ_TEST("InMemoryDirectory") { { auto file = dir->openFile(Path("foo"), WriteMode::CREATE); + KJ_EXPECT(file->getFd() == kj::none); clock.expectChanged(*dir); file->writeAll("foobar"); clock.expectUnchanged(*dir); @@ -755,7 +760,39 @@ KJ_TEST("InMemoryDirectory createTemporary") { file->writeAll("foobar"); KJ_EXPECT(file->readAllText() == "foobar"); KJ_EXPECT(dir->listNames() == nullptr); + KJ_EXPECT(file->getFd() == kj::none); } +#if __linux__ + +KJ_TEST("InMemoryDirectory backed my memfd") { + // Test memfd-backed in-memory directory. We're not going to test all functionality here, since + // we assume filesystem-disk-test covers fd-backed files in depth. + + TestClock clock; + auto dir = newInMemoryDirectory(clock, memfdInMemoryFileFactory()); + auto file = dir->openFile(Path({"foo", "bar"}), WriteMode::CREATE | WriteMode::CREATE_PARENT); + + // Write directly to the FD, verify it is reflected in the file object. 
+ int fd = KJ_ASSERT_NONNULL(file->getFd()); + ssize_t n; + KJ_SYSCALL(n = write(fd, "foo", 3)); + KJ_EXPECT(n == 3); + + KJ_EXPECT(file->readAllText() == "foo"_kj); + + // Re-opening the same file produces an alias of the same memfd. + auto file2 = dir->openFile(Path({"foo", "bar"})); + KJ_EXPECT(file2->readAllText() == "foo"_kj); + file->writeAll("bar"_kj); + KJ_EXPECT(file2->readAllText() == "bar"_kj); + KJ_EXPECT(file2->getFd() != kj::none); + KJ_EXPECT(file->stat().hashCode == file2->stat().hashCode); + + KJ_EXPECT(dir->createTemporary()->getFd() != kj::none); +} + +#endif // __linux__ + } // namespace } // namespace kj diff --git a/c++/src/kj/filesystem.c++ b/c++/src/kj/filesystem.c++ index 8ab1f941bc..d43d5dbf28 100644 --- a/c++/src/kj/filesystem.c++ +++ b/c++/src/kj/filesystem.c++ @@ -28,6 +28,10 @@ #include "mutex.h" #include +#if __linux__ +#include // for memfd_create() +#endif // __linux__ + namespace kj { Path::Path(StringPtr name): Path(heapString(name)) {} @@ -978,7 +982,8 @@ private: class InMemoryDirectory final: public Directory, public AtomicRefcounted { public: - InMemoryDirectory(const Clock& clock): impl(clock) {} + InMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory) + : impl(clock, fileFactory) {} Own cloneFsNode() const override { return atomicAddRef(*this); @@ -1154,15 +1159,15 @@ public: if (path.size() == 0) { KJ_FAIL_REQUIRE("can't replace self") { break; } } else if (path.size() == 1) { - // don't need lock just to read the clock ref + // don't need lock just to construct a file return heap>(*this, path[0], - newInMemoryFile(impl.getWithoutLock().clock), mode); + impl.getWithoutLock().newFile(), mode); } else { KJ_IF_SOME(child, tryGetParent(path[0], mode)) { return child->replaceFile(path.slice(1, path.size()), mode); } } - return heap>(newInMemoryFile(impl.getWithoutLock().clock)); + return heap>(impl.getWithoutLock().newFile()); } Maybe> tryOpenSubdir(PathPtr path, WriteMode mode) const override { @@ 
-1194,15 +1199,15 @@ public: if (path.size() == 0) { KJ_FAIL_REQUIRE("can't replace self") { break; } } else if (path.size() == 1) { - // don't need lock just to read the clock ref + // don't need lock just to construct a directory return heap>(*this, path[0], - newInMemoryDirectory(impl.getWithoutLock().clock), mode); + impl.getWithoutLock().newDirectory(), mode); } else { KJ_IF_SOME(child, tryGetParent(path[0], mode)) { return child->replaceSubdir(path.slice(1, path.size()), mode); } } - return heap>(newInMemoryDirectory(impl.getWithoutLock().clock)); + return heap>(impl.getWithoutLock().newDirectory()); } Maybe> tryAppendFile(PathPtr path, WriteMode mode) const override { @@ -1256,8 +1261,8 @@ public: } Own createTemporary() const override { - // Don't need lock just to read the clock ref. - return newInMemoryFile(impl.getWithoutLock().clock); + // Don't need lock just to construct a file. + return impl.getWithoutLock().newFile(); } bool tryTransfer(PathPtr toPath, WriteMode toMode, @@ -1444,6 +1449,7 @@ private: struct Impl { const Clock& clock; + const InMemoryFileFactory& fileFactory; std::map entries; // Note: If this changes to a non-sorted map, listNames() and listEntries() must be updated to @@ -1451,7 +1457,18 @@ private: Date lastModified; - Impl(const Clock& clock): clock(clock), lastModified(clock.now()) {} + Impl(const Clock& clock, const InMemoryFileFactory& fileFactory) + : clock(clock), fileFactory(fileFactory), lastModified(clock.now()) {} + + Own newFile() const { + // Construct a new empty file. Note: This function is expected to work without the lock held. + return fileFactory.create(clock); + } + Own newDirectory() const { + // Construct a new empty directory. Note: This function is expected to work without the lock + // held. 
+ return newInMemoryDirectory(clock, fileFactory); + } Maybe openEntry(kj::StringPtr name, WriteMode mode) { // TODO(perf): We could avoid a copy if the entry exists, at the expense of a double-lookup @@ -1508,7 +1525,7 @@ private: case FsNode::Type::FILE: KJ_IF_SOME(file, fromDirectory.tryOpenFile(fromPath, WriteMode::MODIFY)) { if (mode == TransferMode::COPY) { - auto copy = newInMemoryFile(clock); + auto copy = newFile(); copy->copy(0, *file, 0, size.orDefault(kj::maxValue)); entry.set(kj::mv(copy)); } else { @@ -1528,7 +1545,7 @@ private: case FsNode::Type::DIRECTORY: KJ_IF_SOME(subdir, fromDirectory.tryOpenSubdir(fromPath, WriteMode::MODIFY)) { if (mode == TransferMode::COPY) { - auto copy = atomicRefcounted(clock); + auto copy = atomicRefcounted(clock, fileFactory); auto& cpim = copy->impl.getWithoutLock(); // safe because just-created for (auto& subEntry: subdir->listEntries()) { EntryImpl newEntry(kj::mv(subEntry.name)); @@ -1633,7 +1650,7 @@ private: } else if (entry.node == nullptr) { KJ_ASSERT(has(mode, WriteMode::CREATE)); lock->modified(); - return entry.init(FileNode { newInMemoryFile(lock->clock) }); + return entry.init(FileNode { lock->newFile() }); } else { KJ_FAIL_REQUIRE("not a file") { return kj::none; } } @@ -1651,7 +1668,7 @@ private: } else if (entry.node == nullptr) { KJ_ASSERT(has(mode, WriteMode::CREATE)); lock->modified(); - return entry.init(DirectoryNode { newInMemoryDirectory(lock->clock) }); + return entry.init(DirectoryNode { lock->newDirectory() }); } else { KJ_FAIL_REQUIRE("not a directory") { return kj::none; } } @@ -1682,7 +1699,7 @@ private: return entry.node.get().directory->clone(); } else if (entry.node == nullptr) { lock->modified(); - return entry.init(DirectoryNode { newInMemoryDirectory(lock->clock) }); + return entry.init(DirectoryNode { lock->newDirectory() }); } // Continue on. 
} @@ -1733,11 +1750,43 @@ private: Own newInMemoryFile(const Clock& clock) { return atomicRefcounted(clock); } -Own newInMemoryDirectory(const Clock& clock) { - return atomicRefcounted(clock); +Own newInMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory) { + return atomicRefcounted(clock, fileFactory); } Own newFileAppender(Own inner) { return heap(kj::mv(inner)); } +const InMemoryFileFactory& defaultInMemoryFileFactory() { + class FactoryImpl: public InMemoryFileFactory { + public: + kj::Own create(const Clock& clock) const override { + return newInMemoryFile(clock); + } + }; + static const FactoryImpl instance; + return instance; +} + +#if __linux__ + +Own newMemfdFile(uint flags) { + int fd; + KJ_SYSCALL(fd = memfd_create("kj-memfd", flags | MFD_CLOEXEC)); + return newDiskFile(AutoCloseFd(fd)); +} + +const InMemoryFileFactory& memfdInMemoryFileFactory() { + class FactoryImpl: public InMemoryFileFactory { + public: + kj::Own create(const Clock& clock) const override { + return newMemfdFile(0); + } + }; + static const FactoryImpl instance; + return instance; +} + +#endif // __linux__ + } // namespace kj diff --git a/c++/src/kj/filesystem.h b/c++/src/kj/filesystem.h index 8b98226727..1c5ae28389 100644 --- a/c++/src/kj/filesystem.h +++ b/c++/src/kj/filesystem.h @@ -927,8 +927,32 @@ class Filesystem { // ======================================================================================= +class InMemoryFileFactory { + // Used to customize the File implementation used by InMemoryDirectory. +public: + virtual kj::Own create(const Clock& clock) const = 0; +}; + +const InMemoryFileFactory& defaultInMemoryFileFactory(); +// Creates files using `newInMemoryFile()`. + +#if __linux__ + +Own newMemfdFile(uint flags = 0); +// Creates a `File` backed by a Linux memfd. This creates an in-memory file which behaves more +// closely to a disk file, compared to newInMemoryFile(). 
In particular, the file has a backing +// FD, and memory mapping doesn't have weird quirks. +// +// `flags` will be passed to `memfd_create()`. (The MFD_CLOEXEC flag is always added.) + +const InMemoryFileFactory& memfdInMemoryFileFactory(); +// Creates files using `newMemfdFile()`. + +#endif // __linux__ + Own newInMemoryFile(const Clock& clock); -Own newInMemoryDirectory(const Clock& clock); +Own newInMemoryDirectory(const Clock& clock, + const InMemoryFileFactory& fileFactory = defaultInMemoryFileFactory()); // Construct file and directory objects which reside in-memory. // // InMemoryFile has the following special properties: @@ -942,6 +966,11 @@ Own newInMemoryDirectory(const Clock& clock); // - link() can link directory nodes in addition to files. // - link() and rename() accept any kind of Directory as `fromDirectory` -- it doesn't need to be // another InMemoryDirectory. However, for rename(), the from path must be a directory. +// +// `fileFactory` can be customized in order to control the implementation of `File` objects created +// using this `InMemoryDirectory`. This is particularly useful for testing where the application +// expects files to have backing file descriptors or implement memory mapping fully correctly, but +// doesn't care as much about directory behavior. Own newFileAppender(Own inner); // Creates an AppendableFile by wrapping a File. Note that this implementation assumes it is the From d9c5f425ed04b0952adce8c2c1ce5e565b1ebcf6 Mon Sep 17 00:00:00 2001 From: Felix Hanau Date: Sun, 24 Dec 2023 22:44:59 -0500 Subject: [PATCH 05/58] Add json-rpc-test in bazel json-rpc and json-rpc-test were not included in bazel so far. Adding the test to bazel resolves a long-standing internal issue. 
--- c++/src/capnp/compat/BUILD.bazel | 39 +++++++++++++++++++++++++++-- c++/src/capnp/compat/json-rpc.c++ | 2 +- c++/src/capnp/compat/json-rpc.capnp | 2 +- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/c++/src/capnp/compat/BUILD.bazel b/c++/src/capnp/compat/BUILD.bazel index 951c9ddc2c..7cb67279ca 100644 --- a/c++/src/capnp/compat/BUILD.bazel +++ b/c++/src/capnp/compat/BUILD.bazel @@ -23,6 +23,32 @@ cc_library( ], ) +cc_library( + name = "json-rpc", + srcs = [ + "json-rpc.c++", + ], + hdrs = [ + "json-rpc.h", + ], + include_prefix = "capnp/compat", + visibility = ["//visibility:public"], + deps = [ + ":json-rpc_capnp", + "//src/kj/compat:kj-http", + ], +) + +cc_capnp_library( + name = "json-rpc_capnp", + srcs = [ + "json-rpc.capnp", + ], + include_prefix = "capnp/compat", + src_prefix = "src", + visibility = ["//visibility:public"], +) + cc_capnp_library( name = "json-test_capnp", srcs = [ @@ -93,13 +119,22 @@ cc_library( "websocket-rpc-test.c++", ]] +cc_test( + name = "json-rpc-test", + srcs = ["json-rpc-test.c++"], + deps = [ + ":json-rpc", + "//src/capnp:capnp-test", + ], +) + cc_test( name = "json-test", srcs = ["json-test.c++"], deps = [ - "//src/capnp:capnp-test", ":json", - ":json-test_capnp" + ":json-test_capnp", + "//src/capnp:capnp-test", ], ) diff --git a/c++/src/capnp/compat/json-rpc.c++ b/c++/src/capnp/compat/json-rpc.c++ index 2a8b88d449..92e1ad46f9 100644 --- a/c++/src/capnp/compat/json-rpc.c++ +++ b/c++/src/capnp/compat/json-rpc.c++ @@ -161,7 +161,7 @@ void JsonRpc::queueError(kj::Maybe id, int code, kj::String error.setMessage(message); // OK to discard result of queueWrite() since it's just one branch of a fork. 
- queueWrite(codec.encode(jsonResponse)); + (void)queueWrite(codec.encode(jsonResponse)); } kj::Promise JsonRpc::readLoop() { diff --git a/c++/src/capnp/compat/json-rpc.capnp b/c++/src/capnp/compat/json-rpc.capnp index 9380788cd7..33e47e0786 100644 --- a/c++/src/capnp/compat/json-rpc.capnp +++ b/c++/src/capnp/compat/json-rpc.capnp @@ -2,7 +2,7 @@ $import "/capnp/c++.capnp".namespace("capnp::json"); -using Json = import "json.capnp"; +using Json = import "/capnp/compat/json.capnp"; struct RpcMessage { jsonrpc @0 :Text; From 3e17170a4f575f87989b565fdb1f95efa69142ae Mon Sep 17 00:00:00 2001 From: Felix Hanau Date: Wed, 20 Dec 2023 21:54:32 -0500 Subject: [PATCH 06/58] Update brotli dependency, move it to WORKSPACE --- c++/WORKSPACE | 11 ++++++++--- c++/build/load_br.bzl | 12 ------------ 2 files changed, 8 insertions(+), 15 deletions(-) delete mode 100644 c++/build/load_br.bzl diff --git a/c++/WORKSPACE b/c++/WORKSPACE index efc25f9167..d3012da069 100644 --- a/c++/WORKSPACE +++ b/c++/WORKSPACE @@ -1,7 +1,6 @@ workspace(name = "capnp-cpp") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("//:build/load_br.bzl", "load_brotli") http_archive( name = "bazel_skylib", @@ -16,6 +15,14 @@ load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace") bazel_skylib_workspace() +http_archive( + name = "brotli", + sha256 = "e720a6ca29428b803f4ad165371771f5398faba397edf6778837a18599ea13ff", + strip_prefix = "brotli-1.1.0", + type = "tgz", + urls = ["https://github.com/google/brotli/archive/refs/tags/v1.1.0.tar.gz"], +) + http_archive( name = "ssl", sha256 = "873ec711658f65192e9c58554ce058d1cfa4e57e13ab5366ee16f76d1c757efc", @@ -52,5 +59,3 @@ http_archive( strip_prefix = "zlib-1.3", urls = ["https://zlib.net/zlib-1.3.tar.xz"], ) - -load_brotli() diff --git a/c++/build/load_br.bzl b/c++/build/load_br.bzl deleted file mode 100644 index fe6fdfedd6..0000000000 --- a/c++/build/load_br.bzl +++ /dev/null @@ -1,12 +0,0 @@ 
-load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - -# Defined in a bzl file to allow dependents to pull in brotli via capnproto. Using latest brotli -# commit due to macOS compile issues with v1.0.9, switch to a release version later -def load_brotli(): - http_archive( - name = "brotli", - sha256 = "e33f397d86aaa7f3e786bdf01a7b5cff4101cfb20041c04b313b149d34332f64", - strip_prefix = "google-brotli-ed1995b", - type = "tgz", - urls = ["https://github.com/google/brotli/tarball/ed1995b6bda19244070ab5d331111f16f67c8054"], - ) From 9084594dbb27b5e855ea4c684babe188eafaa900 Mon Sep 17 00:00:00 2001 From: Felix Hanau Date: Wed, 20 Dec 2023 22:05:03 -0500 Subject: [PATCH 07/58] [nfc] Selectively compile platform-specific code, move test helpers out of libkj --- c++/src/kj/BUILD.bazel | 56 +++++++++++++++++++++++++----------------- c++/src/kj/test.h | 8 +++--- 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/c++/src/kj/BUILD.bazel b/c++/src/kj/BUILD.bazel index e492992c61..db113fa4dd 100644 --- a/c++/src/kj/BUILD.bazel +++ b/c++/src/kj/BUILD.bazel @@ -13,8 +13,6 @@ cc_library( "encoding.c++", "exception.c++", "filesystem.c++", - "filesystem-disk-unix.c++", - "filesystem-disk-win32.c++", "hash.c++", "io.c++", "list.c++", @@ -27,11 +25,13 @@ cc_library( "string.c++", "string-tree.c++", "table.c++", - "test-helpers.c++", "thread.c++", "time.c++", "units.c++", - ], + ] + select({ + "@platforms//os:windows": ["filesystem-disk-win32.c++"], + "//conditions:default": ["filesystem-disk-unix.c++"], + }), hdrs = [ "arena.h", "array.h", @@ -87,12 +87,17 @@ cc_library( srcs = [ "async.c++", "async-io.c++", - "async-io-unix.c++", - "async-io-win32.c++", - "async-unix.c++", - "async-win32.c++", "timer.c++", - ], + ] + select({ + "@platforms//os:windows": [ + "async-io-win32.c++", + "async-win32.c++", + ], + "//conditions:default": [ + "async-io-unix.c++", + "async-unix.c++", + ], + }), hdrs = [ "async.h", "async-inl.h", @@ -100,10 +105,11 @@ cc_library( 
"async-io-internal.h", "async-prelude.h", "async-queue.h", - "async-unix.h", - "async-win32.h", "timer.h", - ], + ] + select({ + "@platforms//os:windows": ["async-win32.h"], + "//conditions:default": ["async-unix.h"], + }), include_prefix = "kj", linkopts = select({ "@platforms//os:windows": [ @@ -120,6 +126,10 @@ cc_library( name = "kj-test", srcs = [ "test.c++", + "test-helpers.c++", + ], + hdrs = [ + "test.h", ], include_prefix = "kj", visibility = ["//visibility:public"], @@ -191,25 +201,25 @@ cc_library( cc_test( name = "filesystem-disk-generic-test", srcs = ["filesystem-disk-generic-test.c++"], + target_compatible_with = [ + "@platforms//os:linux", + ], deps = [ ":filesystem-disk-test-base", ":kj-test", ], - target_compatible_with = [ - "@platforms//os:linux", - ], ) cc_test( name = "filesystem-disk-old-kernel-test", srcs = ["filesystem-disk-old-kernel-test.c++"], + target_compatible_with = [ + "@platforms//os:linux", + ], deps = [ ":filesystem-disk-test-base", ":kj-test", ], - target_compatible_with = [ - "@platforms//os:linux", - ], ) cc_test( @@ -246,12 +256,14 @@ cc_test( cc_test( name = "exception-override-symbolizer-test", srcs = ["exception-override-symbolizer-test.c++"], + linkstatic = True, + target_compatible_with = select({ + "@platforms//os:linux": [], + "@platforms//os:macos": [], + "//conditions:default": ["@platforms//:incompatible"], + }), deps = [ ":kj", ":kj-test", ], - linkstatic = True, - target_compatible_with = [ - "@platforms//os:linux", - ], ) diff --git a/c++/src/kj/test.h b/c++/src/kj/test.h index e3287b886a..1d845c1452 100644 --- a/c++/src/kj/test.h +++ b/c++/src/kj/test.h @@ -21,10 +21,10 @@ #pragma once -#include "debug.h" -#include "vector.h" -#include "function.h" -#include "windows-sanity.h" // work-around macro conflict with `ERROR` +#include +#include +#include +#include // work-around macro conflict with `ERROR` KJ_BEGIN_HEADER From b021ad21d02c989f99bfe495ff9c5d26312dcc3a Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: 
Wed, 27 Dec 2023 12:21:10 -0600 Subject: [PATCH 08/58] Cleanup: Remove ez-rpc. I've long considered this library a mistake. It hides too much of the internals, which leads people to get stuck quickly as soon as they actually need those internals. It's much better to use `kj::setupAsyncIo()` to set up the KJ event loop / networking, and then `capnp::TwoPartyServer` or `capnp::TwoPartyClient` to set up RPC. --- c++/Makefile.am | 7 +- c++/src/capnp/BUILD.bazel | 3 - c++/src/capnp/CMakeLists.txt | 3 - c++/src/capnp/ez-rpc-test.c++ | 75 ------- c++/src/capnp/ez-rpc.c++ | 368 ---------------------------------- c++/src/capnp/ez-rpc.h | 251 ----------------------- c++/src/capnp/rpc-twoparty.h | 3 - c++/src/capnp/rpc.h | 12 +- 8 files changed, 3 insertions(+), 719 deletions(-) delete mode 100644 c++/src/capnp/ez-rpc-test.c++ delete mode 100644 c++/src/capnp/ez-rpc.c++ delete mode 100644 c++/src/capnp/ez-rpc.h diff --git a/c++/Makefile.am b/c++/Makefile.am index 1567491d4d..ab928d59df 100644 --- a/c++/Makefile.am +++ b/c++/Makefile.am @@ -229,8 +229,7 @@ includecapnp_HEADERS = \ src/capnp/rpc-twoparty.h \ src/capnp/rpc.capnp.h \ src/capnp/rpc-twoparty.capnp.h \ - src/capnp/persistent.capnp.h \ - src/capnp/ez-rpc.h + src/capnp/persistent.capnp.h includecapnpcompat_HEADERS = \ src/capnp/compat/json.h \ @@ -374,8 +373,7 @@ libcapnp_rpc_la_SOURCES= \ src/capnp/rpc.capnp.c++ \ src/capnp/rpc-twoparty.c++ \ src/capnp/rpc-twoparty.capnp.c++ \ - src/capnp/persistent.capnp.c++ \ - src/capnp/ez-rpc.c++ + src/capnp/persistent.capnp.c++ libcapnp_json_la_LIBADD = libcapnp.la libkj.la $(PTHREAD_LIBS) libcapnp_json_la_LDFLAGS = -release $(SO_VERSION) -no-undefined @@ -542,7 +540,6 @@ heavy_tests = \ src/capnp/serialize-text-test.c++ \ src/capnp/rpc-test.c++ \ src/capnp/rpc-twoparty-test.c++ \ - src/capnp/ez-rpc-test.c++ \ src/capnp/compat/json-test.c++ \ src/capnp/compat/websocket-rpc-test.c++ \ src/capnp/compiler/lexer-test.c++ \ diff --git a/c++/src/capnp/BUILD.bazel 
b/c++/src/capnp/BUILD.bazel index eea8349985..ec0833cef2 100644 --- a/c++/src/capnp/BUILD.bazel +++ b/c++/src/capnp/BUILD.bazel @@ -60,7 +60,6 @@ cc_library( srcs = [ "capability.c++", "dynamic-capability.c++", - "ez-rpc.c++", "membrane.c++", "persistent.capnp.c++", "reconnect.c++", @@ -71,7 +70,6 @@ cc_library( "serialize-async.c++", ], hdrs = [ - "ez-rpc.h", "persistent.capnp.h", "reconnect.h", "rpc.capnp.h", @@ -231,7 +229,6 @@ cc_library( "dynamic-test.c++", "encoding-test.c++", "endian-test.c++", - "ez-rpc-test.c++", "layout-test.c++", "membrane-test.c++", "message-test.c++", diff --git a/c++/src/capnp/CMakeLists.txt b/c++/src/capnp/CMakeLists.txt index 9980fde617..045cd98f8c 100644 --- a/c++/src/capnp/CMakeLists.txt +++ b/c++/src/capnp/CMakeLists.txt @@ -87,7 +87,6 @@ set(capnp-rpc_sources rpc-twoparty.c++ rpc-twoparty.capnp.c++ persistent.capnp.c++ - ez-rpc.c++ ) set(capnp-rpc_headers rpc-prelude.h @@ -96,7 +95,6 @@ set(capnp-rpc_headers rpc.capnp.h rpc-twoparty.capnp.h persistent.capnp.h - ez-rpc.h ) set(capnp-rpc_schemas rpc.capnp @@ -292,7 +290,6 @@ if(BUILD_TESTING) serialize-text-test.c++ rpc-test.c++ rpc-twoparty-test.c++ - ez-rpc-test.c++ compiler/lexer-test.c++ compiler/type-id-test.c++ test-util.c++ diff --git a/c++/src/capnp/ez-rpc-test.c++ b/c++/src/capnp/ez-rpc-test.c++ deleted file mode 100644 index 0cd2fdbb5e..0000000000 --- a/c++/src/capnp/ez-rpc-test.c++ +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. 
and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. 
- -#define CAPNP_TESTING_CAPNP 1 - -#include "ez-rpc.h" -#include "test-util.h" -#include - -namespace capnp { -namespace _ { -namespace { - -TEST(EzRpc, Basic) { - int callCount = 0; - EzRpcServer server(kj::heap(callCount), "localhost"); - - EzRpcClient client("localhost", server.getPort().wait(server.getWaitScope())); - - auto cap = client.getMain(); - auto request = cap.fooRequest(); - request.setI(123); - request.setJ(true); - - EXPECT_EQ(0, callCount); - auto response = request.send().wait(server.getWaitScope()); - EXPECT_EQ("foo", response.getX()); - EXPECT_EQ(1, callCount); -} - -TEST(EzRpc, DeprecatedNames) { - EzRpcServer server("localhost"); - int callCount = 0; - server.exportCap("cap1", kj::heap(callCount)); - server.exportCap("cap2", kj::heap()); - - EzRpcClient client("localhost", server.getPort().wait(server.getWaitScope())); - - auto cap = client.importCap("cap1"); - auto request = cap.fooRequest(); - request.setI(123); - request.setJ(true); - - EXPECT_EQ(0, callCount); - auto response = request.send().wait(server.getWaitScope()); - EXPECT_EQ("foo", response.getX()); - EXPECT_EQ(1, callCount); - - EXPECT_EQ(0, client.importCap("cap2").castAs() - .getCallSequenceRequest().send().wait(server.getWaitScope()).getN()); - EXPECT_EQ(1, client.importCap("cap2").castAs() - .getCallSequenceRequest().send().wait(server.getWaitScope()).getN()); -} - -} // namespace -} // namespace _ -} // namespace capnp diff --git a/c++/src/capnp/ez-rpc.c++ b/c++/src/capnp/ez-rpc.c++ deleted file mode 100644 index 6add953b6c..0000000000 --- a/c++/src/capnp/ez-rpc.c++ +++ /dev/null @@ -1,368 +0,0 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. 
and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. 
- -#include "ez-rpc.h" -#include "rpc-twoparty.h" -#include -#include -#include -#include -#include - -namespace capnp { - -KJ_THREADLOCAL_PTR(EzRpcContext) threadEzContext = nullptr; - -class EzRpcContext: public kj::Refcounted { -public: - EzRpcContext(): ioContext(kj::setupAsyncIo()) { - threadEzContext = this; - } - - ~EzRpcContext() noexcept(false) { - KJ_REQUIRE(threadEzContext == this, - "EzRpcContext destroyed from different thread than it was created.") { - return; - } - threadEzContext = nullptr; - } - - kj::WaitScope& getWaitScope() { - return ioContext.waitScope; - } - - kj::AsyncIoProvider& getIoProvider() { - return *ioContext.provider; - } - - kj::LowLevelAsyncIoProvider& getLowLevelIoProvider() { - return *ioContext.lowLevelProvider; - } - - static kj::Own getThreadLocal() { - EzRpcContext* existing = threadEzContext; - if (existing != nullptr) { - return kj::addRef(*existing); - } else { - return kj::refcounted(); - } - } - -private: - kj::AsyncIoContext ioContext; -}; - -// ======================================================================================= - -kj::Promise> connectAttach(kj::Own&& addr) { - return addr->connect().attach(kj::mv(addr)); -} - -struct EzRpcClient::Impl { - kj::Own context; - - struct ClientContext { - kj::Own stream; - TwoPartyVatNetwork network; - RpcSystem rpcSystem; - - ClientContext(kj::Own&& stream, ReaderOptions readerOpts) - : stream(kj::mv(stream)), - network(*this->stream, rpc::twoparty::Side::CLIENT, readerOpts), - rpcSystem(makeRpcClient(network)) {} - - Capability::Client getMain() { - word scratch[4]; - memset(scratch, 0, sizeof(scratch)); - MallocMessageBuilder message(scratch); - auto hostId = message.getRoot(); - hostId.setSide(rpc::twoparty::Side::SERVER); - return rpcSystem.bootstrap(hostId); - } - - Capability::Client restore(kj::StringPtr name) { - word scratch[64]; - memset(scratch, 0, sizeof(scratch)); - MallocMessageBuilder message(scratch); - - auto hostIdOrphan = 
message.getOrphanage().newOrphan(); - auto hostId = hostIdOrphan.get(); - hostId.setSide(rpc::twoparty::Side::SERVER); - - auto objectId = message.getRoot(); - objectId.setAs(name); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - return rpcSystem.restore(hostId, objectId); -#pragma GCC diagnostic pop - } - }; - - kj::ForkedPromise setupPromise; - - kj::Maybe> clientContext; - // Filled in before `setupPromise` resolves. - - Impl(kj::StringPtr serverAddress, uint defaultPort, - ReaderOptions readerOpts) - : context(EzRpcContext::getThreadLocal()), - setupPromise(context->getIoProvider().getNetwork() - .parseAddress(serverAddress, defaultPort) - .then([](kj::Own&& addr) { - return connectAttach(kj::mv(addr)); - }).then([this, readerOpts](kj::Own&& stream) { - clientContext = kj::heap(kj::mv(stream), - readerOpts); - }).fork()) {} - - Impl(const struct sockaddr* serverAddress, uint addrSize, - ReaderOptions readerOpts) - : context(EzRpcContext::getThreadLocal()), - setupPromise( - connectAttach(context->getIoProvider().getNetwork() - .getSockaddr(serverAddress, addrSize)) - .then([this, readerOpts](kj::Own&& stream) { - clientContext = kj::heap(kj::mv(stream), - readerOpts); - }).fork()) {} - - Impl(int socketFd, ReaderOptions readerOpts) - : context(EzRpcContext::getThreadLocal()), - setupPromise(kj::Promise(kj::READY_NOW).fork()), - clientContext(kj::heap( - context->getLowLevelIoProvider().wrapSocketFd(socketFd), - readerOpts)) {} -}; - -EzRpcClient::EzRpcClient(kj::StringPtr serverAddress, uint defaultPort, ReaderOptions readerOpts) - : impl(kj::heap(serverAddress, defaultPort, readerOpts)) {} - -EzRpcClient::EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, ReaderOptions readerOpts) - : impl(kj::heap(serverAddress, addrSize, readerOpts)) {} - -EzRpcClient::EzRpcClient(int socketFd, ReaderOptions readerOpts) - : impl(kj::heap(socketFd, readerOpts)) {} - -EzRpcClient::~EzRpcClient() noexcept(false) {} - 
-Capability::Client EzRpcClient::getMain() { - KJ_IF_SOME(client, impl->clientContext) { - return client->getMain(); - } else { - return impl->setupPromise.addBranch().then([this]() { - return KJ_ASSERT_NONNULL(impl->clientContext)->getMain(); - }); - } -} - -Capability::Client EzRpcClient::importCap(kj::StringPtr name) { - KJ_IF_SOME(client, impl->clientContext) { - return client->restore(name); - } else { - return impl->setupPromise.addBranch().then( - [this,name=kj::heapString(name)]() { - return KJ_ASSERT_NONNULL(impl->clientContext)->restore(name); - }); - } -} - -kj::WaitScope& EzRpcClient::getWaitScope() { - return impl->context->getWaitScope(); -} - -kj::AsyncIoProvider& EzRpcClient::getIoProvider() { - return impl->context->getIoProvider(); -} - -kj::LowLevelAsyncIoProvider& EzRpcClient::getLowLevelIoProvider() { - return impl->context->getLowLevelIoProvider(); -} - -// ======================================================================================= - -namespace { - -class DummyFilter: public kj::LowLevelAsyncIoProvider::NetworkFilter { -public: - bool shouldAllow(const struct sockaddr* addr, uint addrlen) override { - return true; - } -}; - -static DummyFilter DUMMY_FILTER; - -} // namespace - -struct EzRpcServer::Impl final: public SturdyRefRestorer, - public kj::TaskSet::ErrorHandler { - Capability::Client mainInterface; - kj::Own context; - - struct ExportedCap { - kj::String name; - Capability::Client cap = nullptr; - - ExportedCap(kj::StringPtr name, Capability::Client cap) - : name(kj::heapString(name)), cap(cap) {} - - ExportedCap() = default; - ExportedCap(const ExportedCap&) = delete; - ExportedCap(ExportedCap&&) = default; - ExportedCap& operator=(const ExportedCap&) = delete; - ExportedCap& operator=(ExportedCap&&) = default; - // Make std::map happy... 
- }; - - std::map exportMap; - - kj::ForkedPromise portPromise; - - kj::TaskSet tasks; - - struct ServerContext { - kj::Own stream; - TwoPartyVatNetwork network; - RpcSystem rpcSystem; - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wdeprecated-declarations" - ServerContext(kj::Own&& stream, SturdyRefRestorer& restorer, - ReaderOptions readerOpts) - : stream(kj::mv(stream)), - network(*this->stream, rpc::twoparty::Side::SERVER, readerOpts), - rpcSystem(makeRpcServer(network, restorer)) {} -#pragma GCC diagnostic pop - }; - - Impl(Capability::Client mainInterface, kj::StringPtr bindAddress, uint defaultPort, - ReaderOptions readerOpts) - : mainInterface(kj::mv(mainInterface)), - context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) { - auto paf = kj::newPromiseAndFulfiller(); - portPromise = paf.promise.fork(); - - tasks.add(context->getIoProvider().getNetwork().parseAddress(bindAddress, defaultPort) - .then([this, portFulfiller=kj::mv(paf.fulfiller), readerOpts](kj::Own&& addr) mutable { - auto listener = addr->listen(); - portFulfiller->fulfill(listener->getPort()); - acceptLoop(kj::mv(listener), readerOpts); - })); - } - - Impl(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts) - : mainInterface(kj::mv(mainInterface)), - context(EzRpcContext::getThreadLocal()), portPromise(nullptr), tasks(*this) { - auto listener = context->getIoProvider().getNetwork() - .getSockaddr(bindAddress, addrSize)->listen(); - portPromise = kj::Promise(listener->getPort()).fork(); - acceptLoop(kj::mv(listener), readerOpts); - } - - Impl(Capability::Client mainInterface, int socketFd, uint port, ReaderOptions readerOpts) - : mainInterface(kj::mv(mainInterface)), - context(EzRpcContext::getThreadLocal()), - portPromise(kj::Promise(port).fork()), - tasks(*this) { - acceptLoop(context->getLowLevelIoProvider().wrapListenSocketFd(socketFd, DUMMY_FILTER), - readerOpts); - } - - void 
acceptLoop(kj::Own&& listener, ReaderOptions readerOpts) { - auto ptr = listener.get(); - tasks.add(ptr->accept().then([this, listener=kj::mv(listener), readerOpts](kj::Own&& connection) mutable { - acceptLoop(kj::mv(listener), readerOpts); - - auto server = kj::heap(kj::mv(connection), *this, readerOpts); - - // Arrange to destroy the server context when all references are gone, or when the - // EzRpcServer is destroyed (which will destroy the TaskSet). - tasks.add(server->network.onDisconnect().attach(kj::mv(server))); - })); - } - - Capability::Client restore(AnyPointer::Reader objectId) override { - if (objectId.isNull()) { - return mainInterface; - } else { - auto name = objectId.getAs(); - auto iter = exportMap.find(name); - if (iter == exportMap.end()) { - KJ_FAIL_REQUIRE("Server exports no such capability.", name) { break; } - return nullptr; - } else { - return iter->second.cap; - } - } - } - - void taskFailed(kj::Exception&& exception) override { - kj::throwFatalException(kj::mv(exception)); - } -}; - -EzRpcServer::EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress, - uint defaultPort, ReaderOptions readerOpts) - : impl(kj::heap(kj::mv(mainInterface), bindAddress, defaultPort, readerOpts)) {} - -EzRpcServer::EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress, - uint addrSize, ReaderOptions readerOpts) - : impl(kj::heap(kj::mv(mainInterface), bindAddress, addrSize, readerOpts)) {} - -EzRpcServer::EzRpcServer(Capability::Client mainInterface, int socketFd, uint port, - ReaderOptions readerOpts) - : impl(kj::heap(kj::mv(mainInterface), socketFd, port, readerOpts)) {} - -EzRpcServer::EzRpcServer(kj::StringPtr bindAddress, uint defaultPort, - ReaderOptions readerOpts) - : EzRpcServer(nullptr, bindAddress, defaultPort, readerOpts) {} - -EzRpcServer::EzRpcServer(struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts) - : EzRpcServer(nullptr, bindAddress, addrSize, readerOpts) {} - 
-EzRpcServer::EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts) - : EzRpcServer(nullptr, socketFd, port, readerOpts) {} - -EzRpcServer::~EzRpcServer() noexcept(false) {} - -void EzRpcServer::exportCap(kj::StringPtr name, Capability::Client cap) { - Impl::ExportedCap entry(kj::heapString(name), cap); - impl->exportMap[entry.name] = kj::mv(entry); -} - -kj::Promise EzRpcServer::getPort() { - return impl->portPromise.addBranch(); -} - -kj::WaitScope& EzRpcServer::getWaitScope() { - return impl->context->getWaitScope(); -} - -kj::AsyncIoProvider& EzRpcServer::getIoProvider() { - return impl->context->getIoProvider(); -} - -kj::LowLevelAsyncIoProvider& EzRpcServer::getLowLevelIoProvider() { - return impl->context->getLowLevelIoProvider(); -} - -} // namespace capnp diff --git a/c++/src/capnp/ez-rpc.h b/c++/src/capnp/ez-rpc.h deleted file mode 100644 index ef2649239a..0000000000 --- a/c++/src/capnp/ez-rpc.h +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#pragma once - -#include "rpc.h" -#include "message.h" - -CAPNP_BEGIN_HEADER - -struct sockaddr; - -namespace kj { class AsyncIoProvider; class LowLevelAsyncIoProvider; } - -namespace capnp { - -class EzRpcContext; - -class EzRpcClient { - // Super-simple interface for setting up a Cap'n Proto RPC client. Example: - // - // # Cap'n Proto schema - // interface Adder { - // add @0 (left :Int32, right :Int32) -> (value :Int32); - // } - // - // // C++ client - // int main() { - // capnp::EzRpcClient client("localhost:3456"); - // Adder::Client adder = client.getMain(); - // auto request = adder.addRequest(); - // request.setLeft(12); - // request.setRight(34); - // auto response = request.send().wait(client.getWaitScope()); - // assert(response.getValue() == 46); - // return 0; - // } - // - // // C++ server - // class AdderImpl final: public Adder::Server { - // public: - // kj::Promise add(AddContext context) override { - // auto params = context.getParams(); - // context.getResults().setValue(params.getLeft() + params.getRight()); - // return kj::READY_NOW; - // } - // }; - // - // int main() { - // capnp::EzRpcServer server(kj::heap(), "*:3456"); - // kj::NEVER_DONE.wait(server.getWaitScope()); - // } - // - // This interface is easy, but it hides a lot of useful features available from the lower-level - // classes: - // - The server can only export a small set of public, singleton capabilities under well-known - // string names. This is fine for transient services where no state needs to be kept between - // connections, but hides the power of Cap'n Proto when it comes to long-lived resources. 
- // - EzRpcClient/EzRpcServer automatically set up a `kj::EventLoop` and make it current for the - // thread. Only one `kj::EventLoop` can exist per thread, so you cannot use these interfaces - // if you wish to set up your own event loop. (However, you can safely create multiple - // EzRpcClient / EzRpcServer objects in a single thread; they will make sure to make no more - // than one EventLoop.) - // - These classes only support simple two-party connections, not multilateral VatNetworks. - // - These classes only support communication over a raw, unencrypted socket. If you want to - // build on an abstract stream (perhaps one which supports encryption), you must use the - // lower-level interfaces. - // - // Some of these restrictions will probably be lifted in future versions, but some things will - // always require using the low-level interfaces directly. If you are interested in working - // at a lower level, start by looking at these interfaces: - // - `kj::setupAsyncIo()` in `kj/async-io.h`. - // - `RpcSystem` in `capnp/rpc.h`. - // - `TwoPartyVatNetwork` in `capnp/rpc-twoparty.h`. - -public: - explicit EzRpcClient(kj::StringPtr serverAddress, uint defaultPort = 0, - ReaderOptions readerOpts = ReaderOptions()); - // Construct a new EzRpcClient and connect to the given address. The connection is formed in - // the background -- if it fails, calls to capabilities returned by importCap() will fail with an - // appropriate exception. - // - // `defaultPort` is the IP port number to use if `serverAddress` does not include it explicitly. - // If unspecified, the port is required in `serverAddress`. - // - // The address is parsed by `kj::Network` in `kj/async-io.h`. See that interface for more info - // on the address format, but basically it's what you'd expect. - // - // `readerOpts` is the ReaderOptions structure used to read each incoming message on the - // connection. 
Setting this may be necessary if you need to receive very large individual - // messages or messages. However, it is recommended that you instead think about how to change - // your protocol to send large data blobs in multiple small chunks -- this is much better for - // both security and performance. See `ReaderOptions` in `message.h` for more details. - - EzRpcClient(const struct sockaddr* serverAddress, uint addrSize, - ReaderOptions readerOpts = ReaderOptions()); - // Like the above constructor, but connects to an already-resolved socket address. Any address - // format supported by `kj::Network` in `kj/async-io.h` is accepted. - - explicit EzRpcClient(int socketFd, ReaderOptions readerOpts = ReaderOptions()); - // Create a client on top of an already-connected socket. - // `readerOpts` acts as in the first constructor. - - ~EzRpcClient() noexcept(false); - - template - typename Type::Client getMain(); - Capability::Client getMain(); - // Get the server's main (aka "bootstrap") interface. - - template - typename Type::Client importCap(kj::StringPtr name) CAPNP_DEPRECATED( - "Change your server to export a main interface, then use getMain() instead."); - Capability::Client importCap(kj::StringPtr name) CAPNP_DEPRECATED( - "Change your server to export a main interface, then use getMain() instead."); - // ** DEPRECATED ** - // - // Ask the sever for the capability with the given name. You may specify a type to automatically - // down-cast to that type. It is up to you to specify the correct expected type. - // - // Named interfaces are deprecated. The new preferred usage pattern is for the server to export - // a "main" interface which itself has methods for getting any other interfaces. - - kj::WaitScope& getWaitScope(); - // Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on - // promises. - - kj::AsyncIoProvider& getIoProvider(); - // Get the underlying AsyncIoProvider set up by the RPC system. 
This is useful if you want - // to do some non-RPC I/O in asynchronous fashion. - - kj::LowLevelAsyncIoProvider& getLowLevelIoProvider(); - // Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you - // want to do some non-RPC I/O in asynchronous fashion. - -private: - struct Impl; - kj::Own impl; -}; - -class EzRpcServer { - // The server counterpart to `EzRpcClient`. See `EzRpcClient` for an example. - -public: - explicit EzRpcServer(Capability::Client mainInterface, kj::StringPtr bindAddress, - uint defaultPort = 0, ReaderOptions readerOpts = ReaderOptions()); - // Construct a new `EzRpcServer` that binds to the given address. An address of "*" means to - // bind to all local addresses. - // - // `defaultPort` is the IP port number to use if `serverAddress` does not include it explicitly. - // If unspecified, a port is chosen automatically, and you must call getPort() to find out what - // it is. - // - // The address is parsed by `kj::Network` in `kj/async-io.h`. See that interface for more info - // on the address format, but basically it's what you'd expect. - // - // The server might not begin listening immediately, especially if `bindAddress` needs to be - // resolved. If you need to wait until the server is definitely up, wait on the promise returned - // by `getPort()`. - // - // `readerOpts` is the ReaderOptions structure used to read each incoming message on the - // connection. Setting this may be necessary if you need to receive very large individual - // messages or messages. However, it is recommended that you instead think about how to change - // your protocol to send large data blobs in multiple small chunks -- this is much better for - // both security and performance. See `ReaderOptions` in `message.h` for more details. 
- - EzRpcServer(Capability::Client mainInterface, struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts = ReaderOptions()); - // Like the above constructor, but binds to an already-resolved socket address. Any address - // format supported by `kj::Network` in `kj/async-io.h` is accepted. - - EzRpcServer(Capability::Client mainInterface, int socketFd, uint port, - ReaderOptions readerOpts = ReaderOptions()); - // Create a server on top of an already-listening socket (i.e. one on which accept() may be - // called). `port` is returned by `getPort()` -- it serves no other purpose. - // `readerOpts` acts as in the other two above constructors. - - explicit EzRpcServer(kj::StringPtr bindAddress, uint defaultPort = 0, - ReaderOptions readerOpts = ReaderOptions()) - CAPNP_DEPRECATED("Please specify a main interface for your server."); - EzRpcServer(struct sockaddr* bindAddress, uint addrSize, - ReaderOptions readerOpts = ReaderOptions()) - CAPNP_DEPRECATED("Please specify a main interface for your server."); - EzRpcServer(int socketFd, uint port, ReaderOptions readerOpts = ReaderOptions()) - CAPNP_DEPRECATED("Please specify a main interface for your server."); - - ~EzRpcServer() noexcept(false); - - void exportCap(kj::StringPtr name, Capability::Client cap); - // Export a capability publicly under the given name, so that clients can import it. - // - // Keep in mind that you can implicitly convert `kj::Own&&` to - // `Capability::Client`, so it's typical to pass something like - // `kj::heap()` as the second parameter. - - kj::Promise getPort(); - // Get the IP port number on which this server is listening. This promise won't resolve until - // the server is actually listening. If the address was not an IP address (e.g. it was a Unix - // domain socket) then getPort() resolves to zero. - - kj::WaitScope& getWaitScope(); - // Get the `WaitScope` for the client's `EventLoop`, which allows you to synchronously wait on - // promises. 
- - kj::AsyncIoProvider& getIoProvider(); - // Get the underlying AsyncIoProvider set up by the RPC system. This is useful if you want - // to do some non-RPC I/O in asynchronous fashion. - - kj::LowLevelAsyncIoProvider& getLowLevelIoProvider(); - // Get the underlying LowLevelAsyncIoProvider set up by the RPC system. This is useful if you - // want to do some non-RPC I/O in asynchronous fashion. - -private: - struct Impl; - kj::Own impl; -}; - -// ======================================================================================= -// inline implementation details - -template -inline typename Type::Client EzRpcClient::getMain() { - return getMain().castAs(); -} - -template -inline typename Type::Client EzRpcClient::importCap(kj::StringPtr name) { - return importCap(name).castAs(); -} - -} // namespace capnp - -CAPNP_END_HEADER diff --git a/c++/src/capnp/rpc-twoparty.h b/c++/src/capnp/rpc-twoparty.h index a2a1cec67c..53db974c91 100644 --- a/c++/src/capnp/rpc-twoparty.h +++ b/c++/src/capnp/rpc-twoparty.h @@ -47,9 +47,6 @@ class TwoPartyVatNetwork: public TwoPartyVatNetworkBase, private RpcFlowController::WindowGetter { // A `VatNetwork` that consists of exactly two parties communicating over an arbitrary byte // stream. This is used to implement the common case of a client/server network. - // - // See `ez-rpc.h` for a simple interface for setting up two-party clients and servers. - // Use `TwoPartyVatNetwork` only if you need the advanced features. public: TwoPartyVatNetwork(MessageStream& msgStream, diff --git a/c++/src/capnp/rpc.h b/c++/src/capnp/rpc.h index b153307f99..1777c75123 100644 --- a/c++/src/capnp/rpc.h +++ b/c++/src/capnp/rpc.h @@ -68,9 +68,6 @@ class RpcSystem: public _::RpcSystemBase { // // See `makeRpcServer()` and `makeRpcClient()` below for convenient syntax for setting up an // `RpcSystem` given a `VatNetwork`. - // - // See `ez-rpc.h` for an even simpler interface for setting up RPC in a typical two-party - // client/server scenario. 
public: template makeRpcServer( // MyMainInterface::Client bootstrap = makeMain(); // auto server = makeRpcServer(network, bootstrap); // kj::NEVER_DONE.wait(waitScope); // run forever -// -// See also ez-rpc.h, which has simpler instructions for the common case of a two-party -// client-server RPC connection. template @@ -221,9 +215,6 @@ RpcSystem makeRpcClient( // MyCapability::Client cap = client.restore(hostId, objId).castAs(); // auto response = cap.fooRequest().send().wait(waitScope); // handleMyResponse(response); -// -// See also ez-rpc.h, which has simpler instructions for the common case of a two-party -// client-server RPC connection. template class SturdyRefRestorer: public _::SturdyRefRestorerBase { @@ -372,8 +363,7 @@ class VatNetwork: public _::VatNetworkBase { // to manage object references and make method calls. // // The most common implementation of VatNetwork is TwoPartyVatNetwork (rpc-twoparty.h). Most - // simple client-server apps will want to use it. (You may even want to use the EZ RPC - // interfaces in `ez-rpc.h` and avoid all of this.) + // simple client-server apps will want to use it. // // TODO(someday): Provide a standard implementation for the public internet. From 61568ebf6f678740668c47855768cbce9905d2cd Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 27 Dec 2023 12:29:23 -0600 Subject: [PATCH 09/58] Cleanup: Delete `KJ_THREADLOCAL_PTR`. Plain old `thread_local` should be fine these days. 
--- c++/Makefile.am | 2 - c++/src/kj/BUILD.bazel | 2 - c++/src/kj/CMakeLists.txt | 2 - c++/src/kj/async-unix.c++ | 7 ++-- c++/src/kj/async.c++ | 5 +-- c++/src/kj/exception.c++ | 5 +-- c++/src/kj/threadlocal-test.c++ | 69 --------------------------------- c++/src/kj/threadlocal.h | 66 ------------------------------- 8 files changed, 7 insertions(+), 151 deletions(-) delete mode 100644 c++/src/kj/threadlocal-test.c++ delete mode 100644 c++/src/kj/threadlocal.h diff --git a/c++/Makefile.am b/c++/Makefile.am index ab928d59df..fc05b7a560 100644 --- a/c++/Makefile.am +++ b/c++/Makefile.am @@ -165,7 +165,6 @@ includekj_HEADERS = \ src/kj/mutex.h \ src/kj/source-location.h \ src/kj/thread.h \ - src/kj/threadlocal.h \ src/kj/filesystem.h \ src/kj/async-prelude.h \ src/kj/async.h \ @@ -584,7 +583,6 @@ capnp_test_SOURCES = \ src/kj/io-test.c++ \ src/kj/mutex-test.c++ \ src/kj/time-test.c++ \ - src/kj/threadlocal-test.c++ \ src/kj/filesystem-test.c++ \ src/kj/filesystem-disk-test.c++ \ src/kj/test-test.c++ \ diff --git a/c++/src/kj/BUILD.bazel b/c++/src/kj/BUILD.bazel index e492992c61..bed00c477d 100644 --- a/c++/src/kj/BUILD.bazel +++ b/c++/src/kj/BUILD.bazel @@ -61,7 +61,6 @@ cc_library( "table.h", "test.h", "thread.h", - "threadlocal.h", "time.h", "tuple.h", "units.h", @@ -164,7 +163,6 @@ cc_library( "string-tree-test.c++", "table-test.c++", "test-test.c++", - "threadlocal-test.c++", "thread-test.c++", "time-test.c++", "tuple-test.c++", diff --git a/c++/src/kj/CMakeLists.txt b/c++/src/kj/CMakeLists.txt index c84f550d16..bd72009254 100644 --- a/c++/src/kj/CMakeLists.txt +++ b/c++/src/kj/CMakeLists.txt @@ -62,7 +62,6 @@ set(kj_headers function.h mutex.h thread.h - threadlocal.h filesystem.h time.h main.h @@ -252,7 +251,6 @@ if(BUILD_TESTING) io-test.c++ mutex-test.c++ time-test.c++ - threadlocal-test.c++ test-test.c++ std/iostream-test.c++ ) diff --git a/c++/src/kj/async-unix.c++ b/c++/src/kj/async-unix.c++ index 6de12479ec..dfd25ee929 100644 --- a/c++/src/kj/async-unix.c++ 
+++ b/c++/src/kj/async-unix.c++ @@ -23,7 +23,6 @@ #include "async-unix.h" #include "debug.h" -#include "threadlocal.h" #include #include #include @@ -70,7 +69,7 @@ bool threadClaimedChildExits = false; namespace { -KJ_THREADLOCAL_PTR(UnixEventPort) threadEventPort = nullptr; +thread_local UnixEventPort* threadEventPort = nullptr; // This is set to the current UnixEventPort just before epoll_pwait(), then back to null after it // returns. @@ -100,7 +99,7 @@ void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { #elif KJ_USE_KQUEUE #if !KJ_HAS_SIGTIMEDWAIT -KJ_THREADLOCAL_PTR(siginfo_t) threadCapture = nullptr; +static thread_local siginfo_t* threadCapture = nullptr; #endif void UnixEventPort::signalHandler(int, siginfo_t* siginfo, void*) noexcept { @@ -152,7 +151,7 @@ struct SignalCapture { #endif }; -KJ_THREADLOCAL_PTR(SignalCapture) threadCapture = nullptr; +thread_local SignalCapture* threadCapture = nullptr; } // namespace diff --git a/c++/src/kj/async.c++ b/c++/src/kj/async.c++ index d4bfdd72ff..80cce9bc8d 100644 --- a/c++/src/kj/async.c++ +++ b/c++/src/kj/async.c++ @@ -42,7 +42,6 @@ #include "async.h" #include "debug.h" #include "vector.h" -#include "threadlocal.h" #include "mutex.h" #include "one-of.h" #include "function.h" @@ -116,7 +115,7 @@ namespace kj { namespace { -KJ_THREADLOCAL_PTR(DisallowAsyncDestructorsScope) disallowAsyncDestructorsScope = nullptr; +thread_local DisallowAsyncDestructorsScope* disallowAsyncDestructorsScope = nullptr; } // namespace @@ -161,7 +160,7 @@ AllowAsyncDestructorsScope::~AllowAsyncDestructorsScope() { namespace { -KJ_THREADLOCAL_PTR(EventLoop) threadLocalEventLoop = nullptr; +thread_local EventLoop* threadLocalEventLoop = nullptr; #define _kJ_ALREADY_READY reinterpret_cast< ::kj::_::Event*>(1) diff --git a/c++/src/kj/exception.c++ b/c++/src/kj/exception.c++ index be582f3d56..3a03f0be6b 100644 --- a/c++/src/kj/exception.c++ +++ b/c++/src/kj/exception.c++ @@ -37,7 +37,6 @@ #include "exception.h" 
#include "string.h" #include "debug.h" -#include "threadlocal.h" #include "miniposix.h" #include "function.h" #include "main.h" @@ -919,7 +918,7 @@ void Exception::addTraceHere() { namespace { -KJ_THREADLOCAL_PTR(ExceptionImpl) currentException = nullptr; +thread_local ExceptionImpl* currentException = nullptr; void validateExceptionPointer(const ExceptionImpl* e) noexcept { // Occasionally in production we are seeing `currentException` have the value 1. Try to figure @@ -1013,7 +1012,7 @@ kj::Exception getDestructionReason(void* traceSeparator, kj::Exception::Type def namespace { -KJ_THREADLOCAL_PTR(ExceptionCallback) threadLocalCallback = nullptr; +thread_local ExceptionCallback* threadLocalCallback = nullptr; } // namespace diff --git a/c++/src/kj/threadlocal-test.c++ b/c++/src/kj/threadlocal-test.c++ deleted file mode 100644 index 7d409912e3..0000000000 --- a/c++/src/kj/threadlocal-test.c++ +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#include "threadlocal.h" -#include "debug.h" -#include "thread.h" -#include - -namespace kj { -namespace { - -KJ_THREADLOCAL_PTR(uint) tls1 = nullptr; -KJ_THREADLOCAL_PTR(uint) tls2; - -TEST(ThreadLocal, Basic) { - // Verify that both started out null. - uint* p = tls1; - EXPECT_EQ(nullptr, p); - p = tls2; - EXPECT_EQ(nullptr, p); - - // Set tls1, then verify that only tls1 changed, not tls2. - uint i = 123; - tls1 = &i; - - p = tls1; - EXPECT_EQ(&i, p); - p = tls2; - EXPECT_EQ(nullptr, p); - - // Check that in another thread, tls1 starts null but can be changed. - uint j = 456; - bool threadDone = false; - Thread([&]() { - p = tls1; - EXPECT_EQ(nullptr, p); - tls1 = &j; - - p = tls1; - EXPECT_EQ(&j, p); - threadDone = true; - }); - EXPECT_TRUE(threadDone); - - // tls1 didn't change in this thread. - p = tls1; - EXPECT_EQ(&i, p); -} - -} // namespace -} // namespace kj diff --git a/c++/src/kj/threadlocal.h b/c++/src/kj/threadlocal.h deleted file mode 100644 index 613b96e788..0000000000 --- a/c++/src/kj/threadlocal.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) 2014, Jason Choy -// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. 
and contributors -// Licensed under the MIT License: -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -#pragma once - -// This file declares a macro `KJ_THREADLOCAL_PTR` for declaring thread-local pointer-typed -// variables. Use like: -// KJ_THREADLOCAL_PTR(MyType) foo = nullptr; -// This is equivalent to: -// thread_local MyType* foo = nullptr; -// This can only be used at the global scope. -// -// AVOID USING THIS. Use of thread-locals is discouraged because they often have many of the same -// properties as singletons: http://www.object-oriented-security.org/lets-argue/singletons -// -// Also, thread-locals tend to be hostile to event-driven code, which can be particularly -// surprising when using fibers (all fibers in the same thread will share the same threadlocals, -// even though they do not share a stack). -// -// That said, thread-locals are sometimes needed for runtime logistics in the KJ framework. 
For -// example, the current exception callback and current EventLoop are stored as thread-local -// pointers. Since KJ only ever needs to store pointers, not values, we avoid the question of -// whether these values' destructors need to be run, and we avoid the need for heap allocation. - -#include "common.h" - -KJ_BEGIN_HEADER - -namespace kj { - -#if __GNUC__ - -#define KJ_THREADLOCAL_PTR(type) static __thread type* -// GCC's __thread is lighter-weight than thread_local and is good enough for our purposes. -// -// TODO(cleanup): The above comment was written many years ago. Is it still true? Shouldn't the -// compiler be smart enough to optimize a thread_local of POD type? - -#else - -#define KJ_THREADLOCAL_PTR(type) static thread_local type* - -#endif // KJ_USE_PTHREAD_TLS - -} // namespace kj - -KJ_END_HEADER From 4066806439698242f991ac67ab414479f3e97a58 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 27 Dec 2023 13:25:39 -0600 Subject: [PATCH 10/58] Update samples for removal of ez-rpc. --- c++/samples/calculator-client.c++ | 22 ++++++++++++++++++---- c++/samples/calculator-server.c++ | 21 +++++++++++++++------ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/c++/samples/calculator-client.c++ b/c++/samples/calculator-client.c++ index 5d8452921c..bdaa5260ec 100644 --- a/c++/samples/calculator-client.c++ +++ b/c++/samples/calculator-client.c++ @@ -20,7 +20,8 @@ // THE SOFTWARE. #include "calculator.capnp.h" -#include +#include +#include #include #include #include @@ -47,13 +48,26 @@ int main(int argc, const char* argv[]) { return 1; } - capnp::EzRpcClient client(argv[1]); - Calculator::Client calculator = client.getMain(); + // First we need to set up the KJ async event loop. This should happen once + // per thread that needs to perform RPC. + auto io = kj::setupAsyncIo(); // Keep an eye on `waitScope`. Whenever you see it used is a place where we // stop and wait for the server to respond. 
If a line of code does not use // `waitScope`, then it does not block! - auto& waitScope = client.getWaitScope(); + auto& waitScope = io.waitScope; + + // Using KJ APIs, let's parse our network address and connect to it. + kj::Network& network = io.provider->getNetwork(); + kj::Own addr = network.parseAddress(argv[1]).wait(waitScope); + kj::Own conn = addr->connect().wait(waitScope); + + // Now we can start the Cap'n Proto RPC system on this connection. + capnp::TwoPartyClient client(*conn); + + // The server exports a "bootstrap" capability implementing the + // `Calculator` interface. + Calculator::Client calculator = client.bootstrap().castAs(); { // Make a request that just evaluates the literal value 123. diff --git a/c++/samples/calculator-server.c++ b/c++/samples/calculator-server.c++ index c2593be3a9..9253aced53 100644 --- a/c++/samples/calculator-server.c++ +++ b/c++/samples/calculator-server.c++ @@ -20,8 +20,9 @@ // THE SOFTWARE. #include "calculator.capnp.h" +#include +#include #include -#include #include #include @@ -196,12 +197,17 @@ int main(int argc, const char* argv[]) { return 1; } - // Set up a server. - capnp::EzRpcServer server(kj::heap(), argv[1]); + // First we need to set up the KJ async event loop. This should happen once + // per thread that needs to perform RPC. + auto io = kj::setupAsyncIo(); + + // Using KJ APIs, let's parse our network address and listen on it. + kj::Network& network = io.provider->getNetwork(); + kj::Own addr = network.parseAddress(argv[1]).wait(io.waitScope); + kj::Own listener = addr->listen(); // Write the port number to stdout, in case it was chosen automatically. - auto& waitScope = server.getWaitScope(); - uint port = server.getPort().wait(waitScope); + uint port = listener->getPort(); if (port == 0) { // The address format "unix:/path/to/socket" opens a unix domain socket, // in which case the port will be zero. 
@@ -210,6 +216,9 @@ int main(int argc, const char* argv[]) { std::cout << "Listening on port " << port << "..." << std::endl; } + // Start the RPC server. + capnp::TwoPartyServer server(kj::heap()); + // Run forever, accepting connections and handling requests. - kj::NEVER_DONE.wait(waitScope); + server.listen(*listener).wait(io.waitScope); } From 2df3de32f12bc59f770a01aa90d58a2a494cc349 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Tue, 26 Dec 2023 15:13:28 -0600 Subject: [PATCH 11/58] Fix InMemoryDirectory deadlock bug. The implementation of `transfer()` would deadlock if the source and destination directories were the same object, since it would attempt to operate on `fromDirectory` while holding the lock on the destination directory. Fixing this required a significant rewrite of this function. --- c++/src/kj/filesystem-test.c++ | 17 +++ c++/src/kj/filesystem.c++ | 259 +++++++++++++++++++++------------ 2 files changed, 180 insertions(+), 96 deletions(-) diff --git a/c++/src/kj/filesystem-test.c++ b/c++/src/kj/filesystem-test.c++ index d1c3bc377f..b1b3166de2 100644 --- a/c++/src/kj/filesystem-test.c++ +++ b/c++/src/kj/filesystem-test.c++ @@ -752,6 +752,23 @@ KJ_TEST("InMemoryDirectory move") { KJ_EXPECT(dst->openFile(Path({"link", "baz", "qux"}))->readAllText() == "bazqux"); } +KJ_TEST("InMemoryDirectory transfer from self") { + TestClock clock; + + auto dir = newInMemoryDirectory(clock); + + auto file = dir->openFile(Path({"foo"}), WriteMode::CREATE); + + dir->transfer(Path({"bar"}), WriteMode::CREATE, Path({"foo"}), TransferMode::MOVE); + + auto list = dir->listNames(); + KJ_EXPECT(list.size() == 1); + KJ_EXPECT(list[0] == "bar"); + + auto file2 = dir->openFile(Path({"bar"})); + KJ_EXPECT(file.get() == file2.get()); +} + KJ_TEST("InMemoryDirectory createTemporary") { TestClock clock; diff --git a/c++/src/kj/filesystem.c++ b/c++/src/kj/filesystem.c++ index d43d5dbf28..f6b954c424 100644 --- a/c++/src/kj/filesystem.c++ +++ b/c++/src/kj/filesystem.c++ @@ -984,6 
+984,9 @@ class InMemoryDirectory final: public Directory, public AtomicRefcounted { public: InMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory) : impl(clock, fileFactory) {} + InMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory, + const Directory& copyFrom, bool copyFiles) + : impl(clock, fileFactory, copyFrom, copyFiles) {} Own cloneFsNode() const override { return atomicAddRef(*this); @@ -1275,31 +1278,104 @@ public: KJ_FAIL_REQUIRE("can't replace self") { return false; } } } else if (toPath.size() == 1) { - // tryTransferChild() needs to at least know the node type, so do an lstat. - KJ_IF_SOME(meta, fromDirectory.tryLstat(fromPath)) { - auto lock = impl.lockExclusive(); - KJ_IF_SOME(entry, lock->openEntry(toPath[0], toMode)) { - // Make sure if we just cerated a new entry, and we don't successfully transfer to it, we - // remove the entry before returning. - bool needRollback = entry.node == nullptr; - KJ_DEFER(if (needRollback) { lock->entries.erase(toPath[0]); }); - - if (lock->tryTransferChild(entry, meta.type, meta.lastModified, meta.size, - fromDirectory, fromPath, mode)) { - lock->modified(); - needRollback = false; - return true; - } else { - KJ_FAIL_REQUIRE("InMemoryDirectory can't link an inode of this type", fromPath) { - return false; - } - } - } else { + if (!has(toMode, WriteMode::MODIFY)) { + // Replacement is not allowed, so we'll have to check upfront if the target path exists. + // Unfortunately we have to take a lock and then drop it immediately since we can't keep + // the lock held while accessing `fromDirectory`. 
+ if (impl.lockShared()->tryGetEntry(toPath[0]) != kj::none) { return false; } + } + + OneOf newNode; + FsNode::Metadata meta; + KJ_IF_SOME(m, fromDirectory.tryLstat(fromPath)) { + meta = m; } else { return false; } + + switch (meta.type) { + case FsNode::Type::FILE: { + auto file = KJ_ASSERT_NONNULL( + fromDirectory.tryOpenFile(fromPath, WriteMode::MODIFY), + "source node deleted concurrently during transfer", fromPath); + + if (mode == TransferMode::COPY) { + auto copy = impl.getWithoutLock().newFile(); + copy->copy(0, *file, 0, meta.size); + file = kj::mv(copy); + } + + newNode = FileNode { kj::mv(file) }; + break; + } + case FsNode::Type::DIRECTORY: { + auto subdir = KJ_ASSERT_NONNULL( + fromDirectory.tryOpenSubdir(fromPath, WriteMode::MODIFY), + "source node deleted concurrently during transfer", fromPath); + + switch (mode) { + case TransferMode::COPY: + // Copying is straightforward: Make a deep copy of the entire directory tree, + // including file contents. + subdir = impl.getWithoutLock().copyDirectory(*subdir, /* copyFiles = */ true); + break; + + case TransferMode::LINK: + // To "link", we can safely just place `subdir` directly into our own tree. + break; + + case TransferMode::MOVE: + // Moving may be tricky: + // + // If `fromDirectory` is an `InMemoryDirectory`, then we know that removing the + // subdir just unlinks the object without modifying it, so we can safely just link it + // into our own tree. + // + // However, if `fromDirectory` is a disk directory, then removing the subdir will + // likely perform a recursive delete, thus leaving `subdir` pointing to an empty + // directory. If we link that into our tree, it's useless. So, instead, perform a + // deep copy of the directory tree upfront, into an InMemoryDirectory. However, file + // content need not be copied, since unlinked files keep their contents until closed. 
+ if (kj::dynamicDowncastIfAvailable(fromDirectory) == + kj::none) { + subdir = impl.getWithoutLock().copyDirectory(*subdir, /* copyFiles = */ false); + } + break; + } + + newNode = DirectoryNode { kj::mv(subdir) }; + break; + } + case FsNode::Type::SYMLINK: { + auto link = KJ_ASSERT_NONNULL(fromDirectory.tryReadlink(fromPath), + "source node deleted concurrently during transfer", fromPath); + + newNode = SymlinkNode {meta.lastModified, kj::mv(link)}; + break; + } + default: + KJ_FAIL_REQUIRE("InMemoryDirectory can't link an inode of this type", fromPath); + } + + if (mode == TransferMode::MOVE) { + KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath); + } + + // Take the lock to insert the entry into our map. Remember that it's important we do not + // manipulate `fromDirectory` while the lock is held, since it could be the same directory. + { + auto lock = impl.lockExclusive(); + KJ_IF_SOME(targetEntry, lock->openEntry(toPath[0], toMode)) { + targetEntry.init(kj::mv(newNode));; + } else { + return false; + } + lock->modified(); + } + + return true; } else { // TODO(someday): Ideally we wouldn't create parent directories if fromPath doesn't exist. // This requires a different approach to the code here, though. @@ -1460,6 +1536,67 @@ private: Impl(const Clock& clock, const InMemoryFileFactory& fileFactory) : clock(clock), fileFactory(fileFactory), lastModified(clock.now()) {} + Impl(const Clock& clock, const InMemoryFileFactory& fileFactory, + const Directory& copyFrom, bool copyFiles) + : clock(clock), fileFactory(fileFactory), lastModified(clock.now()) { + // Implements copyDirectory() (see below). 
+ for (auto& fromEntry: copyFrom.listEntries()) { + kj::Path filename({kj::mv(fromEntry.name)}); + OneOf newNode; + switch (fromEntry.type) { + case FsNode::Type::FILE: { + KJ_IF_SOME(file, copyFrom.tryOpenFile(filename, WriteMode::MODIFY)) { + if (copyFiles) { + auto copy = newFile(); + copy->copy(0, *file, 0, kj::maxValue); + file = kj::mv(copy); + } + + newNode = FileNode { kj::mv(file) }; + break; + } else { + continue; + } + } + + case FsNode::Type::DIRECTORY: { + KJ_IF_SOME(subdir, copyFrom.tryOpenSubdir(filename, WriteMode::MODIFY)) { + subdir = copyDirectory(*subdir, copyFiles); + newNode = DirectoryNode { kj::mv(subdir) }; + break; + } else { + continue; + } + } + + case FsNode::Type::SYMLINK: { + KJ_IF_SOME(link, copyFrom.tryReadlink(filename)) { + KJ_IF_SOME(metadata, copyFrom.tryLstat(filename)) { + newNode = SymlinkNode { metadata.lastModified, kj::mv(link) }; + break; + } else { + continue; + } + } else { + continue; + } + } + + default: + KJ_LOG(ERROR, "couldn't copy node of type not supported by InMemoryDirectory", + filename); + continue; + } + + KJ_ASSERT(newNode != nullptr); + + EntryImpl entry(kj::mv(filename)[0]); + StringPtr nameRef = entry.name; + entry.init(kj::mv(newNode)); + KJ_ASSERT(entries.insert(std::make_pair(nameRef, kj::mv(entry))).second); + } + } + Own newFile() const { // Consturct a new empty file. Note: This function is expected to work without the lock held. return fileFactory.create(clock); @@ -1470,6 +1607,12 @@ private: return newInMemoryDirectory(clock, fileFactory); } + Own copyDirectory(const Directory& other, bool copyFiles) const { + // Creates an in-memory deep copy of the given directory object. If `copyFiles` is true, then + // file contents are copied too, otherwise they are just linked. 
+ return kj::atomicRefcounted(clock, fileFactory, other, copyFiles); + } + Maybe openEntry(kj::StringPtr name, WriteMode mode) { // TODO(perf): We could avoid a copy if the entry exists, at the expense of a double-lookup // if it doesn't. Maybe a better map implementation will solve everything? @@ -1517,82 +1660,6 @@ private: void modified() { lastModified = clock.now(); } - - bool tryTransferChild(EntryImpl& entry, const FsNode::Type type, kj::Maybe lastModified, - kj::Maybe size, const Directory& fromDirectory, - PathPtr fromPath, TransferMode mode) { - switch (type) { - case FsNode::Type::FILE: - KJ_IF_SOME(file, fromDirectory.tryOpenFile(fromPath, WriteMode::MODIFY)) { - if (mode == TransferMode::COPY) { - auto copy = newFile(); - copy->copy(0, *file, 0, size.orDefault(kj::maxValue)); - entry.set(kj::mv(copy)); - } else { - if (mode == TransferMode::MOVE) { - KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { - return false; - } - } - entry.set(kj::mv(file)); - } - return true; - } else { - KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { - return false; - } - } - case FsNode::Type::DIRECTORY: - KJ_IF_SOME(subdir, fromDirectory.tryOpenSubdir(fromPath, WriteMode::MODIFY)) { - if (mode == TransferMode::COPY) { - auto copy = atomicRefcounted(clock, fileFactory); - auto& cpim = copy->impl.getWithoutLock(); // safe because just-created - for (auto& subEntry: subdir->listEntries()) { - EntryImpl newEntry(kj::mv(subEntry.name)); - Path filename(newEntry.name); - if (!cpim.tryTransferChild(newEntry, subEntry.type, kj::none, kj::none, *subdir, - filename, TransferMode::COPY)) { - KJ_LOG(ERROR, "couldn't copy node of type not supported by InMemoryDirectory", - filename); - } else { - StringPtr nameRef = newEntry.name; - cpim.entries.insert(std::make_pair(nameRef, kj::mv(newEntry))); - } - } - entry.set(kj::mv(copy)); - } else { - if (mode == TransferMode::MOVE) { - KJ_ASSERT(fromDirectory.tryRemove(fromPath), 
"couldn't move node", fromPath) { - return false; - } - } - entry.set(kj::mv(subdir)); - } - return true; - } else { - KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { - return false; - } - } - case FsNode::Type::SYMLINK: - KJ_IF_SOME(content, fromDirectory.tryReadlink(fromPath)) { - // Since symlinks are immutable, we can implement LINK the same as COPY. - entry.init(SymlinkNode { lastModified.orDefault(clock.now()), kj::mv(content) }); - if (mode == TransferMode::MOVE) { - KJ_ASSERT(fromDirectory.tryRemove(fromPath), "couldn't move node", fromPath) { - return false; - } - } - return true; - } else { - KJ_FAIL_ASSERT("source node deleted concurrently during transfer", fromPath) { - return false; - } - } - default: - return false; - } - } }; kj::MutexGuarded impl; From 146c8ceca615abb7f6ff3aaa074954758c2c4dc7 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 27 Dec 2023 21:20:06 -0600 Subject: [PATCH 12/58] Cleanup: LLAIOP should not own EventLoop and friends. I'm not sure why I ever designed it this way, but LowLevelAsyncIoProviderImpl has always owned the UnixEventPort, EventLoop, and WaitScope. This means you actually couldn't use the main async I/O implementation while also controlling allocation of these objects, which is silly. 
--- c++/src/kj/async-io-unix.c++ | 46 ++++++++++++++++++++++++----------- c++/src/kj/async-io-win32.c++ | 45 +++++++++++++++++++++++----------- c++/src/kj/async-io.h | 8 ++++++ 3 files changed, 71 insertions(+), 28 deletions(-) diff --git a/c++/src/kj/async-io-unix.c++ b/c++/src/kj/async-io-unix.c++ index 91a85546ec..dd01a4189e 100644 --- a/c++/src/kj/async-io-unix.c++ +++ b/c++/src/kj/async-io-unix.c++ @@ -1474,10 +1474,8 @@ public: class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { public: - LowLevelAsyncIoProviderImpl() - : eventPort(), eventLoop(eventPort), waitScope(eventLoop) {} - - inline WaitScope& getWaitScope() { return waitScope; } + LowLevelAsyncIoProviderImpl(UnixEventPort& eventPort) + : eventPort(eventPort) {} Own wrapInputFd(int fd, uint flags = 0) override { return heap(eventPort, fd, flags, UnixEventPort::FdObserver::OBSERVE_READ); @@ -1539,12 +1537,8 @@ public: Timer& getTimer() override { return eventPort.getTimer(); } - UnixEventPort& getEventPort() { return eventPort; } - private: - UnixEventPort eventPort; - EventLoop eventLoop; - WaitScope waitScope; + UnixEventPort& eventPort; }; // ======================================================================================= @@ -2031,10 +2025,13 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { - LowLevelAsyncIoProviderImpl lowLevel; + UnixEventPort eventPort; + EventLoop eventLoop(eventPort); + WaitScope waitScope(eventLoop); + LowLevelAsyncIoProviderImpl lowLevel(eventPort); auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); - startFunc(ioProvider, *stream, lowLevel.getWaitScope()); + startFunc(ioProvider, *stream, waitScope); }); return { kj::mv(thread), kj::mv(pipe) }; @@ -2053,11 +2050,32 @@ Own newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) { return kj::heap(lowLevel); } +Own newLowLevelAsyncIoProvider(UnixEventPort& 
eventPort) { + return kj::heap(eventPort); +} + AsyncIoContext setupAsyncIo() { - auto lowLevel = heap(); + struct BasicContext { + UnixEventPort eventPort; + EventLoop eventLoop; + WaitScope waitScope; + + BasicContext(): eventLoop(eventPort), waitScope(eventLoop) {} + }; + + auto basicContext = heap(); + auto lowLevel = heap(basicContext->eventPort); auto ioProvider = kj::heap(*lowLevel); - auto& waitScope = lowLevel->getWaitScope(); - auto& eventPort = lowLevel->getEventPort(); + auto& waitScope = basicContext->waitScope; + auto& eventPort = basicContext->eventPort; + + // Historically, `LowLevelAsyncIoProviderImpl` contained the stuff that `BasicContext` now + // contains. However, this made it impossible to create more elaborate EventLoop arrangements + // while still using the default LLAIOP implementation. For backwards-compatibility, + // `setupAsyncIo()` still attaches this context to the LLAIOP, but it's now possible to construct + // these objects directly and LLAIOP on top. + lowLevel = lowLevel.attach(kj::mv(basicContext)); + return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort }; } diff --git a/c++/src/kj/async-io-win32.c++ b/c++/src/kj/async-io-win32.c++ index 87a70aea5b..9333670885 100644 --- a/c++/src/kj/async-io-win32.c++ +++ b/c++/src/kj/async-io-win32.c++ @@ -919,10 +919,7 @@ public: class LowLevelAsyncIoProviderImpl final: public LowLevelAsyncIoProvider { public: - LowLevelAsyncIoProviderImpl() - : eventLoop(eventPort), waitScope(eventLoop) {} - - inline WaitScope& getWaitScope() { return waitScope; } + LowLevelAsyncIoProviderImpl(Win32EventPort& eventPort): eventPort(eventPort) {} Own wrapInputFd(SOCKET fd, uint flags = 0) override { return heap(eventPort, fd, flags); @@ -952,12 +949,8 @@ public: Timer& getTimer() override { return eventPort.getTimer(); } - Win32EventPort& getEventPort() { return eventPort; } - private: - Win32IocpEventPort eventPort; - EventLoop eventLoop; - WaitScope waitScope; + Win32EventPort& eventPort; }; 
// ======================================================================================= @@ -1153,10 +1146,13 @@ public: auto pipe = lowLevel.wrapSocketFd(fds[0], NEW_FD_FLAGS); auto thread = heap([threadFd,startFunc=kj::mv(startFunc)]() mutable { - LowLevelAsyncIoProviderImpl lowLevel; + Win32IocpEventPort eventPort; + EventLoop eventLoop(eventPort); + WaitScope waitScope(eventLoop); + LowLevelAsyncIoProviderImpl lowLevel(eventPort); auto stream = lowLevel.wrapSocketFd(threadFd, NEW_FD_FLAGS); AsyncIoProviderImpl ioProvider(lowLevel); - startFunc(ioProvider, *stream, lowLevel.getWaitScope()); + startFunc(ioProvider, *stream, waitScope); }); return { kj::mv(thread), kj::mv(pipe) }; @@ -1175,13 +1171,34 @@ Own newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel) { return kj::heap(lowLevel); } +Own newLowLevelAsyncIoProvider(Win32EventPort& eventPort) { + return kj::heap(eventPort); +} + AsyncIoContext setupAsyncIo() { _::initWinsockOnce(); - auto lowLevel = heap(); + struct BasicContext { + Win32IocpEventPort eventPort; + EventLoop eventLoop; + WaitScope waitScope; + + BasicContext(): eventLoop(eventPort), waitScope(eventLoop) {} + }; + + auto basicContext = heap(); + auto lowLevel = heap(basicContext->eventPort); auto ioProvider = kj::heap(*lowLevel); - auto& waitScope = lowLevel->getWaitScope(); - auto& eventPort = lowLevel->getEventPort(); + auto& waitScope = basicContext->waitScope; + auto& eventPort = basicContext->eventPort; + + // Historically, `LowLevelAsyncIoProviderImpl` contained the stuff that `BasicContext` now + // contains. However, this made it impossible to create more elaborate EventLoop arrangements + // while still using the default LLAIOP implementation. For backwards-compatibility, + // `setupAsyncIo()` still attaches this context to the LLAIOP, but it's now possible to construct + // these objects directly and LLAIOP on top. 
+ lowLevel = lowLevel.attach(kj::mv(basicContext)); + return { kj::mv(lowLevel), kj::mv(ioProvider), waitScope, eventPort }; } diff --git a/c++/src/kj/async-io.h b/c++/src/kj/async-io.h index 411937668a..ceb412da02 100644 --- a/c++/src/kj/async-io.h +++ b/c++/src/kj/async-io.h @@ -903,6 +903,14 @@ class LowLevelAsyncIoProvider { Own newAsyncIoProvider(LowLevelAsyncIoProvider& lowLevel); // Make a new AsyncIoProvider wrapping a `LowLevelAsyncIoProvider`. +#if _WIN32 +Own newLowLevelAsyncIoProvider(Win32EventPort& eventPort); +// Make a new `LowLevelAsyncIoProvider` backed by a `Win32EventPort`. +#else +Own newLowLevelAsyncIoProvider(UnixEventPort& eventPort); +// Make a new `LowLevelAsyncIoProvider` backed by a `UnixEventPort`. +#endif + struct AsyncIoContext { Own lowLevelProvider; Own provider; From 2614e25790acf4d149c43d0c31297febd971a7e2 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 28 Dec 2023 15:14:14 -0600 Subject: [PATCH 13/58] Extend UnixEventPort on Linux to allow external polling. This makes it possible to use UnixEventPort inside a thread running some other event loop, by arranging for that other event loop to listen for the UnixEventPort's epoll FD becoming ready. Whenever it does, the application must pump the KJ event loop. This also could allow multiple event loops to run in the same thread, or even across some thread pool, as long as only one event loop is active in a thread at a time, and each event loop is active in no more than one thread at a time. A scheduler could schedule event loops to threads when their epoll FD becomes ready. 
--- c++/src/kj/async-unix-test.c++ | 127 +++++++++++++++++++++++++++++++++ c++/src/kj/async-unix.c++ | 107 +++++++++++++++++++++++++++ c++/src/kj/async-unix.h | 70 +++++++++++++++++- c++/src/kj/timer.c++ | 51 ++++++++++--- c++/src/kj/timer.h | 22 +++++- 5 files changed, 364 insertions(+), 13 deletions(-) diff --git a/c++/src/kj/async-unix-test.c++ b/c++/src/kj/async-unix-test.c++ index 448bc24fb4..08565ca0ba 100644 --- a/c++/src/kj/async-unix-test.c++ +++ b/c++/src/kj/async-unix-test.c++ @@ -39,6 +39,7 @@ #include #include #include "mutex.h" +#include #if KJ_USE_EPOLL #include @@ -1119,6 +1120,132 @@ KJ_TEST("UnixEventPort thread-specific signals") { } #endif +#if KJ_USE_EPOLL +KJ_TEST("UnixEventPoll::getFd() for external waiting") { + kj::UnixEventPort port; + kj::EventLoop loop(port); + kj::WaitScope ws(loop); + + auto portIsReady = [&port](int timout = 0) { + struct pollfd pfd; + memset(&pfd, 0, sizeof(pfd)); + pfd.events = POLLIN; + pfd.fd = port.getPollableFd(); + + int n; + KJ_SYSCALL(n = poll(&pfd, 1, timout)); + return n > 0; + }; + + // Test wakeup on observed FD. + { + int pair[2]; + KJ_SYSCALL(pipe(pair)); + kj::AutoCloseFd in(pair[0]); + kj::AutoCloseFd out(pair[1]); + + kj::UnixEventPort::FdObserver observer(port, in, kj::UnixEventPort::FdObserver::OBSERVE_READ); + auto promise = observer.whenBecomesReadable(); + + KJ_EXPECT(!promise.poll(ws)); + ws.poll(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + KJ_SYSCALL(write(out, "a", 1)); + + KJ_EXPECT(portIsReady()); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup due to queuing work to the event loop in-process. + { + ws.poll(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + auto promise = kj::evalLater([]() {}).eagerlyEvaluate(nullptr); + + KJ_EXPECT(portIsReady()); + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup on timeout. 
+ { + auto promise = port.getTimer().afterDelay(50 * kj::MILLISECONDS); + + KJ_EXPECT(!promise.poll(ws)); + ws.poll(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + usleep(50'000); + + KJ_EXPECT(portIsReady()); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup on time in past. This verifies timerfd_settime() won't just silently fail if the + // time is already past. + { + ws.poll(); + + // Schedule time event in the past. + auto promise = port.getTimer().atTime(kj::origin() + 1 * SECONDS); + + // As of this writing, atTime() doesn't do any special handling of times in the past, e.g. to + // immediately resolve the promise. It goes ahead and schedules them like any other I/O. So + // scheduling such a promise will not immediately schedule work on the event loop, and + // preparePollableFdForSleep() will in fact go and timerfd_settime() to a time in the past. (If + // this changes, we'll need to structure this test differently I guess.) + KJ_EXPECT(!loop.isRunnable()); + + port.preparePollableFdForSleep(); + + // Uhhhh... Apparently when timerfd_settime() sets a time in the past, the timerfd does NOT + // immediately become readable. The kernel still needs to process the timer in the background + // before it raises the event. So we will need to give it some time... we give it 10ms here. + KJ_EXPECT(portIsReady(10)); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } + + // Test wakeup when a timer event is created during sleep. + { + ws.poll(); + auto startTime = port.getTimer().now(); + port.preparePollableFdForSleep(); + + KJ_EXPECT(!portIsReady()); + + // When sleeping, passage of real time updates `timer.now()`. + usleep(50'000); + KJ_EXPECT(port.getTimer().now() - startTime >= 50 * MILLISECONDS); + + // We can set a timer now, and the epoll FD will wake up when it expires, even though no timer + // was set when `preparePollableFdForSleep()` was called. 
+ auto promise = port.getTimer().afterDelay(50 * MILLISECONDS); + + // It won't expire too early: the delay was added to the real time, not the last time the + // timer was advanced to. + KJ_EXPECT(!portIsReady(10)); + KJ_EXPECT(portIsReady(40)); + + KJ_ASSERT(promise.poll(ws)); + promise.wait(ws); + } +} +#endif + } // namespace } // namespace kj diff --git a/c++/src/kj/async-unix.c++ b/c++/src/kj/async-unix.c++ index 6de12479ec..77b437780d 100644 --- a/c++/src/kj/async-unix.c++ +++ b/c++/src/kj/async-unix.c++ @@ -36,6 +36,7 @@ #if KJ_USE_EPOLL #include #include +#include #elif KJ_USE_KQUEUE #include #include @@ -542,6 +543,8 @@ void UnixEventPort::wake() const { } bool UnixEventPort::wait() { + sleeping = false; + #ifdef KJ_DEBUG // In debug mode, verify the current signal mask matches the original. { @@ -647,6 +650,19 @@ bool UnixEventPort::processEpollEvents(struct epoll_event events[], int n) { // We were woken. Need to return true. woken = true; + } else if (events[i].data.u64 == 1) { + // timerfd fired. We need to clear it by reading it. + int tfd = KJ_ASSERT_NONNULL(timerFd).get(); + char buffer[16]; + ssize_t n; + KJ_NONBLOCKING_SYSCALL(n = read(tfd, buffer, sizeof(buffer))); + KJ_ASSERT(n == 8 || n < 0); + + timerfdIsArmed = false; + + // The purpose of this event is just to wake up the event loop when needed. We'll check the + // timer queue separately, so we don't need to do anything special in response to this event + // here. } else { FdObserver* observer = reinterpret_cast(events[i].data.ptr); observer->fire(events[i].events); @@ -663,6 +679,8 @@ bool UnixEventPort::poll() { // pending signals. Therefore, we need a completely different approach to poll for signals. We // might as well use regular epoll_wait() in this case, too, to save the kernel some effort. + sleeping = false; + if (signalHead != nullptr || childSet != kj::none) { // Use sigtimedwait() to poll for signals. 
@@ -712,6 +730,85 @@ bool UnixEventPort::poll() { return processEpollEvents(events, n); } +int UnixEventPort::getPollableFd() { + return epollFd.get(); +} + +void UnixEventPort::preparePollableFdForSleep() { + KJ_ASSERT(signalHead == nullptr, + "preparePollableFdForSleep() cannot be used when waiting for signals"); + + if (runnable) { + // There is still immediate work in the queue, so force the epoll to be ready immediately. (See + // comments in setRunnable() regarding using wake() for this.) + wake(); + } else { + updateNextTimerEvent(timerImpl.nextEvent()); + + // Flag that we're sleeping, so setRunnable() knows to use wake() if needed. + sleeping = true; + + // Tell the timer we're sleeping, so that it notifies us if anyone tries to create a new time + // event. + timerImpl.setSleeping(*this); + } +} + +void UnixEventPort::setRunnable(bool runnable) { + this->runnable = runnable; + if (runnable && sleeping) { + // A meta event loop is waiting for the epoll to become readable, and an event was queued + // directly to the event loop in the meantime (e.g. due to a promise fulfiller being + // fulfilled). So, we need to cause the epoll to signal readability, to wake the event loop. We + // can use the cross-thread wake mechanism for this. It'll cause the event loop to spuriously + // check for cross-thread events, incurring a mutex lock, but that's not a big deal, and that's + // much better than creating a whole second event FD for this purpose. + wake(); + sleeping = false; + } +} + +void UnixEventPort::updateNextTimerEvent(kj::Maybe time) { + if (time == kj::none && !timerfdIsArmed) { + // No change needed. + return; + } + + // Create the timerfd if needed. 
+ int tfd; + KJ_IF_SOME(f, timerFd) { + tfd = f; + } else { + KJ_SYSCALL(tfd = timerfd_create(CLOCK_MONOTONIC, TFD_CLOEXEC | TFD_NONBLOCK)); + timerFd = kj::AutoCloseFd(tfd); + + struct epoll_event event; + memset(&event, 0, sizeof(event)); + event.events = EPOLLIN; + event.data.u64 = 1; + KJ_SYSCALL(epoll_ctl(epollFd, EPOLL_CTL_ADD, tfd, &event)); + } + + // Update timerfd's expiration time. + struct itimerspec ts; + memset(&ts, 0, sizeof(ts)); + KJ_IF_SOME(t, time) { + auto t2 = t - origin(); + ts.it_value.tv_sec = t2 / SECONDS; + ts.it_value.tv_nsec = (t2 % SECONDS) / NANOSECONDS; + timerfdIsArmed = true; + } else { + // setting the time to zero will disarm it + timerfdIsArmed = false; + } + + KJ_SYSCALL(timerfd_settime(tfd, TFD_TIMER_ABSTIME, &ts, nullptr)); +} + +kj::TimePoint UnixEventPort::getTimeWhileSleeping() { + return clock.now(); +} + #elif KJ_USE_KQUEUE // ======================================================================================= // kqueue FdObserver implementation @@ -1583,6 +1680,16 @@ void UnixEventPort::wake() const { #endif // KJ_USE_EPOLL, else KJ_USE_KQUEUE, else +#if !KJ_USE_EPOLL +void UnixEventPort::updateNextTimerEvent(kj::Maybe time) { + KJ_UNIMPLEMENTED("SleepHooks not used on this platform, this shouldn't be called"); +} + +kj::TimePoint UnixEventPort::getTimeWhileSleeping() { + KJ_UNIMPLEMENTED("SleepHooks not used on this platform, this shouldn't be called"); +} +#endif + } // namespace kj #endif // !_WIN32 diff --git a/c++/src/kj/async-unix.h b/c++/src/kj/async-unix.h index 665305ea70..7176a7eefc 100644 --- a/c++/src/kj/async-unix.h +++ b/c++/src/kj/async-unix.h @@ -64,7 +64,7 @@ struct timespec; namespace kj { -class UnixEventPort: public EventPort { +class UnixEventPort: public EventPort, private TimerImpl::SleepHooks { // An EventPort implementation which can wait for events on file descriptors as well as signals. // This API only makes sense on Unix. 
// @@ -178,11 +178,70 @@ class UnixEventPort: public EventPort { // This method may capture the `SIGCHLD` signal. You must not use `captureSignal(SIGCHLD)` nor // `onSignal(SIGCHLD)` in your own code if you use `captureChildExit()`. +#if KJ_USE_EPOLL + int getPollableFd(); + // Get a file descriptor which represents the EventPort's backing OS event queue, and becomes + // readable when there are events to process. This may be an epoll FD or a kqueue FD depending on + // the OS. + // + // You MUST use preparePollableFdForSleep() before waiting on this FD. + // + // The caller should not perform operations on this FD. It should only be used for polling in + // some other event loop. This can be useful for allowing UnixEventPort to operate embedded in + // an application that uses some other event loop as its main loop. Whenever this FD becomes + // readable, call waitScope.poll() to handle all available events, then + // `preparePollableFdForSleep()`. + // + // It's also possible to use this to drive multiple event loops from the same thread, or even + // multiple threads in an m:n mapping. When an event loop becomes ready, create a temporary + // WaitScope on the thread where you want to run it, pump the loop, and then destroy the + // WaitScope. A WaitScope binds a loop to a thread while it exists, but an event loop is allowed + // to change threads and a thread is allowed to change event loops as long as no WaitScope is + // currently binding them. + // + // TODO(someday): Currently this is only implemented for epoll, NOT for kqueue. But in principle + // it should be possible to implement for kqueue as well. + + void preparePollableFdForSleep(); + // If you plan to monitor the FD return by getFd() for notifications that this queue is ready, + // you must call preparePollableFdForSleep() after each run of this port's event loop in order to + // ensure that all event types will in fact wake up the queue. 
+ // + // This call is needed in particular to arrange for timer events. Normally, timer events are not + // implemented via the epoll/kqueue at all, but instead the call to epoll_wait() / kevent() is + // given a timeout that causes it to return early when the next timer event is scheduled. This + // doesn't work when the wait is being performed on an external event loop, so instead the + // implementation must arrange to use timerfd (for epoll) or EVFILT_TIMER (for kqueue) so that + // the queue FD actually becomes ready when the next timer event is ready to run. This requires + // some extra syscalls, unfortunately. + // + // A second reason this call is needed is to wake up the event queue if any work is queued + // directly to the event loop via function calls rather than OS events. For example, if anyone + // calls `fulfill()` on a `PromiseFulfiller`, thus queuing work to this event loop, it needs + // to wake up. The EventPort achieves this by implicitly calling `wake()` if any work is queued + // while sleeping -- similar to a cross-thread wakeup. NOTE: Queuing work like this while the + // event loop is asleep is only safe if the EventLoop is still bound to the thread by the + // existence of a WaitScope. If you are destroying the WaitScope while sleeping, then you cannot + // manipulate promises attached to this loop at all while it is asleep. (Cross-thread events, + // e.g. queued via newPromiseAndCrossThreadFulfiller(), or via kj::Executor, are always safe. + // These will cause the queue fd to become ready so that the loop can wake up and respond to the + // events.) + // + // This method will throw an exception if the port is currently waiting on any signals via + // onSignal() or onChildExit(). Although it might theoretically be possible to support signals in + // this mode, deep flaws in the relevant APIs across multiple OSs make it likely not worth + // attempting. 
+#endif + // implements EventPort ------------------------------------------------------ bool wait() override; bool poll() override; void wake() const override; +#if KJ_USE_EPOLL + void setRunnable(bool runnable) override; +#endif + private: class SignalPromiseAdapter; class ChildExitPromiseAdapter; @@ -203,6 +262,11 @@ class UnixEventPort: public EventPort { sigset_t originalMask; AutoCloseFd epollFd; AutoCloseFd eventFd; // Used for cross-thread wakeups. + kj::Maybe timerFd; // Used if preparePollableFdForSleep() is ever called. + + bool sleeping = false; // Was preparePollableFdForSleep() called? + bool runnable = false; // Last value passed to setRunnable(). + bool timerfdIsArmed = false; bool processEpollEvents(struct epoll_event events[], int n); #elif KJ_USE_KQUEUE @@ -234,6 +298,10 @@ class UnixEventPort: public EventPort { static void registerReservedSignal(); #endif static void ignoreSigpipe(); + + // Implements TimerImpl::SleepHooks. + void updateNextTimerEvent(kj::Maybe time) override; + kj::TimePoint getTimeWhileSleeping() override; }; class UnixEventPort::FdObserver: private AsyncObject { diff --git a/c++/src/kj/timer.c++ b/c++/src/kj/timer.c++ index 5aec7a9a6f..9ce5a3b93f 100644 --- a/c++/src/kj/timer.c++ +++ b/c++/src/kj/timer.c++ @@ -40,28 +40,48 @@ struct TimerImpl::Impl { class TimerImpl::TimerPromiseAdapter { public: - TimerPromiseAdapter(PromiseFulfiller& fulfiller, TimerImpl::Impl& impl, TimePoint time) - : time(time), fulfiller(fulfiller), impl(impl) { - pos = impl.timers.insert(this); + TimerPromiseAdapter(PromiseFulfiller& fulfiller, TimerImpl& parent, TimePoint time) + : time(time), fulfiller(fulfiller), parent(parent) { + pos = parent.impl->timers.insert(this); + + KJ_IF_SOME(h, parent.sleepHooks) { + if (pos == parent.impl->timers.begin()) { + h.updateNextTimerEvent(time); + } + } } ~TimerPromiseAdapter() { - if (pos != impl.timers.end()) { - impl.timers.erase(pos); + if (pos != parent.impl->timers.end()) { + KJ_IF_SOME(h, 
parent.sleepHooks) { + bool isFirst = pos == parent.impl->timers.begin(); + + parent.impl->timers.erase(pos); + + if (isFirst) { + if (parent.impl->timers.empty()) { + h.updateNextTimerEvent(kj::none); + } else { + h.updateNextTimerEvent((*parent.impl->timers.begin())->time); + } + } + } else { + parent.impl->timers.erase(pos); + } } } void fulfill() { fulfiller.fulfill(); - impl.timers.erase(pos); - pos = impl.timers.end(); + parent.impl->timers.erase(pos); + pos = parent.impl->timers.end(); } const TimePoint time; private: PromiseFulfiller& fulfiller; - TimerImpl::Impl& impl; + TimerImpl& parent; Impl::Timers::const_iterator pos; }; @@ -70,12 +90,21 @@ inline bool TimerImpl::Impl::TimerBefore::operator()( return lhs->time < rhs->time; } +TimePoint TimerImpl::now() const { + KJ_IF_SOME(h, sleepHooks) { + return h.getTimeWhileSleeping(); + } else { + return time; + } +} + Promise TimerImpl::atTime(TimePoint time) { - return newAdaptedPromise(*impl, time); + auto result = newAdaptedPromise(*this, time); + return result; } Promise TimerImpl::afterDelay(Duration delay) { - return newAdaptedPromise(*impl, time + delay); + return newAdaptedPromise(*this, now() + delay); } TimerImpl::TimerImpl(TimePoint startTime) @@ -110,6 +139,8 @@ Maybe TimerImpl::timeoutToNextEvent(TimePoint start, Duration unit, ui } void TimerImpl::advanceTo(TimePoint newTime) { + sleepHooks = nullptr; + // On Macs running an Intel processor, it has been observed that clock_gettime // may return non monotonic time, even when CLOCK_MONOTONIC is used. // This workaround is to avoid the assert triggering on these machines. diff --git a/c++/src/kj/timer.h b/c++/src/kj/timer.h index eb9443c23b..89f04df275 100644 --- a/c++/src/kj/timer.h +++ b/c++/src/kj/timer.h @@ -108,6 +108,25 @@ class TimerImpl final: public Timer { void advanceTo(TimePoint newTime); // Set the time to `time` and fire any at() events that have been passed. 
+ class SleepHooks { + public: + virtual void updateNextTimerEvent(kj::Maybe time) = 0; + // Called whenever the value returned by `nextEvent()` changes. + + virtual kj::TimePoint getTimeWhileSleeping() = 0; + // Get the current time. While sleeping, we can't lock time in place and advance it on each + // poll of the event queue, because aribtrary time might have passed outside the control of + // the KJ event loop. + }; + + void setSleeping(SleepHooks& hooks) { sleepHooks = hooks; } + // Hooks needed by UnixEventPort::preparePollableFdForSleep(). When the loop is sleeping, we + // would like for the application to be able to invoke the kj::Timer and for it to basically work + // correctly. This requires that we make some callbacks to the UnixEventPort to keep things + // consistent, since we can't assume the UnixEventPort will be actively polling the TimerImpl. + // + // The sleep hooks are automatically cleared when advanceTo() is next called. + // implements Timer ---------------------------------------------------------- TimePoint now() const override; Promise atTime(TimePoint time) override; @@ -118,6 +137,7 @@ class TimerImpl final: public Timer { class TimerPromiseAdapter; TimePoint time; Own impl; + kj::Maybe sleepHooks; }; // ======================================================================================= @@ -137,8 +157,6 @@ Promise Timer::timeoutAfter(Duration delay, Promise&& promise) { })); } -inline TimePoint TimerImpl::now() const { return time; } - } // namespace kj KJ_END_HEADER From b71827d53c6bad80155b2ba691af4b35e1729c95 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 28 Dec 2023 15:35:21 -0600 Subject: [PATCH 14/58] Add test demonstrating EventLoops changing threads. Also, threads changing EventLoops. 
--- c++/src/kj/async-unix-test.c++ | 56 ++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/c++/src/kj/async-unix-test.c++ b/c++/src/kj/async-unix-test.c++ index 08565ca0ba..bdb5a6d2b2 100644 --- a/c++/src/kj/async-unix-test.c++ +++ b/c++/src/kj/async-unix-test.c++ @@ -1244,6 +1244,62 @@ KJ_TEST("UnixEventPoll::getFd() for external waiting") { promise.wait(ws); } } + +KJ_TEST("m:n threads:EventLoops") { + // This test shows that it's possible for an EventLoop to switch threads, and for a thread to + // switch event loops. + + UnixEventPort port1; + EventLoop loop1(port1); + + UnixEventPort port2; + EventLoop loop2(port2); + + kj::TimePoint startTime = kj::origin(); + kj::Promise promise1 = nullptr; + PromiseCrossThreadFulfillerPair xpaf { nullptr, {} }; + const Executor* executor; + + { + WaitScope ws1(loop1); + ws1.poll(); + startTime = port1.getTimer().now(); + promise1 = port1.getTimer().afterDelay(10 * kj::MILLISECONDS); + xpaf = kj::newPromiseAndCrossThreadFulfiller(); + executor = &getCurrentThreadExecutor(); + } + + static thread_local uint threadId = 0; + + threadId = 1; + bool executorDone = false; + + kj::Thread thread([&]() noexcept { + threadId = 2; + + WaitScope ws1(loop1); + promise1.wait(ws1); + KJ_EXPECT(port1.getTimer().now() - startTime >= 10 * kj::MILLISECONDS); + KJ_EXPECT(executorDone); + + xpaf.promise.wait(ws1); + }); + + [&]() noexcept { + WaitScope ws2(loop2); + + // The `executor` we captured earlier is tied to loop1, which has changed threads, so code we + // schedule on it will run there. 
+ uint remoteThreadId = executor->executeAsync([&]() { + return threadId; + }).wait(ws2); + executorDone = true; + KJ_EXPECT(remoteThreadId == 2); + KJ_EXPECT(threadId == 1); + + xpaf.fulfiller->fulfill(); + }(); +} #endif } // namespace From 5e4bbd9800c221716d40d5d3b7c537c5c0e564b9 Mon Sep 17 00:00:00 2001 From: Justin Mazzola Paluska Date: Thu, 4 Jan 2024 08:59:19 -0500 Subject: [PATCH 15/58] Fix typos in filesystem files I found some of these while reading https://github.com/capnproto/capnproto/pull/1889/files and ispell found the rest. --- c++/src/kj/filesystem.c++ | 12 ++++++------ c++/src/kj/filesystem.h | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/c++/src/kj/filesystem.c++ b/c++/src/kj/filesystem.c++ index f6b954c424..c66187c12c 100644 --- a/c++/src/kj/filesystem.c++ +++ b/c++/src/kj/filesystem.c++ @@ -275,7 +275,7 @@ String PathPtr::toWin32StringImpl(bool absolute, bool forApi) const { // False alarm: this is the drive letter. } else { KJ_FAIL_REQUIRE( - "colons are prohibited in win32 paths to avoid triggering alterante data streams", + "colons are prohibited in win32 paths to avoid triggering alternate data streams", result) { // Recover by using a different character which we know Win32 syscalls will reject. result[i] = '|'; @@ -1280,7 +1280,7 @@ public: } else if (toPath.size() == 1) { if (!has(toMode, WriteMode::MODIFY)) { // Replacement is not allowed, so we'll have to check upfront if the target path exists. - // Unfortuantely we have to take a lock and then drop it immediately since we can't keep + // Unfortunately we have to take a lock and then drop it immediately since we can't keep // the lock held while accessing `fromDirectory`. if (impl.lockShared()->tryGetEntry(toPath[0]) != kj::none) { return false; @@ -1333,11 +1333,11 @@ public: // subdir just unlinks the object without modifying it, so we can safely just link it // into our own tree. 
// - // However, if `fromDirectory` is a disk directory, then remvoing the subdir will + // However, if `fromDirectory` is a disk directory, then removing the subdir will // likely perform a recursive delete, thus leaving `subdir` pointing to an empty // directory. If we link that into our tree, it's useless. So, instead, perform a // deep copy of the directory tree upfront, into an InMemoryDirectory. However, file - // content need not be copied, since unliked files keep their contents until closed. + // content need not be copied, since unlinked files keep their contents until closed. if (kj::dynamicDowncastIfAvailable(fromDirectory) == kj::none) { subdir = impl.getWithoutLock().copyDirectory(*subdir, /* copyFiles = */ false); @@ -1598,11 +1598,11 @@ private: } Own newFile() const { - // Consturct a new empty file. Note: This function is expected to work without the lock held. + // Construct a new empty file. Note: This function is expected to work without the lock held. return fileFactory.create(clock); } Own newDirectory() const { - // Consturct a new empty directory. Note: This function is expected to work without the lock + // Construct a new empty directory. Note: This function is expected to work without the lock // held. return newInMemoryDirectory(clock, fileFactory); } diff --git a/c++/src/kj/filesystem.h b/c++/src/kj/filesystem.h index 1c5ae28389..28c0ae1908 100644 --- a/c++/src/kj/filesystem.h +++ b/c++/src/kj/filesystem.h @@ -796,7 +796,7 @@ class Directory: public ReadableDirectory { virtual Own> replaceFile(PathPtr path, WriteMode mode) const = 0; // Construct a file which, when ready, will be atomically moved to `path`, replacing whatever - // is there already. See `Replacer` for detalis. + // is there already. See `Replacer` for details. // // The `CREATE` and `MODIFY` bits of `mode` are not enforced until commit time, hence // `replaceFile()` has no "try" variant. 
@@ -819,7 +819,7 @@ class Directory: public ReadableDirectory { virtual Own> replaceSubdir(PathPtr path, WriteMode mode) const = 0; // Construct a directory which, when ready, will be atomically moved to `path`, replacing - // whatever is there already. See `Replacer` for detalis. + // whatever is there already. See `Replacer` for details. // // The `CREATE` and `MODIFY` bits of `mode` are not enforced until commit time, hence // `replaceSubdir()` has no "try" variant. From 33c0d2c72b79b40ee33be4a53c6032c78583ae9d Mon Sep 17 00:00:00 2001 From: Justin Mazzola Paluska Date: Thu, 4 Jan 2024 08:59:56 -0500 Subject: [PATCH 16/58] Fix weird spacing in InMemoryDirectory constructor declaration I noticed this while reading https://github.com/capnproto/capnproto/pull/1889/files. --- c++/src/kj/filesystem.c++ | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c++/src/kj/filesystem.c++ b/c++/src/kj/filesystem.c++ index c66187c12c..367895e16d 100644 --- a/c++/src/kj/filesystem.c++ +++ b/c++/src/kj/filesystem.c++ @@ -985,7 +985,7 @@ public: InMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory) : impl(clock, fileFactory) {} InMemoryDirectory(const Clock& clock, const InMemoryFileFactory& fileFactory, - const Directory& copyFrom, bool copyFiles) + const Directory& copyFrom, bool copyFiles) : impl(clock, fileFactory, copyFrom, copyFiles) {} Own cloneFsNode() const override { From 6d8f6dd513a3b58c4fbfb065dc1925c9d7c9b9eb Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 11 Jan 2024 11:28:44 -0600 Subject: [PATCH 17/58] Fix typos Co-authored-by: Harris Hancock --- c++/src/kj/async-unix-test.c++ | 2 +- c++/src/kj/async-unix.c++ | 2 +- c++/src/kj/async-unix.h | 2 +- c++/src/kj/timer.h | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/c++/src/kj/async-unix-test.c++ b/c++/src/kj/async-unix-test.c++ index bdb5a6d2b2..bb0cc26dda 100644 --- a/c++/src/kj/async-unix-test.c++ +++ b/c++/src/kj/async-unix-test.c++ @@ -1121,7 
+1121,7 @@ KJ_TEST("UnixEventPort thread-specific signals") { #endif #if KJ_USE_EPOLL -KJ_TEST("UnixEventPoll::getFd() for external waiting") { +KJ_TEST("UnixEventPoll::getPollableFd() for external waiting") { kj::UnixEventPort port; kj::EventLoop loop(port); kj::WaitScope ws(loop); diff --git a/c++/src/kj/async-unix.c++ b/c++/src/kj/async-unix.c++ index 77b437780d..05553004dc 100644 --- a/c++/src/kj/async-unix.c++ +++ b/c++/src/kj/async-unix.c++ @@ -661,7 +661,7 @@ bool UnixEventPort::processEpollEvents(struct epoll_event events[], int n) { timerfdIsArmed = false; // The purpose of this event is just to wake up the event loop when needed. We'll check the - // timer queue separately, so we don't need to do anything special it response to this event + // timer queue separately, so we don't need to do anything special in response to this event // here. } else { FdObserver* observer = reinterpret_cast(events[i].data.ptr); diff --git a/c++/src/kj/async-unix.h b/c++/src/kj/async-unix.h index 7176a7eefc..f0fbe51202 100644 --- a/c++/src/kj/async-unix.h +++ b/c++/src/kj/async-unix.h @@ -203,7 +203,7 @@ class UnixEventPort: public EventPort, private TimerImpl::SleepHooks { // it should be possible to implement for kqueue as well. void preparePollableFdForSleep(); - // If you plan to monitor the FD return by getFd() for notifications that this queue is ready, + // If you plan to monitor the FD return by getPollableFd() for notifications that this queue is ready, // you must call preparePollableFdForSleep() after each run of this port's event loop in order to // ensure that all event types will in fact wake up the queue. // diff --git a/c++/src/kj/timer.h b/c++/src/kj/timer.h index 89f04df275..2ef0c8fc42 100644 --- a/c++/src/kj/timer.h +++ b/c++/src/kj/timer.h @@ -115,7 +115,7 @@ class TimerImpl final: public Timer { virtual kj::TimePoint getTimeWhileSleeping() = 0; // Get the current time. 
While sleeping, we can't lock time in place and advance it on each - // poll of the event queue, because aribtrary time might have passed outside the control of + // poll of the event queue, because arbitrary time might have passed outside the control of // the KJ event loop. }; From 8c346817ceded8071982e29bafe64d3736266228 Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Wed, 20 Dec 2023 15:50:03 -0800 Subject: [PATCH 18/58] Refactor some WebSocket internals to be a coroutine. --- c++/src/kj/compat/http.c++ | 41 +++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 29970d6af5..3b0ac1983f 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -3435,14 +3435,18 @@ private: currentlySending = true; - KJ_IF_SOME(p, sendingPong) { - // We recently sent a pong, make sure it's finished before proceeding. - auto promise = p.then([this, opcode, message]() { - currentlySending = false; - return sendImpl(opcode, message); - }); - sendingPong = kj::none; - return promise; + while (sendingPong != kj::none) { + KJ_IF_SOME(p, sendingPong) { + // Re-check in case of disconnect on a previous loop iteration. + KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); + + // We recently sent a control message; make sure it's finished before proceeding. 
+ // + // + auto localPromise = kj::mv(p); + sendingPong = kj::none; + co_await localPromise; + } } // We don't stop the application from sending further messages after close() -- this is the @@ -3491,21 +3495,16 @@ private: sendParts[0] = sendHeader.compose(true, useCompression, opcode, message.size(), mask); sendParts[1] = message; KJ_ASSERT(!sendHeader.hasRsv2or3(), "RSV bits 2 and 3 must be 0, as we do not currently " - "support an extension that would set these bits"); + "support an extension that would set these bits"); - auto promise = stream->write(sendParts).attach(kj::mv(compressedMessage)); - if (!mask.isZero()) { - promise = promise.attach(kj::mv(ownMessage)); - } - return promise.then([this, size = sendParts[0].size() + sendParts[1].size()]() { - currentlySending = false; + co_await stream->write(sendParts); + currentlySending = false; - // Send queued pong if needed. - if (queuedPong != kj::none) { - setUpSendingPong(); - } - sentBytes += size; - }); + // Send queued pong if needed. + if (queuedPong != kj::none) { + setUpSendingPong(); + }; + sentBytes += sendParts[0].size() + sendParts[1].size();; } void queuePong(kj::Array payload) { From 566791e6aa40c40a997f7df61be343cb63c0640c Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Wed, 20 Dec 2023 17:42:00 -0800 Subject: [PATCH 19/58] WebSocket: Refactor sending of pongs to allow for other control messages. Renamed some things from "pong" to "control message"; the things that are actually pong-specific remain. 
--- c++/src/kj/compat/http.c++ | 102 +++++++++++++++++++++++-------------- 1 file changed, 64 insertions(+), 38 deletions(-) diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 3b0ac1983f..123ff54d0c 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -2593,7 +2593,7 @@ public: : stream(kj::mv(stream)), maskKeyGenerator(maskKeyGenerator), compressionConfig(kj::mv(compressionConfigParam)), errorHandler(errorHandler.orDefault(*this)), - sendingPong(kj::mv(waitBeforeSend)), + sendingControlMessage(kj::mv(waitBeforeSend)), recvBuffer(kj::mv(buffer)), recvData(leftover) { #if KJ_HAS_ZLIB KJ_IF_SOME(config, compressionConfig) { @@ -2634,14 +2634,14 @@ public: kj::Promise disconnect() override { KJ_REQUIRE(!currentlySending, "another message send is already in progress"); - KJ_IF_SOME(p, sendingPong) { - // We recently sent a pong, make sure it's finished before proceeding. + KJ_IF_SOME(p, sendingControlMessage) { + // We recently sent a control message; make sure it's finished before proceeding. currentlySending = true; auto promise = p.then([this]() { currentlySending = false; return disconnect(); }); - sendingPong = kj::none; + sendingControlMessage = kj::none; return promise; } @@ -2652,8 +2652,8 @@ public: } void abort() override { - queuedPong = kj::none; - sendingPong = kj::none; + queuedControlMessage = kj::none; + sendingControlMessage = kj::none; disconnected = true; stream->abortRead(); stream->shutdownWrite(); @@ -3383,6 +3383,7 @@ private: static constexpr byte OPCODE_PONG = 10; static constexpr byte OPCODE_FIRST_CONTROL = 8; + static constexpr byte OPCODE_MAX = 15; // --------------------------------------------------------------------------- @@ -3401,18 +3402,32 @@ private: Header sendHeader; kj::ArrayPtr sendParts[2]; - kj::Maybe> queuedPong; - // queuedPong holds the body of the next pong to write, cleared when the pong is written. 
If a - // more recent ping arrives before the pong is actually written, we can update this value to - // instead respond to the more recent ping. + struct ControlMessage { + byte opcode; + kj::Array payload; + + ControlMessage(byte opcode, kj::Array payload) : opcode(opcode), payload(kj::mv(payload)) { + KJ_REQUIRE(opcode <= OPCODE_MAX); + } + }; - kj::Maybe> sendingPong; - // If a Pong is being sent asynchronously in response to a Ping, this is a promise for the - // completion of that send. + kj::Maybe queuedControlMessage; + // queuedControlMessage holds the body of the next control message to write; it is cleared when the message is + // written. + // + // It may be overwritten; for example, if a more recent ping arrives before the pong is actually written, we can + // update this value to instead respond to the more recent ping. If a bad frame shows up, we can overwrite any + // queued pong with a Close message. + // + // Currently, this holds either a Close or a Pong. + + kj::Maybe> sendingControlMessage; + // If a control message is being sent asynchronously (e.g., a Pong in response to a Ping), this is a + // promise for the completion of that send. // // Additionally, this member is used if we need to block our first send on WebSocket startup, // e.g. because we need to wait for HTTP handshake writes to flush before we can start sending - // WebSocket data. `sendingPong` was overloaded for this use case because the logic is the same. + // WebSocket data. `sendingControlMessage` was overloaded for this use case because the logic is the same. // Perhaps it should be renamed to `blockSend` or `writeQueue`. uint fragmentOpcode = 0; @@ -3435,8 +3450,8 @@ private: currentlySending = true; - while (sendingPong != kj::none) { - KJ_IF_SOME(p, sendingPong) { + while (sendingControlMessage != kj::none) { + KJ_IF_SOME(p, sendingControlMessage) { // Re-check in case of disconnect on a previous loop iteration. 
KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); @@ -3444,7 +3459,7 @@ private: // // auto localPromise = kj::mv(p); - sendingPong = kj::none; + sendingControlMessage = kj::none; co_await localPromise; } } @@ -3500,53 +3515,64 @@ private: co_await stream->write(sendParts); currentlySending = false; - // Send queued pong if needed. - if (queuedPong != kj::none) { - setUpSendingPong(); + // Send queued control message if needed. + if (queuedControlMessage != kj::none) { + setUpSendingControlMessage(); }; sentBytes += sendParts[0].size() + sendParts[1].size();; } void queuePong(kj::Array payload) { - bool alreadyWaitingForPongWrite = (queuedPong != kj::none); + bool alreadyWaitingForPongWrite = false; + + KJ_IF_SOME(controlMessage, queuedControlMessage) { + if (controlMessage.opcode == OPCODE_CLOSE) { + // We're currently sending a Close message, which we only do (at least via queuedControlMessage) when we're + // closing the connection due to error. There's no point queueing a Pong that'll never be sent. + return; + } else { + alreadyWaitingForPongWrite = true; + } + } // Note: According to spec, if the server receives a second ping before responding to the // previous one, it can opt to respond only to the last ping. So we don't have to check if - // queuedPong is already non-null. - queuedPong = kj::mv(payload); + // queuedControlMessage is already non-null. + queuedControlMessage = ControlMessage(OPCODE_PONG, kj::mv(payload)); if (currentlySending) { // There is a message-send in progress, so we cannot write to the stream now. We will set - // up the pong write at the end of the message-send. + // up the control message write at the end of the message-send. return; } if (alreadyWaitingForPongWrite) { // We were already waiting for a pong to be written; don't need to queue another write. 
return; } - setUpSendingPong(); + setUpSendingControlMessage(); } - void setUpSendingPong() { - KJ_IF_SOME(promise, sendingPong) { - sendingPong = promise.then([this]() mutable { - return writeQueuedPong(); + void setUpSendingControlMessage() { + KJ_IF_SOME(promise, sendingControlMessage) { + sendingControlMessage = promise.then([this]() mutable { + return writeQueuedControlMessage(); }); } else { - sendingPong = writeQueuedPong(); + sendingControlMessage = writeQueuedControlMessage(); } } - kj::Promise writeQueuedPong() { - KJ_IF_SOME(q, queuedPong) { - kj::Array payload = kj::mv(q); - queuedPong = kj::none; + kj::Promise writeQueuedControlMessage() { + KJ_IF_SOME(q, queuedControlMessage) { + byte opcode = q.opcode; + kj::Array payload = kj::mv(q.payload); + queuedControlMessage = kj::none; if (hasSentClose || disconnected) { return kj::READY_NOW; } - sendParts[0] = sendHeader.compose(true, false, OPCODE_PONG, + sendParts[0] = sendHeader.compose(true, false, opcode, payload.size(), Mask(maskKeyGenerator)); sendParts[1] = payload; return stream->write(sendParts).attach(kj::mv(payload)); @@ -3556,12 +3582,12 @@ private: } kj::Promise optimizedPumpTo(WebSocketImpl& other) { - KJ_IF_SOME(p, other.sendingPong) { - // We recently sent a pong, make sure it's finished before proceeding. + KJ_IF_SOME(p, other.sendingControlMessage) { + // We recently sent a control message; make sure it's finished before proceeding. auto promise = p.then([this, &other]() { return optimizedPumpTo(other); }); - other.sendingPong = kj::none; + other.sendingControlMessage = kj::none; return promise; } From c0f31ae84e3759605d05c8b7873194a728d82b9a Mon Sep 17 00:00:00 2001 From: James M Snell Date: Wed, 3 Jan 2024 12:59:01 -0800 Subject: [PATCH 20/58] Move GlobFilter into a separate header/unit Addresses a long standing todo in workerd where we duplicate GlobFilter there. 
--- c++/Makefile.am | 2 + c++/src/kj/BUILD.bazel | 2 + c++/src/kj/CMakeLists.txt | 2 + c++/src/kj/glob-filter.c++ | 109 +++++++++++++++++++++++++++++++++++++ c++/src/kj/glob-filter.h | 45 +++++++++++++++ c++/src/kj/test-test.c++ | 1 + c++/src/kj/test.c++ | 91 +------------------------------ c++/src/kj/test.h | 19 +------ 8 files changed, 163 insertions(+), 108 deletions(-) create mode 100644 c++/src/kj/glob-filter.c++ create mode 100644 c++/src/kj/glob-filter.h diff --git a/c++/Makefile.am b/c++/Makefile.am index fc05b7a560..c2296798f7 100644 --- a/c++/Makefile.am +++ b/c++/Makefile.am @@ -151,6 +151,7 @@ includekj_HEADERS = \ src/kj/vector.h \ src/kj/string.h \ src/kj/string-tree.h \ + src/kj/glob-filter.h \ src/kj/hash.h \ src/kj/table.h \ src/kj/map.h \ @@ -274,6 +275,7 @@ libkj_la_SOURCES= \ src/kj/string.c++ \ src/kj/string-tree.c++ \ src/kj/source-location.c++ \ + src/kj/glob-filter.c++ src/kj/hash.c++ \ src/kj/table.c++ \ src/kj/encoding.c++ \ diff --git a/c++/src/kj/BUILD.bazel b/c++/src/kj/BUILD.bazel index b216f6144f..4566cfcbf0 100644 --- a/c++/src/kj/BUILD.bazel +++ b/c++/src/kj/BUILD.bazel @@ -13,6 +13,7 @@ cc_library( "encoding.c++", "exception.c++", "filesystem.c++", + "glob-filter.c++", "hash.c++", "io.c++", "list.c++", @@ -42,6 +43,7 @@ cc_library( "exception.h", "filesystem.h", "function.h", + "glob-filter.h", "hash.h", "io.h", "list.h", diff --git a/c++/src/kj/CMakeLists.txt b/c++/src/kj/CMakeLists.txt index bd72009254..fc20ae7159 100644 --- a/c++/src/kj/CMakeLists.txt +++ b/c++/src/kj/CMakeLists.txt @@ -13,6 +13,7 @@ set(kj_sources_lite mutex.c++ string.c++ source-location.c++ + glob-filter.c++ hash.c++ table.c++ thread.c++ @@ -49,6 +50,7 @@ set(kj_headers string.h string-tree.h source-location.h + glob-filter.h hash.h table.h map.h diff --git a/c++/src/kj/glob-filter.c++ b/c++/src/kj/glob-filter.c++ new file mode 100644 index 0000000000..d8ec620d0c --- /dev/null +++ b/c++/src/kj/glob-filter.c++ @@ -0,0 +1,109 @@ +// Copyright (c) 2013-2014 
Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#include "glob-filter.h" + +namespace kj { + +GlobFilter::GlobFilter(const char* pattern): pattern(heapString(pattern)) {} +GlobFilter::GlobFilter(ArrayPtr pattern): pattern(heapString(pattern)) {} + +bool GlobFilter::matches(StringPtr name) { + // Get out your computer science books. We're implementing a non-deterministic finite automaton. + // + // Our NDFA has one "state" corresponding to each character in the pattern. + // + // As you may recall, an NDFA can be transformed into a DFA where every state in the DFA + // represents some combination of states in the NDFA. Therefore, we actually have to store a + // list of states here. (Actually, what we really want is a set of states, but because our + // patterns are mostly non-cyclic a list of states should work fine and be a bit more efficient.) 
+ + // Our state list starts out pointing only at the start of the pattern. + states.resize(0); + states.add(0); + + Vector scratch; + + // Iterate through each character in the name. + for (char c: name) { + // Pull the current set of states off to the side, so that we can populate `states` with the + // new set of states. + Vector oldStates = kj::mv(states); + states = kj::mv(scratch); + states.resize(0); + + // The pattern can omit a leading path. So if we're at a '/' then enter the state machine at + // the beginning on the next char. + if (c == '/' || c == '\\') { + states.add(0); + } + + // Process each state. + for (uint state: oldStates) { + applyState(c, state); + } + + // Store the previous state vector for reuse. + scratch = kj::mv(oldStates); + } + + // If any one state is at the end of the pattern (or at a wildcard just before the end of the + // pattern), we have a match. + for (uint state: states) { + while (state < pattern.size() && pattern[state] == '*') { + ++state; + } + if (state == pattern.size()) { + return true; + } + } + return false; +} + +void GlobFilter::applyState(char c, int state) { + if (state < pattern.size()) { + switch (pattern[state]) { + case '*': + // At a '*', we both re-add the current state and attempt to match the *next* state. + if (c != '/' && c != '\\') { // '*' doesn't match '/'. + states.add(state); + } + applyState(c, state + 1); + break; + + case '?': + // A '?' matches one character (never a '/'). + if (c != '/' && c != '\\') { + states.add(state + 1); + } + break; + + default: + // Any other character matches only itself. + if (c == pattern[state]) { + states.add(state + 1); + } + break; + } + } +} + +} // namespace kj diff --git a/c++/src/kj/glob-filter.h b/c++/src/kj/glob-filter.h new file mode 100644 index 0000000000..583db6f0c1 --- /dev/null +++ b/c++/src/kj/glob-filter.h @@ -0,0 +1,45 @@ +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. 
and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +#pragma once + +#include +#include + +namespace kj { + +class GlobFilter { + // Implements glob filters for the --filter flag. 
+ +public: + explicit GlobFilter(const char* pattern); + explicit GlobFilter(ArrayPtr pattern); + + bool matches(StringPtr name); + +private: + String pattern; + Vector states; + + void applyState(char c, int state); +}; + +} // namespace kj diff --git a/c++/src/kj/test-test.c++ b/c++/src/kj/test-test.c++ index 379c77d45f..669e252c98 100644 --- a/c++/src/kj/test-test.c++ +++ b/c++/src/kj/test-test.c++ @@ -21,6 +21,7 @@ #include "common.h" #include "test.h" +#include "glob-filter.h" #include #include #include diff --git a/c++/src/kj/test.c++ b/c++/src/kj/test.c++ index cb07d5675b..8b977bc7fa 100644 --- a/c++/src/kj/test.c++ +++ b/c++/src/kj/test.c++ @@ -68,95 +68,6 @@ size_t TestCase::iterCount() { // ======================================================================================= -namespace _ { // private - -GlobFilter::GlobFilter(const char* pattern): pattern(heapString(pattern)) {} -GlobFilter::GlobFilter(ArrayPtr pattern): pattern(heapString(pattern)) {} - -bool GlobFilter::matches(StringPtr name) { - // Get out your computer science books. We're implementing a non-deterministic finite automaton. - // - // Our NDFA has one "state" corresponding to each character in the pattern. - // - // As you may recall, an NDFA can be transformed into a DFA where every state in the DFA - // represents some combination of states in the NDFA. Therefore, we actually have to store a - // list of states here. (Actually, what we really want is a set of states, but because our - // patterns are mostly non-cyclic a list of states should work fine and be a bit more efficient.) - - // Our state list starts out pointing only at the start of the pattern. - states.resize(0); - states.add(0); - - Vector scratch; - - // Iterate through each character in the name. - for (char c: name) { - // Pull the current set of states off to the side, so that we can populate `states` with the - // new set of states. 
- Vector oldStates = kj::mv(states); - states = kj::mv(scratch); - states.resize(0); - - // The pattern can omit a leading path. So if we're at a '/' then enter the state machine at - // the beginning on the next char. - if (c == '/' || c == '\\') { - states.add(0); - } - - // Process each state. - for (uint state: oldStates) { - applyState(c, state); - } - - // Store the previous state vector for reuse. - scratch = kj::mv(oldStates); - } - - // If any one state is at the end of the pattern (or at a wildcard just before the end of the - // pattern), we have a match. - for (uint state: states) { - while (state < pattern.size() && pattern[state] == '*') { - ++state; - } - if (state == pattern.size()) { - return true; - } - } - return false; -} - -void GlobFilter::applyState(char c, int state) { - if (state < pattern.size()) { - switch (pattern[state]) { - case '*': - // At a '*', we both re-add the current state and attempt to match the *next* state. - if (c != '/' && c != '\\') { // '*' doesn't match '/'. - states.add(state); - } - applyState(c, state + 1); - break; - - case '?': - // A '?' matches one character (never a '/'). - if (c != '/' && c != '\\') { - states.add(state + 1); - } - break; - - default: - // Any other character matches only itself. 
- if (c == pattern[state]) { - states.add(state + 1); - } - break; - } - } -} - -} // namespace _ (private) - -// ======================================================================================= - namespace { class TestExceptionCallback: public ExceptionCallback { @@ -255,7 +166,7 @@ public: } } - _::GlobFilter filter(filePattern); + kj::GlobFilter filter(filePattern); for (TestCase* testCase = testCasesHead; testCase != nullptr; testCase = testCase->next) { if (!testCase->matchedFilter && filter.matches(testCase->file) && diff --git a/c++/src/kj/test.h b/c++/src/kj/test.h index 1d845c1452..e40bfbc96b 100644 --- a/c++/src/kj/test.h +++ b/c++/src/kj/test.h @@ -22,6 +22,7 @@ #pragma once #include +#include #include #include #include // work-around macro conflict with `ERROR` @@ -186,24 +187,6 @@ class LogExpectation: public ExceptionCallback { UnwindDetector unwindDetector; }; -class GlobFilter { - // Implements glob filters for the --filter flag. - // - // Exposed in header only for testing. - -public: - explicit GlobFilter(const char* pattern); - explicit GlobFilter(ArrayPtr pattern); - - bool matches(StringPtr name); - -private: - String pattern; - Vector states; - - void applyState(char c, int state); -}; - } // namespace _ (private) } // namespace kj From e02fb4863b7b76c63911aade36aaffd423d66aed Mon Sep 17 00:00:00 2001 From: James M Snell Date: Fri, 12 Jan 2024 16:30:12 -0800 Subject: [PATCH 21/58] Add new KJ_DISALLOW_AS_COROUTINE_PARAM mechanism Allows a type to be explicitly marked in a way that prevents it from being passed into a kj promise coroutine as an arg. ``` class Foo { private: KJ_DISALLOW_AS_COROUTINE_PARAM; }; kj::Promise simpleCoroutine(Foo foo) { // Compile Error!
co_return; } ``` --- c++/src/kj/async-inl.h | 4 ++++ c++/src/kj/common-test.c++ | 28 +++++++++++++++++++++++++ c++/src/kj/common.h | 43 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/c++/src/kj/async-inl.h b/c++/src/kj/async-inl.h index 8032439059..b8efd04674 100644 --- a/c++/src/kj/async-inl.h +++ b/c++/src/kj/async-inl.h @@ -2062,6 +2062,7 @@ template concept NoWaitScope = !isSameType, WaitScope>(); // Define a Concept to use in our `coroutine_traits` specialization to validate allowable coroutine // parameter types. +// TODO(cleanup): This can be removed by adding KJ_DISALLOW_AS_COROUTINE_PARAM to WaitScope. } // namespace kj::_ @@ -2073,6 +2074,9 @@ template struct coroutine_traits, Args...> { // `Args...` are the coroutine's parameter types. + static_assert((!kj::_::isDisallowedInCoroutine() && ...), + "Disallowed type in coroutine"); + static_assert((::kj::_::NoWaitScope && ...), "Coroutines are not allowed to accept `WaitScope` parameters."); // Coroutines should never have access to a WaitScope. 
If they could, coroutines could be both a diff --git a/c++/src/kj/common-test.c++ b/c++/src/kj/common-test.c++ index af5ad74caf..b138777fc2 100644 --- a/c++/src/kj/common-test.c++ +++ b/c++/src/kj/common-test.c++ @@ -940,5 +940,33 @@ KJ_TEST("ArrayPtr::as") { KJ_EXPECT(stdPtr.size() == 5); } +// Verifies the expected values of kj::isDisallowedInCoroutine + +struct DisallowedInCoroutineStruct { + KJ_DISALLOW_AS_COROUTINE_PARAM; +}; +class DisallowedInCoroutinePublic { +public: + KJ_DISALLOW_AS_COROUTINE_PARAM; +}; +class DisallowedInCoroutinePrivate { +private: + KJ_DISALLOW_AS_COROUTINE_PARAM; +}; +struct AllowedInCoroutine {}; + +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(_::isDisallowedInCoroutine()); +static_assert(!_::isDisallowedInCoroutine()); +static_assert(!_::isDisallowedInCoroutine()); +static_assert(!_::isDisallowedInCoroutine()); + } // namespace } // namespace kj diff --git a/c++/src/kj/common.h b/c++/src/kj/common.h index ac4d906e9f..b768bdbf4c 100644 --- a/c++/src/kj/common.h +++ b/c++/src/kj/common.h @@ -2098,6 +2098,49 @@ _::Deferred defer(Func&& func) { #define KJ_DEFER(code) auto KJ_UNIQUE_NAME(_kjDefer) = ::kj::defer([&](){code;}) // Run the given code when the function exits, whether by return or exception. 
+// ======================================================================================= +// IsDisallowedInCoroutine + +namespace _ { + +template +struct IsDisallowedInCoroutine { + static constexpr bool value = false; +}; + +template +struct IsDisallowedInCoroutine::_kj_DissalowedInCoroutine> { + static constexpr bool value = true; +}; + +template +struct IsDisallowedInCoroutine { + static constexpr bool value = true; +}; + +template +constexpr bool isDisallowedInCoroutine() { + return IsDisallowedInCoroutine::value; +} + +} // namespace _ + +#define KJ_DISALLOW_AS_COROUTINE_PARAM \ + using _kj_DissalowedInCoroutine = void; \ + template friend constexpr bool kj::_::isDisallowedInCoroutine(); \ + template friend struct kj::_::IsDisallowedInCoroutine +// Place in the body of a class or struct to indicate that an instance of or reference to this +// type cannot be passed as the parameter to a KJ coroutine. This makes sense, for example, for +// mutex locks, or other types which should never be held across a co_await. +// +// (Types annotated with this likely also should not be used as local variables inside coroutines, +// but there is no way for us to enforce that.) 
+// +// struct Foo { +// KJ_DISALLOW_AS_COROUTINE_PARAM; +// } +// + } // namespace kj KJ_END_HEADER From 7e74c891bd9686e2a4f1425e0da4a9fab450d700 Mon Sep 17 00:00:00 2001 From: James M Snell Date: Sat, 13 Jan 2024 15:01:34 -0800 Subject: [PATCH 22/58] Move GlobFilter tests into glob-filter-test.c++ --- c++/Makefile.am | 3 +- c++/src/kj/BUILD.bazel | 1 + c++/src/kj/CMakeLists.txt | 1 + c++/src/kj/glob-filter-test.c++ | 84 +++++++++++++++++++++++++++++++++ c++/src/kj/test-test.c++ | 53 --------------------- 5 files changed, 88 insertions(+), 54 deletions(-) create mode 100644 c++/src/kj/glob-filter-test.c++ diff --git a/c++/Makefile.am b/c++/Makefile.am index c2296798f7..101365bd46 100644 --- a/c++/Makefile.am +++ b/c++/Makefile.am @@ -275,7 +275,7 @@ libkj_la_SOURCES= \ src/kj/string.c++ \ src/kj/string-tree.c++ \ src/kj/source-location.c++ \ - src/kj/glob-filter.c++ + src/kj/glob-filter.c++ \ src/kj/hash.c++ \ src/kj/table.c++ \ src/kj/encoding.c++ \ @@ -588,6 +588,7 @@ capnp_test_SOURCES = \ src/kj/filesystem-test.c++ \ src/kj/filesystem-disk-test.c++ \ src/kj/test-test.c++ \ + src/kj/glob-filter-test.c++ \ src/capnp/common-test.c++ \ src/capnp/blob-test.c++ \ src/capnp/endian-test.c++ \ diff --git a/c++/src/kj/BUILD.bazel b/c++/src/kj/BUILD.bazel index 4566cfcbf0..de89d41eb6 100644 --- a/c++/src/kj/BUILD.bazel +++ b/c++/src/kj/BUILD.bazel @@ -175,6 +175,7 @@ cc_library( "string-tree-test.c++", "table-test.c++", "test-test.c++", + "glob-filter-test.c++", "thread-test.c++", "time-test.c++", "tuple-test.c++", diff --git a/c++/src/kj/CMakeLists.txt b/c++/src/kj/CMakeLists.txt index fc20ae7159..9cc7a9cac0 100644 --- a/c++/src/kj/CMakeLists.txt +++ b/c++/src/kj/CMakeLists.txt @@ -254,6 +254,7 @@ if(BUILD_TESTING) mutex-test.c++ time-test.c++ test-test.c++ + glob-filter-test.c++ std/iostream-test.c++ ) # TODO: Link with librt on Solaris for sched_yield diff --git a/c++/src/kj/glob-filter-test.c++ b/c++/src/kj/glob-filter-test.c++ new file mode 100644 index 
0000000000..f00218c1fa --- /dev/null +++ b/c++/src/kj/glob-filter-test.c++ @@ -0,0 +1,84 @@ +// Copyright (c) 2013-2014 Sandstorm Development Group, Inc. and contributors +// Licensed under the MIT License: +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +#include "common.h" +#include "test.h" +#include "glob-filter.h" + +namespace kj { +namespace _ { +namespace { + +KJ_TEST("GlobFilter") { + { + GlobFilter filter("foo"); + + KJ_EXPECT(filter.matches("foo")); + KJ_EXPECT(!filter.matches("bar")); + KJ_EXPECT(!filter.matches("foob")); + KJ_EXPECT(!filter.matches("foobbb")); + KJ_EXPECT(!filter.matches("fobbbb")); + KJ_EXPECT(!filter.matches("bfoo")); + KJ_EXPECT(!filter.matches("bbbbbfoo")); + KJ_EXPECT(filter.matches("bbbbb/foo")); + KJ_EXPECT(filter.matches("bar/baz/foo")); + } + + { + GlobFilter filter("foo*"); + + KJ_EXPECT(filter.matches("foo")); + KJ_EXPECT(!filter.matches("bar")); + KJ_EXPECT(filter.matches("foob")); + KJ_EXPECT(filter.matches("foobbb")); + KJ_EXPECT(!filter.matches("fobbbb")); + KJ_EXPECT(!filter.matches("bfoo")); + KJ_EXPECT(!filter.matches("bbbbbfoo")); + KJ_EXPECT(filter.matches("bbbbb/foo")); + KJ_EXPECT(filter.matches("bar/baz/foo")); + } + + { + GlobFilter filter("foo*bar"); + + KJ_EXPECT(filter.matches("foobar")); + KJ_EXPECT(filter.matches("fooxbar")); + KJ_EXPECT(filter.matches("fooxxxbar")); + KJ_EXPECT(!filter.matches("foo/bar")); + KJ_EXPECT(filter.matches("blah/fooxxxbar")); + KJ_EXPECT(!filter.matches("blah/xxfooxxxbar")); + } + + { + GlobFilter filter("foo?bar"); + + KJ_EXPECT(!filter.matches("foobar")); + KJ_EXPECT(filter.matches("fooxbar")); + KJ_EXPECT(!filter.matches("fooxxxbar")); + KJ_EXPECT(!filter.matches("foo/bar")); + KJ_EXPECT(filter.matches("blah/fooxbar")); + KJ_EXPECT(!filter.matches("blah/xxfooxbar")); + } +} + +} // namespace +} // namespace _ +} // namespace kj diff --git a/c++/src/kj/test-test.c++ b/c++/src/kj/test-test.c++ index 669e252c98..1cfa9a31f0 100644 --- a/c++/src/kj/test-test.c++ +++ b/c++/src/kj/test-test.c++ @@ -21,7 +21,6 @@ #include "common.h" #include "test.h" -#include "glob-filter.h" #include #include #include @@ -34,58 +33,6 @@ namespace kj { namespace _ { namespace { -KJ_TEST("GlobFilter") { - { - GlobFilter filter("foo"); - - 
KJ_EXPECT(filter.matches("foo")); - KJ_EXPECT(!filter.matches("bar")); - KJ_EXPECT(!filter.matches("foob")); - KJ_EXPECT(!filter.matches("foobbb")); - KJ_EXPECT(!filter.matches("fobbbb")); - KJ_EXPECT(!filter.matches("bfoo")); - KJ_EXPECT(!filter.matches("bbbbbfoo")); - KJ_EXPECT(filter.matches("bbbbb/foo")); - KJ_EXPECT(filter.matches("bar/baz/foo")); - } - - { - GlobFilter filter("foo*"); - - KJ_EXPECT(filter.matches("foo")); - KJ_EXPECT(!filter.matches("bar")); - KJ_EXPECT(filter.matches("foob")); - KJ_EXPECT(filter.matches("foobbb")); - KJ_EXPECT(!filter.matches("fobbbb")); - KJ_EXPECT(!filter.matches("bfoo")); - KJ_EXPECT(!filter.matches("bbbbbfoo")); - KJ_EXPECT(filter.matches("bbbbb/foo")); - KJ_EXPECT(filter.matches("bar/baz/foo")); - } - - { - GlobFilter filter("foo*bar"); - - KJ_EXPECT(filter.matches("foobar")); - KJ_EXPECT(filter.matches("fooxbar")); - KJ_EXPECT(filter.matches("fooxxxbar")); - KJ_EXPECT(!filter.matches("foo/bar")); - KJ_EXPECT(filter.matches("blah/fooxxxbar")); - KJ_EXPECT(!filter.matches("blah/xxfooxxxbar")); - } - - { - GlobFilter filter("foo?bar"); - - KJ_EXPECT(!filter.matches("foobar")); - KJ_EXPECT(filter.matches("fooxbar")); - KJ_EXPECT(!filter.matches("fooxxxbar")); - KJ_EXPECT(!filter.matches("foo/bar")); - KJ_EXPECT(filter.matches("blah/fooxbar")); - KJ_EXPECT(!filter.matches("blah/xxfooxbar")); - } -} - KJ_TEST("expect exit from exit") { KJ_EXPECT_EXIT(42, _exit(42)); KJ_EXPECT_EXIT(kj::none, _exit(42)); From 09560a220793e0e15e92c2086a33588d8c2b4dd3 Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Tue, 2 Jan 2024 16:05:12 -0800 Subject: [PATCH 23/58] WebSocket: reply with Close upon receipt of an unexpected continuation frame. Previously, WebSocketImpl::receive() would just throw an exception when this happened, which resulted in the destruction of the WebSocket and the closing of its underlying TCP connection. The WebSocket client didn't get any indication of what they did wrong. 
Now, the client gets a Close frame with error code 1002 (terminating due to protocol error) and a reason ("Unexpected continuation frame"), so they can learn what they did wrong. WebSocketImpl::receive still throws, but it waits to do so until after the Close has been sent. There are a bunch of different protocol errors that deserve Close frames, but this commit only addresses one so that the changes to WebSocketImpl are easier to review. --- c++/src/kj/compat/http-test.c++ | 36 +++++++++++++- c++/src/kj/compat/http.c++ | 85 ++++++++++++++++++++++++--------- 2 files changed, 97 insertions(+), 24 deletions(-) diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 148d035e1d..4301ac47a4 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -1907,6 +1907,33 @@ KJ_TEST("WebSocket unexpected RSV bits") { clientTask.wait(waitScope); } +void assertContainsWebSocketClose(kj::ArrayPtr data, uint16_t code, kj::Maybe messageSubstr) { + KJ_ASSERT(data.size() >= 2); // The smallest possible Close frame has size 2. + KJ_ASSERT(data.size() <= 127); // Maximum size for control frames. + KJ_ASSERT((data[0] & 0xf0) == 0x80); // Only the FIN flag is set. + KJ_ASSERT((data[0] & 0x0f) == 8); // OPCODE_CLOSE + + size_t payloadSize = data[1] & 0x7f; + + if (payloadSize == 0) { + // A Close frame with no body has no status code and no reason. + KJ_ASSERT(code == 1005); + KJ_ASSERT(messageSubstr == kj::none); + } else { + KJ_ASSERT(code != 1005); + } + auto payload = data.slice(2); + + KJ_ASSERT(payload.size() >= 2); // The first two bytes are the status code, so we better have at least two bytes.
+ uint16_t gotCode = (payload[0] << 8) | payload[1]; + KJ_ASSERT(gotCode == code); + + KJ_IF_SOME(needle, messageSubstr) { + auto reason = kj::str(payload.asChars().slice(2)); + KJ_ASSERT(reason.contains(needle), reason, needle); + } +} + KJ_TEST("WebSocket unexpected continuation frame") { KJ_HTTP_TEST_SETUP_IO; auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; @@ -1919,7 +1946,11 @@ KJ_TEST("WebSocket unexpected continuation frame") { 0x80, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', // Continuation frame with no start frame, plus FIN }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -1930,7 +1961,8 @@ KJ_TEST("WebSocket unexpected continuation frame") { KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Unexpected continuation frame"_kjc); } KJ_TEST("WebSocket missing continuation frame") { diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 123ff54d0c..412782d5e9 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -2615,18 +2615,7 @@ public: } kj::Promise close(uint16_t code, kj::StringPtr reason) override { - kj::Array payload; - if (code == 1005) { - KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason"); - - // code 1005 -- leave payload empty - } else { - payload = heapArray(reason.size() + 2); - payload[0] = code >> 8; - payload[1] = code; - memcpy(payload.begin() + 2, reason.begin(), reason.size()); - } - + kj::Array payload = serializeClose(code, reason); auto promise = sendImpl(OPCODE_CLOSE, payload); return promise.attach(kj::mv(payload)); } @@ -2664,6 +2653,10 @@ public: } kj::Promise receive(size_t maxSize) override { 
+ KJ_IF_SOME(ex, receiveException) { + return kj::cp(ex); + } + size_t headerSize = Header::headerSize(recvData.begin(), recvData.size()); if (headerSize > recvData.size()) { @@ -2713,9 +2706,14 @@ public: bool isData = opcode < OPCODE_FIRST_CONTROL; if (opcode == OPCODE_CONTINUATION) { if (fragments.empty()) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Unexpected continuation frame" - }); + auto paf = newPromiseAndFulfiller(); + queueClose(1002, "Unexpected continuation frame", kj::mv(paf.fulfiller)); + + return paf.promise.then([this]() -> kj::Promise { + return errorHandler.handleWebSocketProtocolError({ + 1002, "Unexpected continuation frame" + }); + }); } opcode = fragmentOpcode; @@ -3405,12 +3403,20 @@ private: struct ControlMessage { byte opcode; kj::Array payload; + kj::Maybe>> fulfiller; - ControlMessage(byte opcode, kj::Array payload) : opcode(opcode), payload(kj::mv(payload)) { + ControlMessage( + byte opcodeParam, + kj::Array payloadParam, + kj::Maybe>> fulfillerParam) + : opcode(opcodeParam), payload(kj::mv(payloadParam)), fulfiller(kj::mv(fulfillerParam)) { KJ_REQUIRE(opcode <= OPCODE_MAX); } }; + kj::Maybe receiveException; + // If set, all future calls to receive() will throw this exception. + kj::Maybe queuedControlMessage; // queuedControlMessage holds the body of the next control message to write; it is cleared when the message is // written. @@ -3450,7 +3456,7 @@ private: currentlySending = true; - while (sendingControlMessage != kj::none) { + for (;;) { KJ_IF_SOME(p, sendingControlMessage) { // Re-check in case of disconnect on a previous loop iteration. 
KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); @@ -3461,6 +3467,8 @@ private: auto localPromise = kj::mv(p); sendingControlMessage = kj::none; co_await localPromise; + } else { + break; } } @@ -3522,6 +3530,33 @@ private: sentBytes += sendParts[0].size() + sendParts[1].size();; } + void queueClose(uint16_t code, kj::StringPtr reason, kj::Own> fulfiller) { + bool alreadyWaiting = (queuedControlMessage != kj::none); + + // Overwrite any previously-queued message. If there is one, it's just a Pong, and this Close supersedes it. + auto payload = serializeClose(code, reason); + queuedControlMessage = ControlMessage(OPCODE_CLOSE, kj::mv(payload), kj::mv(fulfiller)); + + if (!alreadyWaiting) { + setUpSendingControlMessage(); + } + } + + kj::Array serializeClose(uint16_t code, kj::StringPtr reason) { + kj::Array payload; + if (code == 1005) { + KJ_REQUIRE(reason.size() == 0, "WebSocket close code 1005 cannot have a reason"); + + // code 1005 -- leave payload empty + } else { + payload = heapArray(reason.size() + 2); + payload[0] = code >> 8; + payload[1] = code; + memcpy(payload.begin() + 2, reason.begin(), reason.size()); + } + return kj::mv(payload); + } + void queuePong(kj::Array payload) { bool alreadyWaitingForPongWrite = false; @@ -3531,6 +3566,7 @@ private: // closing the connection due to error. There's no point queueing a Pong that'll never be sent. return; } else { + KJ_ASSERT(controlMessage.opcode == OPCODE_PONG); alreadyWaitingForPongWrite = true; } } @@ -3538,7 +3574,7 @@ private: // Note: According to spec, if the server receives a second ping before responding to the // previous one, it can opt to respond only to the last ping. So we don't have to check if // queuedControlMessage is already non-null. 
- queuedControlMessage = ControlMessage(OPCODE_PONG, kj::mv(payload)); + queuedControlMessage = ControlMessage(OPCODE_PONG, kj::mv(payload), kj::none); if (currentlySending) { // There is a message-send in progress, so we cannot write to the stream now. We will set @@ -3566,18 +3602,23 @@ private: KJ_IF_SOME(q, queuedControlMessage) { byte opcode = q.opcode; kj::Array payload = kj::mv(q.payload); + auto maybeFulfiller = kj::mv(q.fulfiller); queuedControlMessage = kj::none; if (hasSentClose || disconnected) { - return kj::READY_NOW; + KJ_IF_SOME(fulfiller, maybeFulfiller) { + fulfiller->fulfill(); + } + co_return; } sendParts[0] = sendHeader.compose(true, false, opcode, payload.size(), Mask(maskKeyGenerator)); sendParts[1] = payload; - return stream->write(sendParts).attach(kj::mv(payload)); - } else { - return kj::READY_NOW; + co_await stream->write(sendParts); + KJ_IF_SOME(fulfiller, maybeFulfiller) { + fulfiller->fulfill(); + } } } From 6b75731e1da06085706fb9dfb1693bd9404d3104 Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Wed, 3 Jan 2024 15:07:16 -0800 Subject: [PATCH 24/58] WebSocket: reply with Close for various protocol errors. Protocol errors are things like a client setting RSV bits 2 or 3, sending a compressed message when no compression was negotiated, and so on. 
--- c++/src/kj/compat/http-test.c++ | 94 ++++++++++++++++++++------------- c++/src/kj/compat/http.c++ | 58 ++++++++++---------- 2 files changed, 85 insertions(+), 67 deletions(-) diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 4301ac47a4..0888b3b884 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -1879,34 +1879,6 @@ public: } }; -KJ_TEST("WebSocket unexpected RSV bits") { - KJ_HTTP_TEST_SETUP_IO; - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - - WebSocketErrorCatcher errorCatcher; - auto client = kj::mv(pipe.ends[0]); - auto server = newWebSocket(kj::mv(pipe.ends[1]), kj::none, kj::none, errorCatcher); - - byte DATA[] = { - 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', - - 0xF0, 0x05, 'w', 'o', 'r', 'l', 'd' // all RSV bits set, plus FIN - }; - - auto clientTask = client->write(DATA, sizeof(DATA)); - - { - bool gotException = false; - auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); - serverTask.wait(waitScope); - KJ_ASSERT(gotException); - KJ_ASSERT(errorCatcher.errors.size() == 1); - KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); - } - - clientTask.wait(waitScope); -} - void assertContainsWebSocketClose(kj::ArrayPtr data, uint16_t code, kj::Maybe messageSubstr) { KJ_ASSERT(data.size() >= 2); // The smallest possible Close frame has size 2. KJ_ASSERT(data.size() <= 127); // Maximum size for control frames. 
@@ -1934,6 +1906,38 @@ void assertContainsWebSocketClose(kj::ArrayPtr data, uint16_t code, kj } } +KJ_TEST("WebSocket unexpected RSV bits") { + KJ_HTTP_TEST_SETUP_IO; + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + WebSocketErrorCatcher errorCatcher; + auto client = kj::mv(pipe.ends[0]); + auto server = newWebSocket(kj::mv(pipe.ends[1]), kj::none, kj::none, errorCatcher); + + byte DATA[] = { + 0x01, 0x06, 'h', 'e', 'l', 'l', 'o', ' ', + + 0xF0, 0x05, 'w', 'o', 'r', 'l', 'd' // all RSV bits set, plus FIN + }; + + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); + + { + bool gotException = false; + auto serverTask = server->receive().then([](auto&& m) {}, [&gotException](kj::Exception&& ex) { gotException = true; }); + serverTask.wait(waitScope); + KJ_ASSERT(gotException); + KJ_ASSERT(errorCatcher.errors.size() == 1); + KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); + } + + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "RSV bits"_kjc); +} + KJ_TEST("WebSocket unexpected continuation frame") { KJ_HTTP_TEST_SETUP_IO; auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; @@ -1947,7 +1951,6 @@ KJ_TEST("WebSocket unexpected continuation frame") { }; auto rawCloseMessage = kj::heapArray(129); - auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); }); @@ -1978,7 +1981,10 @@ KJ_TEST("WebSocket missing continuation frame") { 0x01, 0x06, 'w', 'o', 'r', 'l', 'd', '!', // Another start frame }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -1988,7 +1994,8 @@ 
KJ_TEST("WebSocket missing continuation frame") { KJ_ASSERT(errorCatcher.errors.size() == 1); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Missing continuation frame"_kjc); } KJ_TEST("WebSocket fragmented control frame") { @@ -2003,7 +2010,10 @@ KJ_TEST("WebSocket fragmented control frame") { 0x09, 0x04, 'd', 'a', 't', 'a' // Fragmented ping frame }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -2014,7 +2024,8 @@ KJ_TEST("WebSocket fragmented control frame") { KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Received fragmented control frame"_kjc); } KJ_TEST("WebSocket unknown opcode") { @@ -2029,7 +2040,10 @@ KJ_TEST("WebSocket unknown opcode") { 0x85, 0x04, 'd', 'a', 't', 'a' // 5 is a reserved opcode }; - auto clientTask = client->write(DATA, sizeof(DATA)); + auto rawCloseMessage = kj::heapArray(129); + auto clientTask = client->write(DATA, sizeof(DATA)).then([&]() { + return client->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { bool gotException = false; @@ -2040,7 +2054,8 @@ KJ_TEST("WebSocket unknown opcode") { KJ_ASSERT(errorCatcher.errors[0].statusCode == 1002); } - clientTask.wait(waitScope); + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1002, "Unknown opcode 5"_kjc); } KJ_TEST("WebSocket unsolicited pong") { @@ -2419,6 +2434,7 @@ KJ_TEST("WebSocket maximum message size") { WebSocketErrorCatcher errorCatcher; FakeEntropySource maskGenerator; + auto* rawClient = pipe.ends[0].get(); auto client 
= newWebSocket(kj::mv(pipe.ends[0]), maskGenerator); auto server = newWebSocket(kj::mv(pipe.ends[1]), kj::none, kj::none, errorCatcher); @@ -2426,9 +2442,12 @@ KJ_TEST("WebSocket maximum message size") { auto biggestAllowedString = kj::strArray(kj::repeat(kj::StringPtr("A"), maxSize), ""); auto tooBigString = kj::strArray(kj::repeat(kj::StringPtr("B"), maxSize + 1), ""); + auto rawCloseMessage = kj::heapArray(129); auto clientTask = client->send(biggestAllowedString) .then([&]() { return client->send(tooBigString); }) - .then([&]() { return client->close(1234, "done"); }); + .then([&]() { + return rawClient->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + }); { auto message = server->receive(maxSize).wait(waitScope); @@ -2442,6 +2461,9 @@ KJ_TEST("WebSocket maximum message size") { KJ_ASSERT(errorCatcher.errors.size() == 1); KJ_ASSERT(errorCatcher.errors[0].statusCode == 1009); } + + auto nread = clientTask.wait(waitScope); + assertContainsWebSocketClose(rawCloseMessage.slice(0, nread), 1009, "too large"_kjc); } class TestWebSocketService final: public HttpService, private kj::TaskSet::ErrorHandler { diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 412782d5e9..abd426f00c 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -2688,40 +2688,28 @@ public: auto& recvHeader = *reinterpret_cast(recvData.begin()); if (recvHeader.hasRsv2or3()) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Received frame had RSV bits 2 or 3 set", - }); + return sendCloseDueToError(1002, "Received frame had RSV bits 2 or 3 set"); } recvData = recvData.slice(headerSize, recvData.size()); size_t payloadLen = recvHeader.getPayloadLen(); if (payloadLen > maxSize) { - return errorHandler.handleWebSocketProtocolError({ - 1009, kj::str("Message is too large: ", payloadLen, " > ", maxSize) - }); + auto description = kj::str("Message is too large: ", payloadLen, " > ", maxSize); + return sendCloseDueToError(1009, 
description.asPtr()).attach(kj::mv(description)); } auto opcode = recvHeader.getOpcode(); bool isData = opcode < OPCODE_FIRST_CONTROL; if (opcode == OPCODE_CONTINUATION) { if (fragments.empty()) { - auto paf = newPromiseAndFulfiller(); - queueClose(1002, "Unexpected continuation frame", kj::mv(paf.fulfiller)); - - return paf.promise.then([this]() -> kj::Promise { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Unexpected continuation frame" - }); - }); + return sendCloseDueToError(1002, "Unexpected continuation frame"); } opcode = fragmentOpcode; } else if (isData) { if (!fragments.empty()) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Missing continuation frame" - }); + return sendCloseDueToError(1002, "Missing continuation frame"); } } @@ -2769,9 +2757,7 @@ public: } else { // Fragmented message, and this isn't the final fragment. if (!isData) { - return errorHandler.handleWebSocketProtocolError({ - 1002, "Received fragmented control frame" - }); + return sendCloseDueToError(1002, "Received fragmented control frame"); } message = kj::heapArray(payloadLen); @@ -2802,11 +2788,10 @@ public: // Provide a reasonable error if a compressed frame is received without compression enabled. if (isCompressed && compressionConfig == kj::none) { - return errorHandler.handleWebSocketProtocolError({ - 1002, kj::str( - "Received a WebSocket frame whose compression bit was set, but the compression " - "extension was not negotiated for this connection.") - }); + return sendCloseDueToError( + 1002, + "Received a WebSocket frame whose compression bit was set, but the compression " + "extension was not negotiated for this connection."); } switch (opcode) { @@ -2879,9 +2864,10 @@ public: // Unsolicited pong. Ignore. 
return receive(maxSize); default: - return errorHandler.handleWebSocketProtocolError({ - 1002, kj::str("Unknown opcode ", opcode) - }); + { + auto description = kj::str("Unknown opcode ", opcode); + return sendCloseDueToError(1002, description.asPtr()).attach(kj::mv(description)); + } } }; @@ -3398,7 +3384,6 @@ private: bool disconnected = false; bool currentlySending = false; Header sendHeader; - kj::ArrayPtr sendParts[2]; struct ControlMessage { byte opcode; @@ -3462,8 +3447,6 @@ private: KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); // We recently sent a control message; make sure it's finished before proceeding. - // - // auto localPromise = kj::mv(p); sendingControlMessage = kj::none; co_await localPromise; @@ -3515,6 +3498,7 @@ private: message = ownMessage; } + kj::ArrayPtr sendParts[2]; sendParts[0] = sendHeader.compose(true, useCompression, opcode, message.size(), mask); sendParts[1] = message; KJ_ASSERT(!sendHeader.hasRsv2or3(), "RSV bits 2 and 3 must be 0, as we do not currently " @@ -3557,6 +3541,17 @@ private: return kj::mv(payload); } + kj::Promise sendCloseDueToError(uint16_t code, kj::StringPtr reason){ + auto paf = newPromiseAndFulfiller(); + queueClose(code, reason, kj::mv(paf.fulfiller)); + + return paf.promise.then([this, code, reason]() -> kj::Promise { + return errorHandler.handleWebSocketProtocolError({ + code, reason + }); + }); + } + void queuePong(kj::Array payload) { bool alreadyWaitingForPongWrite = false; @@ -3612,6 +3607,7 @@ private: co_return; } + kj::ArrayPtr sendParts[2]; sendParts[0] = sendHeader.compose(true, false, opcode, payload.size(), Mask(maskKeyGenerator)); sendParts[1] = payload; From dacb2ea3d9fe34d7a26a84d3d31885ab01a63a55 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Tue, 16 Jan 2024 17:19:38 -0600 Subject: [PATCH 25/58] Fix spurious `expected !currentlySending` exception. We see this in production, probably after a WebSocket send() has failed due to network errors. 
I tried to produce a test for this, but after a few hours of fiddling, decided I have more important things to do. I would like to rewrite http-over-capnp entirely... --- c++/src/capnp/compat/http-over-capnp.c++ | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index ce5d702f27..a10840783e 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -116,7 +116,9 @@ public: shorteningPromise(kj::mv(shorteningPromise)) {} ~CapnpToKjWebSocketAdapter() noexcept(false) { - state->disconnectWebSocket(); + if (clean) { + state->disconnectWebSocket(); + } } kj::Maybe> shortenPath() override { @@ -128,20 +130,34 @@ public: } kj::Promise sendText(SendTextContext context) override { - return state->wrap([&]() { return webSocket.send(context.getParams().getText()); }); + KJ_ASSERT(clean); // should be guaranteed by streaming semantics + clean = false; + co_await state->wrap([&]() { return webSocket.send(context.getParams().getText()); }); + clean = true; } kj::Promise sendData(SendDataContext context) override { - return state->wrap([&]() { return webSocket.send(context.getParams().getData()); }); + KJ_ASSERT(clean); // should be guaranteed by streaming semantics + clean = false; + co_await state->wrap([&]() { return webSocket.send(context.getParams().getData()); }); + clean = true; } kj::Promise close(CloseContext context) override { + KJ_ASSERT(clean); // should be guaranteed by streaming semantics auto params = context.getParams(); - return state->wrap([&]() { return webSocket.close(params.getCode(), params.getReason()); }); + clean = false; + co_await state->wrap([&]() { return webSocket.close(params.getCode(), params.getReason()); }); + clean = true; } private: kj::Own state; kj::WebSocket& webSocket; kj::Promise shorteningPromise; + + bool clean = true; + // It's illegal to call another `send()` or 
`disconnect()` until the previous `send()` has + // completed successfully. We want to send `disconnect()` in the destructor but only if we can + // do so cleanly. }; class HttpOverCapnpFactory::KjToCapnpWebSocketAdapter final: public kj::WebSocket { From 40a6be7b023c2181d705a94f321ba7cc694d50c8 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Tue, 16 Jan 2024 17:29:01 -0600 Subject: [PATCH 26/58] On WebSocket handshake errors, throw DISCONNECTED at server. This should cause us to stop logging spurious sentry issues for client errors. --- c++/src/kj/compat/http.c++ | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 63e19234ad..706dacab4b 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -7941,7 +7941,12 @@ private: } kj::Own sendWebSocketError(StringPtr errorMessage) { - kj::Exception exception = KJ_EXCEPTION(FAILED, + // The client committed a protocol error during a WebSocket handshake. We will send an error + // response back to them, and throw an exception from `acceptWebSocket()` to our app. We'll + // label this as a DISCONNECTED exception, as if the client had simply closed the connection + // rather than committing a protocol error. This is intended to let the server know that it + // wasn't an error on the server's part. (This is a bit of a hack...) + kj::Exception exception = KJ_EXCEPTION(DISCONNECTED, "received bad WebSocket handshake", errorMessage); webSocketError = sendError( HttpHeaders::ProtocolError { 400, "Bad Request", errorMessage, nullptr }); From 044300cd5f1cf10224290a9479b39fbdd4286446 Mon Sep 17 00:00:00 2001 From: Felix Hanau Date: Mon, 22 Jan 2024 18:58:59 -0500 Subject: [PATCH 27/58] Update to zlib 1.3.1, switch to fetching from GitHub The zlib.net host eagerly deletes older zlib releases as they become obsolete; this can cause unexpected build failures whenever there is a new zlib version.
--- c++/WORKSPACE | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/c++/WORKSPACE b/c++/WORKSPACE index d3012da069..d58ffdbd14 100644 --- a/c++/WORKSPACE +++ b/c++/WORKSPACE @@ -55,7 +55,7 @@ cc_library( http_archive( name = "zlib", build_file_content = _zlib_build, - sha256 = "8a9ba2898e1d0d774eca6ba5b4627a11e5588ba85c8851336eb38de4683050a7", - strip_prefix = "zlib-1.3", - urls = ["https://zlib.net/zlib-1.3.tar.xz"], + sha256 = "38ef96b8dfe510d42707d9c781877914792541133e1870841463bfa73f883e32", + strip_prefix = "zlib-1.3.1", + urls = ["https://zlib.net/zlib-1.3.1.tar.xz"], ) From 7ac227cc35cb76b57b2235b77a22e033ecb5e27f Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Tue, 23 Jan 2024 19:44:18 -0600 Subject: [PATCH 28/58] Fully fix exception-masking bug in async pipe. This is essentially the same as #1859 but extended to the two other `teeException*` helpers which I for some reason failed to notice when creating the original fix. --- c++/src/kj/async-io.c++ | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/c++/src/kj/async-io.c++ b/c++/src/kj/async-io.c++ index 45140abfa1..8761f76936 100644 --- a/c++/src/kj/async-io.c++ +++ b/c++/src/kj/async-io.c++ @@ -401,19 +401,35 @@ private: } template - static auto teeExceptionVoid(F& fulfiller) { + static auto teeExceptionVoid(F& fulfiller, Canceler& canceler) { // Returns a functor that can be passed as the second parameter to .then() to propagate the // exception to a given fulfiller. The functor's return type is void. - return [&fulfiller](kj::Exception&& e) { + // + // All use cases of this helper below are also wrapped in `canceler.wrap()`, and fulfilling + // `fulfiller` may cause the canceler to be canceled. 
It's possible the canceler will be + // canceled before the exception even gets a chance to propagate out of the wrapped promise, + // which would have the effect of replacing the original exception with a non-useful + // "operation canceled" exception. To avoid this, we must release the canceler before + // fulfilling the fulfiller. + return [&fulfiller, &canceler](kj::Exception&& e) { + canceler.release(); fulfiller.reject(kj::cp(e)); kj::throwRecoverableException(kj::mv(e)); }; } template - static auto teeExceptionSize(F& fulfiller) { + static auto teeExceptionSize(F& fulfiller, Canceler& canceler) { // Returns a functor that can be passed as the second parameter to .then() to propagate the // exception to a given fulfiller. The functor's return type is size_t. - return [&fulfiller](kj::Exception&& e) -> size_t { + // + // All use cases of this helper below are also wrapped in `canceler.wrap()`, and fulfilling + // `fulfiller` may cause the canceler to be canceled. It's possible the canceler will be + // canceled before the exception even gets a chance to propagate out of the wrapped promise, + // which would have the effect of replacing the original exception with a non-useful + // "operation canceled" exception. To avoid this, we must release the canceler before + // fulfilling the fulfiller. + return [&fulfiller, &canceler](kj::Exception&& e) -> size_t { + canceler.release(); fulfiller.reject(kj::cp(e)); kj::throwRecoverableException(kj::mv(e)); return 0; @@ -576,7 +592,7 @@ private: writeBuffer = writeBuffer.slice(amount, writeBuffer.size()); // We pumped the full amount, so we're done pumping. return amount; - }, teeExceptionSize(fulfiller))); + }, teeExceptionSize(fulfiller, canceler))); } // First piece doesn't cover the whole pump. Figure out how many more pieces to add.
@@ -630,7 +646,7 @@ private: morePieces = newMorePieces; canceler.release(); return amount; - }, teeExceptionSize(fulfiller))); + }, teeExceptionSize(fulfiller, canceler))); } } @@ -807,7 +823,7 @@ private: // Completed entire pumpTo amount. KJ_ASSERT(actual == amount2); return amount2; - }, teeExceptionSize(fulfiller))); + }, teeExceptionSize(fulfiller, canceler))); } void abortRead() override { @@ -1263,7 +1279,7 @@ private: canceler.release(); fulfiller.fulfill(kj::cp(amount)); pipe.endState(*this); - }, teeExceptionVoid(fulfiller))); + }, teeExceptionVoid(fulfiller, canceler))); } auto remainder = pieces.slice(i, pieces.size()); @@ -1292,7 +1308,7 @@ private: fulfiller.fulfill(kj::cp(amount)); pipe.endState(*this); } - }, teeExceptionVoid(fulfiller))); + }, teeExceptionVoid(fulfiller, canceler))); } Promise writeWithFds(ArrayPtr data, From a12c15a7cd0ae5121dca6216f529fb15fdee238c Mon Sep 17 00:00:00 2001 From: Mike Aizatsky Date: Fri, 26 Jan 2024 08:52:33 -0800 Subject: [PATCH 29/58] Revert "Revert "gracefully handling http 431 error"" (#1920) Relanding #1828 This reverts commit 4cb0acf98577d8c29921474fcc8a4eec9111314c. --- c++/src/kj/compat/http-test.c++ | 139 +++++++++++++++++++++----------- c++/src/kj/compat/http.c++ | 109 +++++++++++++++++++------ 2 files changed, 174 insertions(+), 74 deletions(-) diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 428f694a81..4849fc237a 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -43,8 +43,8 @@ #else // Run the test using in-process two-way pipes. 
#define KJ_HTTP_TEST_SETUP_IO \ - kj::EventLoop eventLoop; \ - kj::WaitScope waitScope(eventLoop) + auto io = kj::setupAsyncIo(); \ + auto& waitScope KJ_UNUSED = io.waitScope #define KJ_HTTP_TEST_SETUP_LOOPBACK_LISTENER_AND_ADDR \ auto capPipe = newCapabilityPipe(); \ auto listener = kj::heap(*capPipe.ends[0]); \ @@ -4158,60 +4158,103 @@ KJ_TEST("HttpServer threw exception") { KJ_EXPECT(text.startsWith("HTTP/1.1 500 Internal Server Error"), text); } -KJ_TEST("HttpServer bad request") { - KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service); - - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - - static constexpr auto request = "GET / HTTP/1.1\r\nbad request\r\n\r\n"_kj; - auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); - auto response = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(writePromise.poll(waitScope)); - writePromise.wait(waitScope); +KJ_TEST("HttpServer bad requests") { + struct TestCase { + kj::StringPtr request; + kj::StringPtr expectedResponse; + bool expectWriteError; + }; - static constexpr auto expectedResponse = - "HTTP/1.1 400 Bad Request\r\n" - "Connection: close\r\n" - "Content-Length: 53\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "ERROR: The headers sent by your client are not valid."_kj; + static auto hugeHeaderRequest = kj::str( + "GET /foo/bar HTTP/1.1\r\n", + "Host: ", kj::strArray(kj::repeat("0", 1024 * 1024), ""), "\r\n", + "\r\n"); - KJ_EXPECT(expectedResponse == response, expectedResponse, response); -} + static TestCase testCases[] { + { + // bad request + .request = "GET / HTTP/1.1\r\nbad request\r\n\r\n"_kj, + .expectedResponse = + "HTTP/1.1 400 Bad Request\r\n" + "Connection: close\r\n" + "Content-Length: 53\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The headers sent by your client are not valid."_kj + }, + { + // 
invalid method + .request = "bad request\r\n\r\n"_kj, + .expectedResponse = + "HTTP/1.1 501 Not Implemented\r\n" + "Connection: close\r\n" + "Content-Length: 35\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: Unrecognized request method."_kj + }, + { + // broken service generates a 500 + .request = + "GET /foo/bar HTTP/1.1\r\n" + "Host: example.com\r\n" + "\r\n"_kj, + .expectedResponse = + "HTTP/1.1 500 Internal Server Error\r\n" + "Connection: close\r\n" + "Content-Length: 51\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: The HttpService did not generate a response."_kj, + }, + { + // huge header shouldn't break the server + .request = hugeHeaderRequest, + .expectedResponse = + "HTTP/1.1 431 Request Header Fields Too Large\r\n" + "Connection: close\r\n" + "Content-Length: 24\r\n" + "Content-Type: text/plain\r\n" + "\r\n" + "ERROR: header too large."_kj, + .expectWriteError = true, + }, + }; -KJ_TEST("HttpServer invalid method") { KJ_HTTP_TEST_SETUP_IO; - kj::TimerImpl timer(kj::origin()); - auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + // we need a real timer to test http server grace behavior.
+ auto& timer = io.provider->getTimer(); - HttpHeaderTable table; - BrokenHttpService service; - HttpServer server(timer, table, service); + for (auto testCase : testCases) { + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; - auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); - - static constexpr auto request = "bad request\r\n\r\n"_kj; - auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); - auto response = pipe.ends[1]->readAllText().wait(waitScope); - KJ_EXPECT(writePromise.poll(waitScope)); - writePromise.wait(waitScope); + HttpHeaderTable table; + BrokenHttpService service; + HttpServer server(timer, table, service, { + .canceledUploadGraceBytes = 1024 * 1024, + }); - static constexpr auto expectedResponse = - "HTTP/1.1 501 Not Implemented\r\n" - "Connection: close\r\n" - "Content-Length: 35\r\n" - "Content-Type: text/plain\r\n" - "\r\n" - "ERROR: Unrecognized request method."_kj; + auto listenTask = server.listenHttp(kj::mv(pipe.ends[0])); + + auto request = testCase.request; + auto writePromise = pipe.ends[1]->write(request.begin(), request.size()); + try { + auto response = pipe.ends[1]->readAllText().wait(waitScope); + auto expectedResponse = testCase.expectedResponse; + KJ_EXPECT(expectedResponse == response, expectedResponse, response); + } catch (...) { + auto ex = kj::getCaughtExceptionAsKj(); + KJ_FAIL_REQUIRE("not supposed to happen", ex); + } - KJ_EXPECT(expectedResponse == response, expectedResponse, response); + // write promise should have been resolved already + KJ_EXPECT(writePromise.poll(waitScope)); + try { + writePromise.wait(waitScope); + } catch (...) { + KJ_EXPECT(testCase.expectWriteError, "write error wasn't expected"); + } + } } // Ensure that HttpServerSettings can continue to be constexpr. 
diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 706dacab4b..601ca0900e 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -1453,7 +1453,9 @@ public: } kj::Promise readMessage() override { - auto text = co_await readMessageHeaders(); + auto textOrError = co_await readMessageHeaders(); + KJ_REQUIRE(textOrError.is>(), "bad message"); + auto text = textOrError.get>(); headers.clear(); KJ_REQUIRE(headers.tryParse(text), "bad message"); auto body = getEntityBody(HttpInputStreamImpl::RESPONSE, HttpMethod::GET, 0, headers); @@ -1528,7 +1530,7 @@ public: return !lineBreakBeforeNextHeader && leftover == nullptr; } - kj::Promise> readMessageHeaders() { + kj::Promise, HttpHeaders::ProtocolError>> readMessageHeaders() { ++pendingMessageCount; auto paf = kj::newPromiseAndFulfiller(); @@ -1541,28 +1543,38 @@ public: co_return co_await readHeader(HeaderType::MESSAGE, 0, 0); } - kj::Promise readChunkHeader() { + kj::Promise> readChunkHeader() { KJ_REQUIRE(onMessageDone != kj::none); // We use the portion of the header after the end of message headers. 
- auto text = co_await readHeader(HeaderType::CHUNK, messageHeaderEnd, messageHeaderEnd); - KJ_REQUIRE(text.size() > 0) { break; } - - uint64_t value = 0; - for (char c: text) { - if ('0' <= c && c <= '9') { - value = value * 16 + (c - '0'); - } else if ('a' <= c && c <= 'f') { - value = value * 16 + (c - 'a' + 10); - } else if ('A' <= c && c <= 'F') { - value = value * 16 + (c - 'A' + 10); - } else { - KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; } + auto textOrError = co_await readHeader(HeaderType::CHUNK, messageHeaderEnd, messageHeaderEnd); + + KJ_SWITCH_ONEOF(textOrError) { + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + co_return protocolError; + } + KJ_CASE_ONEOF(text, kj::ArrayPtr) { + KJ_REQUIRE(text.size() > 0) { break; } + + uint64_t value = 0; + for (char c: text) { + if ('0' <= c && c <= '9') { + value = value * 16 + (c - '0'); + } else if ('a' <= c && c <= 'f') { + value = value * 16 + (c - 'a' + 10); + } else if ('A' <= c && c <= 'F') { + value = value * 16 + (c - 'A' + 10); + } else { + KJ_FAIL_REQUIRE("invalid HTTP chunk size", text, text.asBytes()) { break; } + co_return value; + } + } + co_return value; } } - co_return value; + KJ_UNREACHABLE; } inline kj::Promise readRequestHeaders() { @@ -1571,18 +1583,36 @@ public: co_return HttpHeaders::RequestConnectOrProtocolError(resuming); } - auto text = co_await readMessageHeaders(); - headers.clear(); - co_return headers.tryParseRequestOrConnect(text); + auto textOrError = co_await readMessageHeaders(); + KJ_SWITCH_ONEOF(textOrError) { + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + co_return protocolError; + } + KJ_CASE_ONEOF(text, kj::ArrayPtr) { + headers.clear(); + co_return headers.tryParseRequestOrConnect(text); + } + } + + KJ_UNREACHABLE; } inline kj::Promise readResponseHeaders() { // Note: readResponseHeaders() could be called multiple times concurrently when pipelining // requests. 
readMessageHeaders() will serialize these, but it's important not to mess with // state (like calling headers.clear()) before said serialization has taken place. - auto text = co_await readMessageHeaders(); - headers.clear(); - co_return headers.tryParseResponse(text); + auto headersOrError = co_await readMessageHeaders(); + KJ_SWITCH_ONEOF(headersOrError) { + KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { + co_return protocolError; + } + KJ_CASE_ONEOF(text, kj::ArrayPtr) { + headers.clear(); + co_return headers.tryParseResponse(text); + } + } + + KJ_UNREACHABLE; } inline const HttpHeaders& getHeaders() const { return headers; } @@ -1637,6 +1667,11 @@ public: return { headerBuffer.releaseAsBytes(), leftover.asBytes() }; } + kj::Promise discard(AsyncOutputStream &output, size_t maxBytes) { + // Used to read and discard the input during error handling. + return inner.pumpTo(output, maxBytes).ignoreResult(); + } + private: AsyncInputStream& inner; kj::Array headerBuffer; @@ -1685,7 +1720,7 @@ private: CHUNK }; - kj::Promise> readHeader( + kj::Promise, HttpHeaders::ProtocolError>> readHeader( HeaderType type, size_t bufferStart, size_t bufferEnd) { // Reads the HTTP message header or a chunk header (as in transfer-encoding chunked) and // returns the buffer slice containing it. @@ -1731,7 +1766,12 @@ private: // Can't grow because we'd invalidate the HTTP headers. kj::throwFatalException(KJ_EXCEPTION(FAILED, "invalid HTTP chunk size")); } - KJ_REQUIRE(headerBuffer.size() < MAX_BUFFER, "request headers too large"); + if (headerBuffer.size() >= MAX_BUFFER) { + co_return HttpHeaders::ProtocolError { + .statusCode = 431, + .statusMessage = "Request Header Fields Too Large", + .description = "header too large." 
}; + } auto newBuffer = kj::heapArray(headerBuffer.size() * 2); memcpy(newBuffer.begin(), headerBuffer.begin(), headerBuffer.size()); headerBuffer = kj::mv(newBuffer); @@ -2013,7 +2053,9 @@ public: co_return alreadyRead; } else if (chunkSize == 0) { // Read next chunk header. - auto nextChunkSize = co_await getInner().readChunkHeader(); + auto nextChunkSizeOrError = co_await getInner().readChunkHeader(); + KJ_REQUIRE(nextChunkSizeOrError.is(), "bad header"); + auto nextChunkSize = nextChunkSizeOrError.get(); if (nextChunkSize == 0) { doneReading(); } @@ -7532,6 +7574,21 @@ private: KJ_CASE_ONEOF(protocolError, HttpHeaders::ProtocolError) { // Bad request. + auto needClientGrace = protocolError.statusCode == 431; + if (needClientGrace) { + // We're going to reply with an error and close the connection. + // The client might not be able to read the error back. Read some data and wait + // a bit to give client a chance to finish writing. + + auto dummy = kj::heap(); + auto lengthGrace = kj::evalNow([&]() { + return httpInput.discard(*dummy, server.settings.canceledUploadGraceBytes); + }).catch_([](kj::Exception&& e) -> void { }) + .attach(kj::mv(dummy)); + auto timeGrace = server.timer.afterDelay(server.settings.canceledUploadGracePeriod); + co_await lengthGrace.exclusiveJoin(kj::mv(timeGrace)); + } + // sendError() uses Response::send(), which requires that we have a currentMethod, but we // never read one. GET seems like the correct choice here. currentMethod = HttpMethod::GET; From aaed938ba2eba4ceebfb0706fa5249c6f78f8e7c Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Fri, 26 Jan 2024 14:27:13 -0600 Subject: [PATCH 30/58] Fix whitespace: Replace recently-introduced tabs with spaces. 
--- c++/src/kj/compat/http-test.c++ | 2 +- c++/src/kj/compat/http.c++ | 58 ++++++++++++++++----------------- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 4849fc237a..03e3647921 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -2446,7 +2446,7 @@ KJ_TEST("WebSocket maximum message size") { auto clientTask = client->send(biggestAllowedString) .then([&]() { return client->send(tooBigString); }) .then([&]() { - return rawClient->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); + return rawClient->tryRead(rawCloseMessage.begin(), 2, rawCloseMessage.size()); }); { diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 601ca0900e..6df3f69d09 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -2745,13 +2745,13 @@ public: bool isData = opcode < OPCODE_FIRST_CONTROL; if (opcode == OPCODE_CONTINUATION) { if (fragments.empty()) { - return sendCloseDueToError(1002, "Unexpected continuation frame"); + return sendCloseDueToError(1002, "Unexpected continuation frame"); } opcode = fragmentOpcode; } else if (isData) { if (!fragments.empty()) { - return sendCloseDueToError(1002, "Missing continuation frame"); + return sendCloseDueToError(1002, "Missing continuation frame"); } } @@ -2799,7 +2799,7 @@ public: } else { // Fragmented message, and this isn't the final fragment. if (!isData) { - return sendCloseDueToError(1002, "Received fragmented control frame"); + return sendCloseDueToError(1002, "Received fragmented control frame"); } message = kj::heapArray(payloadLen); @@ -2830,10 +2830,10 @@ public: // Provide a reasonable error if a compressed frame is received without compression enabled. 
if (isCompressed && compressionConfig == kj::none) { - return sendCloseDueToError( - 1002, - "Received a WebSocket frame whose compression bit was set, but the compression " - "extension was not negotiated for this connection."); + return sendCloseDueToError( + 1002, + "Received a WebSocket frame whose compression bit was set, but the compression " + "extension was not negotiated for this connection."); } switch (opcode) { @@ -2906,10 +2906,10 @@ public: // Unsolicited pong. Ignore. return receive(maxSize); default: - { - auto description = kj::str("Unknown opcode ", opcode); - return sendCloseDueToError(1002, description.asPtr()).attach(kj::mv(description)); - } + { + auto description = kj::str("Unknown opcode ", opcode); + return sendCloseDueToError(1002, description.asPtr()).attach(kj::mv(description)); + } } }; @@ -3433,9 +3433,9 @@ private: kj::Maybe>> fulfiller; ControlMessage( - byte opcodeParam, - kj::Array payloadParam, - kj::Maybe>> fulfillerParam) + byte opcodeParam, + kj::Array payloadParam, + kj::Maybe>> fulfillerParam) : opcode(opcodeParam), payload(kj::mv(payloadParam)), fulfiller(kj::mv(fulfillerParam)) { KJ_REQUIRE(opcode <= OPCODE_MAX); } @@ -3485,8 +3485,8 @@ private: for (;;) { KJ_IF_SOME(p, sendingControlMessage) { - // Re-check in case of disconnect on a previous loop iteration. - KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); + // Re-check in case of disconnect on a previous loop iteration. + KJ_REQUIRE(!disconnected, "WebSocket can't send after disconnect()"); // We recently sent a control message; make sure it's finished before proceeding. 
auto localPromise = kj::mv(p); @@ -3544,7 +3544,7 @@ private: sendParts[0] = sendHeader.compose(true, useCompression, opcode, message.size(), mask); sendParts[1] = message; KJ_ASSERT(!sendHeader.hasRsv2or3(), "RSV bits 2 and 3 must be 0, as we do not currently " - "support an extension that would set these bits"); + "support an extension that would set these bits"); co_await stream->write(sendParts); currentlySending = false; @@ -3588,9 +3588,9 @@ private: queueClose(code, reason, kj::mv(paf.fulfiller)); return paf.promise.then([this, code, reason]() -> kj::Promise { - return errorHandler.handleWebSocketProtocolError({ - code, reason - }); + return errorHandler.handleWebSocketProtocolError({ + code, reason + }); }); } @@ -3599,12 +3599,12 @@ private: KJ_IF_SOME(controlMessage, queuedControlMessage) { if (controlMessage.opcode == OPCODE_CLOSE) { - // We're currently sending a Close message, which we only do (at least via queuedControlMessage) when we're - // closing the connection due to error. There's no point queueing a Pong that'll never be sent. - return; + // We're currently sending a Close message, which we only do (at least via queuedControlMessage) when we're + // closing the connection due to error. There's no point queueing a Pong that'll never be sent. 
+ return; } else { - KJ_ASSERT(controlMessage.opcode == OPCODE_PONG); - alreadyWaitingForPongWrite = true; + KJ_ASSERT(controlMessage.opcode == OPCODE_PONG); + alreadyWaitingForPongWrite = true; } } @@ -3643,9 +3643,9 @@ private: queuedControlMessage = kj::none; if (hasSentClose || disconnected) { - KJ_IF_SOME(fulfiller, maybeFulfiller) { - fulfiller->fulfill(); - } + KJ_IF_SOME(fulfiller, maybeFulfiller) { + fulfiller->fulfill(); + } co_return; } @@ -3655,7 +3655,7 @@ private: sendParts[1] = payload; co_await stream->write(sendParts); KJ_IF_SOME(fulfiller, maybeFulfiller) { - fulfiller->fulfill(); + fulfiller->fulfill(); } } } From 33e7c103389eff6d04161f5c28c063a078e80bd6 Mon Sep 17 00:00:00 2001 From: Mike Aizatsky Date: Wed, 24 Jan 2024 11:33:59 -0800 Subject: [PATCH 31/58] kj::Rc reference counting smart pointer --- c++/src/kj/refcount-test.c++ | 95 ++++++++++++++++++++++++- c++/src/kj/refcount.h | 132 +++++++++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+), 1 deletion(-) diff --git a/c++/src/kj/refcount-test.c++ b/c++/src/kj/refcount-test.c++ index 6666615567..65fdf48cd2 100644 --- a/c++/src/kj/refcount-test.c++ +++ b/c++/src/kj/refcount-test.c++ @@ -24,7 +24,7 @@ namespace kj { -struct SetTrueInDestructor: public Refcounted { +struct SetTrueInDestructor: public Refcounted, EnableAddRefToThis { SetTrueInDestructor(bool* ptr): ptr(ptr) {} ~SetTrueInDestructor() { *ptr = true; } @@ -57,6 +57,99 @@ TEST(Refcount, Basic) { #endif } +KJ_TEST("Rc") { + bool b = false; + + Rc ref1 = kj::rc(&b); + EXPECT_FALSE(ref1->isShared()); + EXPECT_TRUE(ref1 != nullptr); + EXPECT_FALSE(ref1 == nullptr); + + Rc ref2 = ref1.addRef(); + EXPECT_TRUE(ref1->isShared()); + EXPECT_TRUE(ref1 == ref2); + + { + Rc ref3 = ref2.addRef(); + EXPECT_TRUE(ref3->isShared()); + // ref3 is dropped + } + + EXPECT_FALSE(b); + + // start dropping references one by one + + EXPECT_TRUE(ref2->isShared()); + ref1 = nullptr; + EXPECT_TRUE(ref1 == nullptr); + EXPECT_FALSE(ref2->isShared()); + 
EXPECT_FALSE(b); + EXPECT_FALSE(ref1 == ref2); + + ref2 = nullptr; + EXPECT_TRUE(ref1 == ref2); + + // last reference dropped, SetTrueInDestructor destructor should execute + EXPECT_TRUE(b); +} + +KJ_TEST("Rc Own interop") { + bool b = false; + + Rc ref1 = kj::rc(&b); + + EXPECT_FALSE(b); + auto own = ref1.toOwn(); + EXPECT_TRUE(ref1 == nullptr); + EXPECT_TRUE(own.get() != nullptr); + + EXPECT_FALSE(b); + own = nullptr; + EXPECT_TRUE(b); +} + +struct Child: public SetTrueInDestructor { + Child(bool* ptr): SetTrueInDestructor(ptr) {} +}; + +KJ_TEST("Rc inheritance") { + bool b = false; + + auto child = kj::rc(&b); + + // up casting works automatically + kj::Rc parent = child.addRef(); + + auto down = parent.downcast(); + EXPECT_TRUE(parent == nullptr); + EXPECT_TRUE(down != nullptr); + + EXPECT_FALSE(b); + child = nullptr; + EXPECT_FALSE(b); + down = nullptr; + EXPECT_TRUE(b); +} + +KJ_TEST("EnableAddRefToThis") { + bool b = false; + + auto ref1 = kj::rc(&b); + EXPECT_FALSE(ref1->isShared()); + + auto ref2 = ref1.addRef(); + EXPECT_TRUE(ref2->isShared()); + EXPECT_TRUE(ref1->isShared()); + EXPECT_FALSE(b); + + ref1 = nullptr; + EXPECT_FALSE(ref2->isShared()); + EXPECT_FALSE(b); + + ref2 = nullptr; + EXPECT_TRUE(b); +} + struct SetTrueInDestructor2 { // Like above but doesn't inherit Refcounted. diff --git a/c++/src/kj/refcount.h b/c++/src/kj/refcount.h index 03b5234d8d..2735e2a877 100644 --- a/c++/src/kj/refcount.h +++ b/c++/src/kj/refcount.h @@ -38,6 +38,12 @@ namespace kj { // ======================================================================================= // Non-atomic (thread-unsafe) refcounting +template +class Rc; + +template +class EnableAddRefToThis; + class Refcounted: private Disposer { // Subclass this to create a class that contains a reference count. Then, use // `kj::refcounted()` to allocate a new refcounted pointer. @@ -78,9 +84,13 @@ class Refcounted: private Disposer { // "mutable" because disposeImpl() is const. Bleh. 
void disposeImpl(void* pointer) const override; + template static Own addRefInternal(T* object); + template + static Rc addRcRefInternal(T* object); + template friend Own addRef(T& object); template @@ -88,6 +98,15 @@ class Refcounted: private Disposer { template friend class RefcountedWrapper; + + template + friend Rc rc(Params&&... params); + + template + friend class Rc; + + template + friend class EnableAddRefToThis; }; template @@ -98,6 +117,14 @@ inline Own refcounted(Params&&... params) { return Refcounted::addRefInternal(new T(kj::fwd(params)...)); } +template +inline Rc rc(Params&&... params) { + // Allocate a new refcounted instance of T, passing `params` to its constructor. + // Returns smart pointer that can be used to manage references. + + return Refcounted::addRcRefInternal(new T(kj::fwd(params)...)); +} + template Own addRef(T& object) { // Return a new reference to `object`, which must subclass Refcounted and have been allocated @@ -115,6 +142,111 @@ Own Refcounted::addRefInternal(T* object) { return Own(object, *refcounted); } +template +Rc Refcounted::addRcRefInternal(T* object) { + Refcounted* refcounted = object; + ++refcounted->refcount; + return Rc(object); +} + +template +class Rc { + // Smart pointer for reference counted objects. + // + // There are only two ways to obtain new Rc instances: + // - use kj::rc(...) function to create new T. + // - use kj::Rc::addRef() and the existing Rc instance. + // - use EnableAddRefToThis to allow T instance to add new references to itself. + // + // Suggested usage patterns are: + // - return kj::Rc as value from factory functions: + // kj::Rc createMyService(); + // - pass kj::Rc as rvalue to functions that need to extend T's lifetime: + // void setMyService(kj::Rc&& service) + // - store kj::Rc as data member: + // struct MyComputation { kj::Rc service; }; + // - use toOwn to convert kj::Rc instance to kj::Own and use it + // without being concerned of reference counting behavior. 
+ // To improve the transparency of the code, kj::Own shouldn't be used + // to call addRef() without kj::Rc. + +public: + KJ_DISALLOW_COPY(Rc); + Rc() { } + Rc(decltype(nullptr)) { } + inline Rc(Rc&& other) noexcept = default; + + template ()>> + inline Rc(Rc&& other) noexcept : own(kj::mv(other.own)) { } + + kj::Own toOwn() { + // Convert Rc to Own. + // Nullifies the original Rc. + return kj::mv(own); + } + + kj::Rc addRef() { + T* refcounted = own.get(); + if (refcounted != nullptr) { + return Refcounted::addRcRefInternal(refcounted); + } else { + return kj::Rc(); + } + } + + Rc& operator=(decltype(nullptr)) { + own = nullptr; + return *this; + } + + Rc& operator=(Rc&& other) = default; + + template + Rc downcast() { + return Rc(own.template downcast()); + } + + inline bool operator==(const Rc& other) const { return own.get() == other.own.get(); } + inline bool operator==(decltype(nullptr)) const { return own.get() == nullptr; } + inline bool operator!=(decltype(nullptr)) const { return own.get() != nullptr; } + + inline T* operator->() { return own.get(); } + inline const T* operator->() const { return own.get(); } + + // do not expose * to avoid dangling references + +private: + Rc(T* t) : own(t, *t) { } + Rc(Own&& t) : own(kj::mv(t)) { } + + Own own; + + friend class Refcounted; + + template + friend class Rc; + + template + friend class EnableAddRefToThis; +}; + +template +class EnableAddRefToThis { + // Exposes addRefToThis member function for objects to add + // references to themselves. 
+ +protected: + kj::Rc addRefToThis() const { + const Self* self = static_cast(this); + return Refcounted::addRcRefInternal(self); + } + + kj::Rc addRefToThis() { + Self* self = static_cast(this); + return Refcounted::addRcRefInternal(self); + } +}; + template class RefcountedWrapper: public Refcounted { // Adds refcounting as a wrapper around an existing type, allowing you to construct references From 05fb0778987a58ba1649321eca3e1bb819736ae0 Mon Sep 17 00:00:00 2001 From: Mike Aizatsky Date: Wed, 24 Jan 2024 11:37:08 -0800 Subject: [PATCH 32/58] example use of kj::Rc in http.c++ --- c++/src/kj/compat/http.c++ | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 601ca0900e..5cdda22070 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -4376,7 +4376,7 @@ private: class WebSocketPipeEnd final: public WebSocket { public: - WebSocketPipeEnd(kj::Own in, kj::Own out) + WebSocketPipeEnd(kj::Rc&& in, kj::Rc&& out) : in(kj::mv(in)), out(kj::mv(out)) {} ~WebSocketPipeEnd() noexcept(false) { in->abort(); @@ -4417,17 +4417,17 @@ public: uint64_t receivedByteCount() override { return in->sentByteCount(); } private: - kj::Own in; - kj::Own out; + kj::Rc in; + kj::Rc out; }; } // namespace WebSocketPipe newWebSocketPipe() { - auto pipe1 = kj::refcounted(); - auto pipe2 = kj::refcounted(); + auto pipe1 = kj::rc(); + auto pipe2 = kj::rc(); - auto end1 = kj::heap(kj::addRef(*pipe1), kj::addRef(*pipe2)); + auto end1 = kj::heap(pipe1.addRef(), pipe2.addRef()); auto end2 = kj::heap(kj::mv(pipe2), kj::mv(pipe1)); return { { kj::mv(end1), kj::mv(end2) } }; From 6c572fbe3134d75232b9725676cc4a8be5ffcb9d Mon Sep 17 00:00:00 2001 From: Milan Miladinovic Date: Thu, 21 Dec 2023 12:35:14 -0500 Subject: [PATCH 33/58] Make getPreferredExtensions pure virtual func Prior to this commit, WebSocket::getPreferredExtensions() had a default implementation, which meant derived classes would 
not be required to implement the method themselves. Derived classes probably shouldn't refer to some default implementation, so we require it to be implemented going forward. Co-authored-by: Kenton Varda --- c++/src/capnp/compat/http-over-capnp.c++ | 6 ++ c++/src/kj/compat/http.c++ | 78 +++++++++++++++++++++++- c++/src/kj/compat/http.h | 2 +- 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index a10840783e..fb2732dba5 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -250,6 +250,12 @@ public: uint64_t sentByteCount() override { return sentBytes; } uint64_t receivedByteCount() override { return KJ_ASSERT_NONNULL(in)->receivedByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + // TODO(someday): Optimzed pump is tricky with HttpOverCapnp, we may want to revist + // this but for now we always return none (indicating no preference). + return kj::none; + }; + private: kj::Maybe> in; // One end of a WebSocketPipe, used only for receiving. kj::Maybe out; // Used only for sending. diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 601ca0900e..2ba8d62188 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -3901,6 +3901,19 @@ public: return transferredBytes; } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + + kj::Maybe destinationPumpingTo; + kj::Maybe destinationPumpingFrom; + // Tracks the outstanding pumpTo() and tryPumpFrom() calls currently running on the + // WebSocketPipeEnd, which is the destination side of this WebSocketPipeImpl. This is used by + // the source end to implement getPreferredExtensions(). + // + // getPreferredExtensions() does not fit into the model used by all the other methods because it + // is not directional (not a read nor a write call). 
+ private: kj::Maybe state; // Object-oriented state! If any method call is blocked waiting on activity from the other end, @@ -4019,6 +4032,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4104,6 +4121,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4188,6 +4209,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4284,6 +4309,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4330,6 +4359,9 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; }; class Aborted final: public WebSocket { @@ -4371,6 +4403,9 @@ private: uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; }; }; @@ -4403,19 +4438,50 @@ public: return out->whenAborted(); } kj::Maybe> tryPumpFrom(WebSocket& other) override { - return out->tryPumpFrom(other); + KJ_REQUIRE(in->destinationPumpingFrom == kj::none, "can only call 
tryPumpFrom() once at a time"); + // By convention, we store the WebSocket reference on `in`. + in->destinationPumpingFrom = other; + auto deferredUnregister = kj::defer([this]() { in->destinationPumpingFrom = kj::none; }); + KJ_IF_SOME(p, out->tryPumpFrom(other)) { + return p.attach(kj::mv(deferredUnregister)); + } else { + return kj::none; + } } kj::Promise receive(size_t maxSize) override { return in->receive(maxSize); } kj::Promise pumpTo(WebSocket& other) override { - return in->pumpTo(other); + KJ_REQUIRE(in->destinationPumpingTo == kj::none, "can only call pumpTo() once at a time"); + // By convention, we store the WebSocket reference on `in`. + in->destinationPumpingTo = other; + auto deferredUnregister = kj::defer([this]() { in->destinationPumpingTo = kj::none; }); + return in->pumpTo(other).attach(kj::mv(deferredUnregister)); } uint64_t sentByteCount() override { return out->sentByteCount(); } uint64_t receivedByteCount() override { return in->sentByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + // We want to forward this call to whatever WebSocket the other end of the pipe is pumping + // to/from, if any. We'll check them in an arbitrary order and take the first one we see. + // But really, the hope is that both destinationPumpingTo and destinationPumpingFrom are in fact + // the same object. If they aren't the same, then it's not really clear whose extensions we + // should prefer; the choice here is arbitrary. 
+ KJ_IF_SOME(ws, out->destinationPumpingTo) { + KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { + return kj::mv(result); + } + } + KJ_IF_SOME(ws, out->destinationPumpingFrom) { + KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { + return kj::mv(result); + } + } + return kj::none; + }; + private: kj::Own in; kj::Own out; @@ -6931,6 +6997,10 @@ private: uint64_t sentByteCount() override { return inner->sentByteCount(); } uint64_t receivedByteCount() override { return inner->receivedByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + return inner->getPreferredExtensions(ctx); + }; + private: kj::Own inner; kj::Maybe> completionTask; @@ -8039,6 +8109,10 @@ private: uint64_t sentByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_FAIL_ASSERT(kj::cp(exception)); + }; + private: kj::Exception exception; }; diff --git a/c++/src/kj/compat/http.h b/c++/src/kj/compat/http.h index d77552d65c..9e979001bc 100644 --- a/c++/src/kj/compat/http.h +++ b/c++/src/kj/compat/http.h @@ -689,7 +689,7 @@ class WebSocket { REQUEST, RESPONSE }; - virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) { return kj::none; } + virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) = 0; // If pumpTo() / tryPumpFrom() is able to be optimized only if the other WebSocket is using // certain extensions (e.g. compression settings), then this method returns what those extensions // are. 
For example, matching extensions between standard WebSockets allows pumping to be From ae261d9fde3dbedc8c0334ae2a342b18ce43857a Mon Sep 17 00:00:00 2001 From: Felix Hanau Date: Sun, 28 Jan 2024 20:17:47 -0500 Subject: [PATCH 34/58] Assorted build system cleanup - Remove unused C++ headers/CI dependencies/bazel imports - Mark additional classes as final - Mark the library target within cc_capnp_library as off-by-default. This allows us to avoid building the library archive when start_end_lib is enabled. --- .github/workflows/quick-test.yml | 1 - c++/build/configure.bzl | 2 +- c++/src/capnp/cc_capnp_library.bzl | 3 +++ c++/src/capnp/compiler/capnp.c++ | 1 - c++/src/capnp/compiler/compiler.c++ | 2 +- c++/src/capnp/serialize-async.c++ | 2 +- 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/quick-test.yml b/.github/workflows/quick-test.yml index 9269a7ca74..415ad157d3 100644 --- a/.github/workflows/quick-test.yml +++ b/.github/workflows/quick-test.yml @@ -42,7 +42,6 @@ jobs: libunwind: libunwind-12-dev steps: - uses: actions/checkout@v2 - - uses: bazelbuild/setup-bazelisk@v2 - name: install dependencies # I observed 404s for some packages and added an `apt-get update`. Then, I observed package # conflicts between LLVM 14 and 15, and added the line which removes LLVM 14, the default on diff --git a/c++/build/configure.bzl b/c++/build/configure.bzl index a392fa45b2..4891d61cd5 100644 --- a/c++/build/configure.bzl +++ b/c++/build/configure.bzl @@ -1,4 +1,4 @@ -load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "int_flag") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") def kj_configure(): """Generates set of flag, settings for kj configuration. 
diff --git a/c++/src/capnp/cc_capnp_library.bzl b/c++/src/capnp/cc_capnp_library.bzl index 9e4acd35b9..e6d9f5dada 100644 --- a/c++/src/capnp/cc_capnp_library.bzl +++ b/c++/src/capnp/cc_capnp_library.bzl @@ -87,6 +87,7 @@ def cc_capnp_library( data = [], deps = [], src_prefix = "", + tags = ["off-by-default"], visibility = None, target_compatible_with = None, **kwargs): @@ -122,6 +123,8 @@ def cc_capnp_library( srcs = srcs_cpp, hdrs = hdrs, deps = deps + ["@capnp-cpp//src/capnp:capnp_runtime"], + # Allows us to avoid building the library archive when using start_end_lib + tags = tags, visibility = visibility, target_compatible_with = target_compatible_with, **kwargs diff --git a/c++/src/capnp/compiler/capnp.c++ b/c++/src/capnp/compiler/capnp.c++ index 01e15df8c9..09553f64f2 100644 --- a/c++/src/capnp/compiler/capnp.c++ +++ b/c++/src/capnp/compiler/capnp.c++ @@ -39,7 +39,6 @@ #include #include #include "../message.h" -#include #include #include #include diff --git a/c++/src/capnp/compiler/compiler.c++ b/c++/src/capnp/compiler/compiler.c++ index 49c9b3a83f..c0744d1908 100644 --- a/c++/src/capnp/compiler/compiler.c++ +++ b/c++/src/capnp/compiler/compiler.c++ @@ -256,7 +256,7 @@ private: Node rootNode; }; -class Compiler::Impl: public SchemaLoader::LazyLoadCallback { +class Compiler::Impl final : public SchemaLoader::LazyLoadCallback { public: explicit Impl(AnnotationFlag annotationFlag); virtual ~Impl() noexcept(false); diff --git a/c++/src/capnp/serialize-async.c++ b/c++/src/capnp/serialize-async.c++ index 436b6c9afe..6fd88e393c 100644 --- a/c++/src/capnp/serialize-async.c++ +++ b/c++/src/capnp/serialize-async.c++ @@ -557,7 +557,7 @@ kj::Promise MessageStream::readMessage( // ======================================================================================= -class BufferedMessageStream::MessageReaderImpl: public FlatArrayMessageReader { +class BufferedMessageStream::MessageReaderImpl final : public FlatArrayMessageReader { public: 
MessageReaderImpl(BufferedMessageStream& parent, kj::ArrayPtr data, ReaderOptions options) From 9b83075e22e25e022196680d5c1d0be791375dd9 Mon Sep 17 00:00:00 2001 From: Samuel Merritt Date: Wed, 31 Jan 2024 16:06:21 -0800 Subject: [PATCH 35/58] Allow custom WebSocket error handlers for kj::HttpClient. WebSocketImpl already knows what to do with a custom error handler; this commit just adds plumbing. --- c++/src/kj/compat/http-test.c++ | 66 +++++++++++++++++++++++++++++++++ c++/src/kj/compat/http.c++ | 2 +- c++/src/kj/compat/http.h | 29 ++++++++------- 3 files changed, 83 insertions(+), 14 deletions(-) diff --git a/c++/src/kj/compat/http-test.c++ b/c++/src/kj/compat/http-test.c++ index 4849fc237a..4aebf54a5b 100644 --- a/c++/src/kj/compat/http-test.c++ +++ b/c++/src/kj/compat/http-test.c++ @@ -5099,6 +5099,72 @@ KJ_TEST("newHttpService from HttpClient WebSockets") { writeResponsesPromise.wait(waitScope); } +KJ_TEST("HttpClient WebSocket: client can have a custom WebSocket error handler") { + KJ_HTTP_TEST_SETUP_IO; + kj::TimerImpl timer(kj::origin()); + auto pipe = KJ_HTTP_TEST_CREATE_2PIPE; + + // These are WEBSOCKET_REQUEST_HANDSHAKE and WEBSOCKET_RESPONSE_HANDSHAKE but without the "My-Header" header. + // This test isn't about the HTTP handshake, so the headers are just noise. 
+ const char wsRequestHandshake[] = + " HTTP/1.1\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Key: DCI4TgwiOE4MIjhODCI4Tg==\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n"; + const char wsResponseHandshake[] = + "HTTP/1.1 101 Switching Protocols\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Accept: pShtIFKT0s8RYZvnWY/CrjQD8CM=\r\n" + "\r\n"; + + const byte badFrame[] = { + 0xF0, 0x02, 'y', 'o' // all RSV bits set, plus FIN + }; + const byte closeFrame[] = { + 0x88, 0xa8, 0xC, 0x22, 0x38, 0x4e, 0x3, 0xea, // FIN, opcode=Close, code=1009 + 'R', 'e', 'c', 'e', 'i', 'v', 'e', 'd', ' ', + 'f', 'r', 'a', 'm', 'e', ' ', + 'h', 'a', 'd', ' ', + 'R', 'S', 'V', ' ', + 'b', 'i', 't', 's', ' ', + '2', ' ', + 'o', 'r', ' ', + '3', ' ', + 's', 'e', 't', + }; + + auto request = kj::str("GET /websocket", wsRequestHandshake); + auto serverPromise = expectRead(*pipe.ends[1], request) + .then([&]() { return writeA(*pipe.ends[1], asBytes(wsResponseHandshake)); }) + .then([&]() { return writeA(*pipe.ends[1], badFrame); }) + .then([&]() { return expectRead(*pipe.ends[1], closeFrame); }) + .eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); + + { + HttpHeaderTable table; + FakeEntropySource entropySource; + HttpClientSettings clientSettings; + WebSocketErrorCatcher errorCatcher; + clientSettings.entropySource = entropySource; + clientSettings.webSocketErrorHandler = errorCatcher; + + auto clientStream = kj::mv(pipe.ends[0]); + auto httpClient = newHttpClient(table, *clientStream, clientSettings); + auto wsClientPromise = httpClient->openWebSocket("/websocket", HttpHeaders(table)) + .then([&](kj::HttpClient::WebSocketResponse resp) { return kj::mv(resp.webSocketOrBody.get>()); }) + .then([](kj::Own webSocket) -> kj::Promise { return webSocket->receive().attach(kj::mv(webSocket)); }) + .eagerlyEvaluate([](kj::Exception e) -> kj::WebSocket::Message { return kj::str("irrelevant value"); }); + + 
wsClientPromise.wait(waitScope); + KJ_EXPECT(errorCatcher.errors.size() == 1); + } + + serverPromise.wait(waitScope); +} + KJ_TEST("newHttpService from HttpClient WebSockets disconnect") { KJ_HTTP_TEST_SETUP_IO; kj::TimerImpl timer(kj::origin()); diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index f680e387f7..7c92da7a5a 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -5578,7 +5578,7 @@ public: response.statusText, &httpInput.getHeaders(), upgradeToWebSocket(kj::mv(ownStream), httpInput, httpOutput, settings.entropySource, - kj::mv(compressionParameters)), + kj::mv(compressionParameters), settings.webSocketErrorHandler), }; } else { upgraded = false; diff --git a/c++/src/kj/compat/http.h b/c++/src/kj/compat/http.h index 9e979001bc..abbd969a28 100644 --- a/c++/src/kj/compat/http.h +++ b/c++/src/kj/compat/http.h @@ -1021,6 +1021,19 @@ class HttpClientErrorHandler { // little reason to override this. }; +class WebSocketErrorHandler { +public: + virtual kj::Exception handleWebSocketProtocolError(WebSocket::ProtocolError protocolError); + // Handles low-level protocol errors in received WebSocket data. + // + // This is called when the WebSocket peer sends us bad data *after* a successful WebSocket + // upgrade, e.g. a continuation frame without a preceding start frame, a frame with an unknown + // opcode, or similar. + // + // You would override this method in order to customize the exception. You cannot prevent the + // exception from being thrown. +}; + struct HttpClientSettings { kj::Duration idleTimeout = 5 * kj::SECONDS; // For clients which automatically create new connections, any connection idle for at least this @@ -1046,23 +1059,13 @@ struct HttpClientSettings { }; WebSocketCompressionMode webSocketCompressionMode = NO_COMPRESSION; + kj::Maybe webSocketErrorHandler = kj::none; + // Customize exceptions thrown on WebSocket protocol errors. 
+ kj::Maybe tlsContext; // A reference to a TLS context that will be used when tlsStarter is invoked. }; -class WebSocketErrorHandler { -public: - virtual kj::Exception handleWebSocketProtocolError(WebSocket::ProtocolError protocolError); - // Handles low-level protocol errors in received WebSocket data. - // - // This is called when the WebSocket peer sends us bad data *after* a successful WebSocket - // upgrade, e.g. a continuation frame without a preceding start frame, a frame with an unknown - // opcode, or similar. - // - // You would override this method in order to customize the exception. You cannot prevent the - // exception from being thrown. -}; - kj::Own newHttpClient(kj::Timer& timer, const HttpHeaderTable& responseHeaderTable, kj::Network& network, kj::Maybe tlsNetwork, HttpClientSettings settings = HttpClientSettings()); From 20be408666207623fbfd5002628455c41d15e3b9 Mon Sep 17 00:00:00 2001 From: Milan Miladinovic Date: Tue, 6 Feb 2024 15:26:51 -0500 Subject: [PATCH 36/58] Revert "Make getPreferredExtensions pure virtual func" This reverts commit 6c572fbe3134d75232b9725676cc4a8be5ffcb9d. --- c++/src/capnp/compat/http-over-capnp.c++ | 6 -- c++/src/kj/compat/http.c++ | 78 +----------------------- c++/src/kj/compat/http.h | 2 +- 3 files changed, 3 insertions(+), 83 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index fb2732dba5..a10840783e 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -250,12 +250,6 @@ public: uint64_t sentByteCount() override { return sentBytes; } uint64_t receivedByteCount() override { return KJ_ASSERT_NONNULL(in)->receivedByteCount(); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - // TODO(someday): Optimzed pump is tricky with HttpOverCapnp, we may want to revist - // this but for now we always return none (indicating no preference). 
- return kj::none; - }; - private: kj::Maybe> in; // One end of a WebSocketPipe, used only for receiving. kj::Maybe out; // Used only for sending. diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index f680e387f7..5cdda22070 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -3901,19 +3901,6 @@ public: return transferredBytes; } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_UNREACHABLE; - }; - - kj::Maybe destinationPumpingTo; - kj::Maybe destinationPumpingFrom; - // Tracks the outstanding pumpTo() and tryPumpFrom() calls currently running on the - // WebSocketPipeEnd, which is the destination side of this WebSocketPipeImpl. This is used by - // the source end to implement getPreferredExtensions(). - // - // getPreferredExtensions() does not fit into the model used by all the other methods because it - // is not directional (not a read nor a write call). - private: kj::Maybe state; // Object-oriented state! If any method call is blocked waiting on activity from the other end, @@ -4032,10 +4019,6 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_UNREACHABLE; - }; - private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4121,10 +4104,6 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_UNREACHABLE; - }; - private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4209,10 +4188,6 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_UNREACHABLE; - }; - private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4309,10 +4284,6 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the 
individual states of WebSocketPipeImpl."); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_UNREACHABLE; - }; - private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4359,9 +4330,6 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_UNREACHABLE; - }; }; class Aborted final: public WebSocket { @@ -4403,9 +4371,6 @@ private: uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_UNREACHABLE; - }; }; }; @@ -4438,50 +4403,19 @@ public: return out->whenAborted(); } kj::Maybe> tryPumpFrom(WebSocket& other) override { - KJ_REQUIRE(in->destinationPumpingFrom == kj::none, "can only call tryPumpFrom() once at a time"); - // By convention, we store the WebSocket reference on `in`. - in->destinationPumpingFrom = other; - auto deferredUnregister = kj::defer([this]() { in->destinationPumpingFrom = kj::none; }); - KJ_IF_SOME(p, out->tryPumpFrom(other)) { - return p.attach(kj::mv(deferredUnregister)); - } else { - return kj::none; - } + return out->tryPumpFrom(other); } kj::Promise receive(size_t maxSize) override { return in->receive(maxSize); } kj::Promise pumpTo(WebSocket& other) override { - KJ_REQUIRE(in->destinationPumpingTo == kj::none, "can only call pumpTo() once at a time"); - // By convention, we store the WebSocket reference on `in`. 
- in->destinationPumpingTo = other; - auto deferredUnregister = kj::defer([this]() { in->destinationPumpingTo = kj::none; }); - return in->pumpTo(other).attach(kj::mv(deferredUnregister)); + return in->pumpTo(other); } uint64_t sentByteCount() override { return out->sentByteCount(); } uint64_t receivedByteCount() override { return in->sentByteCount(); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - // We want to forward this call to whatever WebSocket the other end of the pipe is pumping - // to/from, if any. We'll check them in an arbitrary order and take the first one we see. - // But really, the hope is that both destinationPumpingTo and destinationPumpingFrom are in fact - // the same object. If they aren't the same, then it's not really clear whose extensions we - // should prefer; the choice here is arbitrary. - KJ_IF_SOME(ws, out->destinationPumpingTo) { - KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { - return kj::mv(result); - } - } - KJ_IF_SOME(ws, out->destinationPumpingFrom) { - KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { - return kj::mv(result); - } - } - return kj::none; - }; - private: kj::Rc in; kj::Rc out; @@ -6997,10 +6931,6 @@ private: uint64_t sentByteCount() override { return inner->sentByteCount(); } uint64_t receivedByteCount() override { return inner->receivedByteCount(); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - return inner->getPreferredExtensions(ctx); - }; - private: kj::Own inner; kj::Maybe> completionTask; @@ -8109,10 +8039,6 @@ private: uint64_t sentByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } - kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { - KJ_FAIL_ASSERT(kj::cp(exception)); - }; - private: kj::Exception exception; }; diff --git a/c++/src/kj/compat/http.h b/c++/src/kj/compat/http.h index 9e979001bc..d77552d65c 100644 
--- a/c++/src/kj/compat/http.h +++ b/c++/src/kj/compat/http.h @@ -689,7 +689,7 @@ class WebSocket { REQUEST, RESPONSE }; - virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) = 0; + virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) { return kj::none; } // If pumpTo() / tryPumpFrom() is able to be optimized only if the other WebSocket is using // certain extensions (e.g. compression settings), then this method returns what those extensions // are. For example, matching extensions between standard WebSockets allows pumping to be From 4e5f9191497f680b6df5c25a24fea51482275099 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 13:28:12 -0600 Subject: [PATCH 37/58] Delete code: Remove "level 1" http-over-capnp implementation. Our production servers were all updated a year or so ago to use only level 2, so we can safely remove the code implementing level 1, which will make further changes easier. --- c++/src/capnp/compat/BUILD.bazel | 10 -- .../capnp/compat/http-over-capnp-old-test.c++ | 2 - c++/src/capnp/compat/http-over-capnp-test.c++ | 11 +- c++/src/capnp/compat/http-over-capnp.c++ | 115 +++--------------- c++/src/capnp/compat/http-over-capnp.h | 14 ++- 5 files changed, 30 insertions(+), 122 deletions(-) delete mode 100644 c++/src/capnp/compat/http-over-capnp-old-test.c++ diff --git a/c++/src/capnp/compat/BUILD.bazel b/c++/src/capnp/compat/BUILD.bazel index 7cb67279ca..528c919689 100644 --- a/c++/src/capnp/compat/BUILD.bazel +++ b/c++/src/capnp/compat/BUILD.bazel @@ -142,13 +142,3 @@ cc_library( name = "http-over-capnp-test-as-header", hdrs = ["http-over-capnp-test.c++"], ) - -cc_test( - name = "http-over-capnp-old-test", - srcs = ["http-over-capnp-old-test.c++"], - deps = [ - ":http-over-capnp-test-as-header", - ":http-over-capnp", - "//src/capnp:capnp-test" - ], -) diff --git a/c++/src/capnp/compat/http-over-capnp-old-test.c++ b/c++/src/capnp/compat/http-over-capnp-old-test.c++ deleted file mode 100644 index
9a5aea9b13..0000000000 --- a/c++/src/capnp/compat/http-over-capnp-old-test.c++ +++ /dev/null @@ -1,2 +0,0 @@ -#define TEST_PEER_OPTIMIZATION_LEVEL HttpOverCapnpFactory::LEVEL_1 -#include "http-over-capnp-test.c++" diff --git a/c++/src/capnp/compat/http-over-capnp-test.c++ b/c++/src/capnp/compat/http-over-capnp-test.c++ index 5d1e000069..8886e7e5df 100644 --- a/c++/src/capnp/compat/http-over-capnp-test.c++ +++ b/c++/src/capnp/compat/http-over-capnp-test.c++ @@ -657,14 +657,13 @@ KJ_TEST("HttpService isn't destroyed while call outstanding") { KJ_EXPECT(!called); KJ_EXPECT(!destroyed); - auto req = service.startRequestRequest(); + auto req = service.requestRequest(); auto httpReq = req.initRequest(); httpReq.setMethod(capnp::HttpMethod::GET); httpReq.setUrl("/"); - auto serverContext = req.send().wait(waitScope).getContext(); + auto promise = req.send(); service = nullptr; - auto promise = serverContext.whenResolved(); KJ_EXPECT(!promise.poll(waitScope)); KJ_EXPECT(called); @@ -769,7 +768,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto Connect with close") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); kj::Own table = tableBuilder.build(); ConnectWriteCloseService service(*table); kj::HttpServer server(timer, *table, service); @@ -844,7 +843,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto Connect Reject") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); kj::Own table = tableBuilder.build(); ConnectRejectService service(*table); kj::HttpServer server(timer, *table, service); @@ -917,7 +916,7 @@ KJ_TEST("HTTP-over-Cap'n-Proto Connect with startTls") { ByteStreamFactory streamFactory; kj::HttpHeaderTable::Builder tableBuilder; - 
HttpOverCapnpFactory factory(streamFactory, tableBuilder); + HttpOverCapnpFactory factory(streamFactory, tableBuilder, TEST_PEER_OPTIMIZATION_LEVEL); kj::Own table = tableBuilder.build(); ConnectWriteRespService service(*table); kj::HttpServer server(timer, *table, service); diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index a10840783e..56aec1831e 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -398,17 +398,10 @@ public: KjToCapnpHttpServiceAdapter(HttpOverCapnpFactory& factory, capnp::HttpService::Client inner) : factory(factory), inner(kj::mv(inner)) {} - template - kj::Promise requestImpl( - Request rpcRequest, + kj::Promise request( kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, - kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse, - AwaitCompletionFunc&& awaitCompletion) { - // Common implementation calling request() or startRequest(). awaitCompletion() waits for - // final completion in a method-specific way. - // - // TODO(cleanup): When we move to C++17 or newer we can use `if constexpr` instead of a - // callback. + kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override { + auto rpcRequest = inner.requestRequest(); auto metadata = rpcRequest.initRequest(); metadata.setMethod(static_cast(method)); @@ -467,7 +460,7 @@ public: // Wait for the server to indicate completion. Meanwhile, if the // promise is canceled from the client side, we propagate cancellation naturally, and we // also call state->cancel(). - return awaitCompletion(pipeline) + return pipeline.ignoreResult() // Once the server indicates it is done, then we can cancel pumping the request, because // obviously the server won't use it. We should not cancel pumping the response since there // could be data in-flight still. 
@@ -477,18 +470,6 @@ public: .attach(kj::mv(deferredCancel)); } - kj::Promise request( - kj::HttpMethod method, kj::StringPtr url, const kj::HttpHeaders& headers, - kj::AsyncInputStream& requestBody, kj::HttpService::Response& kjResponse) override { - if (factory.peerOptimizationLevel < LEVEL_2) { - return requestImpl(inner.startRequestRequest(), method, url, headers, requestBody, kjResponse, - [](auto& pipeline) { return pipeline.getContext().whenResolved(); }); - } else { - return requestImpl(inner.requestRequest(), method, url, headers, requestBody, kjResponse, - [](auto& pipeline) { return pipeline.ignoreResult(); }); - } - } - kj::Promise connect( kj::StringPtr host, const kj::HttpHeaders& headers, kj::AsyncIoStream& connection, ConnectResponse& tunnel, kj::HttpConnectSettings settings) override { @@ -596,14 +577,9 @@ public: // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. }; -class ResolvedServerRequestContext final: public capnp::HttpService::ServerRequestContext::Server { -public: - // Nothing! It's done. 
-}; - } // namespace -class HttpOverCapnpFactory::HttpServiceResponseImpl +class HttpOverCapnpFactory::HttpServiceResponseImpl final : public kj::HttpService::Response { public: HttpServiceResponseImpl(HttpOverCapnpFactory& factory, @@ -766,50 +742,12 @@ public: }; -class HttpOverCapnpFactory::ServerRequestContextImpl final - : public capnp::HttpService::ServerRequestContext::Server, - public HttpServiceResponseImpl { -public: - ServerRequestContextImpl(HttpOverCapnpFactory& factory, - HttpService::Client serviceCap, - kj::Own request, - capnp::HttpService::ClientRequestContext::Client clientContext, - kj::Own requestBodyIn, - kj::HttpService& kjService) - : HttpServiceResponseImpl(factory, *request, kj::mv(clientContext)), - request(kj::mv(request)), - serviceCap(kj::mv(serviceCap)), - // Note we attach `requestBodyIn` to `task` so that we will implicitly cancel reading - // the request body as soon as the service returns. This is important in particular when - // the request body is not fully consumed, in order to propagate cancellation. - task(kjService.request(method, url, headers, *requestBodyIn, *this) - .attach(kj::mv(requestBodyIn))) {} - - kj::Maybe> shortenPath() override { - return task.then([]() -> Capability::Client { - // If all went well, resolve to a settled capability. - // TODO(perf): Could save a message by resolving to a capability hosted by the client, or - // some special "null" capability that isn't an error but is still transmitted by value. - // Otherwise we need a Release message from client -> server just to drop this... 
- return kj::heap(); - }); - } - - KJ_DISALLOW_COPY_AND_MOVE(ServerRequestContextImpl); - -private: - kj::Own request; - HttpService::Client serviceCap; // ensures the inner kj::HttpService isn't destroyed - kj::Promise task; -}; - class HttpOverCapnpFactory::CapnpToKjHttpServiceAdapter final: public capnp::HttpService::Server { public: CapnpToKjHttpServiceAdapter(HttpOverCapnpFactory& factory, kj::Own inner) : factory(factory), inner(kj::mv(inner)) {} - template - kj::Promise requestImpl(CallContext context, Callback&& callback) { + kj::Promise request(RequestContext context) override { // Common implementation of request() and startRequest(). callback() performs the // method-specific stuff at the end. // @@ -834,14 +772,11 @@ public: auto pipe = kj::newOneWayPipe(expectedSize); auto requestBodyCap = factory.streamFactory.kjToCapnp(kj::mv(pipe.out)); - if (kj::isSameType()) { - // For request(), use context.setPipeline() to enable pipelined calls to the request body - // stream before this RPC completes. (We don't bother when using startRequest() because - // it returns immediately anyway, so this would just waste effort.) - PipelineBuilder pipeline; - pipeline.setRequestBody(kj::cp(requestBodyCap)); - context.setPipeline(pipeline.build()); - } + // For request(), use context.setPipeline() to enable pipelined calls to the request body + // stream before this RPC completes. 
+ PipelineBuilder pipeline; + pipeline.setRequestBody(kj::cp(requestBodyCap)); + context.setPipeline(pipeline.build()); results.setRequestBody(kj::mv(requestBodyCap)); requestBody = kj::mv(pipe.in); @@ -849,31 +784,9 @@ public: requestBody = kj::heap(); } - return callback(results, metadata, params, requestBody); - } - - kj::Promise request(RequestContext context) override { - return requestImpl(kj::mv(context), - [&](auto& results, auto& metadata, auto& params, auto& requestBody) { - class FinalHttpServiceResponseImpl final: public HttpServiceResponseImpl { - public: - using HttpServiceResponseImpl::HttpServiceResponseImpl; - }; - auto impl = kj::heap(factory, metadata, params.getContext()); - auto promise = inner->request(impl->method, impl->url, impl->headers, *requestBody, *impl); - return promise.attach(kj::mv(requestBody), kj::mv(impl)); - }); - } - - kj::Promise startRequest(StartRequestContext context) override { - return requestImpl(kj::mv(context), - [&](auto& results, auto& metadata, auto& params, auto& requestBody) { - results.setContext(kj::heap( - factory, thisCap(), capnp::clone(metadata), params.getContext(), kj::mv(requestBody), - *inner)); - - return kj::READY_NOW; - }); + auto impl = kj::heap(factory, metadata, params.getContext()); + auto promise = inner->request(impl->method, impl->url, impl->headers, *requestBody, *impl); + return promise.attach(kj::mv(requestBody), kj::mv(impl)); } kj::Promise connect(ConnectContext context) override { diff --git a/c++/src/capnp/compat/http-over-capnp.h b/c++/src/capnp/compat/http-over-capnp.h index 6b16749118..1170723aa2 100644 --- a/c++/src/capnp/compat/http-over-capnp.h +++ b/c++/src/capnp/compat/http-over-capnp.h @@ -57,8 +57,11 @@ class HttpOverCapnpFactory { // will improve efficiency but breaks compatibility with older peers that don't implement newer // levels. - LEVEL_1, - // Use startRequest(), the original version of the protocol. 
+ // There used to be a LEVEL_1, which used `startRequest()`, the original version of the + // protocol. Support for this level was removed in the v2 branch in order to simplify the code. + // If you have existing servers in the wild implementing this protocol that don't support + // LEVEL_2, then your clients will have to stick to Cap'n Proto 1.x until those servers are all + // updated. LEVEL_2 // Use request(). This is more efficient than startRequest() but won't work with old peers that @@ -66,7 +69,12 @@ class HttpOverCapnpFactory { }; HttpOverCapnpFactory(ByteStreamFactory& streamFactory, HeaderIdBundle headerIds, - OptimizationLevel peerOptimizationLevel = LEVEL_1); + OptimizationLevel peerOptimizationLevel); + // Note: `peerOptimizationLevel` use to be optional, but defaulted to LEVEL_1. However, any + // client still setting this to LEVEL_1 will be unable to talk to any server who is running new + // code where LEVEL_1 was removed. So if you hit a compile error because your code is not setting + // this option, you will need to roll back to an older version of Cap'n Proto for now, until you + // can update all code in production to pass LEVEL_2 here. kj::Own capnpToKj(capnp::HttpService::Client rpcService); capnp::HttpService::Client kjToCapnp(kj::Own service); From 8f3a65f16d676a0419914df558a7e1dae36bd04f Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 13:48:32 -0600 Subject: [PATCH 38/58] Refactor: Use RevocableServer instead of assertNotCanceled() for ClientRequestContext. This is a cleaner approach, and is already used in connect() below. 
--- c++/src/capnp/compat/http-over-capnp.c++ | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 56aec1831e..51369d9aec 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -55,12 +55,6 @@ public: } } - void assertNotCanceled() { - if (tasks == kj::none) { - kj::throwFatalException(KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request")); - } - } - void addTask(kj::Promise task) { KJ_IF_SOME(t, tasks) { t.add(kj::mv(task)); @@ -267,14 +261,9 @@ public: kj::HttpService::Response& kjResponse) : factory(factory), state(kj::mv(state)), kjResponse(kjResponse) {} - ~ClientRequestContextImpl() noexcept(false) { - // Note this implicitly cancels the upstream pump task. - } - kj::Promise startResponse(StartResponseContext context) override { KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); sent = true; - state->assertNotCanceled(); auto params = context.getParams(); auto rpcResponse = params.getResponse(); @@ -305,7 +294,6 @@ public: kj::Promise startWebSocket(StartWebSocketContext context) override { KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); sent = true; - state->assertNotCanceled(); auto params = context.getParams(); @@ -340,7 +328,6 @@ private: bool sent = false; kj::HttpService::Response& kjResponse; - // Must check state->assertNotCanceled() before using this. 
}; class HttpOverCapnpFactory::ConnectClientRequestContextImpl final @@ -432,8 +419,11 @@ public: state->cancel(); }); - rpcRequest.setContext( - kj::heap(factory, kj::addRef(*state), kjResponse)); + auto context = kj::heap( + factory, kj::addRef(*state), kjResponse); + RevocableServer revocableContext(*context); + + rpcRequest.setContext(revocableContext.getClient()); auto pipeline = rpcRequest.send(); @@ -467,7 +457,7 @@ public: .attach(kj::mv(pumpRequestTask)) // finishTasks() will wait for the respones to complete. .then([state = kj::mv(state)]() mutable { return state->finishTasks(); }) - .attach(kj::mv(deferredCancel)); + .attach(kj::mv(deferredCancel), kj::mv(revocableContext), kj::mv(context)); } kj::Promise connect( From 324d0e736f52c201274b28eb24fa9102c95fa35e Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 13:58:34 -0600 Subject: [PATCH 39/58] Refactor: Coroutinize KjToCapnpHttpServiceAdapter::request() and connect(). This lets us save a bunch of allocations! --- c++/src/capnp/compat/http-over-capnp.c++ | 41 ++++++++++++------------ 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 51369d9aec..0ecb5473e7 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -415,18 +415,18 @@ public: } auto state = kj::refcounted(); - auto deferredCancel = kj::defer([state = kj::addRef(*state)]() mutable { - state->cancel(); - }); + KJ_DEFER(state->cancel()); - auto context = kj::heap( - factory, kj::addRef(*state), kjResponse); - RevocableServer revocableContext(*context); + ClientRequestContextImpl context(factory, kj::addRef(*state), kjResponse); + RevocableServer revocableContext(context); rpcRequest.setContext(revocableContext.getClient()); auto pipeline = rpcRequest.send(); + // Make sure the request message isn't pinned into memory through the co_await below. 
+ { auto drop = kj::mv(rpcRequest); } + // Pump upstream -- unless we don't expect a request body. kj::Maybe> pumpRequestTask; KJ_IF_SOME(rb, maybeRequestBody) { @@ -450,14 +450,15 @@ public: // Wait for the server to indicate completion. Meanwhile, if the // promise is canceled from the client side, we propagate cancellation naturally, and we // also call state->cancel(). - return pipeline.ignoreResult() - // Once the server indicates it is done, then we can cancel pumping the request, because - // obviously the server won't use it. We should not cancel pumping the response since there - // could be data in-flight still. - .attach(kj::mv(pumpRequestTask)) - // finishTasks() will wait for the respones to complete. - .then([state = kj::mv(state)]() mutable { return state->finishTasks(); }) - .attach(kj::mv(deferredCancel), kj::mv(revocableContext), kj::mv(context)); + co_await pipeline.ignoreResult(); + + // Once the server indicates it is done, then we can cancel pumping the request, because + // obviously the server won't use it. We should not cancel pumping the response since there + // could be data in-flight still. + { auto drop = kj::mv(pumpRequestTask); } + + // finishTasks() will wait for the respones to complete. + co_await state->finishTasks(); } kj::Promise connect( @@ -469,8 +470,8 @@ public: rpcRequest.setDown(factory.streamFactory.kjToCapnp(kj::mv(downPipe.out))); rpcRequest.initSettings().setUseTls(settings.useTls); - auto context = kj::heap(factory, tunnel); - RevocableServer revocableContext(*context); + ConnectClientRequestContextImpl context(factory, tunnel); + RevocableServer revocableContext(context); auto builder = capnp::Request< capnp::HttpService::ConnectParams, @@ -480,6 +481,9 @@ public: rpcRequest.setContext(revocableContext.getClient()); RemotePromise pipeline = rpcRequest.send(); + // Make sure the request message isn't pinned into memory through the co_await below. 
+ { auto drop = kj::mv(rpcRequest); } + // We read from `downPipe` (the other side writes into it.) auto downPumpTask = downPipe.in->pumpTo(connection) .then([&connection, down = kj::mv(downPipe.in)](uint64_t) -> kj::Promise { @@ -511,10 +515,7 @@ public: return kj::NEVER_DONE; }); - return pipeline.ignoreResult() - .attach(kj::mv(downPumpTask), kj::mv(upPumpTask), kj::mv(revocableContext)) - // Separate attach to make sure `revocableContext` is destroyed before `context`. - .attach(kj::mv(context)); + co_await pipeline.ignoreResult(); } From 4a52e699201fbb9c070f291bca21bb2d2b468123 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 14:12:59 -0600 Subject: [PATCH 40/58] Refactor: Avoid using RequestState for request pump failures. This is a step towards getting rid of RequestState. --- c++/src/capnp/compat/http-over-capnp.c++ | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 0ecb5473e7..5f76c89e28 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -428,13 +428,14 @@ public: { auto drop = kj::mv(rpcRequest); } // Pump upstream -- unless we don't expect a request body. + kj::Maybe> pumpRequestFailedReason; kj::Maybe> pumpRequestTask; KJ_IF_SOME(rb, maybeRequestBody) { auto bodyOut = factory.streamFactory.capnpToKjExplicitEnd(pipeline.getRequestBody()); pumpRequestTask = rb.pumpTo(*bodyOut) .then([&bodyOut = *bodyOut](uint64_t) mutable { return bodyOut.end(); - }).eagerlyEvaluate([state = kj::addRef(*state), bodyOut = kj::mv(bodyOut)] + }).eagerlyEvaluate([&pumpRequestFailedReason, bodyOut = kj::mv(bodyOut)] (kj::Exception&& e) mutable { // A DISCONNECTED exception probably means the server decided not to read the whole request // before responding. In that case we simply want the pump to end, so that on this end it @@ -442,7 +443,7 @@ public: // exception in that case. 
For any other exception, we want to merge the exception with // the final result. if (e.getType() != kj::Exception::Type::DISCONNECTED) { - state->taskFailed(kj::mv(e)); + pumpRequestFailedReason = kj::heap(kj::mv(e)); } }); } @@ -457,6 +458,12 @@ public: // could be data in-flight still. { auto drop = kj::mv(pumpRequestTask); } + // If the request pump failed (for a non-disconnect reason) we'd better propagate that + // exception. + KJ_IF_SOME(e, pumpRequestFailedReason) { + kj::throwFatalException(kj::mv(*e)); + } + // finishTasks() will wait for the respones to complete. co_await state->finishTasks(); } From 0d2a2f253a673b95084739357a4dd5f882bf5a58 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 16:33:55 -0600 Subject: [PATCH 41/58] Refactor: Eliminate HttpOverCapnpFactory::RequestState::addTask(). We can't quite remove the TaskSet itself yet, since disconnectWebSocket() still uses it. --- c++/src/capnp/compat/http-over-capnp.c++ | 36 +++++++++++++----------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 5f76c89e28..7ad912ea90 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -55,14 +55,6 @@ public: } } - void addTask(kj::Promise task) { - KJ_IF_SOME(t, tasks) { - t.add(kj::mv(task)); - } else { - // Just drop the task. - } - } - kj::Promise finishTasks() { // This is merged into the final promise, so we don't need to worry about wrapping it for // cancellation. 
@@ -262,8 +254,7 @@ public: : factory(factory), state(kj::mv(state)), kjResponse(kjResponse) {} kj::Promise startResponse(StartResponseContext context) override { - KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); - sent = true; + KJ_REQUIRE(responsePumpTask == kj::none, "already called startResponse() or startWebSocket()"); auto params = context.getParams(); auto rpcResponse = params.getResponse(); @@ -284,16 +275,15 @@ public: if (hasBody) { auto pipe = kj::newOneWayPipe(); results.setBody(factory.streamFactory.kjToCapnp(kj::mv(pipe.out))); - state->addTask(pipe.in->pumpTo(*bodyStream) + responsePumpTask = pipe.in->pumpTo(*bodyStream) .ignoreResult() - .attach(kj::mv(bodyStream), kj::mv(pipe.in))); + .attach(kj::mv(bodyStream), kj::mv(pipe.in)); } return kj::READY_NOW; } kj::Promise startWebSocket(StartWebSocketContext context) override { - KJ_REQUIRE(!sent, "already called startResponse() or startWebSocket()"); - sent = true; + KJ_REQUIRE(responsePumpTask == kj::none, "already called startResponse() or startWebSocket()"); auto params = context.getParams(); @@ -305,7 +295,7 @@ public: auto upWrapper = kj::heap( nullptr, params.getUpSocket(), kj::mv(shorteningPaf.fulfiller)); - state->addTask(webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper)) + responsePumpTask = webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper)) .catch_([&webSocket=webSocket](kj::Exception&& e) -> kj::Promise { // The pump in the client -> server direction failed. The error may have originated from // either the client or the server. In case it came from the server, we want to call .abort() @@ -313,7 +303,7 @@ public: // .abort() probably is a noop. 
webSocket.abort(); return kj::mv(e); - })); + }); auto results = context.getResults(MessageSize { 16, 1 }); results.setDownSocket(kj::heap( @@ -322,10 +312,18 @@ public: return kj::READY_NOW; } + kj::Promise finishPump() { + KJ_IF_SOME(r, responsePumpTask) { + return kj::mv(r); + } else { + return kj::READY_NOW; + } + } + private: HttpOverCapnpFactory& factory; kj::Own state; - bool sent = false; + kj::Maybe> responsePumpTask; kj::HttpService::Response& kjResponse; }; @@ -464,7 +462,11 @@ public: kj::throwFatalException(kj::mv(*e)); } + // Finish pumping the response or WebSocket. (Probably it's already finished.) + co_await context.finishPump(); + // finishTasks() will wait for the respones to complete. + // TODO(now): Eliminate this. co_await state->finishTasks(); } From 7569830445eb8a453519b4dbb1ca0d31c0927b33 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 17:31:43 -0600 Subject: [PATCH 42/58] Refactor: Remove RequestState::disconnectWebSocket(). Instead, we change CapnpToKjWebSocketAdapter to pass off responsibility for the `disconnect()` call to its creator. This seems like a better design -- we can now actually explicitly wait for the stream to complete. Previously I think it was possible that we'd kill the websocket prematurely if it had a bunch of messages queued up still upon stream completion. 
--- c++/src/capnp/compat/http-over-capnp.c++ | 48 +++++++++++++++--------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 7ad912ea90..e78d98e69e 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -79,12 +79,6 @@ public: this->webSocket = kj::mv(webSocket); } - void disconnectWebSocket() { - KJ_IF_SOME(t, tasks) { - t.add(kj::evalNow([&]() { return KJ_ASSERT_NONNULL(webSocket)->disconnect(); })); - } - } - private: kj::Maybe error; kj::Maybe> webSocket; @@ -97,13 +91,21 @@ private: class HttpOverCapnpFactory::CapnpToKjWebSocketAdapter final: public capnp::WebSocket::Server { public: CapnpToKjWebSocketAdapter(kj::Own state, kj::WebSocket& webSocket, - kj::Promise shorteningPromise) + kj::Promise shorteningPromise, + kj::Own>> onEnd) : state(kj::mv(state)), webSocket(webSocket), - shorteningPromise(kj::mv(shorteningPromise)) {} + shorteningPromise(kj::mv(shorteningPromise)), + onEnd(kj::mv(onEnd)) {} + // `onEnd` is resolved if and when the stream (in this direction) ends cleanly. ~CapnpToKjWebSocketAdapter() noexcept(false) { if (clean) { - state->disconnectWebSocket(); + onEnd->fulfill(state->wrap([&]() { + return webSocket.disconnect(); + })); + } else { + // TODO(now): Capture actual exception? 
+ onEnd->reject(KJ_EXCEPTION(FAILED, "WebSocket-over-capnp downstream failed")); } } @@ -139,6 +141,7 @@ private: kj::Own state; kj::WebSocket& webSocket; kj::Promise shorteningPromise; + kj::Own>> onEnd; bool clean = true; // It's illegal to call another `send()` or `disconnect()` until the previous `send()` has @@ -295,7 +298,7 @@ public: auto upWrapper = kj::heap( nullptr, params.getUpSocket(), kj::mv(shorteningPaf.fulfiller)); - responsePumpTask = webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper)) + auto upPumpTask = webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper)) .catch_([&webSocket=webSocket](kj::Exception&& e) -> kj::Promise { // The pump in the client -> server direction failed. The error may have originated from // either the client or the server. In case it came from the server, we want to call .abort() @@ -306,8 +309,16 @@ public: }); auto results = context.getResults(MessageSize { 16, 1 }); - results.setDownSocket(kj::heap( - kj::addRef(*state), webSocket, kj::mv(shorteningPaf.promise))); + auto downPaf = kj::newPromiseAndFulfiller>(); + auto downSocket = kj::heap( + kj::addRef(*state), webSocket, kj::mv(shorteningPaf.promise), kj::mv(downPaf.fulfiller)); + results.setDownSocket(kj::mv(downSocket)); + + // Note: This intentionally uses joinPromises and not joinPromisesFailFast, because + // finishPump() isn't called to wait on the final promise until we already expect both + // promises to be done anyway, so if we cancel one as a result of the other failing, it + // provides no benefit and possibly creates confusion. 
+ responsePumpTask = kj::joinPromises(kj::arr(kj::mv(upPumpTask), kj::mv(downPaf.promise))); return kj::READY_NOW; } @@ -652,18 +663,21 @@ public: auto dummyState = kj::refcounted(); auto& pipeEnd0Ref = *pipe.ends[0]; dummyState->holdWebSocket(kj::mv(pipe.ends[0])); + auto upPumpPaf = kj::newPromiseAndFulfiller>(); req.setUpSocket(kj::heap( - kj::mv(dummyState), pipeEnd0Ref, kj::mv(shorteningPaf.promise))); + kj::mv(dummyState), pipeEnd0Ref, kj::mv(shorteningPaf.promise), + kj::mv(upPumpPaf.fulfiller))); auto pipeline = req.send(); auto result = kj::heap( kj::mv(pipe.ends[1]), pipeline.getDownSocket(), kj::mv(shorteningPaf.fulfiller)); - // Note we need eagerlyEvaluate() here to force proactively discarding the response object, - // since it holds a reference to `downSocket`. replyTask = pipeline.ignoreResult() - .eagerlyEvaluate([](kj::Exception&& e) { - KJ_LOG(INFO, "HTTP-over-RPC startWebSocketRequest() failed", e); + .then([upPumpTask = kj::mv(upPumpPaf.promise)]() mutable { + return kj::mv(upPumpTask); + }).eagerlyEvaluate([](kj::Exception&& e) { + // TODO(now): Maybe we should cancel the whole request and propogate this back to the caller? + KJ_LOG(INFO, "HTTP-over-RPC WebSocket pump failed on server side", e); }); return result; From 0952bbc90942b50451eb2d71e71bcc400060c46c Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 17:44:35 -0600 Subject: [PATCH 43/58] Refactor: Eliminate RequestState::holdWebSocket(). --- c++/src/capnp/compat/http-over-capnp.c++ | 39 ++++++++++++++---------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index e78d98e69e..d7902008d8 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -72,13 +72,6 @@ public: } } - void holdWebSocket(kj::Own webSocket) { - // Hold on to this WebSocket until cancellation.
- KJ_REQUIRE(this->webSocket == nullptr); - KJ_REQUIRE(tasks != nullptr); - this->webSocket = kj::mv(webSocket); - } - private: kj::Maybe error; kj::Maybe> webSocket; @@ -96,6 +89,12 @@ public: : state(kj::mv(state)), webSocket(webSocket), shorteningPromise(kj::mv(shorteningPromise)), onEnd(kj::mv(onEnd)) {} + CapnpToKjWebSocketAdapter(kj::Own state, kj::Own webSocket, + kj::Promise shorteningPromise, + kj::Own>> onEnd) + : state(kj::mv(state)), webSocket(*webSocket), ownWebSocket(kj::mv(webSocket)), + shorteningPromise(kj::mv(shorteningPromise)), + onEnd(kj::mv(onEnd)) {} // `onEnd` is resolved if and when the stream (in this direction) ends cleanly. ~CapnpToKjWebSocketAdapter() noexcept(false) { @@ -140,6 +139,7 @@ public: private: kj::Own state; kj::WebSocket& webSocket; + kj::Own ownWebSocket; kj::Promise shorteningPromise; kj::Own>> onEnd; @@ -292,14 +292,12 @@ public: auto shorteningPaf = kj::newPromiseAndFulfiller>(); - auto ownWebSocket = kjResponse.acceptWebSocket(factory.headersToKj(params.getHeaders())); - auto& webSocket = *ownWebSocket; - state->holdWebSocket(kj::mv(ownWebSocket)); + auto webSocket = kjResponse.acceptWebSocket(factory.headersToKj(params.getHeaders())); auto upWrapper = kj::heap( nullptr, params.getUpSocket(), kj::mv(shorteningPaf.fulfiller)); - auto upPumpTask = webSocket.pumpTo(*upWrapper).attach(kj::mv(upWrapper)) - .catch_([&webSocket=webSocket](kj::Exception&& e) -> kj::Promise { + auto upPumpTask = webSocket->pumpTo(*upWrapper).attach(kj::mv(upWrapper)) + .catch_([&webSocket=*webSocket](kj::Exception&& e) -> kj::Promise { // The pump in the client -> server direction failed. The error may have originated from // either the client or the server. In case it came from the server, we want to call .abort() // to propagate the problem back to the client. 
If the error came from the client, then @@ -311,7 +309,7 @@ public: auto results = context.getResults(MessageSize { 16, 1 }); auto downPaf = kj::newPromiseAndFulfiller>(); auto downSocket = kj::heap( - kj::addRef(*state), webSocket, kj::mv(shorteningPaf.promise), kj::mv(downPaf.fulfiller)); + kj::addRef(*state), *webSocket, kj::mv(shorteningPaf.promise), kj::mv(downPaf.fulfiller)); results.setDownSocket(kj::mv(downSocket)); // Note: This intentionally uses joinPromises and not joinPromisesFailFast, because @@ -320,6 +318,16 @@ public: // provides no benefit and possibly creates confusion. responsePumpTask = kj::joinPromises(kj::arr(kj::mv(upPumpTask), kj::mv(downPaf.promise))); + // We need to hold onto this WebSocket until `CapnpToKjWebSocketAdapter` is canceled or + // destroyed. If `responsePumpTask`completes successfully, then `CapnpToKjWebSocketAdapter` + // has to have been destroyed, since `downPaf.promise` doesn't resolve until that point. But + // in the case of request cancellation, it is our own destructor that will call `cancel()` + // on the `CapnpToKjWebSocketAdapter`, so we should make sure the `webSocket` outlives that. + // + // (Additionally, the WebSocket must outlive `responsePumpTask` itself, even when it is + // canceled.) + ownWebSocket = kj::mv(webSocket); + return kj::READY_NOW; } @@ -334,6 +342,7 @@ public: private: HttpOverCapnpFactory& factory; kj::Own state; + kj::Maybe> ownWebSocket; kj::Maybe> responsePumpTask; kj::HttpService::Response& kjResponse; @@ -661,11 +670,9 @@ public: // a RequestState here so that we can reuse the implementation of CapnpToKjWebSocketAdapter // that needs this for the client side. 
auto dummyState = kj::refcounted(); - auto& pipeEnd0Ref = *pipe.ends[0]; - dummyState->holdWebSocket(kj::mv(pipe.ends[0])); auto upPumpPaf = kj::newPromiseAndFulfiller>(); req.setUpSocket(kj::heap( - kj::mv(dummyState), pipeEnd0Ref, kj::mv(shorteningPaf.promise), + kj::mv(dummyState), kj::mv(pipe.ends[0]), kj::mv(shorteningPaf.promise), kj::mv(upPumpPaf.fulfiller))); auto pipeline = req.send(); From de4100e82b15d762b0e33863868ddf072e73c71e Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 25 Jan 2024 09:46:48 -0600 Subject: [PATCH 44/58] Refactor: CapnpToKjWebSocketAdapter doesn't depend on RequestState. Instead, since this is the only remaining user of any of the cancellation functionality, we simply merge it into this class itself. I think the code ends up somewhat simpler, though I think it could get even moreso if it turns out `disconnect()` isn't really needed. I think this may be the case: WebSocket connections already have an explicit close message and are not intended to run with the socket itself in a half-open state. So maybe we could just remove `disconnect()` entirely and close the underlying socket in the destructor instead. But that's a change that extends outside http-over-capnp, so not doing it here. 
--- c++/src/capnp/compat/http-over-capnp.c++ | 158 +++++++++++++++++------ 1 file changed, 119 insertions(+), 39 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index d7902008d8..e70fd5e15a 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -83,33 +83,87 @@ private: class HttpOverCapnpFactory::CapnpToKjWebSocketAdapter final: public capnp::WebSocket::Server { public: - CapnpToKjWebSocketAdapter(kj::Own state, kj::WebSocket& webSocket, + CapnpToKjWebSocketAdapter(kj::WebSocket& webSocket, kj::Promise shorteningPromise, - kj::Own>> onEnd) - : state(kj::mv(state)), webSocket(webSocket), + kj::Own>> onEnd, + kj::Maybe&> selfRef) + : webSocket(webSocket), shorteningPromise(kj::mv(shorteningPromise)), - onEnd(kj::mv(onEnd)) {} - CapnpToKjWebSocketAdapter(kj::Own state, kj::Own webSocket, + onEnd(kj::mv(onEnd)), selfRef(selfRef) { + KJ_IF_SOME(s, selfRef) { + s = *this; + } + } + CapnpToKjWebSocketAdapter(kj::Own webSocket, kj::Promise shorteningPromise, kj::Own>> onEnd) - : state(kj::mv(state)), webSocket(*webSocket), ownWebSocket(kj::mv(webSocket)), + : webSocket(*webSocket), ownWebSocket(kj::mv(webSocket)), shorteningPromise(kj::mv(shorteningPromise)), onEnd(kj::mv(onEnd)) {} // `onEnd` is resolved if and when the stream (in this direction) ends cleanly. + // + // `selfRef`, if given, will be initialized to point back to this object, and will be nulled + // out in the destructor. This is intended to allow the caller to arrange to call cancel() if + // the capability still exists when the underlying `webSocket` is about to go away. + // + // The second version of the constructor takes ownership of the underlying `webSocket`. In + // this case, a `selfRef` isn't needed since there's no need to call `cancel()`. 
~CapnpToKjWebSocketAdapter() noexcept(false) { - if (clean) { - onEnd->fulfill(state->wrap([&]() { - return webSocket.disconnect(); + // The peer dropped the capability, which means the WebSocket stream has ended. We want to + // translate this to a `disconnect()` call on the `kj::WebSocket`, if it is still around. + + // Null out our self-ref, if any. + KJ_IF_SOME(s, selfRef) { + s = nullptr; + } + + // Arrange to call disconnect() and then notify the observer that the stream has finished. + // If an error has occurred, this will also propagate that. + // + // Note that we can't just use `wrap()` here because the canceler is about to be destroyed, + // which would immediately cancel the operation. Luckily, we don't actually need to wrap this + // promise in the canceler because we can assume that the observer listening on this fulfiller + // will cancel the promise themselves at the same time as calling cancel() on this object. + KJ_IF_SOME(e, error) { + onEnd->reject(kj::mv(*e)); + } else KJ_IF_SOME(ws, webSocket) { + onEnd->fulfill(kj::evalNow([&]() { + return ws.disconnect().attach(kj::mv(ownWebSocket)); })); } else { - // TODO(now): Capture actual exception? - onEnd->reject(KJ_EXCEPTION(FAILED, "WebSocket-over-capnp downstream failed")); + // cancel() was called -- we assume no one is waiting on the fulfiller + } + } + + void cancel() { + // Called when the overall HTTP request completes or is canceled while this capability still + // exists. Since we can't force the peer to drop the capability, we have to disable it. + // Further access to `webSocket` must be blocked since it is no longer valid. + // + // Arguably we could instead use capnp::RevocableServer to accomplish something similar. The + // problem is, we also do actually want to know when the peer drops this capability. With + // RevocableServer, we no longer get notification of that -- the destructor runs when we tell + // it to, rather than when the peer drops the cap. 
+ // + // TODO(cleanup): Could RevocableServer be improved to allow us to notice the drop? + // Alternatively, maybe it's not really that important for us to call disconnect() + // proactively, considering: + // - The application can send a close message for explicit end. + // - A client->server disconnect will presumably cancel the whole request anyway. + // - A server->client disconnect will presumably be followed by the server returning from + // the request() RPC. + + selfRef = kj::none; + webSocket = kj::none; + { auto drop = kj::mv(ownWebSocket); } + if (!canceler.isEmpty()) { + canceler.cancel(KJ_EXCEPTION(DISCONNECTED, "request canceled")); } } kj::Maybe> shortenPath() override { - auto onAbort = webSocket.whenAborted() + auto onAbort = canceler.wrap(KJ_ASSERT_NONNULL(webSocket).whenAborted()) .then([]() -> kj::Promise { return KJ_EXCEPTION(DISCONNECTED, "WebSocket was aborted"); }); @@ -117,36 +171,60 @@ public: } kj::Promise sendText(SendTextContext context) override { - KJ_ASSERT(clean); // should be guaranteed by streaming semantics - clean = false; - co_await state->wrap([&]() { return webSocket.send(context.getParams().getText()); }); - clean = true; + return wrap([&](kj::WebSocket& ws) { return ws.send(context.getParams().getText()); }); } kj::Promise sendData(SendDataContext context) override { - KJ_ASSERT(clean); // should be guaranteed by streaming semantics - clean = false; - co_await state->wrap([&]() { return webSocket.send(context.getParams().getData()); }); - clean = true; + return wrap([&](kj::WebSocket& ws) { return ws.send(context.getParams().getData()); }); } kj::Promise close(CloseContext context) override { - KJ_ASSERT(clean); // should be guaranteed by streaming semantics auto params = context.getParams(); - clean = false; - co_await state->wrap([&]() { return webSocket.close(params.getCode(), params.getReason()); }); - clean = true; + return wrap([&](kj::WebSocket& ws) { return ws.close(params.getCode(), params.getReason()); }); } 
private: - kj::Own state; - kj::WebSocket& webSocket; + kj::Maybe webSocket; // becomes none when canceled kj::Own ownWebSocket; kj::Promise shorteningPromise; kj::Own>> onEnd; + kj::Maybe&> selfRef; + + kj::Canceler canceler; - bool clean = true; - // It's illegal to call another `send()` or `disconnect()` until the previous `send()` has - // completed successfully. We want to send `disconnect()` in the destructor but only if we can - // do so cleanly. + kj::Maybe> error; + + kj::WebSocket& getWebSocket() { + return KJ_REQUIRE_NONNULL(webSocket, "request canceled"); + } + + template + kj::Promise wrap(Func&& func) { + KJ_IF_SOME(e, error) { + kj::throwFatalException(kj::cp(*e)); + } + + // Detect cancellation (of the operation) and mark the object broken in this case. + bool done = false; + KJ_DEFER({ + if (!done && error == kj::none) { + error = kj::heap(KJ_EXCEPTION(FAILED, + "a write was canceled before completing, breaking the WebSocket")); + } + }); + + try { + KJ_IF_SOME(ws, webSocket) { + co_await canceler.wrap(func(ws)); + } else { + kj::throwFatalException(KJ_EXCEPTION(DISCONNECTED, "request canceled")); + } + } catch (...) 
{ + auto e = kj::getCaughtExceptionAsKj(); + error = kj::heap(kj::cp(e)); + kj::throwFatalException(kj::mv(e)); + } + + done = true; + } }; class HttpOverCapnpFactory::KjToCapnpWebSocketAdapter final: public kj::WebSocket { @@ -256,6 +334,12 @@ public: kj::HttpService::Response& kjResponse) : factory(factory), state(kj::mv(state)), kjResponse(kjResponse) {} + ~ClientRequestContextImpl() noexcept(false) { + KJ_IF_SOME(ws, maybeWebSocket) { + ws.cancel(); + } + } + kj::Promise startResponse(StartResponseContext context) override { KJ_REQUIRE(responsePumpTask == kj::none, "already called startResponse() or startWebSocket()"); @@ -309,7 +393,7 @@ public: auto results = context.getResults(MessageSize { 16, 1 }); auto downPaf = kj::newPromiseAndFulfiller>(); auto downSocket = kj::heap( - kj::addRef(*state), *webSocket, kj::mv(shorteningPaf.promise), kj::mv(downPaf.fulfiller)); + *webSocket, kj::mv(shorteningPaf.promise), kj::mv(downPaf.fulfiller), maybeWebSocket); results.setDownSocket(kj::mv(downSocket)); // Note: This intentionally uses joinPromises and not joinPromisesFailFast, because @@ -344,6 +428,7 @@ private: kj::Own state; kj::Maybe> ownWebSocket; kj::Maybe> responsePumpTask; + kj::Maybe maybeWebSocket; kj::HttpService::Response& kjResponse; }; @@ -664,16 +749,11 @@ public: auto pipe = kj::newWebSocketPipe(); auto shorteningPaf = kj::newPromiseAndFulfiller>(); - // We don't need the RequestState mechanism on the server side because - // CapnpToKjWebSocketAdapter wraps a pipe end, and that pipe end can continue to exist beyond - // the lifetime of the request, because the other end will have been dropped. We only create - // a RequestState here so that we can reuse the implementation of CapnpToKjWebSocketAdapter - // that needs this for the client side. - auto dummyState = kj::refcounted(); + // Note that since CapnpToKjWebSocketAdapter takes ownership of the pipe end, we don't need + // to cancel it later. 
Dropping the other end of the pipe will have the same effect. auto upPumpPaf = kj::newPromiseAndFulfiller>(); req.setUpSocket(kj::heap( - kj::mv(dummyState), kj::mv(pipe.ends[0]), kj::mv(shorteningPaf.promise), - kj::mv(upPumpPaf.fulfiller))); + kj::mv(pipe.ends[0]), kj::mv(shorteningPaf.promise), kj::mv(upPumpPaf.fulfiller))); auto pipeline = req.send(); auto result = kj::heap( From bb50d1eeaadb7f57cac781db597b361cc6e07e46 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Wed, 24 Jan 2024 17:47:22 -0600 Subject: [PATCH 45/58] Delete code: HttpOverCapnpFactory::RequestState is no longer needed. --- c++/src/capnp/compat/http-over-capnp.c++ | 66 ++---------------------- c++/src/capnp/compat/http-over-capnp.h | 2 - 2 files changed, 3 insertions(+), 65 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index e70fd5e15a..f9c3c0cabc 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -29,56 +29,6 @@ namespace capnp { using kj::uint; using kj::byte; -class HttpOverCapnpFactory::RequestState final - : public kj::Refcounted, public kj::TaskSet::ErrorHandler { -public: - RequestState() { - tasks.emplace(*this); - } - - template - auto wrap(Func&& func) -> decltype(func()) { - if (tasks == kj::none) { - return KJ_EXCEPTION(DISCONNECTED, "client canceled HTTP request"); - } else { - return canceler.wrap(func()); - } - } - - void cancel() { - if (tasks != kj::none) { - if (!canceler.isEmpty()) { - canceler.cancel(KJ_EXCEPTION(DISCONNECTED, "request canceled")); - } - tasks = kj::none; - webSocket = kj::none; - } - } - - kj::Promise finishTasks() { - // This is merged into the final promise, so we don't need to worry about wrapping it for - // cancellation. 
- return KJ_REQUIRE_NONNULL(tasks).onEmpty() - .then([this]() { - KJ_IF_SOME(e, error) { - kj::throwRecoverableException(kj::mv(e)); - } - }); - } - - void taskFailed(kj::Exception&& exception) override { - if (error == kj::none) { - error = kj::mv(exception); - } - } - -private: - kj::Maybe error; - kj::Maybe> webSocket; - kj::Canceler canceler; - kj::Maybe tasks; -}; - // ======================================================================================= class HttpOverCapnpFactory::CapnpToKjWebSocketAdapter final: public capnp::WebSocket::Server { @@ -330,9 +280,8 @@ class HttpOverCapnpFactory::ClientRequestContextImpl final : public capnp::HttpService::ClientRequestContext::Server { public: ClientRequestContextImpl(HttpOverCapnpFactory& factory, - kj::Own state, kj::HttpService::Response& kjResponse) - : factory(factory), state(kj::mv(state)), kjResponse(kjResponse) {} + : factory(factory), kjResponse(kjResponse) {} ~ClientRequestContextImpl() noexcept(false) { KJ_IF_SOME(ws, maybeWebSocket) { @@ -425,7 +374,6 @@ public: private: HttpOverCapnpFactory& factory; - kj::Own state; kj::Maybe> ownWebSocket; kj::Maybe> responsePumpTask; kj::Maybe maybeWebSocket; @@ -517,10 +465,7 @@ public: maybeRequestBody = requestBody; } - auto state = kj::refcounted(); - KJ_DEFER(state->cancel()); - - ClientRequestContextImpl context(factory, kj::addRef(*state), kjResponse); + ClientRequestContextImpl context(factory, kjResponse); RevocableServer revocableContext(context); rpcRequest.setContext(revocableContext.getClient()); @@ -552,8 +497,7 @@ public: } // Wait for the server to indicate completion. Meanwhile, if the - // promise is canceled from the client side, we propagate cancellation naturally, and we - // also call state->cancel(). + // promise is canceled from the client side, we propagate cancellation naturally. 
co_await pipeline.ignoreResult(); // Once the server indicates it is done, then we can cancel pumping the request, because @@ -569,10 +513,6 @@ public: // Finish pumping the response or WebSocket. (Probably it's already finished.) co_await context.finishPump(); - - // finishTasks() will wait for the respones to complete. - // TODO(now): Eliminate this. - co_await state->finishTasks(); } kj::Promise connect( diff --git a/c++/src/capnp/compat/http-over-capnp.h b/c++/src/capnp/compat/http-over-capnp.h index 1170723aa2..6a428c1c65 100644 --- a/c++/src/capnp/compat/http-over-capnp.h +++ b/c++/src/capnp/compat/http-over-capnp.h @@ -88,8 +88,6 @@ class HttpOverCapnpFactory { kj::Array valueCapnpToKj; kj::HashMap valueKjToCapnp; - class RequestState; - class CapnpToKjWebSocketAdapter; class KjToCapnpWebSocketAdapter; From af54b4dfc676850d9cd7682d89b4756008725d51 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 25 Jan 2024 10:21:21 -0600 Subject: [PATCH 46/58] Refactor: Coroutinize CapnpToKjHttpServiceAdapter. 
--- c++/src/capnp/compat/http-over-capnp.c++ | 26 ++++++++++++------------ 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index f9c3c0cabc..1aca42be54 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -825,9 +825,8 @@ public: requestBody = kj::heap(); } - auto impl = kj::heap(factory, metadata, params.getContext()); - auto promise = inner->request(impl->method, impl->url, impl->headers, *requestBody, *impl); - return promise.attach(kj::mv(requestBody), kj::mv(impl)); + HttpServiceResponseImpl impl(factory, metadata, params.getContext()); + co_await inner->request(impl.method, impl.url, impl.headers, *requestBody, impl); } kj::Promise connect(ConnectContext context) override { @@ -884,19 +883,20 @@ public: return kj::NEVER_DONE; }); - PipelineBuilder pb; - auto eofWrapper = kj::heap(kj::mv(ref2)); - auto up = factory.streamFactory.kjToCapnp(kj::mv(eofWrapper), kj::mv(tlsStarter)); - pb.setUp(kj::cp(up)); + { + PipelineBuilder pb; + auto eofWrapper = kj::heap(kj::mv(ref2)); + auto up = factory.streamFactory.kjToCapnp(kj::mv(eofWrapper), kj::mv(tlsStarter)); + pb.setUp(kj::cp(up)); - context.setPipeline(pb.build()); - context.initResults(capnp::MessageSize { 4, 1 }).setUp(kj::mv(up)); + context.setPipeline(pb.build()); + context.initResults(capnp::MessageSize { 4, 1 }).setUp(kj::mv(up)); + } - auto response = kj::heap( - factory, context.getParams().getContext()); + { auto drop = kj::mv(refcounted); } - return inner->connect(host, headers, *pipe.ends[0], *response, settings).attach( - kj::mv(host), kj::mv(headers), kj::mv(response), kj::mv(pipe)) + HttpOverCapnpConnectResponseImpl response(factory, context.getParams().getContext()); + co_await inner->connect(host, headers, *pipe.ends[0], response, settings) .exclusiveJoin(kj::mv(pumpTask)); } From 7dd54e88e2aef5ee7dd576891ad745457d6512d5 Mon Sep 17 00:00:00 2001 
From: Kenton Varda Date: Thu, 25 Jan 2024 10:31:14 -0600 Subject: [PATCH 47/58] Refactor: Join server-side reply task with overall request. If attempting to send the reply back to the client fails, we should cancel out the whole request and propagate the exception, rather than uselessly info-log it. --- c++/src/capnp/compat/http-over-capnp.c++ | 51 ++++++++++++++---------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 1aca42be54..1d0494d120 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -629,17 +629,21 @@ class HttpOverCapnpFactory::HttpServiceResponseImpl final public: HttpServiceResponseImpl(HttpOverCapnpFactory& factory, capnp::HttpRequest::Reader request, - capnp::HttpService::ClientRequestContext::Client clientContext) + capnp::HttpService::ClientRequestContext::Client clientContext, + kj::Own>> replyFulfiller) : factory(factory), method(validateMethod(request.getMethod())), url(request.getUrl()), headers(factory.headersToKj(request.getHeaders())), - clientContext(kj::mv(clientContext)) {} + clientContext(kj::mv(clientContext)), + replyFulfiller(kj::mv(replyFulfiller)) {} + // `replyFulfiller` is eventually fulfilled with the task that sends the reply back to + // the client. 
kj::Own send( uint statusCode, kj::StringPtr statusText, const kj::HttpHeaders& headers, kj::Maybe expectedBodySize = kj::none) override { - KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()"); + KJ_REQUIRE(replyFulfiller->isWaiting(), "already called send() or acceptWebSocket()"); auto req = clientContext.startResponseRequest(); @@ -659,26 +663,19 @@ public: hasBody = s > 0; } - auto logError = [hasBody](kj::Exception&& e) { - KJ_LOG(INFO, "HTTP-over-RPC startResponse() failed", hasBody, e); - }; if (hasBody) { auto pipeline = req.send(); auto result = factory.streamFactory.capnpToKj(pipeline.getBody()); - replyTask = pipeline.ignoreResult().eagerlyEvaluate(kj::mv(logError)); + replyFulfiller->fulfill(pipeline.ignoreResult()); return result; } else { - replyTask = req.send().ignoreResult().eagerlyEvaluate(kj::mv(logError)); + replyFulfiller->fulfill(req.send().ignoreResult()); return kj::heap(); } - - // We don't actually wait for replyTask anywhere, because we may be all done with this HTTP - // message before the client gets a chance to respond, and we don't want to force an extra - // network round trip. If the client fails this call that's the client's problem, really. } kj::Own acceptWebSocket(const kj::HttpHeaders& headers) override { - KJ_REQUIRE(replyTask == nullptr, "already called send() or acceptWebSocket()"); + KJ_REQUIRE(replyFulfiller->isWaiting(), "already called send() or acceptWebSocket()"); auto req = clientContext.startWebSocketRequest(); @@ -699,13 +696,12 @@ public: auto result = kj::heap( kj::mv(pipe.ends[1]), pipeline.getDownSocket(), kj::mv(shorteningPaf.fulfiller)); - replyTask = pipeline.ignoreResult() + replyFulfiller->fulfill(pipeline.ignoreResult() .then([upPumpTask = kj::mv(upPumpPaf.promise)]() mutable { + // We need to continue pumping the WebSocket in the client->server direction, so let's do + // that as part of the reply task. 
return kj::mv(upPumpTask); - }).eagerlyEvaluate([](kj::Exception&& e) { - // TODO(now): Maybe we should cancel the whole request and propogate this back to the caller? - KJ_LOG(INFO, "HTTP-over-RPC WebSocket pump failed on server side", e); - }); + })); return result; } @@ -715,7 +711,7 @@ public: kj::StringPtr url; kj::HttpHeaders headers; capnp::HttpService::ClientRequestContext::Client clientContext; - kj::Maybe> replyTask; + kj::Own>> replyFulfiller; static kj::HttpMethod validateMethod(capnp::HttpMethod method) { KJ_REQUIRE(method <= capnp::HttpMethod::UNSUBSCRIBE, "unknown method", method); @@ -825,8 +821,21 @@ public: requestBody = kj::heap(); } - HttpServiceResponseImpl impl(factory, metadata, params.getContext()); - co_await inner->request(impl.method, impl.url, impl.headers, *requestBody, impl); + auto replyPaf = kj::newPromiseAndFulfiller>(); + auto replyPromise = replyPaf.promise.then([]() -> kj::Promise { + // The reply may complete before the request promise does. We don't want to cancel the + // request if the reply completed successfully, so return NEVER_DONE here so that the + // exclusiveJoin() below becomes a no-op. + // + // On the other hand, if the reply throws an exception, we want to cancel the request and + // propagate that exception immediately! + return kj::NEVER_DONE; + }); + + HttpServiceResponseImpl impl( + factory, metadata, params.getContext(), kj::mv(replyPaf.fulfiller)); + co_await inner->request(impl.method, impl.url, impl.headers, *requestBody, impl) + .exclusiveJoin(kj::mv(replyPromise)); } kj::Promise connect(ConnectContext context) override { From ba1657471693663c5acbe94f346a94c973a43b2e Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Thu, 25 Jan 2024 10:47:00 -0600 Subject: [PATCH 48/58] Cleanup: Create a public `NullStream`, remove various private versions. 
--- c++/src/capnp/compat/http-over-capnp.c++ | 44 +---------------- c++/src/kj/async-io.c++ | 22 +++++++++ c++/src/kj/async-io.h | 17 +++++++ c++/src/kj/compat/http.c++ | 63 ++++++------------------ 4 files changed, 55 insertions(+), 91 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 1d0494d120..159c1a1983 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -584,46 +584,6 @@ kj::Own HttpOverCapnpFactory::capnpToKj(capnp::HttpService::Cli // ======================================================================================= -namespace { - -class NullInputStream final: public kj::AsyncInputStream { - // TODO(cleanup): This class has been replicated in a bunch of places now, make it public - // somewhere. - -public: - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return kj::constPromise(); - } - - kj::Maybe tryGetLength() override { - return uint64_t(0); - } - - kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override { - return kj::constPromise(); - } -}; - -class NullOutputStream final: public kj::AsyncOutputStream { - // TODO(cleanup): This class has been replicated in a bunch of places now, make it public - // somewhere. - -public: - kj::Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - kj::Promise write(kj::ArrayPtr> pieces) override { - return kj::READY_NOW; - } - kj::Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } - - // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. 
-}; - -} // namespace - class HttpOverCapnpFactory::HttpServiceResponseImpl final : public kj::HttpService::Response { public: @@ -670,7 +630,7 @@ public: return result; } else { replyFulfiller->fulfill(req.send().ignoreResult()); - return kj::heap(); + return kj::heap(); } } @@ -818,7 +778,7 @@ public: results.setRequestBody(kj::mv(requestBodyCap)); requestBody = kj::mv(pipe.in); } else { - requestBody = kj::heap(); + requestBody = kj::heap(); } auto replyPaf = kj::newPromiseAndFulfiller>(); diff --git a/c++/src/kj/async-io.c++ b/c++/src/kj/async-io.c++ index 8761f76936..b547c3b2a6 100644 --- a/c++/src/kj/async-io.c++ +++ b/c++/src/kj/async-io.c++ @@ -80,6 +80,28 @@ Maybe> AsyncInputStream::tryTee(uint64_t) { return kj::none; } +kj::Promise NullStream::tryRead(void* buffer, size_t minBytes, size_t maxBytes) { + return kj::constPromise(); +} +kj::Maybe NullStream::tryGetLength() { + return uint64_t(0); +} +kj::Promise NullStream::pumpTo(kj::AsyncOutputStream& output, uint64_t amount) { + return kj::constPromise(); +} + +kj::Promise NullStream::write(const void* buffer, size_t size) { + return kj::READY_NOW; +} +kj::Promise NullStream::write(kj::ArrayPtr> pieces) { + return kj::READY_NOW; +} +kj::Promise NullStream::whenWriteDisconnected() { + return kj::NEVER_DONE; +} + +void NullStream::shutdownWrite() {} + namespace { class AsyncPump { diff --git a/c++/src/kj/async-io.h b/c++/src/kj/async-io.h index ceb412da02..68fe673787 100644 --- a/c++/src/kj/async-io.h +++ b/c++/src/kj/async-io.h @@ -171,6 +171,23 @@ class AsyncIoStream: public AsyncInputStream, public AsyncOutputStream { // isn't wrapping a file descriptor. }; +class NullStream final: public AsyncIoStream { + // Convenience class that implements an I/O stream that ignores all writes and returns EOF for + // all reads. + // + // Hint: You can also use this class when you just need an input stream or an output stream. 
+public: + kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override; + kj::Maybe tryGetLength() override; + kj::Promise pumpTo(kj::AsyncOutputStream& output, uint64_t amount) override; + + kj::Promise write(const void* buffer, size_t size) override; + kj::Promise write(kj::ArrayPtr> pieces) override; + kj::Promise whenWriteDisconnected() override; + + void shutdownWrite() override; +}; + Promise unoptimizedPumpTo( AsyncInputStream& input, AsyncOutputStream& output, uint64_t amount, uint64_t completedSoFar = 0); diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 06dec43f4a..39d253f1c6 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -5211,14 +5211,21 @@ kj::OneOf tryParseExtensionAgreement( "an invalid value.")); return kj::mv(e); } + } // namespace _ (private) + namespace { -class NullInputStream final: public kj::AsyncInputStream { + +class HeadResponseStream final: public kj::AsyncInputStream { + // An input stream which returns no data, but `tryGetLength()` returns a specified value. Used + // for HEAD responses, where the size is known but the body content is not sent. public: - NullInputStream(kj::Maybe expectedLength = size_t(0)) + HeadResponseStream(kj::Maybe expectedLength) : expectedLength(expectedLength) {} kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { + // TODO(someday): Maybe this should throw? We should not be trying to read the body of a + // HEAD response. return constPromise(); } @@ -5234,48 +5241,6 @@ private: kj::Maybe expectedLength; }; -class NullOutputStream final: public kj::AsyncOutputStream { -public: - Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - Promise write(ArrayPtr> pieces) override { - return kj::READY_NOW; - } - Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } - - // We can't really optimize tryPumpFrom() unless AsyncInputStream grows a skip() method. 
-}; - -class NullIoStream final: public kj::AsyncIoStream { -public: - void shutdownWrite() override {} - - Promise write(const void* buffer, size_t size) override { - return kj::READY_NOW; - } - Promise write(ArrayPtr> pieces) override { - return kj::READY_NOW; - } - Promise whenWriteDisconnected() override { - return kj::NEVER_DONE; - } - - kj::Promise tryRead(void* buffer, size_t minBytes, size_t maxBytes) override { - return constPromise(); - } - - kj::Maybe tryGetLength() override { - return kj::Maybe((uint64_t)0); - } - - kj::Promise pumpTo(AsyncOutputStream& output, uint64_t amount) override { - return constPromise(); - } -}; - class HttpClientImpl final: public HttpClient, private HttpClientErrorHandler { public: @@ -6683,7 +6648,7 @@ public: auto requestPaf = kj::newPromiseAndFulfiller>(); responder->setPromise(kj::mv(requestPaf.promise)); - auto in = kj::heap(); + auto in = kj::heap(); auto promise = service.request(HttpMethod::GET, urlCopy, *headersCopy, *in, *responder) .attach(kj::mv(in), kj::mv(urlCopy), kj::mv(headersCopy)); requestPaf.fulfiller->fulfill(kj::mv(promise)); @@ -6847,11 +6812,11 @@ private: headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable { fulfiller->fulfill({ statusCode, statusTextCopy, headersCopy.get(), - kj::heap(expectedBodySize) + kj::heap(expectedBodySize) .attach(kj::mv(statusTextCopy), kj::mv(headersCopy)) }); }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - return kj::heap(); + return kj::heap(); } else { auto pipe = newOneWayPipe(expectedBodySize); @@ -6998,11 +6963,11 @@ private: headersCopy=kj::mv(headersCopy),expectedBodySize]() mutable { fulfiller->fulfill({ statusCode, statusTextCopy, headersCopy.get(), - kj::Own(kj::heap(expectedBodySize) + kj::Own(kj::heap(expectedBodySize) .attach(kj::mv(statusTextCopy), kj::mv(headersCopy))) }); }).eagerlyEvaluate([](kj::Exception&& e) { KJ_LOG(ERROR, e); }); - return kj::heap(); + return kj::heap(); } else { auto pipe = newOneWayPipe(expectedBodySize); 
From b69240097ad19b256f7927af836adaf3bbedcb6c Mon Sep 17 00:00:00 2001 From: Mike Aizatsky Date: Wed, 7 Feb 2024 10:54:15 -0800 Subject: [PATCH 49/58] kj::Arc smart pointer for atomic reference counted objects. (#1925) --- c++/src/kj/refcount-test.c++ | 80 ++++++++++++++++++++++++- c++/src/kj/refcount.h | 112 +++++++++++++++++++++++++++++++++-- 2 files changed, 185 insertions(+), 7 deletions(-) diff --git a/c++/src/kj/refcount-test.c++ b/c++/src/kj/refcount-test.c++ index 65fdf48cd2..c2254989f8 100644 --- a/c++/src/kj/refcount-test.c++ +++ b/c++/src/kj/refcount-test.c++ @@ -28,6 +28,8 @@ struct SetTrueInDestructor: public Refcounted, EnableAddRefToThis newRef() { return addRefToThis(); } + bool* ptr; }; @@ -131,13 +133,13 @@ KJ_TEST("Rc inheritance") { EXPECT_TRUE(b); } -KJ_TEST("EnableAddRefToThis") { +KJ_TEST("Refcounted::EnableAddRefToThis") { bool b = false; auto ref1 = kj::rc(&b); EXPECT_FALSE(ref1->isShared()); - auto ref2 = ref1.addRef(); + auto ref2 = ref1->newRef(); EXPECT_TRUE(ref2->isShared()); EXPECT_TRUE(ref1->isShared()); EXPECT_FALSE(b); @@ -227,4 +229,78 @@ KJ_TEST("RefcountedWrapper") { } } + +struct AtomicSetTrueInDestructor: public AtomicRefcounted, + EnableAddRefToThis { + + AtomicSetTrueInDestructor(bool* ptr): ptr(ptr) {} + ~AtomicSetTrueInDestructor() { *ptr = true; } + + kj::Arc newRef() { return addRefToThis(); } + + bool* ptr; +}; + +KJ_TEST("Arc") { + bool b = false; + + kj::Arc ref1 = kj::arc(&b); + EXPECT_FALSE(ref1->isShared()); + EXPECT_TRUE(ref1 != nullptr); + EXPECT_FALSE(ref1 == nullptr); + + kj::Arc ref2 = ref1.addRef(); + + // can be always cast to Arc + kj::Arc ref3 = ref1.addRef(); + + // addRef works for const references too + kj::Arc ref4 = ref3.addRef(); + + ref1 = nullptr; + EXPECT_TRUE(ref1 == nullptr); + ref2 = nullptr; + EXPECT_TRUE(ref2 == nullptr); + ref3 = nullptr; + EXPECT_TRUE(ref3 == nullptr); + + EXPECT_FALSE(b); + ref4 = nullptr; + EXPECT_TRUE(b); +} + +KJ_TEST("AtomicRefcounted::EnableAddRefToThis") { + bool b = 
false; + + kj::Arc ref1 = kj::arc(&b); + EXPECT_FALSE(ref1->isShared()); + + kj::Arc ref2 = ref1->newRef(); + EXPECT_TRUE(ref2->isShared()); + EXPECT_TRUE(ref1->isShared()); + EXPECT_FALSE(b); + + ref1 = nullptr; + EXPECT_FALSE(ref2->isShared()); + EXPECT_FALSE(b); + + ref2 = nullptr; + EXPECT_TRUE(b); +} + +KJ_TEST("Arc Own interop") { + bool b = false; + + kj::Arc ref1 = kj::arc(&b); + + EXPECT_FALSE(b); + auto own = ref1.toOwn(); + EXPECT_TRUE(ref1 == nullptr); + EXPECT_TRUE(own.get() != nullptr); + + EXPECT_FALSE(b); + own = nullptr; + EXPECT_TRUE(b); +} + } // namespace kj diff --git a/c++/src/kj/refcount.h b/c++/src/kj/refcount.h index 2735e2a877..7b1502663c 100644 --- a/c++/src/kj/refcount.h +++ b/c++/src/kj/refcount.h @@ -144,6 +144,7 @@ Own Refcounted::addRefInternal(T* object) { template Rc Refcounted::addRcRefInternal(T* object) { + static_assert(kj::canConvert()); Refcounted* refcounted = object; ++refcounted->refcount; return Rc(object); @@ -153,7 +154,7 @@ template class Rc { // Smart pointer for reference counted objects. // - // There are only two ways to obtain new Rc instances: + // There are only three ways to obtain new Rc instances: // - use kj::rc(...) function to create new T. // - use kj::Rc::addRef() and the existing Rc instance. // - use EnableAddRefToThis to allow T instance to add new references to itself. @@ -234,16 +235,17 @@ template class EnableAddRefToThis { // Exposes addRefToThis member function for objects to add // references to themselves. + // Can be used both with Refcounted and AtomicRefcounted objects. 
protected: - kj::Rc addRefToThis() const { + auto addRefToThis() const { const Self* self = static_cast(this); - return Refcounted::addRcRefInternal(self); + return Self::addRcRefInternal(self); } - kj::Rc addRefToThis() { + auto addRefToThis() { Self* self = static_cast(this); - return Refcounted::addRcRefInternal(self); + return Self::addRcRefInternal(self); } }; @@ -313,6 +315,9 @@ Own>> refcountedWrapper(Own&& wrapped) { #endif #endif +template +class Arc; + class AtomicRefcounted: private kj::Disposer { public: AtomicRefcounted() = default; @@ -350,6 +355,19 @@ class AtomicRefcounted: private kj::Disposer { friend kj::Maybe> atomicAddRefWeak(const T& object); template friend kj::Own atomicRefcounted(Params&&... params); + + template + static kj::Arc addRcRefInternal(T* object); + template + static kj::Arc addRcRefInternal(const T* object); + + template + friend class Arc; + template + friend kj::Arc arc(Params&&... params); + + template + friend class EnableAddRefToThis; }; template @@ -357,6 +375,11 @@ inline kj::Own atomicRefcounted(Params&&... params) { return AtomicRefcounted::addRefInternal(new T(kj::fwd(params)...)); } +template +inline kj::Arc arc(Params&&... params) { + return AtomicRefcounted::addRcRefInternal(new T(kj::fwd(params)...)); +} + template kj::Own atomicAddRef(T& object) { KJ_IREQUIRE(object.AtomicRefcounted::refcount > 0, @@ -410,6 +433,85 @@ kj::Own AtomicRefcounted::addRefInternal(const T* object) { return kj::Own(object, *refcounted); } +template +kj::Arc AtomicRefcounted::addRcRefInternal(T* object) { + static_assert(kj::canConvert()); + return kj::Arc(addRefInternal(object)); +} + +template +kj::Arc AtomicRefcounted::addRcRefInternal(const T* object) { + static_assert(kj::canConvert()); + return kj::Arc(addRefInternal(object)); +} + +template +class Arc { + // Smart pointer for atomic reference counted objects. + // + // Usage is similar to kj::Rc. 
+ +public: + KJ_DISALLOW_COPY(Arc); + Arc() { } + Arc(decltype(nullptr)) { } + inline Arc(Arc&& other) noexcept = default; + + template ()>> + inline Arc(Arc&& other) noexcept : own(kj::mv(other.own)) { } + + kj::Own toOwn() { + // Convert Arc to Own. + // Nullifies the original Arc. + return kj::mv(own); + } + + kj::Arc addRef() { + T* refcounted = own.get(); + if (refcounted != nullptr) { + return AtomicRefcounted::addRcRefInternal(refcounted); + } else { + return kj::Arc(); + } + } + + Arc& operator=(decltype(nullptr)) { + own = nullptr; + return *this; + } + + Arc& operator=(Arc&& other) = default; + + template + Arc downcast() { + return Arc(own.template downcast()); + } + + inline bool operator==(const Arc& other) const { return own.get() == other.own.get(); } + inline bool operator==(decltype(nullptr)) const { return own.get() == nullptr; } + inline bool operator!=(decltype(nullptr)) const { return own.get() != nullptr; } + + inline T* operator->() { return own.get(); } + inline const T* operator->() const { return own.get(); } + + // do not expose * to avoid dangling references + +private: + Arc(T* t) : own(t, *t) { } + Arc(Own&& t) : own(kj::mv(t)) { } + + Own own; + + friend class AtomicRefcounted; + + template + friend class Arc; + + template + friend class EnableAddRefToThis; +}; + + } // namespace kj KJ_END_HEADER From 0985926e3b7356adc5a0306cc6e9d3f7f48c2e32 Mon Sep 17 00:00:00 2001 From: Milan Miladinovic Date: Thu, 8 Feb 2024 11:11:46 -0500 Subject: [PATCH 50/58] Revert "Revert "Make getPreferredExtensions pure virtual func"" This reverts commit 20be408666207623fbfd5002628455c41d15e3b9. 
--- c++/src/capnp/compat/http-over-capnp.c++ | 6 ++ c++/src/kj/compat/http.c++ | 78 +++++++++++++++++++++++- c++/src/kj/compat/http.h | 2 +- 3 files changed, 83 insertions(+), 3 deletions(-) diff --git a/c++/src/capnp/compat/http-over-capnp.c++ b/c++/src/capnp/compat/http-over-capnp.c++ index 159c1a1983..139e6eedd6 100644 --- a/c++/src/capnp/compat/http-over-capnp.c++ +++ b/c++/src/capnp/compat/http-over-capnp.c++ @@ -267,6 +267,12 @@ public: uint64_t sentByteCount() override { return sentBytes; } uint64_t receivedByteCount() override { return KJ_ASSERT_NONNULL(in)->receivedByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + // TODO(someday): Optimized pump is tricky with HttpOverCapnp, we may want to revisit + // this but for now we always return none (indicating no preference). + return kj::none; + }; + private: kj::Maybe> in; // One end of a WebSocketPipe, used only for receiving. kj::Maybe out; // Used only for sending. diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 39d253f1c6..447e15e3ac 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -3901,6 +3901,19 @@ public: return transferredBytes; } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + + kj::Maybe destinationPumpingTo; + kj::Maybe destinationPumpingFrom; + // Tracks the outstanding pumpTo() and tryPumpFrom() calls currently running on the + // WebSocketPipeEnd, which is the destination side of this WebSocketPipeImpl. This is used by + // the source end to implement getPreferredExtensions(). + // + // getPreferredExtensions() does not fit into the model used by all the other methods because it + // is not directional (not a read nor a write call). + private: kj::Maybe state; // Object-oriented state! 
If any method call is blocked waiting on activity from the other end, @@ -4019,6 +4032,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4104,6 +4121,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4188,6 +4209,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4284,6 +4309,10 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; + private: kj::PromiseFulfiller& fulfiller; WebSocketPipeImpl& pipe; @@ -4330,6 +4359,9 @@ private: KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; }; class Aborted final: public WebSocket { @@ -4371,6 +4403,9 @@ private: uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("Bytes are not counted for the individual states of WebSocketPipeImpl."); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_UNREACHABLE; + }; }; }; @@ -4403,19 +4438,50 @@ public: return out->whenAborted(); } kj::Maybe> tryPumpFrom(WebSocket& other) override { - return out->tryPumpFrom(other); + KJ_REQUIRE(in->destinationPumpingFrom == kj::none, "can only call tryPumpFrom() once at a time"); + // By convention, we store 
the WebSocket reference on `in`. + in->destinationPumpingFrom = other; + auto deferredUnregister = kj::defer([this]() { in->destinationPumpingFrom = kj::none; }); + KJ_IF_SOME(p, out->tryPumpFrom(other)) { + return p.attach(kj::mv(deferredUnregister)); + } else { + return kj::none; + } } kj::Promise receive(size_t maxSize) override { return in->receive(maxSize); } kj::Promise pumpTo(WebSocket& other) override { - return in->pumpTo(other); + KJ_REQUIRE(in->destinationPumpingTo == kj::none, "can only call pumpTo() once at a time"); + // By convention, we store the WebSocket reference on `in`. + in->destinationPumpingTo = other; + auto deferredUnregister = kj::defer([this]() { in->destinationPumpingTo = kj::none; }); + return in->pumpTo(other).attach(kj::mv(deferredUnregister)); } uint64_t sentByteCount() override { return out->sentByteCount(); } uint64_t receivedByteCount() override { return in->sentByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + // We want to forward this call to whatever WebSocket the other end of the pipe is pumping + // to/from, if any. We'll check them in an arbitrary order and take the first one we see. + // But really, the hope is that both destinationPumpingTo and destinationPumpingFrom are in fact + // the same object. If they aren't the same, then it's not really clear whose extensions we + // should prefer; the choice here is arbitrary. 
+ KJ_IF_SOME(ws, out->destinationPumpingTo) { + KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { + return kj::mv(result); + } + } + KJ_IF_SOME(ws, out->destinationPumpingFrom) { + KJ_IF_SOME(result, ws.getPreferredExtensions(ctx)) { + return kj::mv(result); + } + } + return kj::none; + }; + private: kj::Rc in; kj::Rc out; @@ -6896,6 +6962,10 @@ private: uint64_t sentByteCount() override { return inner->sentByteCount(); } uint64_t receivedByteCount() override { return inner->receivedByteCount(); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + return inner->getPreferredExtensions(ctx); + }; + private: kj::Own inner; kj::Maybe> completionTask; @@ -8004,6 +8074,10 @@ private: uint64_t sentByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } uint64_t receivedByteCount() override { KJ_FAIL_ASSERT("received bad WebSocket handshake"); } + kj::Maybe getPreferredExtensions(ExtensionsContext ctx) override { + KJ_FAIL_ASSERT(kj::cp(exception)); + }; + private: kj::Exception exception; }; diff --git a/c++/src/kj/compat/http.h b/c++/src/kj/compat/http.h index 028cc3ed23..abbd969a28 100644 --- a/c++/src/kj/compat/http.h +++ b/c++/src/kj/compat/http.h @@ -689,7 +689,7 @@ class WebSocket { REQUEST, RESPONSE }; - virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) { return kj::none; } + virtual kj::Maybe getPreferredExtensions(ExtensionsContext ctx) = 0; // If pumpTo() / tryPumpFrom() is able to be optimized only if the other WebSocket is using // certain extensions (e.g. compression settings), then this method returns what those extensions // are. 
For example, matching extensions between standard WebSockets allows pumping to be From 41e3c5d933c9d756cde4319f8d54e0e2fb1c109a Mon Sep 17 00:00:00 2001 From: Vaci Date: Wed, 28 Feb 2024 18:58:11 +0000 Subject: [PATCH 51/58] check for empty AggregateConnectionReceiver --- c++/src/kj/async-io.c++ | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/c++/src/kj/async-io.c++ b/c++/src/kj/async-io.c++ index b547c3b2a6..fd741bfda1 100644 --- a/c++/src/kj/async-io.c++ +++ b/c++/src/kj/async-io.c++ @@ -2893,9 +2893,10 @@ public: } uint getPort() override { - return receivers[0]->getPort(); + return receivers.size() > 0 ? receivers[0]->getPort() : 0u; } void getsockopt(int level, int option, void* value, uint* length) override { + KJ_REQUIRE(receivers.size() > 0); return receivers[0]->getsockopt(level, option, value, length); } void setsockopt(int level, int option, const void* value, uint length) override { @@ -2905,6 +2906,7 @@ public: } } void getsockname(struct sockaddr* addr, uint* length) override { + KJ_REQUIRE(receivers.size() > 0); return receivers[0]->getsockname(addr, length); } From e49c1eb82b1e39f216f8313f97cf35d2c34b9f89 Mon Sep 17 00:00:00 2001 From: Vaci Date: Fri, 1 Mar 2024 16:27:08 +0000 Subject: [PATCH 52/58] test that an empty AggregateConnectionReceiver fails gracefully --- c++/src/kj/async-io-test.c++ | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/c++/src/kj/async-io-test.c++ b/c++/src/kj/async-io-test.c++ index 3b567fe1bc..948d2de276 100644 --- a/c++/src/kj/async-io-test.c++ +++ b/c++/src/kj/async-io-test.c++ @@ -3080,6 +3080,22 @@ KJ_TEST("AggregateConnectionReceiver") { acceptPromise3.wait(ws); } +KJ_TEST("AggregateConnectionReceiver empty") { + auto aggregate = newAggregateConnectionReceiver({}); + KJ_EXPECT(aggregate->getPort() == 0); + + int value; + uint length = sizeof(value); + + KJ_IF_SOME(exception, kj::runCatchingExceptions([&]() { + aggregate->getsockopt(0, 0, &value, &length); + })) { + 
(void)exception; + } else { + KJ_FAIL_EXPECT("Expected an exception"); + } +} + // ======================================================================================= // Tests for optimized pumpTo() between OS handles. Note that this is only even optimized on // some OSes (only Linux as of this writing), but the behavior should still be the same on all From d06dda888785d37d1bcd4e08a0a44334c76ff686 Mon Sep 17 00:00:00 2001 From: Jianyong Chen Date: Mon, 4 Mar 2024 16:59:48 +0800 Subject: [PATCH 53/58] kj/table: fix the initialization of BTreeImpl::MaybeUint with uint. This constructor is currently unused, so the current issue has no immediate impact. Signed-off-by: Jianyong Chen --- c++/src/kj/table.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/c++/src/kj/table.h b/c++/src/kj/table.h index 6f670a273b..2a4a83c76c 100644 --- a/c++/src/kj/table.h +++ b/c++/src/kj/table.h @@ -1164,7 +1164,7 @@ class BTreeImpl::MaybeUint { // A nullable uint, using the value zero to mean null and shifting all other values up by 1. 
public: MaybeUint() = default; - inline MaybeUint(uint i): i(i - 1) {} + inline MaybeUint(uint i): i(i + 1) {} inline MaybeUint(decltype(nullptr)): i(0) {} inline bool operator==(decltype(nullptr)) const { return i == 0; } From 06db80191ab6d09465654a23ea63ed2e3aed4020 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Mon, 4 Mar 2024 09:52:51 -0600 Subject: [PATCH 54/58] Update c++/src/kj/async-io-test.c++ --- c++/src/kj/async-io-test.c++ | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/c++/src/kj/async-io-test.c++ b/c++/src/kj/async-io-test.c++ index 948d2de276..eade6c5391 100644 --- a/c++/src/kj/async-io-test.c++ +++ b/c++/src/kj/async-io-test.c++ @@ -3087,13 +3087,7 @@ KJ_TEST("AggregateConnectionReceiver empty") { int value; uint length = sizeof(value); - KJ_IF_SOME(exception, kj::runCatchingExceptions([&]() { - aggregate->getsockopt(0, 0, &value, &length); - })) { - (void)exception; - } else { - KJ_FAIL_EXPECT("Expected an exception"); - } + KJ_EXPECT_THROW_MESSAGE("receivers.size() > 0", aggregate->getsockopt(0, 0, &value, &length)); } // ======================================================================================= From 5f1d11b91a0f7e78cea2711619f2e08d3694e71a Mon Sep 17 00:00:00 2001 From: Joe Lee Date: Mon, 4 Mar 2024 10:29:49 -0800 Subject: [PATCH 55/58] capnp-rpc: retain RpcConnectionState reference in error handling lambda Fixes potential read-after-free that can happen if the connection's shutdown() method throws an exception. 
--- c++/src/capnp/rpc-test.c++ | 48 ++++++++++++++++++++++++++++++++++++++ c++/src/capnp/rpc.c++ | 5 ++-- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/c++/src/capnp/rpc-test.c++ b/c++/src/capnp/rpc-test.c++ index b8a6dccc11..06963a13f1 100644 --- a/c++/src/capnp/rpc-test.c++ +++ b/c++/src/capnp/rpc-test.c++ @@ -342,6 +342,9 @@ public: } } kj::Promise shutdown() override { + KJ_IF_SOME(e, network.shutdownExceptionToThrow) { + return kj::cp(e); + } KJ_IF_SOME(p, partner) { auto paf = kj::newPromiseAndFulfiller(); p.fulfillOnEnd = kj::mv(paf.fulfiller); @@ -410,11 +413,16 @@ public: } } + void setShutdownExceptionToThrow(kj::Exception&& e) { + shutdownExceptionToThrow = kj::mv(e); + } + private: TestNetwork& network; kj::StringPtr self; uint sent = 0; uint received = 0; + kj::Maybe shutdownExceptionToThrow = kj::none; std::map> connections; std::queue>>> fulfillerQueue; @@ -1303,6 +1311,46 @@ TEST(Rpc, Abort) { EXPECT_TRUE(conn->receiveIncomingMessage().wait(context.waitScope) == nullptr); } +KJ_TEST("handles exceptions thrown during disconnect") { + // This is similar to the earlier "abort" test, but throws an exception on + // connection shutdown, to exercise the RpcConnectionState error handler. + + TestContext context; + + MallocMessageBuilder refMessage(128); + auto hostId = refMessage.initRoot(); + hostId.setHost("server"); + + context.serverNetwork.setShutdownExceptionToThrow( + KJ_EXCEPTION(FAILED, "a_disconnect_exception")); + + auto conn = KJ_ASSERT_NONNULL(context.clientNetwork.connect(hostId)); + + { + // Send an invalid message (Return to non-existent question). + auto msg = conn->newOutgoingMessage(128); + auto body = msg->getBody().initAs().initReturn(); + body.setAnswerId(1234); + body.setCanceled(); + msg->send(); + } + + { + // The internal exception handler of RpcSystemBase logs exceptions thrown + // during disconnect, which the test framework will flag as a failure if we + // don't explicitly tell it to expect the logged output. 
+ KJ_EXPECT_LOG(ERROR, "a_disconnect_exception"); + + // Force outstanding promises to completion. The server should detect the + // invalid message and disconnect, which should cause the connection's + // disconnect() to throw an exception that will then be handled by a + // RpcConnectionState handler. Since other state instances were freed prior + // to the handler invocation, this caused failures in earlier versions of + // the code when run under asan. + kj::Promise(kj::NEVER_DONE).poll(context.waitScope); + } +} + KJ_TEST("loopback bootstrap()") { int callCount = 0; test::TestInterface::Client bootstrap = kj::heap(callCount); diff --git a/c++/src/capnp/rpc.c++ b/c++/src/capnp/rpc.c++ index 84710fc7e8..b15d23fb84 100644 --- a/c++/src/capnp/rpc.c++ +++ b/c++/src/capnp/rpc.c++ @@ -503,7 +503,8 @@ public: auto shutdownPromise = dyingConnection->shutdown() .attach(kj::mv(dyingConnection)) .then([]() -> kj::Promise { return kj::READY_NOW; }, - [this, origException = kj::mv(exception)](kj::Exception&& shutdownException) -> kj::Promise { + [self = kj::addRef(*this), origException = kj::mv(exception)]( + kj::Exception&& shutdownException) -> kj::Promise { // Don't report disconnects as an error. if (shutdownException.getType() == kj::Exception::Type::DISCONNECTED) { return kj::READY_NOW; @@ -516,7 +517,7 @@ public: } // We are shutting down after receive error, ignore shutdown exception since underlying // transport is probably broken. 
- if (receiveIncomingMessageError) { + if (self->receiveIncomingMessageError) { return kj::READY_NOW; } return kj::mv(shutdownException); From c0b4726bb720ed7cfb81cc29347c194b7a2128e4 Mon Sep 17 00:00:00 2001 From: Jonas Vautherin Date: Thu, 7 Mar 2024 00:46:39 +0100 Subject: [PATCH 56/58] nitpick: fix a few typos in the comments --- c++/src/capnp/arena.h | 2 +- c++/src/capnp/capability.h | 2 +- c++/src/capnp/rpc.c++ | 14 +++++++------- c++/src/kj/async.h | 4 ++-- c++/src/kj/common.h | 4 ++-- c++/src/kj/exception.c++ | 2 +- 6 files changed, 14 insertions(+), 14 deletions(-) diff --git a/c++/src/capnp/arena.h b/c++/src/capnp/arena.h index aeaff8448d..7308912b80 100644 --- a/c++/src/capnp/arena.h +++ b/c++/src/capnp/arena.h @@ -330,7 +330,7 @@ class BuilderArena final: public Arena { SegmentBuilder* addExternalSegment(kj::ArrayPtr content); // Add a new segment to the arena which points to some existing memory region. The segment is - // assumed to be completley full; the arena will never allocate from it. In fact, the segment + // assumed to be completely full; the arena will never allocate from it. In fact, the segment // is considered read-only. Any attempt to get a Builder pointing into this segment will throw // an exception. Readers are allowed, however. // diff --git a/c++/src/capnp/capability.h b/c++/src/capnp/capability.h index b74c40c02b..7219845604 100644 --- a/c++/src/capnp/capability.h +++ b/c++/src/capnp/capability.h @@ -133,7 +133,7 @@ class Request: public Params::Builder { // to complete (and possibly other things, if that RPC itself returned a promise capability), // but when using `sendPipelineOnly()`, `whenResolved()` may complete immediately, or never, or // at an arbitrary time. Do not rely on it. - // - Normal path shortening may not work with these capabilities. For exmaple, if the caller + // - Normal path shortening may not work with these capabilities. 
For example, if the caller // forwards a pipelined capability back to the callee's vat, calls made by the callee to that // capability may continue to proxy through the caller. Conversely, if the callee ends up // returning a capability that points back to the caller's vat, calls on the pipelined diff --git a/c++/src/capnp/rpc.c++ b/c++/src/capnp/rpc.c++ index b15d23fb84..02775a31e3 100644 --- a/c++/src/capnp/rpc.c++ +++ b/c++/src/capnp/rpc.c++ @@ -714,7 +714,7 @@ private: bool gotReturnForHighQuestionId = false; // Becomes true if we ever get a `Return` message for a high question ID (with top bit set), // which we use in cases where we've hinted to the peer that we don't want a `Return`. If the - // peer sends us one anyway then it seemingly doesn't not implement our hints. We need to stop + // peer sends us one anyway then it seemingly does not implement our hints. We need to stop // using the hints in this case before the high question ID space wraps around since otherwise // we might reuse an ID that the peer thinks is still in use. @@ -1933,7 +1933,7 @@ private: question.paramExports = kj::mv(exports); question.isTailCall = isTailCall; - // Make the QuentionRef and result promise. + // Make the QuestionRef and result promise. SendInternalResult result; auto paf = kj::newPromiseAndFulfiller>>(); result.questionRef = kj::refcounted( @@ -2024,7 +2024,7 @@ private: question.paramExports = kj::mv(exports); question.isTailCall = false; - // Make the QuentionRef and result promise. + // Make the QuestionRef and result promise. auto questionRef = kj::refcounted(*connectionState, questionId, kj::none); question.selfRef = *questionRef; @@ -2249,7 +2249,7 @@ private: struct Resolution { kj::Own returnedCap; - // The capabiilty that appeared in the response message in this slot. + // The capability that appeared in the response message in this slot. 
kj::Own unwrapped; // Exactly what `getInnermostClient(returnedCap)` produced at the time that the return @@ -2463,7 +2463,7 @@ private: if (!responseImpl.hasCapabilities()) { returnMessage.setNoFinishNeeded(true); - // Tell ourselves that a finsih was already received, so that `cleanupAnswerTable()` + // Tell ourselves that a finish was already received, so that `cleanupAnswerTable()` // removes the answer table entry. receivedFinish = true; @@ -2696,7 +2696,7 @@ private: // Cancellation state ---------------------------------- bool receivedFinish = false; - // True if a `Finish` message has been recevied OR we sent a `Return` with `noFinishNedeed`. + // True if a `Finish` message has been received OR we sent a `Return` with `noFinishNeeded`. // In either case, it is our responsibility to clean up the answer table. kj::UnwindDetector unwindDetector; @@ -3451,7 +3451,7 @@ private: // Carol returns a capability from this call that points all the way back though Bob to // Alice. When this return capability passes through Bob, Bob will resolve the previous // promise-pipeline capability to it. However, Bob has to send a Disembargo to Carol before - // completing this resolution. In the meantime, though, Bob returns the final repsonse to + // completing this resolution. In the meantime, though, Bob returns the final response to // Alice. Alice then *also* sends a Disembargo to Bob. The Alice -> Bob Disembargo might // arrive at Bob before the Bob -> Carol Disembargo has resolved, in which case the // Disembargo is delivered to a promise capability. diff --git a/c++/src/kj/async.h b/c++/src/kj/async.h index f271ea5b21..f305869477 100644 --- a/c++/src/kj/async.h +++ b/c++/src/kj/async.h @@ -125,7 +125,7 @@ class [[nodiscard]] Promise: protected _::PromiseBase { // `kj::READY_NOW` to an already-fulfilled Promise. You may also implicitly convert a // `kj::Exception` to an already-broken promise of any type. 
// - // Promises are linear types -- they are moveable but not copyable. If a Promise is destroyed + // Promises are linear types -- they are movable but not copyable. If a Promise is destroyed // or goes out of scope (without being moved elsewhere), any ongoing asynchronous operations // meant to fulfill the promise will be canceled if possible. All methods of `Promise` (unless // otherwise noted) actually consume the promise in the sense of move semantics. (Arguably they @@ -476,7 +476,7 @@ PromiseForResult retryOnDisconnect(Func&& func) KJ_WARN_UNUSED_RESUL template PromiseForResult startFiber( size_t stackSize, Func&& func, SourceLocation location = {}) KJ_WARN_UNUSED_RESULT; -// Executes `func()` in a fiber, returning a promise for the eventual reseult. `func()` will be +// Executes `func()` in a fiber, returning a promise for the eventual result. `func()` will be // passed a `WaitScope&` as its parameter, allowing it to call `.wait()` on promises. Thus, `func()` // can be written in a synchronous, blocking style, instead of using `.then()`. This is often much // easier to write and read, and may even be significantly faster if it allows the use of stack diff --git a/c++/src/kj/common.h b/c++/src/kj/common.h index b768bdbf4c..67718d5951 100644 --- a/c++/src/kj/common.h +++ b/c++/src/kj/common.h @@ -290,7 +290,7 @@ typedef unsigned char byte; #define KJ_DEPRECATED(reason) \ __attribute__((deprecated)) #define KJ_UNAVAILABLE(reason) = delete -// If the `unavailable` attribute is not supproted, just mark the method deleted, which at least +// If the `unavailable` attribute is not supported, just mark the method deleted, which at least // makes it a compile-time error to try to call it. Note that on Clang, marking a method deleted // *and* unavailable unfortunately defeats the purpose of the unavailable annotation, as the // generic "deleted" error is reported instead. 
@@ -1771,7 +1771,7 @@ class ArrayPtr: public DisallowConstCopyIfNotConst { // ArrayPtr ptr = { 1, 2, 3 }; // foo(ptr[1]); // undefined behavior! // Any KJ programmer should be able to recognize that this is UB, because an ArrayPtr does not own -// its content. That's not what this constructor is for, tohugh. This constructor is meant to allow +// its content. That's not what this constructor is for, though. This constructor is meant to allow // code like this: // int foo(ArrayPtr p); // // ... later ... diff --git a/c++/src/kj/exception.c++ b/c++/src/kj/exception.c++ index 3a03f0be6b..2443c284df 100644 --- a/c++/src/kj/exception.c++ +++ b/c++/src/kj/exception.c++ @@ -255,7 +255,7 @@ ArrayPtr getStackTrace(ArrayPtr space, uint ignoreCount) { #if (__GNUC__ && !_WIN32) || __clang__ // Allow dependents to override the implementation of stack symbolication by making it a weak -// symbol. We prefer weak symbols over some sort of callback registration mechanism becasue this +// symbol. We prefer weak symbols over some sort of callback registration mechanism because this // allows an alternate symbolication library to be easily linked into tests without changing the // code of the test. __attribute__((weak)) From 4e6b53f2aeaa7434675e919f2b64afd38c5ef896 Mon Sep 17 00:00:00 2001 From: Kenton Varda Date: Mon, 11 Mar 2024 13:11:58 -0500 Subject: [PATCH 57/58] In ~ConcurrencyLimitingHttpClient(), crash more eagerly. This error log is necessarily followed by a use-after-free later on that can be nasty. We should crash eagerly instead. 
--- c++/src/kj/compat/http.c++ | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/c++/src/kj/compat/http.c++ b/c++/src/kj/compat/http.c++ index 447e15e3ac..3e655ed11b 100644 --- a/c++/src/kj/compat/http.c++ +++ b/c++/src/kj/compat/http.c++ @@ -6452,14 +6452,11 @@ public: maxConcurrentRequests(maxConcurrentRequests), countChangedCallback(kj::mv(countChangedCallback)) {} - ~ConcurrencyLimitingHttpClient() noexcept(false) { - if (concurrentRequests > 0) { - static bool logOnce KJ_UNUSED = ([&] { - KJ_LOG(ERROR, "ConcurrencyLimitingHttpClient getting destroyed when concurrent requests " - "are still active", concurrentRequests); - return true; - })(); - } + ~ConcurrencyLimitingHttpClient() noexcept { + // Crash in this case because otherwise we'll have UAF later on. + KJ_ASSERT(concurrentRequests == 0, + "ConcurrencyLimitingHttpClient getting destroyed when concurrent requests " + "are still active"); } Request request(HttpMethod method, kj::StringPtr url, const HttpHeaders& headers, From 24c7498612bffcbb8a11965f3c98126f57b9455c Mon Sep 17 00:00:00 2001 From: Nick Hynes Date: Thu, 29 Feb 2024 13:13:21 +0800 Subject: [PATCH 58/58] Add AF_VSOCK support to async io --- c++/src/kj/async-io-internal.h | 1 + c++/src/kj/async-io-test.c++ | 50 ++++++++++++++++++++++++++++++++++ c++/src/kj/async-io-unix.c++ | 36 ++++++++++++++++++++++++ c++/src/kj/async-io.c++ | 32 ++++++++++++++++------ 4 files changed, 111 insertions(+), 8 deletions(-) diff --git a/c++/src/kj/async-io-internal.h b/c++/src/kj/async-io-internal.h index d030ad9577..bc0df7c30c 100644 --- a/c++/src/kj/async-io-internal.h +++ b/c++/src/kj/async-io-internal.h @@ -58,6 +58,7 @@ class NetworkFilter: public LowLevelAsyncIoProvider::NetworkFilter { Vector denyCidrs; bool allowUnix; bool allowAbstractUnix; + bool allowVsock; bool allowPublic = false; bool allowNetwork = false; diff --git a/c++/src/kj/async-io-test.c++ b/c++/src/kj/async-io-test.c++ index eade6c5391..0107705cfe 100644 --- 
a/c++/src/kj/async-io-test.c++ +++ b/c++/src/kj/async-io-test.c++ @@ -367,6 +367,49 @@ TEST(AsyncIo, AncillaryMessageHandler) { } #endif +#if __linux__ +TEST(AsyncIo, VmSocket) { + auto ioContext = setupAsyncIo(); + auto& network = ioContext.provider->getNetwork(); + + Own listener; + Own server; + Own client; + + char receiveBuffer[4]; + + int port = ((getpid() >> 10) % (0xffff - 1024 + 1)) + 1024; // wrap to dynamic port range + auto ready = newPromiseAndFulfiller(); + + ready.promise.then([&]() { + return network.parseAddress(kj::str("vsock:2:", port)); + }).then([&](Own&& addr) { + auto promise = addr->connect(); + return promise.then([&,addr=kj::mv(addr)](auto result) mutable { + client = kj::mv(result); + return client->write("foo", 3); + }); + }).detach([](kj::Exception&& exception) { + KJ_FAIL_EXPECT(exception); + }); + + kj::String result = network.parseAddress(kj::str("vsock:2:", port)) + .then([&](Own&& result) { + listener = result->listen(); + ready.fulfiller->fulfill(); + return listener->accept(); + }).then([&](auto result) { + server = kj::mv(result); + return server->tryRead(receiveBuffer, 3, 4); + }).then([&](size_t n) { + EXPECT_EQ(3u, n); + return heapString(receiveBuffer, n); + }).wait(ioContext.waitScope); + + EXPECT_EQ("foo", result); +} +#endif + String tryParse(WaitScope& waitScope, Network& network, StringPtr text, uint portHint = 0) { return network.parseAddress(text, portHint).wait(waitScope)->toString(); } @@ -409,6 +452,10 @@ TEST(AsyncIo, AddressParsing) { EXPECT_EQ("unix-abstract:foo/bar/baz", tryParse(w, network, "unix-abstract:foo/bar/baz")); #endif +#if __linux__ + EXPECT_EQ("vsock:4294967295:1234", tryParse(w, network, "vsock:-1:1234")); +#endif + // We can parse services by name... // // For some reason, Android and some various Linux distros do not support service names. 
@@ -1265,6 +1312,9 @@ KJ_TEST("Network::restrictPeers()") { #if !_WIN32 KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "unix:/foo")); #endif +#if __linux__ + KJ_EXPECT_THROW_MESSAGE("restrictPeers", tryParse(w, *restrictedNetwork, "vsock:-1:80")); +#endif auto addr = restrictedNetwork->parseAddress("127.0.0.1").wait(w); diff --git a/c++/src/kj/async-io-unix.c++ b/c++/src/kj/async-io-unix.c++ index dd01a4189e..27e3dbfd9b 100644 --- a/c++/src/kj/async-io-unix.c++ +++ b/c++/src/kj/async-io-unix.c++ @@ -60,6 +60,7 @@ #if __linux__ #include +#include #endif #if !defined(SO_PEERCRED) && defined(LOCAL_PEERCRED) @@ -985,6 +986,11 @@ public: return str("unix:", path); } } +#if __linux__ + case AF_VSOCK: { + return str("vsock:", addr.vsock.svm_cid, ":", addr.vsock.svm_port); + } +#endif default: return str("(unknown address family ", addr.generic.sa_family, ")"); } @@ -1043,6 +1049,33 @@ public: return array.finish(); } +#if __linux__ + if (str.startsWith("vsock:")) { + StringPtr path = str.slice(strlen("vsock:")); + + char* endptr; + unsigned int cid = strtoul(path.cStr(), &endptr, 0); + KJ_REQUIRE(*endptr == ':', "missing vsock port"); + unsigned int port = strtoul(endptr + 1, &endptr, 0); + KJ_REQUIRE(*endptr == '\0', "invalid vsock addr"); + + memset(&result.addr.vsock, 0, sizeof(result.addr.vsock)); + result.addr.vsock.svm_family = AF_VSOCK; + result.addr.vsock.svm_cid = cid; + result.addr.vsock.svm_port = port; + result.addrlen = sizeof(struct sockaddr_vm); + + if (!result.parseAllowedBy(filter)) { + KJ_FAIL_REQUIRE("VM sockets blocked by restrictPeers()"); + return Array(); + } + + auto array = kj::heapArrayBuilder(1); + array.add(result); + return array.finish(); + } +#endif + // Try to separate the address and port. 
ArrayPtr addrPart; Maybe portPart; @@ -1195,6 +1228,9 @@ private: struct sockaddr_in inet4; struct sockaddr_in6 inet6; struct sockaddr_un unixDomain; +#if __linux__ + struct sockaddr_vm vsock; +#endif struct sockaddr_storage storage; } addr; diff --git a/c++/src/kj/async-io.c++ b/c++/src/kj/async-io.c++ index fd741bfda1..c8e36787e6 100644 --- a/c++/src/kj/async-io.c++ +++ b/c++/src/kj/async-io.c++ @@ -3082,7 +3082,7 @@ bool matchesAny(ArrayPtr cidrs, const struct sockaddr* addr) { } NetworkFilter::NetworkFilter() - : allowUnix(true), allowAbstractUnix(true) { + : allowUnix(true), allowAbstractUnix(true), allowVsock(true) { allowCidrs.add(CidrRange::inet4({0,0,0,0}, 0)); allowCidrs.add(CidrRange::inet6({}, {}, 0)); denyCidrs.addAll(reservedCidrs()); @@ -3090,7 +3090,7 @@ NetworkFilter::NetworkFilter() NetworkFilter::NetworkFilter(ArrayPtr allow, ArrayPtr deny, NetworkFilter& next) - : allowUnix(false), allowAbstractUnix(false), next(next) { + : allowUnix(false), allowAbstractUnix(false), allowVsock(false), next(next) { for (auto rule: allow) { if (rule == "local") { allowCidrs.addAll(localCidrs()); @@ -3107,6 +3107,8 @@ NetworkFilter::NetworkFilter(ArrayPtr allow, ArrayPtr allow, ArrayPtrsa_family == AF_VSOCK) return allowVsock; +#endif + bool allowed = false; uint allowSpecificity = 0; @@ -3197,15 +3205,23 @@ bool NetworkFilter::shouldAllowParse(const struct sockaddr* addr, uint addrlen) } } else { #endif - if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && - (allowPublic || allowNetwork)) { - matched = true; - } - for (auto& cidr: allowCidrs) { - if (cidr.matchesFamily(addr->sa_family)) { +#if __linux__ + if (addr->sa_family == AF_VSOCK) { + if (allowVsock) matched = true; + } else { +#endif + if ((addr->sa_family == AF_INET || addr->sa_family == AF_INET6) && + (allowPublic || allowNetwork)) { matched = true; } + for (auto& cidr: allowCidrs) { + if (cidr.matchesFamily(addr->sa_family)) { + matched = true; + } + } +#if __linux__ } +#endif #if 
!_WIN32 } #endif