Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: json serialization and deserialization support stringy enums #112

Merged
merged 10 commits into from
Sep 2, 2020
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def unit(session, proto="python"):
"--cov-config=.coveragerc",
"--cov-report=term",
"--cov-report=html",
os.path.join("tests", ""),
*(session.posargs or [os.path.join("tests", "")]),
)


Expand Down
27 changes: 25 additions & 2 deletions proto/_file_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
import collections
import inspect
import logging

from google.protobuf import descriptor_pb2
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import reflection
Expand All @@ -27,11 +28,33 @@

class _FileInfo(
collections.namedtuple(
"_FileInfo", ["descriptor", "messages", "enums", "name", "nested"]
"_FileInfo",
["descriptor", "messages", "enums", "name", "nested", "nested_enum"],
)
):
registry = {} # Mapping[str, '_FileInfo']

@classmethod
def maybe_add_descriptor(cls, filename, package):
    """Return the registry entry for ``filename``, creating it on first use.

    Args:
        filename: Key under which the entry is stored in ``cls.registry``;
            also used as the ``name`` of a newly created file descriptor
            proto and of the new entry itself.
        package: Package assigned to a newly created file descriptor
            proto. It is not consulted when a truthy entry for
            ``filename`` is already registered.

    Returns:
        The already-registered entry, or the freshly registered one.
    """
    known = cls.registry.get(filename)
    if known:
        return known

    # First sighting of this file: start from an empty proto3 file
    # descriptor and empty bookkeeping containers.
    file_proto = descriptor_pb2.FileDescriptorProto(
        name=filename, package=package, syntax="proto3",
    )
    fresh = cls(
        descriptor=file_proto,
        enums=collections.OrderedDict(),
        messages=collections.OrderedDict(),
        name=filename,
        nested={},
        nested_enum={},
    )
    cls.registry[filename] = fresh
    return fresh

@staticmethod
def proto_file_name(name):
    """Return the proto file path derived from a dotted module name.

    Only the dots inside ``name`` are converted to path separators; the
    ``.proto`` suffix is appended afterwards so that it is left intact.
    For example ``"pkg.mod"`` becomes ``"pkg/mod.proto"`` (and not
    ``"pkg/mod/proto"``, which is what replacing dots on the fully
    formatted string would produce).

    Args:
        name: A dotted name, typically a Python module's ``__module__``.

    Returns:
        str: The slash-separated path ending in ``.proto``.
    """
    return "{0}.proto".format(name.replace(".", "/"))

def _get_manifest(self, new_class):
module = inspect.getmodule(new_class)
if hasattr(module, "__protobuf__"):
Expand Down
51 changes: 51 additions & 0 deletions proto/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import enum

from google.protobuf import descriptor_pb2

from proto import _file_info
from proto import _package_info
from proto.marshal.rules.enums import EnumRule

Expand All @@ -30,9 +33,51 @@ def __new__(mcls, name, bases, attrs):
# this component belongs within the file.
package, marshal = _package_info.compile(name, attrs)

# Determine the local path of this proto component within the file.
local_path = tuple(attrs.get("__qualname__", name).split("."))

# Sanity check: We get the wrong full name if a class is declared
# inside a function local scope; correct this.
if "<locals>" in local_path:
ix = local_path.index("<locals>")
local_path = local_path[: ix - 1] + local_path[ix + 1 :]

# Determine the full name in protocol buffers.
full_name = ".".join((package,) + local_path).lstrip(".")
enum_desc = descriptor_pb2.EnumDescriptorProto(
name=name,
# Note: the superclass ctor removes the variants, so get them now.
# Note: proto3 requires that the first variant value be zero.
value=sorted(
(
descriptor_pb2.EnumValueDescriptorProto(name=name, number=number)
# Minor hack to get all the enum variants out.
for name, number in attrs.items()
if isinstance(number, int)
),
key=lambda v: v.number,
),
)

filename = _file_info._FileInfo.proto_file_name(
attrs.get("__module__", name.lower())
)

file_info = _file_info._FileInfo.maybe_add_descriptor(filename, package)
if len(local_path) == 1:
file_info.descriptor.enum_type.add().MergeFrom(enum_desc)
else:
file_info.nested_enum[local_path] = enum_desc

# Run the superclass constructor.
cls = super().__new__(mcls, name, bases, attrs)

# We can't just add a "_meta" element to attrs because the Enum
# machinery doesn't know what to do with a non-int value.
cls._meta = _EnumInfo(full_name=full_name, pb=enum_desc)

