Improve copy of map fields.
On copy construction/assignment we can:
 - guarantee there is enough space for all elements, which avoids rehashing.
 - guarantee all elements are unique, which avoids per-key lookups.

PiperOrigin-RevId: 709057139
protobuf-github-bot authored and copybara-github committed Dec 23, 2024
1 parent 913f7b0 commit 34a397b
Showing 4 changed files with 134 additions and 31 deletions.
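The two bullets in the commit message correspond to familiar hash-table costs. A minimal sketch of the first one, using std::unordered_map purely as an illustration (this is not protobuf's Map, and not code from the commit):

#include <unordered_map>

// Reserving the final size up front means the copy loop below never
// triggers a rehash. The second saving, skipping the per-key uniqueness
// lookup, has no std:: equivalent; in the commit it comes from the internal
// InsertUnique() path, which trusts that keys cloned from an existing map
// are already distinct.
std::unordered_map<int, int> Copy(const std::unordered_map<int, int>& src) {
  std::unordered_map<int, int> dst;
  dst.reserve(src.size());  // one table allocation, zero rehashes in the loop
  for (const auto& kv : src) dst.insert(kv);
  return dst;
}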
130 changes: 103 additions & 27 deletions src/google/protobuf/map.h
@@ -27,19 +27,19 @@
 #include <type_traits>
 #include <utility>

-#include "absl/base/optimization.h"
-#include "absl/memory/memory.h"
-#include "google/protobuf/message_lite.h"
-
 #include "absl/base/attributes.h"
+#include "absl/base/optimization.h"
+#include "absl/base/prefetch.h"
 #include "absl/container/btree_map.h"
 #include "absl/hash/hash.h"
 #include "absl/log/absl_check.h"
 #include "absl/meta/type_traits.h"
+#include "absl/numeric/bits.h"
 #include "absl/strings/string_view.h"
 #include "google/protobuf/arena.h"
 #include "google/protobuf/generated_enum_util.h"
 #include "google/protobuf/internal_visibility.h"
+#include "google/protobuf/message_lite.h"
 #include "google/protobuf/port.h"
 #include "google/protobuf/wire_format_lite.h"
@@ -365,6 +365,7 @@ class PROTOBUF_EXPORT UntypedMapBase {
  protected:
   // 16 bytes is the minimum useful size for the array cache in the arena.
   static constexpr map_index_t kMinTableSize = 16 / sizeof(void*);
+  static constexpr map_index_t kMaxTableSize = map_index_t{1} << 31;

  public:
   Arena* arena() const { return arena_; }
@@ -792,6 +793,25 @@ class KeyMapBase : public UntypedMapBase {
     return num_buckets - num_buckets / 16 * 4 - num_buckets % 2;
   }

+  // For a particular size, calculate the lowest capacity `cap` where
+  // `size <= CalculateHiCutoff(cap)`.
+  static size_type CalculateCapacityForSize(size_type size) {
+    ABSL_DCHECK_NE(size, 0u);
+
+    if (size > kMaxTableSize / 2) {
+      return kMaxTableSize;
+    }
+
+    size_t capacity = size_type{1} << (std::numeric_limits<size_type>::digits -
+                                       absl::countl_zero(size - 1));
+
+    if (size > CalculateHiCutoff(capacity)) {
+      capacity *= 2;
+    }
+
+    return std::max<size_type>(capacity, kMinTableSize);
+  }
+
   void AssertLoadFactor() const {
     ABSL_DCHECK_LE(num_elements_, CalculateHiCutoff(num_buckets_));
   }
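To make the cutoff arithmetic concrete, a worked example (assuming a 32-bit map_index_t and the constants defined in this file):

// size = 13: the smallest power of two >= 13 is 16, but
//   CalculateHiCutoff(16) = 16 - (16/16)*4 - 16%2 = 12 < 13,
//   so the capacity doubles to 32 (CalculateHiCutoff(32) = 24 >= 13).
// size = 12: the power of two is again 16, and 12 <= 12, so 16 is returned.
// size > kMaxTableSize / 2: clamped to kMaxTableSize (1 << 31) up front,
//   which also keeps the left shift in the function from overflowing.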
@@ -812,7 +832,9 @@
     // practice, this seems fine.
     if (ABSL_PREDICT_FALSE(new_size > hi_cutoff)) {
       if (num_buckets_ <= max_size() / 2) {
-        Resize(num_buckets_ * 2);
+        Resize(kMinTableSize > kGlobalEmptyTableSize * 2
+                   ? std::max(kMinTableSize, num_buckets_ * 2)
+                   : num_buckets_ * 2);
         return true;
       }
     } else if (ABSL_PREDICT_FALSE(new_size <= lo_cutoff &&
@@ -836,13 +858,34 @@
     return false;
   }

+  // Interpret `head` as a linked list and insert all the nodes into `this`.
+  // REQUIRES: this->empty()
+  // REQUIRES: the input nodes have unique keys
+  PROTOBUF_NOINLINE void MergeIntoEmpty(NodeBase* head, size_t num_nodes) {
+    ABSL_DCHECK_EQ(size(), size_t{0});
+    ABSL_DCHECK_NE(num_nodes, size_t{0});
+    if (const map_index_t needed_capacity = CalculateCapacityForSize(num_nodes);
+        needed_capacity != this->num_buckets_) {
+      Resize(std::max(kMinTableSize, needed_capacity));
+    }
+    num_elements_ = num_nodes;
+    AssertLoadFactor();
+    while (head != nullptr) {
+      KeyNode* node = static_cast<KeyNode*>(head);
+      head = head->next;
+      absl::PrefetchToLocalCacheNta(head);
+      InsertUnique(BucketNumber(TS::ToView(node->key())), node);
+    }
+  }
+
   // Resize to the given number of buckets.
   void Resize(map_index_t new_num_buckets) {
+    ABSL_DCHECK_GE(new_num_buckets, kMinTableSize);
+    ABSL_DCHECK(absl::has_single_bit(new_num_buckets));
     if (num_buckets_ == kGlobalEmptyTableSize) {
       // This is the global empty array.
       // Just overwrite with a new one. No need to transfer or free anything.
-      ABSL_DCHECK_GE(kMinTableSize, new_num_buckets);
-      num_buckets_ = index_of_first_non_null_ = kMinTableSize;
+      num_buckets_ = index_of_first_non_null_ = new_num_buckets;
       table_ = CreateEmptyTable(num_buckets_);
       return;
     }
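The handoff protocol used by MergeIntoEmpty, as a self-contained stand-in (hypothetical types, not protobuf's): nodes are threaded through their existing next pointers, so assembling the batch allocates nothing and the receiver can size its table exactly once.

#include <cstddef>

// Stand-in node: protobuf reuses NodeBase::next as the list link.
struct Node {
  int key;
  Node* next;
};

// Push-front each cloned node onto an intrusive singly linked list; the
// resulting head plus the count is what a MergeIntoEmpty()-style bulk
// insert consumes.
Node* Chain(Node* nodes[], std::size_t count) {
  Node* head = nullptr;
  for (std::size_t i = 0; i < count; ++i) {
    nodes[i]->next = head;
    head = nodes[i];
  }
  return head;
}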
@@ -997,7 +1040,7 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
  private:
   Map(Arena* arena, const Map& other) : Map(arena) {
     StaticValidityCheck();
-    insert(other.begin(), other.end());
+    CopyFromImpl(other);
   }
   static_assert(!std::is_const<mapped_type>::value &&
                     !std::is_const<key_type>::value,
@@ -1343,7 +1386,7 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
   Map& operator=(const Map& other) ABSL_ATTRIBUTE_LIFETIME_BOUND {
     if (this != &other) {
       clear();
-      insert(other.begin(), other.end());
+      CopyFromImpl(other);
     }
     return *this;
   }
@@ -1352,12 +1395,13 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
     if (arena() == other.arena()) {
       InternalSwap(&other);
     } else {
-      // TODO: optimize this. The temporary copy can be allocated
-      // in the same arena as the other message, and the "other = copy" can
-      // be replaced with the fast-path swap above.
-      Map copy = *this;
-      *this = other;
-      other = copy;
+      size_t other_size = other.size();
+      Node* other_copy = this->CloneFromOther(other);
+      other = *this;
+      this->clear();
+      if (other_size != 0) {
+        this->MergeIntoEmpty(other_copy, other_size);
+      }
     }
   }

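Step by step, for maps A and B on different arenas, the new else branch: (1) clones B's entries onto A's arena as a detached node list, (2) copy-assigns A's contents into B (`other = *this`), (3) clears A, and (4) rebuilds A from the cloned list via MergeIntoEmpty. Each map's elements are copied once, two copies in total, into a table sized exactly once; the removed code performed three full element-by-element copy-assignments through a temporary Map.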
@@ -1406,18 +1450,7 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
   }

   template <typename K, typename... Args>
-  std::pair<iterator, bool> TryEmplaceInternal(K&& k, Args&&... args) {
-    auto p = this->FindHelper(TS::ToView(k));
-    internal::map_index_t b = p.bucket;
-    // Case 1: key was already present.
-    if (p.node != nullptr)
-      return std::make_pair(iterator(internal::UntypedMapIterator{
-                                static_cast<Node*>(p.node), this, p.bucket}),
-                            false);
-    // Case 2: insert.
-    if (this->ResizeIfLoadIsOutOfRange(this->num_elements_ + 1)) {
-      b = this->BucketNumber(TS::ToView(k));
-    }
+  PROTOBUF_ALWAYS_INLINE Node* CreateNode(K&& k, Args&&... args) {
     // If K is not key_type, make the conversion to key_type explicit.
     using TypeToInit = typename std::conditional<
         std::is_same<typename std::decay<K>::type, key_type>::value, K&&,
@@ -1437,7 +1470,50 @@ class Map : private internal::KeyMapBase<internal::KeyForBase<Key>> {
     // Note: if `T` is arena constructible, `Args` needs to be empty.
     Arena::CreateInArenaStorage(&node->kv.second, this->arena_,
                                 std::forward<Args>(args)...);
+    return node;
+  }

+  // Copy all elements from `other`, using the arena from `this`.
+  // Return them as a linked list, using the `next` pointer in the node.
+  PROTOBUF_NOINLINE Node* CloneFromOther(const Map& other) {
+    Node* head = nullptr;
+    for (const auto& [key, value] : other) {
+      Node* new_node;
+      if constexpr (std::is_base_of_v<MessageLite, mapped_type>) {
+        new_node = CreateNode(key);
+        new_node->kv.second = value;
+      } else {
+        new_node = CreateNode(key, value);
+      }
+      new_node->next = head;
+      head = new_node;
+    }
+    return head;
+  }
+
+  void CopyFromImpl(const Map& other) {
+    if (other.empty()) return;
+    // We split the logic in two: first we clone the data which requires
+    // Key/Value types, then we insert them all which only requires Key.
+    // That way we reduce code duplication.
+    this->MergeIntoEmpty(CloneFromOther(other), other.size());
+  }
+
+  template <typename K, typename... Args>
+  std::pair<iterator, bool> TryEmplaceInternal(K&& k, Args&&... args) {
+    auto p = this->FindHelper(TS::ToView(k));
+    internal::map_index_t b = p.bucket;
+    // Case 1: key was already present.
+    if (p.node != nullptr) {
+      return std::make_pair(iterator(internal::UntypedMapIterator{
+                                static_cast<Node*>(p.node), this, p.bucket}),
+                            false);
+    }
+    // Case 2: insert.
+    if (this->ResizeIfLoadIsOutOfRange(this->num_elements_ + 1)) {
+      b = this->BucketNumber(TS::ToView(k));
+    }
+    auto* node = CreateNode(std::forward<K>(k), std::forward<Args>(args)...);
     this->InsertUnique(b, node);
     ++this->num_elements_;
     return std::make_pair(iterator(internal::UntypedMapIterator{node, this, b}),
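Both public copy operations now funnel into this batch path. A short usage sketch (ordinary public API; the comments describe what changes under the hood):

#include <string>

#include "google/protobuf/map.h"

void CopyExample() {
  google::protobuf::Map<std::string, int> src;
  src["a"] = 1;
  src["b"] = 2;

  // Copy construction and copy assignment both run CloneFromOther() once,
  // then MergeIntoEmpty() into an exactly-sized table: no rehashing and no
  // per-key duplicate lookups.
  google::protobuf::Map<std::string, int> dst1 = src;  // CopyFromImpl(src)
  google::protobuf::Map<std::string, int> dst2;
  dst2 = src;  // operator= clears, then takes the same CopyFromImpl() path
}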
8 changes: 5 additions & 3 deletions src/google/protobuf/map_field.h
@@ -582,6 +582,10 @@ class TypeDefinedMapFieldBase : public MapFieldBase {
   TypeDefinedMapFieldBase(const VTable* vtable, Arena* arena)
       : MapFieldBase(vtable, arena), map_(arena) {}

+  TypeDefinedMapFieldBase(const VTable* vtable, Arena* arena,
+                          const TypeDefinedMapFieldBase& from)
+      : MapFieldBase(vtable, arena), map_(arena, from.GetMap()) {}
+
  protected:
   ~TypeDefinedMapFieldBase() { map_.~Map(); }

@@ -657,9 +661,7 @@ class MapField final : public TypeDefinedMapFieldBase<Key, T> {
   MapField(ArenaInitialized, Arena* arena) : MapField(arena) {}
   MapField(InternalVisibility, Arena* arena) : MapField(arena) {}
   MapField(InternalVisibility, Arena* arena, const MapField& from)
-      : MapField(arena) {
-    this->MergeFromImpl(*this, from);
-  }
+      : TypeDefinedMapFieldBase<Key, T>(&kVTable, arena, from) {}

  private:
   typedef void InternalArenaConstructable_;
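With the new base-class constructor above, the arena-aware copy constructor initializes its map directly from `from`'s map, so generated-message copies take the presized CopyFromImpl path instead of routing through MergeFromImpl.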
18 changes: 18 additions & 0 deletions src/google/protobuf/map_test.cc
@@ -186,6 +186,24 @@ TEST(MapTest, CopyConstructionMaintainsProperLoadFactor) {
   }
 }

+TEST(MapTest, CalculateCapacityForSizeTest) {
+  for (size_t size = 1; size < 1000; ++size) {
+    size_t capacity = MapTestPeer::CalculateCapacityForSize(size);
+    // Verify is large enough for `size`.
+    EXPECT_LE(size, MapTestPeer::CalculateHiCutoff(capacity));
+    if (capacity > MapTestPeer::kMinTableSize) {
+      // Verify it's the smallest capacity that's large enough.
+      EXPECT_GT(size, MapTestPeer::CalculateHiCutoff(capacity / 2));
+    }
+  }
+
+  // Verify very large size does not overflow bucket calculation.
+  for (size_t size : {0x30000001u, 0x40000000u, 0x50000000u, 0x60000000u,
+                      0x70000000u, 0x80000000u, 0x90000000u, 0xFFFFFFFFu}) {
+    EXPECT_EQ(0x80000000u, MapTestPeer::CalculateCapacityForSize(size));
+  }
+}
+
 TEST(MapTest, AlwaysSerializesBothEntries) {
   for (const Message* prototype :
        {static_cast<const Message*>(
9 changes: 8 additions & 1 deletion src/google/protobuf/map_test.inc
@@ -15,6 +15,7 @@
 #endif  // _WIN32

 #include <algorithm>
+#include <cstddef>
 #include <memory>
 #include <random>
 #include <sstream>
@@ -201,14 +202,20 @@ struct MapTestPeer {
     map.Resize(num_buckets);
   }

-  static int CalculateHiCutoff(int num_buckets) {
+  static size_t CalculateHiCutoff(size_t num_buckets) {
     return Map<int, int>::CalculateHiCutoff(num_buckets);
   }

+  static size_t CalculateCapacityForSize(size_t size) {
+    return Map<int, int>::CalculateCapacityForSize(size);
+  }
+
   template <typename Map>
   static auto GetTypeInfo() {
     return Map::GetTypeInfo();
   }
+
+  static constexpr size_t kMinTableSize = UntypedMapBase::kMinTableSize;
 };

 namespace {
