Skip to content

Commit

Permalink
Remove RegisterExtension in message class
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 539568063
  • Loading branch information
anandolee authored and copybara-github committed Jun 12, 2023
1 parent 6fe5c6f commit 3560e23
Show file tree
Hide file tree
Showing 9 changed files with 2,300 additions and 1,815 deletions.
466 changes: 267 additions & 199 deletions php/ext/google/protobuf/php-upb.c

Large diffs are not rendered by default.

1,563 changes: 866 additions & 697 deletions php/ext/google/protobuf/php-upb.h

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions protobuf_deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def protobuf_deps():
_github_archive(
name = "upb",
repo = "https://github.com/protocolbuffers/upb",
commit = "7f0092a8021466009e65367aed68f2a1867da880",
sha256 = "3c1e1e58f96b97dde14c0e911cfb6378b14027b17314e946f6a6ce3fcf5b2088",
commit = "56a770818cf47f8ac9e2ac1585a8d2b764214479",
sha256 = "",
patches = ["@com_google_protobuf//build_defs:upb.patch"],
)
26 changes: 16 additions & 10 deletions python/google/protobuf/internal/message_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,16 +149,14 @@ def testGetMessages(self):
self.assertEqual(None,
msg1.Extensions._FindExtensionByNumber(12321))
self.assertEqual(2, len(msg1.Extensions))
if api_implementation.Type() == 'cpp':
self.assertRaises(TypeError,
msg1.Extensions._FindExtensionByName, 0)
self.assertRaises(TypeError,
msg1.Extensions._FindExtensionByNumber, '')
else:
if api_implementation.Type() == 'python':
self.assertEqual(None,
msg1.Extensions._FindExtensionByName(0))
self.assertEqual(None,
msg1.Extensions._FindExtensionByNumber(''))
else:
self.assertRaises(TypeError, msg1.Extensions._FindExtensionByName, 0)
self.assertRaises(TypeError, msg1.Extensions._FindExtensionByNumber, '')

def testDuplicateExtensionNumber(self):
pool = descriptor_pool.DescriptorPool()
Expand All @@ -181,9 +179,11 @@ def testDuplicateExtensionNumber(self):
msg.extension.add(
name='extension_field',
number=2,
type=descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='Extension',
extendee='Container')
extendee='Container',
)
pool.Add(f)
msgs = message_factory.GetMessageClassesForFiles([f.name], pool)
self.assertIn('google.protobuf.python.internal.Extension', msgs)
Expand All @@ -197,9 +197,11 @@ def testDuplicateExtensionNumber(self):
msg.extension.add(
name='extension_field',
number=2,
type=descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='Duplicate',
extendee='Container')
extendee='Container',
)
pool.Add(f)

with self.assertRaises(Exception) as cm:
Expand Down Expand Up @@ -240,15 +242,19 @@ def testExtensionValueInDifferentFile(self):
f3.extension.add(
name='top_level_extension_field',
number=2,
type=descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='ValueType',
extendee='Container')
extendee='Container',
)
f3.message_type.add(name='Extension').extension.add(
name='nested_extension_field',
number=3,
type=descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL,
type_name='ValueType',
extendee='Container')
extendee='Container',
)

class SimpleDescriptorDB:

Expand Down
5 changes: 1 addition & 4 deletions python/google/protobuf/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ class Message(object):
# have an Extensions attribute with __getitem__ and __setitem__.
# Again, not sure how to best convey this.

# TODO(robinson): Document that the class must also have a static
# RegisterExtension(extension_field) method.
# Not sure how to best express at this point.

# TODO(robinson): Document these fields and methods.

__slots__ = []
Expand Down Expand Up @@ -367,6 +363,7 @@ def ByteSize(self):
def FromString(cls, s):
raise NotImplementedError

# TODO(b/286557203): Remove it in OSS
@staticmethod
def RegisterExtension(field_descriptor):
raise NotImplementedError
Expand Down
17 changes: 15 additions & 2 deletions python/google/protobuf/message_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,13 @@ def GetMessageClassesForFiles(files, pool):

for extension in file_desc.extensions_by_name.values():
extended_class = GetMessageClass(extension.containing_type)
extended_class.RegisterExtension(extension)
if api_implementation.Type() != 'python':
# TODO(b/286443080): Remove this check here. Duplicate extension
# register check should be in descriptor_pool.
if extension is not pool.FindExtensionByNumber(
extension.containing_type, extension.number
):
raise ValueError('Double registration of Extensions')
# Recursively load protos for extension field, in order to be able to
# fully represent the extension. This matches the behavior for regular
# fields too.
Expand Down Expand Up @@ -136,7 +142,14 @@ def _InternalCreateMessageClass(descriptor):
GetMessageClass(field.message_type)
for extension in result_class.DESCRIPTOR.extensions:
extended_class = GetMessageClass(extension.containing_type)
extended_class.RegisterExtension(extension)
if api_implementation.Type() != 'python':
# TODO(b/286443080): Remove this check here. Duplicate extension
# register check should be in descriptor_pool.
pool = extension.containing_type.file.pool
if extension is not pool.FindExtensionByNumber(
extension.containing_type, extension.number
):
raise ValueError('Double registration of Extensions')
if extension.message_type:
GetMessageClass(extension.message_type)
return result_class
Expand Down
5 changes: 0 additions & 5 deletions python/google/protobuf/pyext/message_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,6 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
if (py_extension == nullptr) {
return nullptr;
}
ScopedPyObjectPtr result(cmessage::RegisterExtension(
py_extended_class.get(), py_extension.get()));
if (result == nullptr) {
return nullptr;
}
}
return reinterpret_cast<CMessageClass*>(message_class.release());
}
Expand Down
Loading

0 comments on commit 3560e23

Please sign in to comment.