Skip to content

Commit

Permalink
add forward compatibility for byte datagram keys
Browse files Browse the repository at this point in the history
  • Loading branch information
jackrobison committed Sep 28, 2020
1 parent d0f21c0 commit 3a64ceb
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
24 changes: 17 additions & 7 deletions lbry/dht/serialization/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,27 +144,37 @@ def __init__(self, packet_type: int, rpc_id: bytes, node_id: bytes, exception_ty
self.response = response.decode()


def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
def _decode_datagram(datagram: bytes):
msg_types = {
REQUEST_TYPE: RequestDatagram,
RESPONSE_TYPE: ResponseDatagram,
ERROR_TYPE: ErrorDatagram
}

primitive: typing.Dict = bdecode(datagram)
if primitive[0] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = primitive[0] # pylint: disable=unsubscriptable-object

converted = {
str(k).encode() if not isinstance(k, bytes) else k: v for k, v in primitive.items()
}

if converted[b'0'] in [REQUEST_TYPE, ERROR_TYPE, RESPONSE_TYPE]: # pylint: disable=unsubscriptable-object
datagram_type = converted[b'0'] # pylint: disable=unsubscriptable-object
else:
raise ValueError("invalid datagram type")
datagram_class = msg_types[datagram_type]
decoded = {
k: primitive[i] # pylint: disable=unsubscriptable-object
k: converted[str(i).encode()] # pylint: disable=unsubscriptable-object
for i, k in enumerate(datagram_class.required_fields)
if i in primitive # pylint: disable=unsupported-membership-test
if str(i).encode() in converted # pylint: disable=unsupported-membership-test
}
for i, _ in enumerate(OPTIONAL_FIELDS):
if i + OPTIONAL_ARG_OFFSET in primitive:
decoded[i + OPTIONAL_ARG_OFFSET] = primitive[i + OPTIONAL_ARG_OFFSET]
if str(i + OPTIONAL_ARG_OFFSET).encode() in converted:
decoded[i + OPTIONAL_ARG_OFFSET] = converted[str(i + OPTIONAL_ARG_OFFSET).encode()]
return decoded, datagram_class


def decode_datagram(datagram: bytes) -> typing.Union[RequestDatagram, ResponseDatagram, ErrorDatagram]:
decoded, datagram_class = _decode_datagram(datagram)
return datagram_class(**decoded)


Expand Down
34 changes: 34 additions & 0 deletions tests/unit/dht/serialization/test_datagram.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import binascii
import unittest
from lbry.dht.error import DecodeError
from lbry.dht.serialization.bencoding import _bencode
from lbry.dht.serialization.datagram import RequestDatagram, ResponseDatagram, decode_datagram, ErrorDatagram
from lbry.dht.serialization.datagram import _decode_datagram
from lbry.dht.serialization.datagram import REQUEST_TYPE, RESPONSE_TYPE, ERROR_TYPE
from lbry.dht.serialization.datagram import make_compact_address, decode_compact_address

Expand Down Expand Up @@ -139,6 +141,38 @@ def test_optional_field_backwards_compatible(self):
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
self.assertEqual(b'ping', datagram.method)

def test_str_or_int_keys(self):
datagram = decode_datagram(_bencode({
b'0': 0,
b'1': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
b'2': b'111111111111111111111111111111111111111111111111',
b'3': b'ping',
b'4': [{b'protocolVersion': 1}],
b'5': b'should not error'
}))
self.assertEqual(datagram.packet_type, REQUEST_TYPE)
self.assertEqual(b'ping', datagram.method)

def test_mixed_str_or_int_keys(self):
# datagram, _ = _bencode({
# b'0': 0,
# 1: b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
# b'2': b'111111111111111111111111111111111111111111111111',
# 3: b'ping',
# b'4': [{b'protocolVersion': 1}],
# b'5': b'should not error'
# }))
encoded = binascii.unhexlify(b"64313a3069306569316532303a0abcb5269d6cfc1e87a08e920bf39fe9df8e92fc313a3234383a313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131313131693365343a70696e67313a346c6431353a70726f746f636f6c56657273696f6e6931656565313a3531363a73686f756c64206e6f74206572726f7265")
self.assertDictEqual(
{
'packet_type': 0,
'rpc_id': b'\n\xbc\xb5&\x9dl\xfc\x1e\x87\xa0\x8e\x92\x0b\xf3\x9f\xe9\xdf\x8e\x92\xfc',
'node_id': b'111111111111111111111111111111111111111111111111',
'method': b'ping',
'args': [{b'protocolVersion': 1}]
}, _decode_datagram(encoded)[0]
)


class TestCompactAddress(unittest.TestCase):
def test_encode_decode(self, address='1.2.3.4', port=4444, node_id=b'1' * 48):
Expand Down

0 comments on commit 3a64ceb

Please sign in to comment.