From adb2c4b415c35cda3e4394384ec579ea8dd8ec92 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Tue, 20 Jun 2023 20:15:16 -0700 Subject: [PATCH] When parsing repeated primitives, append to a tmp array on stack. Adding to a temporary array of values on stack, then merging it to RepeatedField minimizes dynamic growth of RepeatedField. PiperOrigin-RevId: 542123764 --- .../generated_message_tctable_lite.cc | 244 +++++++++++------- src/google/protobuf/repeated_field.h | 28 ++ .../protobuf/repeated_field_unittest.cc | 18 ++ 3 files changed, 198 insertions(+), 92 deletions(-) diff --git a/src/google/protobuf/generated_message_tctable_lite.cc b/src/google/protobuf/generated_message_tctable_lite.cc index 460131703529..906519c8f01b 100644 --- a/src/google/protobuf/generated_message_tctable_lite.cc +++ b/src/google/protobuf/generated_message_tctable_lite.cc @@ -36,6 +36,7 @@ #include #include +#include "absl/base/optimization.h" #include "google/protobuf/generated_message_tctable_decl.h" #include "google/protobuf/generated_message_tctable_impl.h" #include "google/protobuf/inlined_string_field.h" @@ -43,6 +44,7 @@ #include "google/protobuf/map.h" #include "google/protobuf/message_lite.h" #include "google/protobuf/parse_context.h" +#include "google/protobuf/repeated_field.h" #include "google/protobuf/varint_shuffle.h" #include "google/protobuf/wire_format_lite.h" #include "utf8_validity.h" @@ -406,6 +408,45 @@ inline PROTOBUF_ALWAYS_INLINE void InvertPacked(TcFieldData& data) { data.data ^= Wt ^ WireFormatLite::WIRETYPE_LENGTH_DELIMITED; } +constexpr uint32_t kAccumulatorBytesOnStack = 256; + +// Accumulates fields to buffer repeated fields on parsing path to avoid growing +// repeated field container type too frequently. It flushes to the backing +// repeated fields if it's full or out of the scope. A larger buffer (e.g. 2KiB) +// is actually harmful due to: +// - increased stack overflow risk +// - extra cache misses on accessing local variables +// - less competitive to the cost of growing large buffer +template +class ScopedFieldAccumulator { + public: + constexpr explicit ScopedFieldAccumulator(ContainerType& field) + : field_(field) {} + + ~ScopedFieldAccumulator() { + if (ABSL_PREDICT_TRUE(current_size_ > 0)) { + field_.MergeFromArray(buffer_, current_size_); + } + } + + void Add(ElementType v) { + if (ABSL_PREDICT_FALSE(current_size_ == kSize)) { + field_.MergeFromArray(buffer_, kSize); + current_size_ = 0; + } + buffer_[current_size_++] = v; + } + + private: + static constexpr uint32_t kSize = + kAccumulatorBytesOnStack / sizeof(ElementType); + static_assert(kSize > 0, "Size cannot be zero"); + + uint32_t current_size_ = 0; + ElementType buffer_[kSize]; + ContainerType& field_; +}; + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -632,14 +673,17 @@ PROTOBUF_ALWAYS_INLINE const char* TcParser::RepeatedFixed( } auto& field = RefAt>(msg, data.offset()); const auto tag = UnalignedLoad(ptr); - do { - field.Add(UnalignedLoad(ptr + sizeof(TagType))); - ptr += sizeof(TagType) + sizeof(LayoutType); - if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) { - PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - } while (UnalignedLoad(ptr) == tag); + { + ScopedFieldAccumulator accumulator(field); + do { + accumulator.Add(UnalignedLoad(ptr + sizeof(TagType))); + ptr += sizeof(TagType) + sizeof(LayoutType); + if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) goto parse_loop; + } while (UnalignedLoad(ptr) == tag); + } PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS); +parse_loop: + PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); } PROTOBUF_NOINLINE const char* TcParser::FastF32R1(PROTOBUF_TC_PARAM_DECL) { @@ -971,19 +1015,22 @@ PROTOBUF_ALWAYS_INLINE const char* TcParser::RepeatedVarint( } auto& field = RefAt>(msg, data.offset()); const auto expected_tag = UnalignedLoad(ptr); - do { - ptr += sizeof(TagType); - FieldType tmp; - ptr = ParseVarint(ptr, &tmp); - if (ptr == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - field.Add(ZigZagDecodeHelper(tmp)); - if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) { - PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - } while (UnalignedLoad(ptr) == expected_tag); + { + ScopedFieldAccumulator accumulator(field); + do { + ptr += sizeof(TagType); + FieldType tmp; + ptr = ParseVarint(ptr, &tmp); + if (ptr == nullptr) goto error; + accumulator.Add(ZigZagDecodeHelper(tmp)); + if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) goto parse_loop; + } while (UnalignedLoad(ptr) == expected_tag); + } PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS); +parse_loop: + PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); +error: + PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); } PROTOBUF_NOINLINE const char* TcParser::FastV8R1(PROTOBUF_TC_PARAM_DECL) { @@ -1044,7 +1091,8 @@ const char* TcParser::PackedVarint(PROTOBUF_TC_PARAM_DECL) { // pending hasbits now: SyncHasbits(msg, hasbits, table); auto* field = &RefAt>(msg, data.offset()); - return ctx->ReadPackedVarint(ptr, [field](uint64_t varint) { + ScopedFieldAccumulator accumulator(*field); + return ctx->ReadPackedVarint(ptr, [&](uint64_t varint) { FieldType val; if (zigzag) { if (sizeof(FieldType) == 8) { @@ -1055,7 +1103,7 @@ const char* TcParser::PackedVarint(PROTOBUF_TC_PARAM_DECL) { } else { val = varint; } - field->Add(val); + accumulator.Add(val); }); } @@ -1190,28 +1238,33 @@ const char* TcParser::RepeatedEnum(PROTOBUF_TC_PARAM_DECL) { auto& field = RefAt>(msg, data.offset()); const auto expected_tag = UnalignedLoad(ptr); const TcParseTableBase::FieldAux aux = *table->field_aux(data.aux_idx()); - do { - const char* ptr2 = ptr; // save for unknown enum case - ptr += sizeof(TagType); - uint64_t tmp; - ptr = ParseVarint(ptr, &tmp); - if (ptr == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - if (PROTOBUF_PREDICT_FALSE( - !EnumIsValidAux(static_cast(tmp), xform_val, aux))) { - // We can avoid duplicate work in MiniParse by directly calling - // table->fallback. - ptr = ptr2; - PROTOBUF_MUSTTAIL return FastUnknownEnumFallback(PROTOBUF_TC_PARAM_PASS); - } - field.Add(static_cast(tmp)); - if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) { - PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - } while (UnalignedLoad(ptr) == expected_tag); + { + ScopedFieldAccumulator accumulator(field); + do { + const char* ptr2 = ptr; // save for unknown enum case + ptr += sizeof(TagType); + uint64_t tmp; + ptr = ParseVarint(ptr, &tmp); + if (ptr == nullptr) goto error; + if (PROTOBUF_PREDICT_FALSE( + !EnumIsValidAux(static_cast(tmp), xform_val, aux))) { + // We can avoid duplicate work in MiniParse by directly calling + // table->fallback. + ptr = ptr2; + goto unknown_enum_fallback; + } + accumulator.Add(static_cast(tmp)); + if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) goto parse_loop; + } while (UnalignedLoad(ptr) == expected_tag); + } PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS); +parse_loop: + PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); +error: + PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); +unknown_enum_fallback: + PROTOBUF_MUSTTAIL return FastUnknownEnumFallback(PROTOBUF_TC_PARAM_PASS); } const TcParser::UnknownFieldOps& TcParser::GetUnknownFieldOps( @@ -1345,19 +1398,22 @@ const char* TcParser::RepeatedEnumSmallRange(PROTOBUF_TC_PARAM_DECL) { auto& field = RefAt>(msg, data.offset()); auto expected_tag = UnalignedLoad(ptr); const uint8_t max = data.aux_idx(); - do { - uint8_t v = ptr[sizeof(TagType)]; - if (PROTOBUF_PREDICT_FALSE(min > v || v > max)) { - PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - field.Add(static_cast(v)); - ptr += sizeof(TagType) + 1; - if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) { - PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - } while (UnalignedLoad(ptr) == expected_tag); + { + ScopedFieldAccumulator accumulator(field); + do { + uint8_t v = ptr[sizeof(TagType)]; + if (PROTOBUF_PREDICT_FALSE(min > v || v > max)) goto mini_parse; + accumulator.Add(static_cast(v)); + ptr += sizeof(TagType) + 1; + if (PROTOBUF_PREDICT_FALSE(!ctx->DataAvailable(ptr))) goto parse_loop; + } while (UnalignedLoad(ptr) == expected_tag); + } PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS); +parse_loop: + PROTOBUF_MUSTTAIL return ToParseLoop(PROTOBUF_TC_PARAM_NO_DATA_PASS); +mini_parse: + PROTOBUF_MUSTTAIL return MiniParse(PROTOBUF_TC_PARAM_NO_DATA_PASS); } PROTOBUF_NOINLINE const char* TcParser::FastEr0R1(PROTOBUF_TC_PARAM_DECL) { @@ -1846,9 +1902,10 @@ PROTOBUF_NOINLINE const char* TcParser::MpRepeatedFixed( constexpr auto size = sizeof(uint64_t); const char* ptr2 = ptr; uint32_t next_tag; + ScopedFieldAccumulator accumulator(field); do { ptr = ptr2; - *field.Add() = UnalignedLoad(ptr); + accumulator.Add(UnalignedLoad(ptr)); ptr += size; if (!ctx->DataAvailable(ptr)) break; ptr2 = ReadTag(ptr, &next_tag); @@ -1862,9 +1919,10 @@ PROTOBUF_NOINLINE const char* TcParser::MpRepeatedFixed( constexpr auto size = sizeof(uint32_t); const char* ptr2 = ptr; uint32_t next_tag; + ScopedFieldAccumulator accumulator(field); do { ptr = ptr2; - *field.Add() = UnalignedLoad(ptr); + accumulator.Add(UnalignedLoad(ptr)); ptr += size; if (!ctx->DataAvailable(ptr)) break; ptr2 = ReadTag(ptr, &next_tag); @@ -1993,66 +2051,60 @@ PROTOBUF_NOINLINE const char* TcParser::MpRepeatedVarint( auto& field = RefAt>(msg, entry.offset); const char* ptr2 = ptr; uint32_t next_tag; + ScopedFieldAccumulator accumulator(field); do { uint64_t tmp; ptr = ParseVarint(ptr2, &tmp); - if (ptr == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - field.Add(is_zigzag ? WireFormatLite::ZigZagDecode64(tmp) : tmp); + if (ptr == nullptr) goto error; + accumulator.Add(is_zigzag ? WireFormatLite::ZigZagDecode64(tmp) : tmp); if (!ctx->DataAvailable(ptr)) break; ptr2 = ReadTag(ptr, &next_tag); - if (ptr2 == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } + if (ptr2 == nullptr) goto error; } while (next_tag == decoded_tag); } else if (rep == field_layout::kRep32Bits) { auto& field = RefAt>(msg, entry.offset); const char* ptr2 = ptr; uint32_t next_tag; + ScopedFieldAccumulator accumulator(field); do { uint64_t tmp; ptr = ParseVarint(ptr2, &tmp); - if (ptr == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } + if (ptr == nullptr) goto error; if (is_validated_enum) { if (!EnumIsValidAux(tmp, xform_val, *table->field_aux(&entry))) { ptr = ptr2; - PROTOBUF_MUSTTAIL return MpUnknownEnumFallback( - PROTOBUF_TC_PARAM_PASS); + goto unknown_enum_fallback; } } else if (is_zigzag) { tmp = WireFormatLite::ZigZagDecode32(tmp); } - field.Add(tmp); + accumulator.Add(tmp); if (!ctx->DataAvailable(ptr)) break; ptr2 = ReadTag(ptr, &next_tag); - if (ptr2 == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } + if (ptr2 == nullptr) goto error; } while (next_tag == decoded_tag); } else { ABSL_DCHECK_EQ(rep, static_cast(field_layout::kRep8Bits)); auto& field = RefAt>(msg, entry.offset); const char* ptr2 = ptr; uint32_t next_tag; + ScopedFieldAccumulator accumulator(field); do { uint64_t tmp; ptr = ParseVarint(ptr2, &tmp); - if (ptr == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } - field.Add(static_cast(tmp)); + if (ptr == nullptr) goto error; + accumulator.Add(static_cast(tmp)); if (!ctx->DataAvailable(ptr)) break; ptr2 = ReadTag(ptr, &next_tag); - if (ptr2 == nullptr) { - PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); - } + if (ptr2 == nullptr) goto error; } while (next_tag == decoded_tag); } PROTOBUF_MUSTTAIL return ToTagDispatch(PROTOBUF_TC_PARAM_NO_DATA_PASS); +error: + PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); +unknown_enum_fallback: + PROTOBUF_MUSTTAIL return MpUnknownEnumFallback(PROTOBUF_TC_PARAM_PASS); } PROTOBUF_NOINLINE const char* TcParser::MpPackedVarint(PROTOBUF_TC_PARAM_DECL) { @@ -2074,33 +2126,41 @@ PROTOBUF_NOINLINE const char* TcParser::MpPackedVarint(PROTOBUF_TC_PARAM_DECL) { uint16_t rep = type_card & field_layout::kRepMask; if (rep == field_layout::kRep64Bits) { - auto* field = &RefAt>(msg, entry.offset); - return ctx->ReadPackedVarint(ptr, [field, is_zigzag](uint64_t value) { - field->Add(is_zigzag ? WireFormatLite::ZigZagDecode64(value) : value); - }); + auto& field = RefAt>(msg, entry.offset); + ScopedFieldAccumulator accumulator(field); + return ctx->ReadPackedVarint( + ptr, [&accumulator, is_zigzag](uint64_t value) { + accumulator.Add(is_zigzag ? WireFormatLite::ZigZagDecode64(value) + : value); + }); } else if (rep == field_layout::kRep32Bits) { - auto* field = &RefAt>(msg, entry.offset); + auto& field = RefAt>(msg, entry.offset); if (is_validated_enum) { const TcParseTableBase::FieldAux aux = *table->field_aux(entry.aux_idx); - return ctx->ReadPackedVarint(ptr, [=](int32_t value) { + ScopedFieldAccumulator accumulator(field); + return ctx->ReadPackedVarint(ptr, [=, &accumulator](int32_t value) { if (!EnumIsValidAux(value, xform_val, aux)) { AddUnknownEnum(msg, table, data.tag(), value); } else { - field->Add(value); + accumulator.Add(value); } }); } else { - return ctx->ReadPackedVarint(ptr, [field, is_zigzag](uint64_t value) { - field->Add(is_zigzag ? WireFormatLite::ZigZagDecode32( - static_cast(value)) - : value); - }); + ScopedFieldAccumulator accumulator(field); + return ctx->ReadPackedVarint( + ptr, [&accumulator, is_zigzag](uint64_t value) { + accumulator.Add(is_zigzag ? WireFormatLite::ZigZagDecode32( + static_cast(value)) + : value); + }); } } else { ABSL_DCHECK_EQ(rep, static_cast(field_layout::kRep8Bits)); - auto* field = &RefAt>(msg, entry.offset); - return ctx->ReadPackedVarint( - ptr, [field](uint64_t value) { field->Add(value); }); + auto& field = RefAt>(msg, entry.offset); + ScopedFieldAccumulator accumulator(field); + return ctx->ReadPackedVarint(ptr, [&](uint64_t value) { + accumulator.Add(static_cast(value)); + }); } PROTOBUF_MUSTTAIL return Error(PROTOBUF_TC_PARAM_NO_DATA_PASS); diff --git a/src/google/protobuf/repeated_field.h b/src/google/protobuf/repeated_field.h index 2da745dc0fb0..c657f076fcb4 100644 --- a/src/google/protobuf/repeated_field.h +++ b/src/google/protobuf/repeated_field.h @@ -45,6 +45,7 @@ #define GOOGLE_PROTOBUF_REPEATED_FIELD_H__ #include +#include #include #include #include @@ -56,7 +57,9 @@ #include "google/protobuf/port.h" #include "absl/base/attributes.h" #include "absl/base/dynamic_annotations.h" +#include "absl/base/optimization.h" #include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" #include "absl/meta/type_traits.h" #include "absl/strings/cord.h" #include "google/protobuf/generated_enum_util.h" @@ -346,6 +349,8 @@ class RepeatedField final // This is public due to it being called by generated code. inline void InternalSwap(RepeatedField* other); + void MergeFromArray(const Element* array, size_t length); + private: template friend class Arena::InternalHelper; @@ -605,6 +610,29 @@ inline int RepeatedField::Capacity() const { return total_size_; } +template +inline void RepeatedField::MergeFromArray(const Element* array, + size_t length) { + // Only supports trivially copyable types. + static_assert(std::is_trivially_copyable::value, + "only trivialy copyable types are supported"); + + ABSL_DCHECK_GT(length, 0u); + if (ABSL_PREDICT_TRUE(current_size_ + length > total_size_)) { + Grow(current_size_, current_size_ + length); + } + Element* elem = unsafe_elements(); + ABSL_DCHECK_NE(elem, nullptr); + void* p = elem + ExchangeCurrentSize(current_size_ + length); + memcpy(p, array, sizeof(Element) * length); +} + +template <> +inline void RepeatedField::MergeFromArray(const absl::Cord* array, + size_t length) { + ABSL_LOG(FATAL) << "not supported"; +} + template inline void RepeatedField::AddAlreadyReserved(Element value) { ABSL_DCHECK_LT(current_size_, total_size_); diff --git a/src/google/protobuf/repeated_field_unittest.cc b/src/google/protobuf/repeated_field_unittest.cc index 5507b2475d16..20d8494b4a76 100644 --- a/src/google/protobuf/repeated_field_unittest.cc +++ b/src/google/protobuf/repeated_field_unittest.cc @@ -529,6 +529,24 @@ TEST(RepeatedField, MergeFrom) { EXPECT_EQ(5, destination.Get(4)); } +TEST(RepeatedField, MergeFromArray) { + RepeatedField rep; + + for (int i = 0; i < 7; ++i) { + rep.Add(i); + } + int array[] = {7, 8, 9, 10, 11, 12}; + rep.MergeFromArray(array, 6); + for (int i = 13; i < 19; ++i) { + rep.Add(i); + } + + EXPECT_EQ(rep.size(), 19); + for (int i = 0; i < 19; ++i) { + EXPECT_EQ(rep.Get(i), i); + } +} + TEST(RepeatedField, CopyFrom) { RepeatedField source, destination;