file_info.enums[full_name] = cls

# Register the enum with the marshal.
marshal.register(cls, EnumRule(cls))

Expand All @@ -44,3 +89,9 @@ class Enum(enum.IntEnum, metaclass=ProtoEnumMeta):
"""A enum object that also builds a protobuf enum descriptor."""

pass


class _EnumInfo:
    """Metadata holder attached to a proto enum class.

    Args:
        full_name: The fully-qualified name stored for the enum.
        pb: The descriptor object stored alongside the name; it is kept
            as given and never inspected by this class.
    """

    def __init__(self, *, full_name: str, pb):
        self.full_name = full_name
        self.pb = pb

    def __repr__(self) -> str:
        # Only the name is shown: the descriptor's own repr can span many
        # lines and would drown out the useful part in logs and debuggers.
        return "{0}(full_name={1!r})".format(type(self).__name__, self.full_name)
26 changes: 12 additions & 14 deletions proto/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,20 @@ def descriptor(self):
type_name = (
self.message.DESCRIPTOR.full_name
if hasattr(self.message, "DESCRIPTOR")
else self.message.meta.full_name
else self.message._meta.full_name
)
elif isinstance(self.enum, str):
if not self.enum.startswith(self.package):
self.enum = "{package}.{name}".format(
package=self.package, name=self.enum,
)
type_name = self.enum
elif self.enum:
# Nos decipiat.
#
# As far as the wire format is concerned, enums are int32s.
# Protocol buffers itself also only sends ints; the enum
# objects are simply helper classes for translating names
# and values and it is the user's job to resolve to an int.
#
# Therefore, the non-trivial effort of adding the actual
# enum descriptors seems to add little or no actual value.
#
# FIXME: Eventually, come back and put in the actual enum
# descriptors.
proto_type = ProtoType.INT32
type_name = (
self.enum.DESCRIPTOR.full_name
if hasattr(self.enum, "DESCRIPTOR")
else self.enum._meta.full_name
)

