From 5866e4575c7e43cfcabb98efd6f36b4ba5c4bb30 Mon Sep 17 00:00:00 2001 From: Mike Kruskal Date: Tue, 23 May 2023 13:59:02 -0700 Subject: [PATCH] Add assertions to reflection methods. This will ensure that the message is the appropriate type. Failing to pass this check can lead to UB and crashes. PiperOrigin-RevId: 534549318 --- objectivec/GPBAny.pbobjc.m | 4 +- objectivec/GPBApi.pbobjc.m | 12 +-- objectivec/GPBDuration.pbobjc.m | 4 +- objectivec/GPBEmpty.pbobjc.m | 4 +- objectivec/GPBFieldMask.pbobjc.m | 4 +- objectivec/GPBSourceContext.pbobjc.m | 4 +- objectivec/GPBStruct.pbobjc.m | 12 +-- objectivec/GPBTimestamp.pbobjc.m | 4 +- objectivec/GPBType.pbobjc.m | 20 ++--- objectivec/GPBWrappers.pbobjc.m | 36 ++++----- .../protobuf/generated_message_reflection.cc | 77 ++++++++++++++----- .../generated_message_reflection_unittest.cc | 10 +++ 12 files changed, 120 insertions(+), 71 deletions(-) diff --git a/objectivec/GPBAny.pbobjc.m b/objectivec/GPBAny.pbobjc.m index ada41e569d3d..9b7939d52b34 100644 --- a/objectivec/GPBAny.pbobjc.m +++ b/objectivec/GPBAny.pbobjc.m @@ -91,9 +91,9 @@ + (GPBDescriptor *)descriptor { "\001\001\004\241!!\000"; [localDescriptor setupExtraTextInfo:extraTextFormatInfo]; #endif // !GPBOBJC_SKIP_MESSAGE_TEXTFORMAT_EXTRAS - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBApi.pbobjc.m b/objectivec/GPBApi.pbobjc.m index a6283a84f176..b416445c0f04 100644 --- a/objectivec/GPBApi.pbobjc.m +++ b/objectivec/GPBApi.pbobjc.m @@ -145,9 +145,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBApi__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -272,9 +272,9 @@ + (GPBDescriptor *)descriptor { "\002\002\007\244\241!!\000\004\010\244\241!!\000"; [localDescriptor setupExtraTextInfo:extraTextFormatInfo]; #endif // !GPBOBJC_SKIP_MESSAGE_TEXTFORMAT_EXTRAS - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -341,9 +341,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBMixin__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBDuration.pbobjc.m b/objectivec/GPBDuration.pbobjc.m index 1a54a76b306c..a618e9cd57b5 100644 --- a/objectivec/GPBDuration.pbobjc.m +++ b/objectivec/GPBDuration.pbobjc.m @@ -86,9 +86,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBDuration__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBEmpty.pbobjc.m b/objectivec/GPBEmpty.pbobjc.m index b69738c5934f..c33be8bb963a 100644 --- a/objectivec/GPBEmpty.pbobjc.m +++ b/objectivec/GPBEmpty.pbobjc.m @@ -62,9 +62,9 @@ + (GPBDescriptor *)descriptor { fieldCount:0 storageSize:sizeof(GPBEmpty__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBFieldMask.pbobjc.m b/objectivec/GPBFieldMask.pbobjc.m index 8de5ad42ffb0..347bcbef0e92 100644 --- a/objectivec/GPBFieldMask.pbobjc.m +++ b/objectivec/GPBFieldMask.pbobjc.m @@ -75,9 +75,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBFieldMask__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBSourceContext.pbobjc.m b/objectivec/GPBSourceContext.pbobjc.m index 7ae605b32a81..cea417de1be8 100644 --- a/objectivec/GPBSourceContext.pbobjc.m +++ b/objectivec/GPBSourceContext.pbobjc.m @@ -75,9 +75,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBSourceContext__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBStruct.pbobjc.m b/objectivec/GPBStruct.pbobjc.m index eabd761866e0..600b49574f51 100644 --- a/objectivec/GPBStruct.pbobjc.m +++ b/objectivec/GPBStruct.pbobjc.m @@ -115,9 +115,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBStruct__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -222,9 +222,9 @@ + (GPBDescriptor *)descriptor { [localDescriptor setupOneofs:oneofs count:(uint32_t)(sizeof(oneofs) / sizeof(char*)) firstHasIndex:-1]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -285,9 +285,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBListValue__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBTimestamp.pbobjc.m b/objectivec/GPBTimestamp.pbobjc.m index edd73498a21f..dc6baebf2c95 100644 --- a/objectivec/GPBTimestamp.pbobjc.m +++ b/objectivec/GPBTimestamp.pbobjc.m @@ -86,9 +86,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBTimestamp__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBType.pbobjc.m b/objectivec/GPBType.pbobjc.m index db7a41f7e515..7b44e3995170 100644 --- a/objectivec/GPBType.pbobjc.m +++ b/objectivec/GPBType.pbobjc.m @@ -307,9 +307,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBType__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -468,9 +468,9 @@ + (GPBDescriptor *)descriptor { "\001\006\004\241!!\000"; [localDescriptor setupExtraTextInfo:extraTextFormatInfo]; #endif // !GPBOBJC_SKIP_MESSAGE_TEXTFORMAT_EXTRAS - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -593,9 +593,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBEnum__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -673,9 +673,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBEnumValue__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -730,9 +730,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBOption__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/objectivec/GPBWrappers.pbobjc.m b/objectivec/GPBWrappers.pbobjc.m index 07be68bacee4..30461e3f6fe9 100644 --- a/objectivec/GPBWrappers.pbobjc.m +++ b/objectivec/GPBWrappers.pbobjc.m @@ -83,9 +83,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBDoubleValue__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -129,9 +129,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBFloatValue__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -175,9 +175,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBInt64Value__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -221,9 +221,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBUInt64Value__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -267,9 +267,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBInt32Value__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -313,9 +313,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBUInt32Value__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -358,9 +358,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBBoolValue__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -404,9 +404,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBStringValue__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; @@ -450,9 +450,9 @@ + (GPBDescriptor *)descriptor { fieldCount:(uint32_t)(sizeof(fields) / sizeof(GPBMessageFieldDescription)) storageSize:sizeof(GPBBytesValue__storage_) flags:(GPBDescriptorInitializationFlags)(GPBDescriptorInitializationFlag_UsesClassRefs | GPBDescriptorInitializationFlag_Proto3OptionalKnown | GPBDescriptorInitializationFlag_ClosedEnumSupportKnown)]; - #if defined(DEBUG) && DEBUG +#if defined(DEBUG) && DEBUG NSAssert(descriptor == nil, @"Startup recursed!"); - #endif // DEBUG +#endif // DEBUG descriptor = localDescriptor; } return descriptor; diff --git a/src/google/protobuf/generated_message_reflection.cc b/src/google/protobuf/generated_message_reflection.cc index bd3b4e07ce25..b563c459cecd 100644 --- a/src/google/protobuf/generated_message_reflection.cc +++ b/src/google/protobuf/generated_message_reflection.cc @@ -46,6 +46,7 @@ #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" #include "absl/strings/match.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "google/protobuf/descriptor.h" @@ -201,6 +202,22 @@ void ReportReflectionUsageError(const Descriptor* descriptor, << description; } +#ifndef NDEBUG +void ReportReflectionUsageMessageError(const Descriptor* expected, + const Descriptor* actual, + const FieldDescriptor* field, + const char* method) { + ABSL_LOG(FATAL) << absl::StrFormat( + "Protocol Buffer reflection usage error:\n" + " Method : google::protobuf::Reflection::%s\n" + " Expected type: %s\n" + " Actual type : %s\n" + " Field : %s\n" + " Problem : Message is not the right type for reflection", + method, expected->full_name(), actual->full_name(), field->full_name()); +} +#endif + const char* cpptype_names_[FieldDescriptor::MAX_CPPTYPE + 1] = { "INVALID_CPPTYPE", "CPPTYPE_INT32", "CPPTYPE_INT64", "CPPTYPE_UINT32", "CPPTYPE_UINT64", "CPPTYPE_DOUBLE", "CPPTYPE_FLOAT", "CPPTYPE_BOOL", @@ -266,6 +283,16 @@ static void ReportReflectionUsageEnumTypeError( if (value->type() != field->enum_type()) \ ReportReflectionUsageEnumTypeError(descriptor_, field, #METHOD, value) +#ifdef NDEBUG +// Avoid a virtual method call in optimized builds. +#define USAGE_CHECK_MESSAGE(METHOD, MESSAGE) +#else +#define USAGE_CHECK_MESSAGE(METHOD, MESSAGE) \ + if (descriptor_ != (MESSAGE)->GetDescriptor()) \ + ReportReflectionUsageMessageError(descriptor_, (MESSAGE)->GetDescriptor(), \ + field, #METHOD) +#endif + #define USAGE_CHECK_MESSAGE_TYPE(METHOD) \ USAGE_CHECK_EQ(field->containing_type(), descriptor_, METHOD, \ "Field does not match message type."); @@ -277,10 +304,17 @@ static void ReportReflectionUsageEnumTypeError( "Field is singular; the method requires a repeated field.") #define USAGE_CHECK_ALL(METHOD, LABEL, CPPTYPE) \ + USAGE_CHECK_MESSAGE(METHOD, &message); \ USAGE_CHECK_MESSAGE_TYPE(METHOD); \ USAGE_CHECK_##LABEL(METHOD); \ USAGE_CHECK_TYPE(METHOD, CPPTYPE) +#define USAGE_MUTABLE_CHECK_ALL(METHOD, LABEL, CPPTYPE) \ + USAGE_CHECK_MESSAGE(METHOD, message); \ + USAGE_CHECK_MESSAGE_TYPE(METHOD); \ + USAGE_CHECK_##LABEL(METHOD); \ + USAGE_CHECK_TYPE(METHOD, CPPTYPE) + } // namespace // =================================================================== @@ -1175,6 +1209,7 @@ void Reflection::UnsafeArenaSwapFields( bool Reflection::HasField(const Message& message, const FieldDescriptor* field) const { + USAGE_CHECK_MESSAGE(HasField, &message); USAGE_CHECK_MESSAGE_TYPE(HasField); USAGE_CHECK_SINGULAR(HasField); @@ -1273,6 +1308,7 @@ void Reflection::InternalSwap(Message* lhs, Message* rhs) const { int Reflection::FieldSize(const Message& message, const FieldDescriptor* field) const { + USAGE_CHECK_MESSAGE(FieldSize, &message); USAGE_CHECK_MESSAGE_TYPE(FieldSize); USAGE_CHECK_REPEATED(FieldSize); @@ -1318,6 +1354,7 @@ int Reflection::FieldSize(const Message& message, void Reflection::ClearField(Message* message, const FieldDescriptor* field) const { + USAGE_CHECK_MESSAGE(ClearField, message); USAGE_CHECK_MESSAGE_TYPE(ClearField); if (field->is_extension()) { @@ -1435,6 +1472,7 @@ void Reflection::ClearField(Message* message, void Reflection::RemoveLast(Message* message, const FieldDescriptor* field) const { + USAGE_CHECK_MESSAGE(RemoveLast, message); USAGE_CHECK_MESSAGE_TYPE(RemoveLast); USAGE_CHECK_REPEATED(RemoveLast); @@ -1483,7 +1521,7 @@ void Reflection::RemoveLast(Message* message, Message* Reflection::ReleaseLast(Message* message, const FieldDescriptor* field) const { - USAGE_CHECK_ALL(ReleaseLast, REPEATED, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(ReleaseLast, REPEATED, MESSAGE); Message* released; if (field->is_extension()) { @@ -1508,7 +1546,7 @@ Message* Reflection::ReleaseLast(Message* message, Message* Reflection::UnsafeArenaReleaseLast( Message* message, const FieldDescriptor* field) const { - USAGE_CHECK_ALL(UnsafeArenaReleaseLast, REPEATED, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(UnsafeArenaReleaseLast, REPEATED, MESSAGE); if (field->is_extension()) { return static_cast( @@ -1527,6 +1565,7 @@ Message* Reflection::UnsafeArenaReleaseLast( void Reflection::SwapElements(Message* message, const FieldDescriptor* field, int index1, int index2) const { + USAGE_CHECK_MESSAGE(Swap, message); USAGE_CHECK_MESSAGE_TYPE(Swap); USAGE_CHECK_REPEATED(Swap); @@ -1692,7 +1731,7 @@ void Reflection::ListFields(const Message& message, \ void Reflection::Set##TYPENAME( \ Message* message, const FieldDescriptor* field, PASSTYPE value) const { \ - USAGE_CHECK_ALL(Set##TYPENAME, SINGULAR, CPPTYPE); \ + USAGE_MUTABLE_CHECK_ALL(Set##TYPENAME, SINGULAR, CPPTYPE); \ if (field->is_extension()) { \ return MutableExtensionSet(message)->Set##TYPENAME( \ field->number(), field->type(), value, field); \ @@ -1715,7 +1754,7 @@ void Reflection::ListFields(const Message& message, void Reflection::SetRepeated##TYPENAME(Message* message, \ const FieldDescriptor* field, \ int index, PASSTYPE value) const { \ - USAGE_CHECK_ALL(SetRepeated##TYPENAME, REPEATED, CPPTYPE); \ + USAGE_MUTABLE_CHECK_ALL(SetRepeated##TYPENAME, REPEATED, CPPTYPE); \ if (field->is_extension()) { \ MutableExtensionSet(message)->SetRepeated##TYPENAME(field->number(), \ index, value); \ @@ -1726,7 +1765,7 @@ void Reflection::ListFields(const Message& message, \ void Reflection::Add##TYPENAME( \ Message* message, const FieldDescriptor* field, PASSTYPE value) const { \ - USAGE_CHECK_ALL(Add##TYPENAME, REPEATED, CPPTYPE); \ + USAGE_MUTABLE_CHECK_ALL(Add##TYPENAME, REPEATED, CPPTYPE); \ if (field->is_extension()) { \ MutableExtensionSet(message)->Add##TYPENAME( \ field->number(), field->type(), field->options().packed(), value, \ @@ -1846,7 +1885,7 @@ absl::Cord Reflection::GetCord(const Message& message, void Reflection::SetString(Message* message, const FieldDescriptor* field, std::string value) const { - USAGE_CHECK_ALL(SetString, SINGULAR, STRING); + USAGE_MUTABLE_CHECK_ALL(SetString, SINGULAR, STRING); if (field->is_extension()) { return MutableExtensionSet(message)->SetString( field->number(), field->type(), std::move(value), field); @@ -1897,7 +1936,7 @@ void Reflection::SetString(Message* message, const FieldDescriptor* field, void Reflection::SetString(Message* message, const FieldDescriptor* field, const absl::Cord& value) const { - USAGE_CHECK_ALL(SetString, SINGULAR, STRING); + USAGE_MUTABLE_CHECK_ALL(SetString, SINGULAR, STRING); if (field->is_extension()) { return absl::CopyCordToString(value, MutableExtensionSet(message)->MutableString( @@ -1981,7 +2020,7 @@ const std::string& Reflection::GetRepeatedStringReference( void Reflection::SetRepeatedString(Message* message, const FieldDescriptor* field, int index, std::string value) const { - USAGE_CHECK_ALL(SetRepeatedString, REPEATED, STRING); + USAGE_MUTABLE_CHECK_ALL(SetRepeatedString, REPEATED, STRING); if (field->is_extension()) { MutableExtensionSet(message)->SetRepeatedString(field->number(), index, std::move(value)); @@ -1999,7 +2038,7 @@ void Reflection::SetRepeatedString(Message* message, void Reflection::AddString(Message* message, const FieldDescriptor* field, std::string value) const { - USAGE_CHECK_ALL(AddString, REPEATED, STRING); + USAGE_MUTABLE_CHECK_ALL(AddString, REPEATED, STRING); if (field->is_extension()) { MutableExtensionSet(message)->AddString(field->number(), field->type(), std::move(value), field); @@ -2048,7 +2087,7 @@ void Reflection::SetEnum(Message* message, const FieldDescriptor* field, void Reflection::SetEnumValue(Message* message, const FieldDescriptor* field, int value) const { - USAGE_CHECK_ALL(SetEnumValue, SINGULAR, ENUM); + USAGE_MUTABLE_CHECK_ALL(SetEnumValue, SINGULAR, ENUM); if (!CreateUnknownEnumValues(field)) { // Check that the value is valid if we don't support direct storage of // unknown enum values. @@ -2105,7 +2144,7 @@ void Reflection::SetRepeatedEnum(Message* message, const FieldDescriptor* field, void Reflection::SetRepeatedEnumValue(Message* message, const FieldDescriptor* field, int index, int value) const { - USAGE_CHECK_ALL(SetRepeatedEnum, REPEATED, ENUM); + USAGE_MUTABLE_CHECK_ALL(SetRepeatedEnum, REPEATED, ENUM); if (!CreateUnknownEnumValues(field)) { // Check that the value is valid if we don't support direct storage of // unknown enum values. @@ -2139,7 +2178,7 @@ void Reflection::AddEnum(Message* message, const FieldDescriptor* field, void Reflection::AddEnumValue(Message* message, const FieldDescriptor* field, int value) const { - USAGE_CHECK_ALL(AddEnum, REPEATED, ENUM); + USAGE_MUTABLE_CHECK_ALL(AddEnum, REPEATED, ENUM); if (!CreateUnknownEnumValues(field)) { // Check that the value is valid if we don't support direct storage of // unknown enum values. @@ -2226,7 +2265,7 @@ const Message& Reflection::GetMessage(const Message& message, Message* Reflection::MutableMessage(Message* message, const FieldDescriptor* field, MessageFactory* factory) const { - USAGE_CHECK_ALL(MutableMessage, SINGULAR, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(MutableMessage, SINGULAR, MESSAGE); if (factory == nullptr) factory = message_factory_; @@ -2261,7 +2300,7 @@ Message* Reflection::MutableMessage(Message* message, void Reflection::UnsafeArenaSetAllocatedMessage( Message* message, Message* sub_message, const FieldDescriptor* field) const { - USAGE_CHECK_ALL(SetAllocatedMessage, SINGULAR, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(SetAllocatedMessage, SINGULAR, MESSAGE); if (field->is_extension()) { @@ -2326,7 +2365,7 @@ void Reflection::SetAllocatedMessage(Message* message, Message* sub_message, Message* Reflection::UnsafeArenaReleaseMessage(Message* message, const FieldDescriptor* field, MessageFactory* factory) const { - USAGE_CHECK_ALL(ReleaseMessage, SINGULAR, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(ReleaseMessage, SINGULAR, MESSAGE); if (factory == nullptr) factory = message_factory_; @@ -2390,7 +2429,7 @@ const Message& Reflection::GetRepeatedMessage(const Message& message, Message* Reflection::MutableRepeatedMessage(Message* message, const FieldDescriptor* field, int index) const { - USAGE_CHECK_ALL(MutableRepeatedMessage, REPEATED, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(MutableRepeatedMessage, REPEATED, MESSAGE); if (field->is_extension()) { return static_cast( @@ -2410,7 +2449,7 @@ Message* Reflection::MutableRepeatedMessage(Message* message, Message* Reflection::AddMessage(Message* message, const FieldDescriptor* field, MessageFactory* factory) const { - USAGE_CHECK_ALL(AddMessage, REPEATED, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(AddMessage, REPEATED, MESSAGE); if (factory == nullptr) factory = message_factory_; @@ -2452,7 +2491,7 @@ Message* Reflection::AddMessage(Message* message, const FieldDescriptor* field, void Reflection::AddAllocatedMessage(Message* message, const FieldDescriptor* field, Message* new_entry) const { - USAGE_CHECK_ALL(AddAllocatedMessage, REPEATED, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(AddAllocatedMessage, REPEATED, MESSAGE); if (field->is_extension()) { MutableExtensionSet(message)->AddAllocatedMessage(field, new_entry); @@ -2471,7 +2510,7 @@ void Reflection::AddAllocatedMessage(Message* message, void Reflection::UnsafeArenaAddAllocatedMessage(Message* message, const FieldDescriptor* field, Message* new_entry) const { - USAGE_CHECK_ALL(UnsafeArenaAddAllocatedMessage, REPEATED, MESSAGE); + USAGE_MUTABLE_CHECK_ALL(UnsafeArenaAddAllocatedMessage, REPEATED, MESSAGE); if (field->is_extension()) { MutableExtensionSet(message)->UnsafeArenaAddAllocatedMessage(field, diff --git a/src/google/protobuf/generated_message_reflection_unittest.cc b/src/google/protobuf/generated_message_reflection_unittest.cc index 4c8e0ab7098f..29e584355493 100644 --- a/src/google/protobuf/generated_message_reflection_unittest.cc +++ b/src/google/protobuf/generated_message_reflection_unittest.cc @@ -1299,6 +1299,7 @@ TEST(GeneratedMessageReflectionTest, ArenaReleaseOneofMessageTest) { TEST(GeneratedMessageReflectionTest, UsageErrors) { unittest::TestAllTypes message; + unittest::ForeignMessage foreign; const Reflection* reflection = message.GetReflection(); const Descriptor* descriptor = message.GetDescriptor(); @@ -1322,6 +1323,15 @@ TEST(GeneratedMessageReflectionTest, UsageErrors) { " Field : protobuf_unittest.TestAllTypes.repeated_int32\n" " Problem : Field is repeated; the method requires a " "singular field."); + EXPECT_DEBUG_DEATH( + reflection->GetInt32(foreign, + descriptor->FindFieldByName("optional_int32")), + "Protocol Buffer reflection usage error:\n" + " Method : google::protobuf::Reflection::GetInt32\n" + " Expected type: protobuf_unittest.TestAllTypes\n" + " Actual type : protobuf_unittest.ForeignMessage\n" + " Field : protobuf_unittest.TestAllTypes.optional_int32\n" + " Problem : Message is not the right type for reflection"); EXPECT_DEATH( reflection->GetInt32( message,