Skip to content

Commit

Permalink
[feat] Make serialize coerce the input. (#7)
Browse files Browse the repository at this point in the history
This is important because it makes gRPC's `request_serializer` be more
liberal in what it accepts when given `serialize`.
  • Loading branch information
lukesneeringer authored Dec 26, 2018
1 parent b4acb4d commit c1e20e1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
18 changes: 12 additions & 6 deletions proto/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,19 +257,24 @@ def __prepare__(mcls, name, bases, **kwargs):
def meta(cls):
return cls._meta

def pb(cls, obj=None):
def pb(cls, obj=None, *, coerce: bool = False):
"""Return the underlying protobuf Message class or instance.
Args:
obj: If provided, and an instance of ``cls``, return the
underlying protobuf instance.
coerce (bool): If provided, will attempt to coerce ``obj`` to
``cls`` if it is not already an instance.
"""
if obj is None:
return cls.meta.pb
if not isinstance(obj, cls):
raise TypeError('%r is not an instance of %s' % (
obj, cls.__name__,
))
if coerce:
obj = cls(obj)
else:
raise TypeError('%r is not an instance of %s' % (
obj, cls.__name__,
))
return obj._pb

def wrap(cls, pb):
Expand All @@ -285,12 +290,13 @@ def serialize(cls, instance) -> bytes:
"""Return the serialized proto.
Args:
instance: An instance of this message type.
instance: An instance of this message type, or something
compatible (accepted by the type's constructor).
Returns:
bytes: The serialized representation of the protocol buffer.
"""
return cls.pb(instance).SerializeToString()
return cls.pb(instance, coerce=True).SerializeToString()

def deserialize(cls, payload: bytes) -> 'Message':
"""Given a serialized proto, deserialize it into a Message instance.
Expand Down
10 changes: 10 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,16 @@ class Foo(proto.Message):
assert Foo.serialize(foo) == Foo.pb(foo).SerializeToString()


def test_message_dict_serialize():
class Foo(proto.Message):
bar = proto.Field(proto.INT32, number=1)
baz = proto.Field(proto.STRING, number=2)
bacon = proto.Field(proto.BOOL, number=3)

foo = {'bar': 42, 'bacon': True}
assert Foo.serialize(foo) == Foo.pb(foo, coerce=True).SerializeToString()


def test_message_deserialize():
class OldFoo(proto.Message):
bar = proto.Field(proto.INT32, number=1)
Expand Down

0 comments on commit c1e20e1

Please sign in to comment.