From 3f4e48642392246ae57de0a13963003b8a7df057 Mon Sep 17 00:00:00 2001 From: Justin King Date: Fri, 21 Mar 2025 12:29:16 -0700 Subject: [PATCH] Implement optimized `string` and `bytes` concatenation PiperOrigin-RevId: 739263703 --- common/internal/byte_string.cc | 64 ++++++++++++++++++++++++++++ common/internal/byte_string.h | 7 +++ common/values/bytes_value.cc | 7 +++ common/values/bytes_value.h | 4 ++ common/values/string_value.cc | 18 ++------ runtime/standard/string_functions.cc | 10 +---- 6 files changed, 88 insertions(+), 22 deletions(-) diff --git a/common/internal/byte_string.cc b/common/internal/byte_string.cc index 0e2c19a65..416cb8621 100644 --- a/common/internal/byte_string.cc +++ b/common/internal/byte_string.cc @@ -56,6 +56,51 @@ T ConsumeAndDestroy(T& object) { } // namespace +ByteString ByteString::Concat(const ByteString& lhs, const ByteString& rhs, + absl::Nonnull arena) { + ABSL_DCHECK(arena != nullptr); + + if (lhs.empty()) { + return rhs; + } + if (rhs.empty()) { + return lhs; + } + + if (lhs.GetKind() == ByteStringKind::kLarge || + rhs.GetKind() == ByteStringKind::kLarge) { + // If either the left or right are absl::Cord, use absl::Cord. + absl::Cord result; + result.Append(lhs.ToCord()); + result.Append(rhs.ToCord()); + return ByteString(std::move(result)); + } + + const size_t lhs_size = lhs.size(); + const size_t rhs_size = rhs.size(); + const size_t result_size = lhs_size + rhs_size; + ByteString result; + if (result_size <= kSmallByteStringCapacity) { + // If the resulting string fits in inline storage, do it. + result.rep_.small.size = result_size; + result.rep_.small.arena = arena; + lhs.CopyToArray(result.rep_.small.data); + rhs.CopyToArray(result.rep_.small.data + lhs_size); + } else { + // Otherwise allocate on the arena. + char* result_data = + reinterpret_cast(arena->AllocateAligned(result_size)); + lhs.CopyToArray(result_data); + rhs.CopyToArray(result_data + lhs_size); + result.rep_.medium.data = result_data; + result.rep_.medium.size = result_size; + result.rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; + result.rep_.medium.kind = ByteStringKind::kMedium; + } + return result; +} + ByteString::ByteString(Allocator<> allocator, absl::string_view string) { ABSL_DCHECK_LE(string.size(), max_size()); auto* arena = allocator.arena(); @@ -249,6 +294,25 @@ void ByteString::RemoveSuffix(size_t n) { } } +void ByteString::CopyToArray(absl::Nonnull out) const { + ABSL_DCHECK(out != nullptr); + + switch (GetKind()) { + case ByteStringKind::kSmall: { + absl::string_view small = GetSmall(); + std::memcpy(out, small.data(), small.size()); + } break; + case ByteStringKind::kMedium: { + absl::string_view medium = GetMedium(); + std::memcpy(out, medium.data(), medium.size()); + } break; + case ByteStringKind::kLarge: { + const absl::Cord& large = GetLarge(); + (CopyCordToArray)(large, out); + } break; + } +} + std::string ByteString::ToString() const { switch (GetKind()) { case ByteStringKind::kSmall: diff --git a/common/internal/byte_string.h b/common/internal/byte_string.h index f2c11589e..e2f38a5c4 100644 --- a/common/internal/byte_string.h +++ b/common/internal/byte_string.h @@ -43,6 +43,7 @@ namespace cel { class BytesValueInputStream; class BytesValueOutputStream; +class StringValue; namespace common_internal { @@ -171,6 +172,9 @@ absl::string_view LegacyByteString(const ByteString& string, bool stable, class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI [[nodiscard]] ByteString final { public: + static ByteString Concat(const ByteString& lhs, const ByteString& rhs, + absl::Nonnull arena); + ByteString() : ByteString(NewDeleteAllocator()) {} explicit ByteString(absl::Nullable string) @@ -333,6 +337,7 @@ ByteString final { friend struct ByteStringTestFriend; friend class cel::BytesValueInputStream; friend class cel::BytesValueOutputStream; + friend class cel::StringValue; friend absl::string_view LegacyByteString( const ByteString& string, bool stable, absl::Nonnull arena); @@ -475,6 +480,8 @@ ByteString final { static void DestroyLarge(LargeByteStringRep& rep) { GetLarge(rep).~Cord(); } + void CopyToArray(absl::Nonnull out) const; + ByteStringRep rep_; }; diff --git a/common/values/bytes_value.cc b/common/values/bytes_value.cc index cac57b320..4a8123e5e 100644 --- a/common/values/bytes_value.cc +++ b/common/values/bytes_value.cc @@ -23,6 +23,7 @@ #include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/internal/byte_string.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/strings.h" @@ -53,6 +54,12 @@ std::string BytesDebugString(const Bytes& value) { } // namespace +BytesValue BytesValue::Concat(const BytesValue& lhs, const BytesValue& rhs, + absl::Nonnull arena) { + return BytesValue( + common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena)); +} + std::string BytesValue::DebugString() const { return BytesDebugString(*this); } absl::Status BytesValue::SerializeTo( diff --git a/common/values/bytes_value.h b/common/values/bytes_value.h index 56c2130e8..98652755d 100644 --- a/common/values/bytes_value.h +++ b/common/values/bytes_value.h @@ -81,6 +81,10 @@ class BytesValue final : private common_internal::ValueMixin { absl::Nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND) = delete; + static BytesValue Concat(const BytesValue& lhs, const BytesValue& rhs, + absl::Nonnull arena + ABSL_ATTRIBUTE_LIFETIME_BOUND); + ABSL_DEPRECATED("Use From") explicit BytesValue(absl::Nullable value) : value_(value) {} diff --git a/common/values/string_value.cc b/common/values/string_value.cc index 8fb4f4a1d..72fda5114 100644 --- a/common/values/string_value.cc +++ b/common/values/string_value.cc @@ -13,8 +13,8 @@ // limitations under the License. #include +#include #include -#include #include "google/protobuf/wrappers.pb.h" #include "absl/base/nullability.h" @@ -25,6 +25,7 @@ #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "common/internal/byte_string.h" #include "common/value.h" #include "internal/status_macros.h" #include "internal/strings.h" @@ -58,19 +59,8 @@ std::string StringDebugString(const Bytes& value) { StringValue StringValue::Concat(const StringValue& lhs, const StringValue& rhs, absl::Nonnull arena) { - ABSL_DCHECK(arena != nullptr); - - if (lhs.IsEmpty()) { - return rhs; - } - if (rhs.IsEmpty()) { - return lhs; - } - - absl::Cord result; - result.Append(lhs.ToCord()); - result.Append(rhs.ToCord()); - return StringValue(std::move(result)); + return StringValue( + common_internal::ByteString::Concat(lhs.value_, rhs.value_, arena)); } std::string StringValue::DebugString() const { diff --git a/runtime/standard/string_functions.cc b/runtime/standard/string_functions.cc index e6b60c618..d14e7674c 100644 --- a/runtime/standard/string_functions.cc +++ b/runtime/standard/string_functions.cc @@ -40,10 +40,7 @@ absl::StatusOr ConcatString( absl::Nonnull, absl::Nonnull, absl::Nonnull arena) { - // TODO: use StringValue::Concat when remaining interop usages - // removed. Modern concat implementation forces additional copies when - // converting to legacy string values. - return StringValue(arena, absl::StrCat(value1.ToString(), value2.ToString())); + return StringValue::Concat(value1, value2, arena); } // Concatenation for bytes type. @@ -52,10 +49,7 @@ absl::StatusOr ConcatBytes( absl::Nonnull, absl::Nonnull, absl::Nonnull arena) { - // TODO: use BytesValue::Concat when remaining interop usages - // removed. Modern concat implementation forces additional copies when - // converting to legacy string values. - return BytesValue(arena, absl::StrCat(value1.ToString(), value2.ToString())); + return BytesValue::Concat(value1, value2, arena); } bool StringContains(const StringValue& value, const StringValue& substr) {