diff --git a/tests/protocol/helpers.py b/tests/protocol/helpers.py index 4a69c48..bad5e90 100644 --- a/tests/protocol/helpers.py +++ b/tests/protocol/helpers.py @@ -3,7 +3,7 @@ from typing import Optional from unittest.mock import Mock -from zerocom.protocol.abc import BaseReader, BaseWriter +from zerocom.protocol.base_io import BaseReader, BaseWriter class Reader(BaseReader): diff --git a/tests/protocol/test_abc.py b/tests/protocol/test_abc.py deleted file mode 100644 index ac6a060..0000000 --- a/tests/protocol/test_abc.py +++ /dev/null @@ -1,227 +0,0 @@ -from __future__ import annotations - -import pytest - -from tests.protocol.helpers import ReadFunctionMock, Reader, WriteFunctionMock, Writer - - -class TestReader: - @classmethod - def setup_class(cls): - """Initialize writer instance to be tested.""" - cls.reader = Reader() - - @pytest.fixture - def read_mock(self, monkeypatch: pytest.MonkeyPatch): - """Monkeypatch the read function with a mock which is returned.""" - mock_f = ReadFunctionMock() - monkeypatch.setattr(self.reader.__class__, "read", mock_f) - yield mock_f - - # Run this assertion after the test, to ensure that all specified data - # to be read, actually was read - mock_f.assert_read_everything() - - @pytest.mark.parametrize( - "read_bytes,expected_value", - ( - ([10], 10), - ([255], 255), - ([0], 0), - ), - ) - def test_read_ubyte(self, read_bytes: list[int], expected_value: int, read_mock: ReadFunctionMock): - """Reading byte int should return an integer in a single unsigned byte.""" - read_mock.combined_data = bytearray(read_bytes) - assert self.reader.read_ubyte() == expected_value - - @pytest.mark.parametrize( - "read_bytes,expected_value", - ( - ([236], -20), - ([128], -128), - ([20], 20), - ([127], 127), - ), - ) - def test_read_byte(self, read_bytes: list[int], expected_value: int, read_mock: ReadFunctionMock): - """Negative number bytes should be read from two's complement format.""" - read_mock.combined_data = bytearray(read_bytes) - assert self.reader.read_byte() == expected_value - - @pytest.mark.parametrize( - "read_bytes,expected_value", - ( - ([0], 0), - ([1], 1), - ([2], 2), - ([15], 15), - ([127], 127), - ([128, 1], 128), - ([129, 1], 129), - ([255, 1], 255), - ([192, 132, 61], 1000000), - ([255, 255, 255, 255, 7], 2147483647), - ), - ) - def test_read_varint(self, read_bytes: list[int], expected_value: int, read_mock: ReadFunctionMock): - """Reading varint bytes results in correct values.""" - read_mock.combined_data = bytearray(read_bytes) - assert self.reader.read_varint() == expected_value - - @pytest.mark.parametrize( - "read_bytes,expected_value", - ( - ([0], 0), - ([154, 1], 154), - ([255, 255, 3], 2**16 - 1), - ), - ) - def test_read_varint_max_size(self, read_bytes: list[int], expected_value: int, read_mock: ReadFunctionMock): - """Varint reading should be limitable to n max bytes and work with values in range.""" - read_mock.combined_data = bytearray(read_bytes) - assert self.reader.read_varint(max_size=2) == expected_value - - def test_read_varnum_max_size_out_of_range(self, read_mock: ReadFunctionMock): - """Varint reading limited to n max bytes should raise an IOError for numbers out of this range.""" - read_mock.combined_data = bytearray([128, 128, 4]) - with pytest.raises(IOError): - self.reader.read_varint(max_size=2) - - @pytest.mark.parametrize( - "read_bytes,expected_string", - ( - ([len("test")] + list(map(ord, "test")), "test"), - ([len("a" * 100)] + list(map(ord, "a" * 100)), "a" * 100), - ([0], ""), - ), - ) - def test_read_utf(self, read_bytes: list[int], expected_string: str, read_mock: ReadFunctionMock): - """Reading UTF string results in correct values.""" - read_mock.combined_data = bytearray(read_bytes) - assert self.reader.read_utf() == expected_string - - @pytest.mark.parametrize( - "read_bytes,expected_bytes", - ( - ([1, 1], [1]), - ([0], []), - ([5, 104, 101, 108, 108, 111], [104, 101, 108, 108, 111]), - ), - ) - def test_read_bytearray(self, read_bytes: list[int], expected_bytes: list[int], read_mock: ReadFunctionMock): - """Writing a bytearray results in correct bytes.""" - read_mock.combined_data = bytearray(read_bytes) - assert self.reader.read_bytearray() == bytearray(expected_bytes) - - -class TestWriter: - @classmethod - def setup_class(cls): - """Initialize writer instance to be tested.""" - cls.writer = Writer() - - @pytest.fixture - def write_mock(self, monkeypatch: pytest.MonkeyPatch): - """Monkeypatch the write function with a mock which is returned.""" - mock_f = WriteFunctionMock() - monkeypatch.setattr(self.writer.__class__, "write", mock_f) - return mock_f - - def test_write_byte(self, write_mock: WriteFunctionMock): - """Writing byte int should store an integer in a single byte.""" - self.writer.write_byte(15) - write_mock.assert_has_data(bytearray([15])) - - def test_write_byte_negative(self, write_mock: WriteFunctionMock): - """Negative number bytes should be stored in two's complement format.""" - self.writer.write_byte(-20) - write_mock.assert_has_data(bytearray([236])) - - def test_write_byte_out_of_range(self): - """Signed bytes should only allow writes from -128 to 127.""" - with pytest.raises(ValueError): - self.writer.write_byte(-129) - with pytest.raises(ValueError): - self.writer.write_byte(128) - - def test_write_ubyte(self, write_mock: WriteFunctionMock): - """Writing unsigned byte int should store an integer in a single byte.""" - self.writer.write_byte(80) - write_mock.assert_has_data(bytearray([80])) - - def test_write_ubyte_out_of_range(self): - """Unsigned bytes should only allow writes from 0 to 255.""" - with pytest.raises(ValueError): - self.writer.write_ubyte(256) - with pytest.raises(ValueError): - self.writer.write_ubyte(-1) - - @pytest.mark.parametrize( - "number,expected_bytes", - ( - (0, [0]), - (1, [1]), - (2, [2]), - (15, [15]), - (127, [127]), - (128, [128, 1]), - (129, [129, 1]), - (255, [255, 1]), - (1000000, [192, 132, 61]), - (2147483647, [255, 255, 255, 255, 7]), - ), - ) - def test_write_varint(self, number: int, expected_bytes: list[int], write_mock: WriteFunctionMock): - """Writing varints results in correct bytes.""" - self.writer.write_varint(number) - write_mock.assert_has_data(bytearray(expected_bytes)) - - def test_write_varint_out_of_range(self): - """Varint without max size should only work with positive integers.""" - with pytest.raises(ValueError): - self.writer.write_varint(-1) - - @pytest.mark.parametrize( - "number,expected_bytes", - ( - (0, [0]), - (154, [154, 1]), - (2**16 - 1, [255, 255, 3]), - ), - ) - def test_write_varint_max_size(self, number: int, expected_bytes: list[int], write_mock: WriteFunctionMock): - """Varints should be limitable to n max bytes and work with values in range.""" - self.writer.write_varint(number, max_size=2) - write_mock.assert_has_data(bytearray(expected_bytes)) - - def test_write_varint_max_size_out_of_range(self): - """Varints limited to n max bytes should raise ValueErrors for numbers out of this range.""" - with pytest.raises(ValueError): - self.writer.write_varint(2**16, max_size=2) - - @pytest.mark.parametrize( - "string,expected_bytes", - ( - ("test", [len("test")] + list(map(ord, "test"))), - ("a" * 100, [len("a" * 100)] + list(map(ord, "a" * 100))), - ("", [0]), - ), - ) - def test_write_utf(self, string: str, expected_bytes: list[int], write_mock: WriteFunctionMock): - """Writing UTF string results in correct bytes.""" - self.writer.write_utf(string) - write_mock.assert_has_data(bytearray(expected_bytes)) - - @pytest.mark.parametrize( - "input_bytes,expected_bytes", - ( - ([1], [1, 1]), - ([], [0]), - ([104, 101, 108, 108, 111], [5, 104, 101, 108, 108, 111]), - ), - ) - def test_write_bytearray(self, input_bytes: list[int], expected_bytes: list[int], write_mock: WriteFunctionMock): - """Writing a bytearray results in correct bytes.""" - self.writer.write_bytearray(bytearray(input_bytes)) - write_mock.assert_has_data(bytearray(expected_bytes)) diff --git a/tests/protocol/test_base_io.py b/tests/protocol/test_base_io.py new file mode 100644 index 0000000..6249bd7 --- /dev/null +++ b/tests/protocol/test_base_io.py @@ -0,0 +1,294 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from tests.protocol.helpers import ReadFunctionMock, Reader, WriteFunctionMock, Writer +from zerocom.protocol.base_io import INT_FORMATS_TYPE, StructFormat +from zerocom.protocol.utils import to_twos_complement + + +class TestReader: + @classmethod + def setup_class(cls): + """Initialize writer instance to be tested.""" + cls.reader = Reader() + + @pytest.fixture + def read_mock(self, monkeypatch: pytest.MonkeyPatch): + """Monkeypatch the read function with a mock which is returned.""" + mock_f = ReadFunctionMock() + monkeypatch.setattr(self.reader.__class__, "read", mock_f) + yield mock_f + + # Run this assertion after the test, to ensure that all specified data + # to be read, actually was read + mock_f.assert_read_everything() + + @pytest.mark.parametrize( + "format,read_bytes,expected_value", + ( + (StructFormat.UBYTE, [0], 0), + (StructFormat.UBYTE, [10], 10), + (StructFormat.UBYTE, [255], 255), + (StructFormat.BYTE, [0], 0), + (StructFormat.BYTE, [20], 20), + (StructFormat.BYTE, [127], 127), + (StructFormat.BYTE, [to_twos_complement(-20, bits=8)], -20), + (StructFormat.BYTE, [to_twos_complement(-128, bits=8)], -128), + ), + ) + def test_read_value( + self, + format: INT_FORMATS_TYPE, + read_bytes: list[int], + expected_value: Any, + read_mock: ReadFunctionMock, + ): + """Reading given values of certain struct format should produce proper expected values.""" + read_mock.combined_data = bytearray(read_bytes) + assert self.reader.read_value(format) == expected_value + + @pytest.mark.parametrize( + "read_bytes,expected_value", + ( + ([0], 0), + ([1], 1), + ([2], 2), + ([15], 15), + ([127], 127), + ([128, 1], 128), + ([129, 1], 129), + ([255, 1], 255), + ([192, 132, 61], 1000000), + ([255, 255, 255, 255, 7], 2147483647), + ), + ) + def test_read_varuint(self, read_bytes: list[int], expected_value: int, read_mock: ReadFunctionMock): + """Reading varuint bytes results in correct values.""" + read_mock.combined_data = bytearray(read_bytes) + assert self.reader.read_varint(max_bits=32) == expected_value + + @pytest.mark.parametrize( + "read_bytes,max_bits", + ( + ([128, 128, 4], 16), + ([128, 128, 128, 128, 16], 32), + ), + ) + def test_read_varuint_out_of_range(self, read_bytes: list[int], max_bits: int, read_mock: ReadFunctionMock): + """Varuint reading limited to n max bits should raise an IOError for numbers out of this range.""" + read_mock.combined_data = bytearray(read_bytes) + with pytest.raises(IOError): + self.reader.read_varuint(max_bits=max_bits) + + @pytest.mark.parametrize( + "read_bytes,expected_value", + ( + ([0], 0), + ([1], 1), + ([128, 1], 128), + ([255, 1], 255), + ([255, 255, 255, 255, 7], 2147483647), + ([255, 255, 255, 255, 15], -1), + ([128, 254, 255, 255, 15], -256), + ), + ) + def test_read_varint(self, read_bytes: list[int], expected_value: int, read_mock: ReadFunctionMock): + """Reading varint bytes results in correct values.""" + read_mock.combined_data = bytearray(read_bytes) + assert self.reader.read_varint(max_bits=32) == expected_value + + @pytest.mark.parametrize( + "read_bytes,max_bits", + ( + ([128, 128, 4], 16), + ([255, 255, 255, 255, 23], 32), + ([128, 128, 192, 152, 214, 197, 215, 227, 235, 10], 64), + ([128, 128, 192, 231, 169, 186, 168, 156, 148, 245, 255, 255, 255, 255, 3], 64), + ), + ) + def test_read_varint_out_of_range(self, read_bytes: list[int], max_bits: int, read_mock: ReadFunctionMock): + """Reading varint outside of signed max_bits int range should raise ValueError on it's own.""" + read_mock.combined_data = bytearray(read_bytes) + with pytest.raises(IOError): + self.reader.read_varint(max_bits=max_bits) + + # The data bytearray was intentionally not fully read/depleted, however by default + # ending the function with data remaining in the read_mock would trigger an + # AssertionError, so we expllicitly clear it here to prevent that error + read_mock.combined_data = bytearray() + + @pytest.mark.parametrize( + "read_bytes,expected_string", + ( + ([len("test")] + list(map(ord, "test")), "test"), + ([len("a" * 100)] + list(map(ord, "a" * 100)), "a" * 100), + ([0], ""), + ), + ) + def test_read_utf(self, read_bytes: list[int], expected_string: str, read_mock: ReadFunctionMock): + """Reading UTF string results in correct values.""" + read_mock.combined_data = bytearray(read_bytes) + assert self.reader.read_utf() == expected_string + + @pytest.mark.parametrize( + "read_bytes,expected_bytes", + ( + ([1, 1], [1]), + ([0], []), + ([5, 104, 101, 108, 108, 111], [104, 101, 108, 108, 111]), + ), + ) + def test_read_bytearray(self, read_bytes: list[int], expected_bytes: list[int], read_mock: ReadFunctionMock): + """Writing a bytearray results in correct bytes.""" + read_mock.combined_data = bytearray(read_bytes) + assert self.reader.read_bytearray() == bytearray(expected_bytes) + + +class TestWriter: + @classmethod + def setup_class(cls): + """Initialize writer instance to be tested.""" + cls.writer = Writer() + + @pytest.fixture + def write_mock(self, monkeypatch: pytest.MonkeyPatch): + """Monkeypatch the write function with a mock which is returned.""" + mock_f = WriteFunctionMock() + monkeypatch.setattr(self.writer.__class__, "write", mock_f) + return mock_f + + @pytest.mark.parametrize( + "format,value,expected_bytes", + ( + (StructFormat.UBYTE, 0, [0]), + (StructFormat.UBYTE, 15, [15]), + (StructFormat.UBYTE, 255, [255]), + (StructFormat.BYTE, 0, [0]), + (StructFormat.BYTE, 15, [15]), + (StructFormat.BYTE, 127, [127]), + (StructFormat.BYTE, -20, [to_twos_complement(-20, bits=8)]), + (StructFormat.BYTE, -128, [to_twos_complement(-128, bits=8)]), + ), + ) + def test_write_value( + self, + format: INT_FORMATS_TYPE, + value: Any, + expected_bytes: list[int], + write_mock: WriteFunctionMock, + ): + """Writing different values of certain struct format should produce proper expected values.""" + self.writer.write_value(format, value) + write_mock.assert_has_data(bytearray(expected_bytes)) + + @pytest.mark.parametrize( + "format,value", + ( + (StructFormat.UBYTE, -1), + (StructFormat.UBYTE, 256), + (StructFormat.BYTE, -129), + (StructFormat.BYTE, 128), + ), + ) + def test_write_value_out_of_range( + self, + format: INT_FORMATS_TYPE, + value: Any, + ): + """Trying to write out of range values for given struct type should produce an exception.""" + with pytest.raises(ValueError): + self.writer.write_value(format, value) + + @pytest.mark.parametrize( + "number,expected_bytes", + ( + (0, [0]), + (1, [1]), + (2, [2]), + (15, [15]), + (127, [127]), + (128, [128, 1]), + (129, [129, 1]), + (255, [255, 1]), + (1000000, [192, 132, 61]), + (2147483647, [255, 255, 255, 255, 7]), + ), + ) + def test_write_varuint(self, number: int, expected_bytes: list[int], write_mock: WriteFunctionMock): + """Writing varuints results in correct bytes.""" + self.writer.write_varuint(number, max_bits=32) + write_mock.assert_has_data(bytearray(expected_bytes)) + + @pytest.mark.parametrize( + "write_value,max_bits", + ( + (-1, 128), + (-1, 1), + (2**16, 16), + (2**32, 32), + ), + ) + def test_write_varuint_out_of_range(self, write_value: int, max_bits: int): + """Trying to write a varuint bigger than specified bit size should produce ValueError""" + with pytest.raises(ValueError): + self.writer.write_varuint(write_value, max_bits=max_bits) + + @pytest.mark.parametrize( + "number,expected_bytes", + ( + (0, [0]), + (1, [1]), + (128, [128, 1]), + (255, [255, 1]), + (2147483647, [255, 255, 255, 255, 7]), + (-1, [255, 255, 255, 255, 15]), + (-256, [128, 254, 255, 255, 15]), + ), + ) + def test_write_varint(self, number: int, expected_bytes: list[int], write_mock: WriteFunctionMock): + """Writing varints results in correct bytes.""" + self.writer.write_varint(number, max_bits=32) + write_mock.assert_has_data(bytearray(expected_bytes)) + + @pytest.mark.parametrize( + "value,max_bits", + ( + (-2147483649, 32), + (2147483648, 32), + (10**20, 32), + (-(10**20), 32), + ), + ) + def test_write_varint_out_of_range(self, value: int, max_bits: int): + """Writing varint outside of signed max_bits int range should raise ValueError on it's own.""" + with pytest.raises(ValueError): + self.writer.write_varint(value, max_bits=max_bits) + + @pytest.mark.parametrize( + "string,expected_bytes", + ( + ("test", [len("test")] + list(map(ord, "test"))), + ("a" * 100, [len("a" * 100)] + list(map(ord, "a" * 100))), + ("", [0]), + ), + ) + def test_write_utf(self, string: str, expected_bytes: list[int], write_mock: WriteFunctionMock): + """Writing UTF string results in correct bytes.""" + self.writer.write_utf(string) + write_mock.assert_has_data(bytearray(expected_bytes)) + + @pytest.mark.parametrize( + "input_bytes,expected_bytes", + ( + ([1], [1, 1]), + ([], [0]), + ([104, 101, 108, 108, 111], [5, 104, 101, 108, 108, 111]), + ), + ) + def test_write_bytearray(self, input_bytes: list[int], expected_bytes: list[int], write_mock: WriteFunctionMock): + """Writing a bytearray results in correct bytes.""" + self.writer.write_bytearray(bytearray(input_bytes)) + write_mock.assert_has_data(bytearray(expected_bytes)) diff --git a/zerocom/exceptions.py b/zerocom/exceptions.py new file mode 100644 index 0000000..491c8a2 --- /dev/null +++ b/zerocom/exceptions.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from enum import Enum +from typing import Literal, Optional, overload + + +class ZerocomError(Exception): + ... + + +class MalformedPacketState(Enum): + """Enum describing all possible states for a malformed packet.""" + + MALFORMED_PACKET_ID = "Failed to read packet id" + UNRECOGNIZED_PACKET_ID = "Unknown packet id" + MALFORMED_PACKET_BODY = "Reading packet failed" + + +class MalformedPacketError(ZerocomError): + """Exception representing an issue while receiving packet.""" + + @overload + def __init__(self, state: Literal[MalformedPacketState.MALFORMED_PACKET_ID], *, ioerror: IOError): + ... + + @overload + def __init__(self, state: Literal[MalformedPacketState.UNRECOGNIZED_PACKET_ID], *, packet_id: int): + ... + + @overload + def __init__(self, state: Literal[MalformedPacketState.MALFORMED_PACKET_BODY], *, ioerror: IOError, packet_id: int): + ... + + def __init__( + self, + state: MalformedPacketState, + *, + ioerror: Optional[IOError] = None, + packet_id: Optional[int] = None, + ): + self.state = state + self.packet_id = packet_id + self.ioerror = ioerror + + msg_tail = [] + if self.packet_id: + msg_tail.append(f"Packet ID: {self.packet_id}") + if self.ioerror: + msg_tail.append(f"Underlying IOError data: {self.ioerror}") + + msg = self.state.value + if len(msg_tail) > 0: + msg += f" ({', '.join(msg_tail)})" + + self.msg = msg + return super().__init__(msg) diff --git a/zerocom/packets/__init__.py b/zerocom/packets/__init__.py new file mode 100644 index 0000000..7b37142 --- /dev/null +++ b/zerocom/packets/__init__.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from zerocom.exceptions import MalformedPacketError, MalformedPacketState +from zerocom.packets.abc import Packet +from zerocom.packets.message import MessagePacket +from zerocom.protocol.base_io import BaseReader, BaseWriter, StructFormat + +_PACKETS: list[type[Packet]] = [MessagePacket] +PACKET_MAP: dict[int, type[Packet]] = {} + +for packet_cls in _PACKETS: + PACKET_MAP[packet_cls.PACKET_ID] = packet_cls + + +# TODO: Consider adding these functions into BaseWriter/BaseReader + + +def write_packet(writer: BaseWriter, packet: Packet) -> None: + """Write given packet.""" + writer.write_value(StructFormat.SHORT, packet.PACKET_ID) + packet.write(writer) + + +def read_packet(reader: BaseReader) -> Packet: + """Read any arbitrary packet based on it's ID.""" + try: + packet_id = reader.read_value(StructFormat.SHORT) + except IOError as exc: + raise MalformedPacketError(MalformedPacketState.MALFORMED_PACKET_ID, ioerror=exc) + + if packet_id not in PACKET_MAP: + raise MalformedPacketError(MalformedPacketState.UNRECOGNIZED_PACKET_ID, packet_id=packet_id) + + try: + return PACKET_MAP[packet_id].read(reader) + except IOError as exc: + raise MalformedPacketError(MalformedPacketState.MALFORMED_PACKET_BODY, ioerror=exc, packet_id=packet_id) diff --git a/zerocom/packets/abc.py b/zerocom/packets/abc.py new file mode 100644 index 0000000..1552935 --- /dev/null +++ b/zerocom/packets/abc.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from abc import ABC +from typing import ClassVar + +from zerocom.protocol.rw_capable import ReadWriteCapable + + +class Packet(ReadWriteCapable, ABC): + """Base class for all packets""" + + PACKET_ID: ClassVar[int] + + def __init__(self, *args, **kwargs): + """Enforce PAKCET_ID being set for each instance of concrete packet classes.""" + cls = self.__class__ + if not hasattr(cls, "PACKET_ID"): + raise TypeError(f"Can't instantiate abstract {cls.__name__} class without defining PACKET_ID variable.") + return super().__init__(*args, **kwargs) diff --git a/zerocom/packets/message.py b/zerocom/packets/message.py new file mode 100644 index 0000000..3714a59 --- /dev/null +++ b/zerocom/packets/message.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import rsa + +from zerocom.packets.abc import Packet +from zerocom.protocol.base_io import BaseReader, BaseWriter + +if TYPE_CHECKING: + from typing_extensions import Self + + +class MessagePacket(Packet): + """Packet conveying message information.""" + + PACKET_ID = 0 + + def __init__(self, content: str, signature: bytes) -> None: + super().__init__() + self.content = content + self.signature = signature + + def write(self, writer: BaseWriter) -> None: + writer.write_bytearray(self.signature) + writer.write_utf(self.content) + + @classmethod + def read(cls, reader: BaseReader) -> Self: + signature = reader.read_bytearray() + content = reader.read_utf() + return cls(content, signature) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(content={self.content!r}, signature={self.signature!r})" + + def verify(self, public_key: rsa.PublicKey) -> bool: + """Verify that the message signature was made by a private key with given public_key.""" + try: + used_hash = rsa.verify(self.content.encode(), self.signature, public_key) + except rsa.VerificationError: + return False + else: + return used_hash == "SHA-1" + + @classmethod + def make_signed(cls, content: str, private_key: rsa.PrivateKey) -> Self: + """Create a new message packet with given content signed by given private_key.""" + signature = rsa.sign(content.encode(), private_key, "SHA-1") + return cls(content, signature) diff --git a/zerocom/protocol/abc.py b/zerocom/protocol/abc.py deleted file mode 100644 index d85043c..0000000 --- a/zerocom/protocol/abc.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import annotations - -import struct -from abc import ABC, abstractmethod -from itertools import count -from typing import Any, Optional, cast - -from zerocom.protocol.utils import enforce_range - - -class BaseWriter(ABC): - """Base class holding write buffer/connection interactions.""" - - __slots__ = () - - @abstractmethod - def write(self, data: bytes) -> None: - ... - - def _write_packed(self, fmt: str, *value: object) -> None: - """Write a value of given struct format in big-endian mode. - - Available formats are listed in struct module's docstring. - """ - self.write(struct.pack(">" + fmt, *value)) - - @enforce_range(typ="Byte (8-bit signed int)", byte_size=1, signed=True) - def write_byte(self, value: int) -> None: - """Write a single signed 8-bit integer. - - Signed 8-bit integers must be within the range of -128 and 127. Going outside this range will raise a - ValueError. - - Number is written in two's complement format. - """ - self._write_packed("b", value) - - @enforce_range(typ="Unsigned byte (8-bit unsigned int)", byte_size=1, signed=False) - def write_ubyte(self, value: int) -> None: - """Write a single unsigned 8-bit integer. - - Unsigned 8-bit integers must be within range of 0 and 255. Going outside this range will raise a ValueError. - """ - self._write_packed("B", value) - - def write_varint(self, value: int, *, max_size: Optional[int] = None) -> None: - """Write an arbitrarily big unsigned integer in a variable length format. - - This is a standard way of transmitting ints, and it allows smaller numbers to take less bytes. - - Will keep writing bytes until the value is depleted (fully sent). If `max_size` is specified, writing will be - limited up to integer values of max_size bytes, and trying to write bigger values will rase a ValueError. Note - that limiting to max_size of 4 (32-bit int) doesn't imply at most 4 bytes will be sent, and will in fact take 5 - bytes at most, due to the variable encoding overhead. - - Varnums use 7 least significant bits of each sent byte to encode the value, and the most significant bit to - indicate whether there is another byte after it. The least significant group is written first, followed by each - of the more significant groups, making varints little-endian, however in groups of 7 bits, not 8. - """ - # We can't use enforce_range as decorator directly, because our byte_size varies - # instead run it manually from here as a check function - _wrapper = enforce_range( - typ=f"{max_size if max_size else 'unlimited'}-byte unsigned varnum", - byte_size=max_size if max_size else None, - signed=False, - ) - _check_f = _wrapper(lambda self, value: None) - _check_f(self, value) - - remaining = value - while True: - if remaining & ~0x7F == 0: # final byte - self.write_ubyte(remaining) - return - # Write only 7 least significant bits, with the first being 1. - # first bit here represents that there will be another value after - self.write_ubyte(remaining & 0x7F | 0x80) - # Subtract the value we've already sent (7 least significant bits) - remaining >>= 7 - - def write_utf(self, value: str, max_varint_size: int = 2) -> None: - """Write a UTF-8 encoded string, prefixed with a varshort of it's size (in bytes). - - Will write n bytes, depending on the amount of bytes in the string + up to 3 bytes from prefix varshort, - holding this size (n). This means a maximum of 2**31-1 + 5 bytes can be written. - - Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the - worst case of 4 bytes per every character, at most 8192 characters can be written, however this number - will usually be much bigger (up to 4x) since it's unlikely each character would actually take up 4 bytes. (All - of the ASCII characters only take up 1 byte). - - If the given string is longer than this, ValueError will be raised for trying to write an invalid varshort. - """ - data = bytearray(value, "utf-8") - self.write_varint(len(value), max_size=max_varint_size) - self.write(data) - - def write_bytearray(self, data: bytearray) -> None: - """Write an arbitrary sequence of bytes, prefixed with a varint of it's size.""" - self.write_varint(len(data)) - self.write(data) - - -class BaseReader(ABC): - """Base class holding read buffer/connection interactions.""" - - __slots__ = () - - @abstractmethod - def read(self, length: int) -> bytearray: - ... - - def _read_unpacked(self, fmt: str) -> Any: # noqa: ANN401 - """Read bytes and unpack them into given struct format in big-endian mode. - - - The amount of bytes to read will be determined based on the format string automatically. - i.e.: With format of "iii" (referring to 3 signed 32-bit ints), the read length is set as 3x4 (since a signed - 32-bit int takes 4 bytes), making the total length to read 12 bytes, returned as Tuple[int, int, int] - - Available formats are listed in struct module's docstring. - """ - length = struct.calcsize(fmt) - data = self.read(length) - unpacked = struct.unpack(">" + fmt, data) - - if len(unpacked) == 1: - return unpacked[0] - return unpacked - - def read_byte(self) -> int: - """Read a single signed 8-bit integer. - - Will read 1 byte in two's complement format, getting int values between -128 and 127. - """ - return self._read_unpacked("b") - - def read_ubyte(self) -> int: - """Read a single unsigned 8-bit integer. - - Will read 1 byte, getting int value between 0 and 255 directly. - """ - return self._read_unpacked("B") - - def read_varint(self, *, max_size: Optional[int] = None) -> int: - """Read an arbitrarily big unsigned integer in a variable length format. - - This is a standard way of transmitting ints, and it allows smaller numbers to take less bytes. - - Will keep reading bytes until the value is depleted (fully sent). If `max_size` is specified, reading will be - limited up to integer values of max_size bytes, and trying to read bigger values will rase an IOError. Note - that limiting to max_size of 4 (32-bit int) doesn't imply at most 4 bytes will be sent, and will in fact take 5 - bytes at most, due to the variable encoding overhead. - - Varnums use 7 least significant bits of each sent byte to encode the value, and the most significant bit to - indicate whether there is another byte after it. The least significant group is written first, followed by each - of the more significant groups, making varints little-endian, however in groups of 7 bits, not 8. - """ - value_max = (1 << (max_size * 8)) - 1 if max_size else None - result = 0 - for i in count(): - byte = self.read_ubyte() - # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place - # then simply add them (OR) as additional 7 most significant bits in our result - result |= (byte & 0x7F) << (7 * i) - - # Ensure that we stop reading and raise an error if the size gets over the maximum - # (if the current amount of bits is higher than allowed size in bits) - if value_max and result > value_max: - max_size = cast(int, max_size) - raise IOError(f"Received varint was outside the range of {max_size}-byte ({max_size * 8}-bit) int.") - - # If the most significant bit is 0, we should stop reading - if not byte & 0x80: - break - - return result - - def read_utf(self, max_varint_size: int = 2) -> str: - """Read a UTF-8 encoded string, prefixed with a varshort of it's size (in bytes). - - Will read n bytes, depending on the prefix varint (amount of bytes in the string) + up to 3 bytes from prefix - varshort itself, holding this size (n). This means a maximum of 2**15-1 + 3 bytes can be read (and written). - - Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the - worst case of 4 bytes per every character, at most 8192 characters can be written, however this number - will usually be much bigger (up to 4x) since it's unlikely each character would actually take up 4 bytes. (All - of the ASCII characters only take up 1 byte). - """ - length = self.read_varint(max_size=max_varint_size) - bytes = self.read(length) - return bytes.decode("utf-8") - - def read_bytearray(self) -> bytearray: - """Read an arbitrary sequence of bytes, prefixed with a varint of it's size.""" - length = self.read_varint() - return self.read(length) diff --git a/zerocom/protocol/base_io.py b/zerocom/protocol/base_io.py new file mode 100644 index 0000000..95c94c0 --- /dev/null +++ b/zerocom/protocol/base_io.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import struct +from abc import ABC, abstractmethod +from enum import Enum +from itertools import count +from typing import Literal, TypeAlias, Union, overload + +from zerocom.protocol.utils import from_twos_complement, to_twos_complement + + +class StructFormat(str, Enum): + """All possible write/read struct types.""" + + BOOL = "?" + CHAR = "c" + BYTE = "b" + UBYTE = "B" + SHORT = "h" + USHORT = "H" + INT = "i" + UINT = "I" + LONG = "l" + ULONG = "L" + FLOAT = "f" + DOUBLE = "d" + HALFFLOAT = "e" + LONGLONG = "q" + ULONGLONG = "Q" + + +INT_FORMATS_TYPE: TypeAlias = Union[ + Literal[StructFormat.BYTE], + Literal[StructFormat.UBYTE], + Literal[StructFormat.SHORT], + Literal[StructFormat.USHORT], + Literal[StructFormat.INT], + Literal[StructFormat.UINT], + Literal[StructFormat.LONG], + Literal[StructFormat.ULONG], + Literal[StructFormat.LONGLONG], + Literal[StructFormat.ULONGLONG], +] + +FLOAT_FORMATS_TYPE: TypeAlias = Union[ + Literal[StructFormat.FLOAT], + Literal[StructFormat.DOUBLE], + Literal[StructFormat.HALFFLOAT], +] + + +class BaseWriter(ABC): + """Base class holding write buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + def write(self, data: bytes) -> None: + ... + + @overload + def write_value(self, fmt: INT_FORMATS_TYPE, value: int) -> None: + ... + + @overload + def write_value(self, fmt: FLOAT_FORMATS_TYPE, value: float) -> None: + ... + + @overload + def write_value(self, fmt: Literal[StructFormat.BOOL], value: bool) -> None: + ... + + @overload + def write_value(self, fmt: Literal[StructFormat.CHAR], value: str) -> None: + ... + + def write_value(self, fmt: StructFormat, value: object) -> None: + """Write a value of given struct format in big-endian mode.""" + try: + self.write(struct.pack(">" + fmt.value, value)) + except struct.error as exc: + raise ValueError(str(exc)) from exc + + def write_varuint(self, value: int, /, *, max_bits: int) -> None: + """Write an arbitrarily big unsigned integer in a variable length format. + + This is a standard way of transmitting ints, and it allows smaller numbers to take less bytes. + + Writing will be limited up to integer values of `max_bits` bits, and trying to write bigger values will rase a + ValueError. Note that setting `max_bits` to for example 32 bits doesn't mean that at most 4 bytes will be sent, + in this case it would actually take at most 5 bytes, due to the variable encoding overhead. + + Varints send bytes where 7 least significant bits are value bits, and the most significant bit is continuation + flag bit. If this continuation bit is set (1), it indicates that there will be another varnum byte sent after + this one. The least significant group is written first, followed by each of the more significant groups, making + varnums little-endian, however in groups of 7 bits, not 8. + """ + value_max = (1 << (max_bits)) - 1 + if value < 0 or value > value_max: + raise ValueError(f"Tried to write varint outside of the range of {max_bits}-bit int.") + + remaining = value + while True: + if remaining & ~0x7F == 0: # final byte + self.write_value(StructFormat.UBYTE, remaining) + return + # Write only 7 least significant bits with the first bit being 1, marking there will be another byte + self.write_value(StructFormat.UBYTE, remaining & 0x7F | 0x80) + # Subtract the value we've already sent (7 least significant bits) + remaining >>= 7 + + def write_varint(self, value: int, /, *, max_bits: int) -> None: + """Write an arbitrarily big signed integer in a variable length format. + + For more information about varints check `write_varuint` docstring. + """ + val = to_twos_complement(value, bits=max_bits) + self.write_varuint(val, max_bits=max_bits) + + def write_utf(self, value: str, /, *, max_varuint_bits: int = 16) -> None: + """Write a UTF-8 encoded string, prefixed with a varuint of given bit size. + + Will write n bytes, depending on the amount of bytes in the string + bytes from prefix varuint, + holding this size (n). + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. In most + cases, characters will generally only take 1 byte per character (for all ASCII characters). + + The amount of bytes can't surpass the specified varuint size, otherwise a ValueError will be raised from trying + to write an invalid varuint. + """ + data = bytearray(value, "utf-8") + self.write_varuint(len(data), max_bits=max_varuint_bits) + self.write(data) + + def write_bytearray(self, data: bytes, /, *, max_varuint_bits: int = 16) -> None: + """Write an arbitrary sequence of bytes, prefixed with a varint of it's size.""" + self.write_varuint(len(data), max_bits=max_varuint_bits) + self.write(data) + + +class BaseReader(ABC): + """Base class holding read buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + def read(self, length: int) -> bytearray: + ... + + @overload + def read_value(self, fmt: INT_FORMATS_TYPE) -> int: + ... + + @overload + def read_value(self, fmt: FLOAT_FORMATS_TYPE) -> float: + ... + + @overload + def read_value(self, fmt: Literal[StructFormat.BOOL]) -> bool: + ... + + @overload + def read_value(self, fmt: Literal[StructFormat.CHAR]) -> str: + ... + + def read_value(self, fmt: StructFormat) -> object: + """Read a value into given struct format in big-endian mode. + + The amount of bytes to read will be determined based on the struct format automatically. + """ + length = struct.calcsize(fmt.value) + data = self.read(length) + try: + unpacked = struct.unpack(">" + fmt.value, data) + except struct.error as exc: + raise ValueError(str(exc)) from exc + return unpacked[0] + + def read_varuint(self, *, max_bits: int) -> int: + """Read an arbitrarily big unsigned integer in a variable length format. + + This is a standard way of transmitting ints, and it allows smaller numbers to take less bytes. + + Reading will be limited up to integer values of `max_bits` bits, and trying to read bigger values will rase an + IOError. Note that setting `max_bits` to for example 32 bits doesn't mean that at most 4 bytes will be read, + in this case it would actually read at most 5 bytes, due to the variable encoding overhead. + + Varints send bytes where 7 least significant bits are value bits, and the most significant bit is continuation + flag bit. If this continuation bit is set (1), it indicates that there will be another varnum byte sent after + this one. The least significant group is written first, followed by each of the more significant groups, making + varnums little-endian, however in groups of 7 bits, not 8. + """ + value_max = (1 << (max_bits)) - 1 + + result = 0 + for i in count(): # pragma: no branch # count() iterator won't ever deplete + byte = self.read_value(StructFormat.UBYTE) + # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place + # then simply add them (OR) as additional 7 most significant bits in our result + result |= (byte & 0x7F) << (7 * i) + + # Ensure that we stop reading and raise an error if the size gets over the maximum + # (if the current amount of bits is higher than allowed size in bits) + if result > value_max: + raise IOError(f"Received varint was outside the range of {max_bits}-bit int.") + + # If the most significant bit is 0, we should stop reading + if not byte & 0x80: + break + + return result + + def read_varint(self, *, max_bits: int) -> int: + """Read an arbitrarily big signed integer in a variable length format. + + For more information about varints check `read_varuint` docstring. + """ + unsigned_num = self.read_varuint(max_bits=max_bits) + val = from_twos_complement(unsigned_num, bits=max_bits) + return val + + def read_utf(self, *, max_varuint_bits: int = 16) -> str: + """Read a UTF-8 encoded string, prefixed with a varuint of given bit size. + + Will read n bytes, depending on the amount of bytes in the string + bytes from prefix varuint, + holding this size (n) + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. In most + cases, characters will generally only take 1 byte per character (for all ASCII characters). + + The amount of bytes can't surpass the specified varuint size, otherwise an IOError will be raised from trying + to read an invalid varuint. + """ + length = self.read_varuint(max_bits=max_varuint_bits) + bytes = self.read(length) + return bytes.decode("utf-8") + + def read_bytearray(self, *, max_varuint_bits: int = 16) -> bytearray: + """Read an arbitrary sequence of bytes, prefixed with a varint of it's size.""" + length = self.read_varuint(max_bits=max_varuint_bits) + return self.read(length) diff --git a/zerocom/protocol/buffer.py b/zerocom/protocol/buffer.py index 210df6d..b4bbf3a 100644 --- a/zerocom/protocol/buffer.py +++ b/zerocom/protocol/buffer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from zerocom.protocol.abc import BaseReader, BaseWriter +from zerocom.protocol.base_io import BaseReader, BaseWriter class Buffer(BaseReader, BaseWriter, bytearray): diff --git a/zerocom/protocol/connection.py b/zerocom/protocol/connection.py index 5375ad6..3d28e3b 100644 --- a/zerocom/protocol/connection.py +++ b/zerocom/protocol/connection.py @@ -3,7 +3,7 @@ import socket from typing import Generic, TypeVar -from zerocom.protocol.abc import BaseReader, BaseWriter +from zerocom.protocol.base_io import BaseReader, BaseWriter T_SOCK = TypeVar("T_SOCK", bound=socket.socket) diff --git a/zerocom/protocol/rw_capable.py b/zerocom/protocol/rw_capable.py new file mode 100644 index 0000000..3bd2074 --- /dev/null +++ b/zerocom/protocol/rw_capable.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from zerocom.protocol.base_io import BaseReader, BaseWriter + +if TYPE_CHECKING: + from typing_extensions import Self + + +class WriteCapable(ABC): + """Base class providing writing capabilities.""" + + @abstractmethod + def write(self, writer: BaseWriter) -> None: + raise NotImplementedError() + + +class ReadCapable(ABC): + """Base class providing reading capabilities.""" + + @classmethod + @abstractmethod + def read(cls, reader: BaseReader) -> Self: + raise NotImplementedError() + + +class ReadWriteCapable(ReadCapable, WriteCapable): + """Base class providing read and write capabilities.""" diff --git a/zerocom/protocol/utils.py b/zerocom/protocol/utils.py index e0d0350..b0c6803 100644 --- a/zerocom/protocol/utils.py +++ b/zerocom/protocol/utils.py @@ -1,54 +1,25 @@ from __future__ import annotations -from functools import wraps -from typing import Callable, Optional, TYPE_CHECKING, TypeVar, cast -if TYPE_CHECKING: - from typing_extensions import ParamSpec +def to_twos_complement(num: int, bits: int) -> int: + """Convert a given number into twos complement format of given amount of bits.""" + value_max = 1 << (bits - 1) + value_min = value_max * -1 + # With two's complement, we have one more negative number than positive + # this means we can't be exactly at value_max, but we can be at exactly value_min + if num >= value_max or num < value_min: + raise ValueError(f"Can't convert number {num} into {bits}-bit twos complement format - out of range") - P = ParamSpec("P") -else: - P = [] + return num + (1 << bits) if num < 0 else num -R = TypeVar("R") +def from_twos_complement(num: int, bits: int) -> int: + """Convert a given number from twos complement format of given amount of bits.""" + value_max = (1 << bits) - 1 + if num < 0 or num > value_max: + raise ValueError(f"Can't convert number {num} from {bits}-bit twos complement format - out of range") -def enforce_range(*, typ: str, byte_size: Optional[int], signed: bool) -> Callable: - """Decorator enforcing proper int value range, based on the number of max bytes (size). + if num & (1 << (bits - 1)) != 0: + num -= 1 << bits - If a value is outside of the automatically determined allowed range, a ValueError will be raised, - showing the given `typ` along with the allowed range info. - - If the byte_size is None, infinite max size is assumed. Note that this is only possible with unsigned types, - since there's no point in enforcing infinite range. - """ - if byte_size is None: - if signed is True: - raise ValueError("Enforcing infinite byte-size for signed type doesn't make sense (includes all numbers).") - value_max = float("inf") - value_max_s = "infinity" - value_min = 0 - value_min_s = "0" - else: - if signed: - value_max = (1 << (byte_size * 8 - 1)) - 1 - value_max_s = f"{value_max} (2**{byte_size * 8 - 1} - 1)" - value_min = -1 << (byte_size * 8 - 1) - value_min_s = f"{value_min} (-2**{byte_size * 8 - 1})" - else: - value_max = (1 << (byte_size * 8)) - 1 - value_max_s = f"{value_max} (2**{byte_size * 8} - 1)" - value_min = 0 - value_min_s = "0" - - def wrapper(func: Callable[P, R]) -> Callable[P, R]: - @wraps(func) - def inner(*args: P.args, **kwargs: P.kwargs) -> R: - value = cast(int, args[1]) - if value > value_max or value < value_min: - raise ValueError(f"{typ} must be within {value_min_s} and {value_max_s}, got {value}.") - return func(*args, **kwargs) - - return inner - - return wrapper + return num