Skip to content

Commit

Permalink
hpb: support hpb::RepeatedField<T> inside extensions (GetExtension)
Browse files Browse the repository at this point in the history
Before this change, hpb had no way of returning repeated fields (that are extensions) -- they were incorrectly treated as pure scalars (int32 vs repeated<int32>).

We rectify this hole and now return RepeatedField<T> for a given T.

This CL also cleans up the `if constexpr` special casing we were performing inside GetExtension and delegates that to the UpbExtensionTrait.

PiperOrigin-RevId: 706789273
  • Loading branch information
honglooker authored and copybara-github committed Dec 16, 2024
1 parent dc72833 commit 758b1fb
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 24 deletions.
66 changes: 45 additions & 21 deletions hpb/extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
namespace hpb {
class ExtensionRegistry;

template <typename T>
class RepeatedField;

namespace internal {
template <typename Extendee, typename Extension>
class ExtensionIdentifier;

absl::Status MoveExtension(upb_Message* message, upb_Arena* message_arena,
const upb_MiniTableExtension* ext,
Expand All @@ -52,20 +57,44 @@ absl::Status SetExtension(upb_Message* message, upb_Arena* message_arena,
template <typename T, typename = void>
struct UpbExtensionTrait;

template <typename T>
struct UpbExtensionTrait<hpb::RepeatedField<T>> {
using ReturnType = typename RepeatedField<T>::CProxy;
using DefaultType = std::false_type;

template <typename Msg, typename Id>
static constexpr ReturnType Get(Msg message, const Id& id) {
auto upb_arr = upb_Message_GetExtensionArray(
hpb::interop::upb::GetMessage(message), id.mini_table_ext());
return ReturnType(upb_arr, hpb::interop::upb::GetArena(message));
}
};

template <>
struct UpbExtensionTrait<int32_t> {
using DefaultType = int32_t;
using ReturnType = int32_t;
static constexpr auto kGetter = upb_Message_GetExtensionInt32;
static constexpr auto kSetter = upb_Message_SetExtensionInt32;

template <typename Msg, typename Id>
static constexpr ReturnType Get(Msg message, const Id& id) {
auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id);
return upb_Message_GetExtensionInt32(hpb::interop::upb::GetMessage(message),
id.mini_table_ext(), default_val);
}
};

template <>
struct UpbExtensionTrait<int64_t> {
using DefaultType = int64_t;
using ReturnType = int64_t;
static constexpr auto kGetter = upb_Message_GetExtensionInt64;
static constexpr auto kSetter = upb_Message_SetExtensionInt64;
template <typename Msg, typename Id>
static constexpr ReturnType Get(Msg message, const Id& id) {
auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id);
return upb_Message_GetExtensionInt64(hpb::interop::upb::GetMessage(message),
id.mini_table_ext(), default_val);
}
};

// TODO: b/375460289 - flesh out non-promotional msg support that does
Expand All @@ -74,6 +103,19 @@ template <typename T>
struct UpbExtensionTrait<T> {
using DefaultType = std::false_type;
using ReturnType = Ptr<const T>;
template <typename Msg, typename Id>
static constexpr absl::StatusOr<ReturnType> Get(Msg message, const Id& id) {
upb_MessageValue value;
const bool ok = ::hpb::internal::GetOrPromoteExtension(
hpb::interop::upb::GetMessage(message), id.mini_table_ext(),
hpb::interop::upb::GetArena(message), &value);
if (!ok) {
return ExtensionNotFoundError(
upb_MiniTableExtension_Number(id.mini_table_ext()));
}
return Ptr<const T>(::hpb::interop::upb::MakeCHandle<T>(
value.msg_val, hpb::interop::upb::GetArena(message)));
}
};

// -------------------------------------------------------------------
Expand Down Expand Up @@ -285,25 +327,7 @@ absl::StatusOr<typename internal::UpbExtensionTrait<Extension>::ReturnType>
GetExtension(
Ptr<T> message,
const ::hpb::internal::ExtensionIdentifier<Extendee, Extension>& id) {
if constexpr (std::is_integral_v<Extension>) {
auto default_val = hpb::internal::PrivateAccess::GetDefaultValue(id);
absl::StatusOr<Extension> res =
hpb::internal::UpbExtensionTrait<Extension>::kGetter(
hpb::interop::upb::GetMessage(message), id.mini_table_ext(),
default_val);
return res;
} else {
upb_MessageValue value;
const bool ok = ::hpb::internal::GetOrPromoteExtension(
hpb::interop::upb::GetMessage(message), id.mini_table_ext(),
hpb::interop::upb::GetArena(message), &value);
if (!ok) {
return ExtensionNotFoundError(
upb_MiniTableExtension_Number(id.mini_table_ext()));
}
return Ptr<const Extension>(::hpb::interop::upb::MakeCHandle<Extension>(
value.msg_val, hpb::interop::upb::GetArena(message)));
}
return hpb::internal::UpbExtensionTrait<Extension>::Get(message, id);
}

template <typename T, typename Extendee, typename Extension,
Expand Down
2 changes: 2 additions & 0 deletions hpb/status.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include "google/protobuf/hpb/status.h"

#include <cstdint>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/types/source_location.h"
Expand Down
4 changes: 3 additions & 1 deletion hpb/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#ifndef GOOGLE_PROTOBUF_HPB_STATUS_H__
#define GOOGLE_PROTOBUF_HPB_STATUS_H__

#include <cstdint>

#include "absl/status/status.h"
#include "absl/types/source_location.h"
#include "upb/wire/decode.h"
Expand All @@ -30,7 +32,7 @@ absl::Status MessageAllocationError(
SourceLocation loc = SourceLocation::current());

absl::Status ExtensionNotFoundError(
int extension_number, SourceLocation loc = SourceLocation::current());
uint32_t extension_number, SourceLocation loc = SourceLocation::current());

absl::Status MessageDecodeError(upb_DecodeStatus status,
SourceLocation loc = SourceLocation::current());
Expand Down
12 changes: 10 additions & 2 deletions hpb_generator/gen_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,14 @@ void WriteExtensionIdentifierHeader(const protobuf::FieldDescriptor* ext,
std::string mini_table_name =
absl::StrCat(ExtensionIdentifierBase(ext), "_", ext->name(), "_ext");
std::string linkage = ext->extension_scope() ? "static" : "extern";
std::string ext_type = CppTypeParameterName(ext);
if (ext->is_repeated()) {
ext_type = absl::StrCat("::hpb::RepeatedField<", ext_type, ">");
}
ctx.Emit(
{{"linkage", linkage},
{"extendee_type", ContainingTypeName(ext)},
{"extension_type", CppTypeParameterName(ext)},
{"extension_type", ext_type},
{"extension_name", ext->name()}},
R"cc(
$linkage$ const ::hpb::internal::ExtensionIdentifier<$extendee_type$,
Expand All @@ -70,12 +74,16 @@ void WriteExtensionIdentifier(const protobuf::FieldDescriptor* ext,
absl::StrCat(ExtensionIdentifierBase(ext), "_", ext->name(), "_ext");
std::string class_prefix =
ext->extension_scope() ? ClassName(ext->extension_scope()) + "::" : "";
std::string ext_type = CppTypeParameterName(ext);
if (ext->is_repeated()) {
ext_type = absl::StrCat("::hpb::RepeatedField<", ext_type, ">");
}
ctx.Emit(
{{"containing_type_name", ContainingTypeName(ext)},
{"mini_table_name", mini_table_name},
{"ext_name", ext->name()},
{"default_value", DefaultValue(ext)},
{"ext_type", CppTypeParameterName(ext)},
{"ext_type", ext_type},
{"class_prefix", class_prefix}},
R"cc(
constexpr ::hpb::internal::ExtensionIdentifier<$containing_type_name$,
Expand Down
3 changes: 3 additions & 0 deletions hpb_generator/gen_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ std::string ToCamelCase(const absl::string_view input, bool lower_first) {
}

std::string DefaultValue(const FieldDescriptor* field) {
if (field->is_repeated()) {
return "::std::false_type()";
}
switch (field->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32:
return absl::StrCat(field->default_value_int32());
Expand Down
49 changes: 49 additions & 0 deletions hpb_generator/tests/extension_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ using ::hpb_unittest::protos::theme;
using ::hpb_unittest::protos::ThemeExtension;
using ::hpb_unittest::someotherpackage::protos::int32_ext;
using ::hpb_unittest::someotherpackage::protos::int64_ext;
using ::hpb_unittest::someotherpackage::protos::repeated_int32_ext;
using ::hpb_unittest::someotherpackage::protos::repeated_int64_ext;
using ::hpb_unittest::someotherpackage::protos::repeated_string_ext;

using ::testing::status::IsOkAndHolds;

Expand Down Expand Up @@ -403,4 +406,50 @@ TEST(CppGeneratedCode, ExtensionFieldNumberConstant) {
EXPECT_EQ(12003, ::hpb::ExtensionNumber(ThemeExtension::theme_extension));
}

TEST(CppGeneratedCode, GetExtensionRepeatedi32) {
TestModel model;
upb::Arena arena;
hpb::ExtensionRegistry extensions(arena);
extensions.AddExtension(repeated_int32_ext);
// These bytes are the serialized form of a repeated int32 field
// with two elements: [2, 3] @index 13004
auto bytes = "\342\254\006\002\002\003";
auto parsed_model = hpb::Parse<TestModel>(bytes, extensions).value();
auto res = hpb::GetExtension(&parsed_model, repeated_int32_ext);
EXPECT_EQ(true, res.ok());
EXPECT_EQ(res->size(), 2);
EXPECT_EQ((*res)[0], 2);
EXPECT_EQ((*res)[1], 3);
}

TEST(CppGeneratedCode, GetExtensionRepeatedi64) {
TestModel model;
upb::Arena arena;
hpb::ExtensionRegistry extensions(arena);
extensions.AddExtension(repeated_int64_ext);
// These bytes represent a repeated int64 field with one element: [322].
auto bytes = "\352\254\006\002\302\002";
auto parsed_model = hpb::Parse<TestModel>(bytes, extensions).value();
auto res = hpb::GetExtension(&parsed_model, repeated_int64_ext);
EXPECT_EQ(true, res.ok());
EXPECT_EQ(res->size(), 1);
EXPECT_EQ((*res)[0], 322);
}

TEST(CppGeneratedCode, GetExtensionRepeatedString) {
TestModel model;
upb::Arena arena;
hpb::ExtensionRegistry extensions(arena);
extensions.AddExtension(repeated_string_ext);
// These bytes represent a repeated string field with two elements:
// ["hello", "world"] @index 13006.
auto bytes = "\362\254\006\005hello\362\254\006\005world";
auto parsed_model = hpb::Parse<TestModel>(bytes, extensions).value();
auto res = hpb::GetExtension(&parsed_model, repeated_string_ext);
EXPECT_EQ(true, res.ok());
EXPECT_EQ(res->size(), 2);
EXPECT_EQ((*res)[0], "hello");
EXPECT_EQ((*res)[1], "world");
}

} // namespace
5 changes: 5 additions & 0 deletions hpb_generator/tests/test_extension.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,9 @@ extend TestModel {
int32 int32_ext = 13002 [default = 644];

int64 int64_ext = 13003 [default = 2147483648];

repeated int32 repeated_int32_ext = 13004;
repeated int64 repeated_int64_ext = 13005;

repeated string repeated_string_ext = 13006;
}

0 comments on commit 758b1fb

Please sign in to comment.