# Set the descriptor.
self._descriptor = descriptor_pb2.FieldDescriptorProto(
Expand Down
33 changes: 17 additions & 16 deletions proto/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,24 +183,13 @@ def __new__(mcls, name, bases, attrs):
# Determine the filename.
# We determine an appropriate proto filename based on the
# Python module.
filename = "{0}.proto".format(
new_attrs.get("__module__", name.lower()).replace(".", "/")
filename = _file_info._FileInfo.proto_file_name(
new_attrs.get("__module__", name.lower())
)

# Get or create the information about the file, including the
# descriptor to which the new message descriptor shall be added.
file_info = _file_info._FileInfo.registry.setdefault(
filename,
_file_info._FileInfo(
descriptor=descriptor_pb2.FileDescriptorProto(
name=filename, package=package, syntax="proto3",
),
enums=collections.OrderedDict(),
messages=collections.OrderedDict(),
name=filename,
nested={},
),
)
file_info = _file_info._FileInfo.maybe_add_descriptor(filename, package)

# Ensure any imports that would be necessary are assigned to the file
# descriptor proto being created.
Expand All @@ -227,6 +216,11 @@ def __new__(mcls, name, bases, attrs):
for child_path in child_paths:
desc.nested_type.add().MergeFrom(file_info.nested.pop(child_path))

# Same thing, but for enums
child_paths = [p for p in file_info.nested_enum.keys() if local_path == p[:-1]]
for child_path in child_paths:
desc.enum_type.add().MergeFrom(file_info.nested_enum.pop(child_path))

# Add the descriptor to the file if it is a top-level descriptor,
# or to a "holding area" for nested messages otherwise.
if len(local_path) == 1:
Expand Down Expand Up @@ -325,17 +319,24 @@ def deserialize(cls, payload: bytes) -> "Message":
"""
return cls.wrap(cls.pb().FromString(payload))

def to_json(cls, instance) -> str:
def to_json(cls, instance, *, use_integers_for_enums=True) -> str:
"""Given a message instance, serialize it to json

Args:
instance: An instance of this message type, or something
compatible (accepted by the type's constructor).
use_integers_for_enums (Optional(bool)): An option that determines whether enum
values should be represented by strings (False) or integers (True).
Default is True.

Returns:
str: The json string representation of the protocol buffer.
"""
return MessageToJson(cls.pb(instance))
return MessageToJson(
cls.pb(instance),
use_integers_for_enums=use_integers_for_enums,
including_default_value_fields=True,
)

def from_json(cls, payload) -> "Message":
"""Given a json string representing an instance,
Expand Down
73 changes: 73 additions & 0 deletions tests/test_fields_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import proto
import sys


def test_outer_enum_init():
Expand Down Expand Up @@ -252,3 +253,75 @@ class Foo(proto.Message):
assert "color" not in foo
assert not foo.color
assert Foo.pb(foo).color == 0


# Module-level enum used by test_enum_outest below. The tests that follow
# declare their own local ``Zone`` classes with a different set of members.
class Zone(proto.Enum):
    EPIPELAGIC = 0
    MESOPELAGIC = 1
    ABYSSOPELAGIC = 2
    HADOPELAGIC = 3


def test_enum_outest():
    """Build a module-level ``Zone`` from an existing member via ``value=``."""
    z = Zone(value=Zone.MESOPELAGIC)

    # The constructed value must compare equal to the member it came from.
    assert z == Zone.MESOPELAGIC


def test_nested_enum_from_string():
    """An enum given by variant name inside a nested dict payload is accepted.

    The repeated ``squids`` field is populated from a dict whose ``zone``
    value is the string ``"MESOPELAGIC"``; the resulting element must equal
    a ``Squid`` built with the ``Zone.MESOPELAGIC`` member.
    """

    class Zone(proto.Enum):
        EPIPELAGIC = 0
        MESOPELAGIC = 1
        BATHYPELAGIC = 2
        ABYSSOPELAGIC = 3

    class Squid(proto.Message):
        zone = proto.Field(Zone, number=1)

    class Trawl(proto.Message):
        # Note: this indirection with the nested field
        # is necessary to trigger the exception for testing.
        # Setting the field in an existing message accepts strings AND
        # checks for valid variants.
        # Similarly, constructing a message directly with a top level
        # enum field kwarg passed as a string is also handled correctly, i.e.
        # s = Squid(zone="ABYSSOPELAGIC")
        # does NOT raise an exception.
        # NOTE(review): the assertions below expect success rather than an
        # exception; presumably "the exception" refers to behavior from
        # before this change -- confirm and reword if so.
        squids = proto.RepeatedField(Squid, number=1)

    t = Trawl(squids=[{"zone": "MESOPELAGIC"}])
    assert t.squids[0] == Squid(zone=Zone.MESOPELAGIC)


def test_enum_field_by_string():
    """A field whose enum is declared by name (``enum="Zone"``) round-trips."""

    class Zone(proto.Enum):
        EPIPELAGIC = 0
        MESOPELAGIC = 1
        BATHYPELAGIC = 2
        ABYSSOPELAGIC = 3

    class Squid(proto.Message):
        # The enum type is referenced by its name as a string rather than
        # by passing the ``Zone`` class itself.
        zone = proto.Field(proto.ENUM, number=1, enum="Zone")

    s = Squid(zone=Zone.BATHYPELAGIC)
    assert s.zone == Zone.BATHYPELAGIC


def test_enum_field_by_string_with_package():
    """By-name enum field declared while the module has a package manifest.

    A ``__protobuf__`` attribute with package ``mollusca.cephalopoda`` is
    set on this module only for the duration of the class definitions, then
    removed again. A ``Squid`` built from the string ``"ABYSSOPELAGIC"``
    must expose ``Zone.ABYSSOPELAGIC``.
    """
    sys.modules[__name__].__protobuf__ = proto.module(package="mollusca.cephalopoda")
    try:

        class Zone(proto.Enum):
            EPIPELAGIC = 0
            MESOPELAGIC = 1
            BATHYPELAGIC = 2
            ABYSSOPELAGIC = 3

        class Squid(proto.Message):
            # The enum is referenced by its bare name, without the package.
            zone = proto.Field(proto.ENUM, number=1, enum="Zone")

    finally:
        # Always drop the temporary module attribute, even if class
        # creation raised.
        del sys.modules[__name__].__protobuf__

    s = Squid(zone="ABYSSOPELAGIC")
    assert s.zone == Zone.ABYSSOPELAGIC
1 change: 1 addition & 0 deletions tests/test_file_info_salting.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def sample_file_info(name):
messages=collections.OrderedDict(),
name=filename,
nested={},
nested_enum={},
),
)

Expand Down
1 change: 1 addition & 0 deletions tests/test_file_info_salting_with_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def sample_file_info(name):
messages=collections.OrderedDict(),
name=filename,
nested={},
nested_enum={},
),
)

Expand Down