From 9ac9098c9c6c89a5408778c0c7a1dc5db2c6abda Mon Sep 17 00:00:00 2001 From: Anton Agestam Date: Mon, 29 Jan 2024 16:14:03 +0100 Subject: [PATCH] Raise BufferUnderflow when too few bytes to parse --- src/kio/serial/errors.py | 8 +- src/kio/serial/readers.py | 44 ++++++---- tests/serial/test_readers.py | 163 +++++++++++++++++++++++++++-------- 3 files changed, 157 insertions(+), 58 deletions(-) diff --git a/src/kio/serial/errors.py b/src/kio/serial/errors.py index 5ddfcd92..b01fcbe6 100644 --- a/src/kio/serial/errors.py +++ b/src/kio/serial/errors.py @@ -10,13 +10,17 @@ class UnexpectedNull(DecodeError): ... -class SchemaError(SerialError): +class BufferUnderflow(DecodeError): ... -class OutOfBoundValue(SerialError): +class SchemaError(SerialError): ... class EncodeError(SerialError): ... + + +class OutOfBoundValue(EncodeError): + ... diff --git a/src/kio/serial/readers.py b/src/kio/serial/readers.py index a1c3d1dd..9d7d5a95 100644 --- a/src/kio/serial/readers.py +++ b/src/kio/serial/readers.py @@ -23,46 +23,54 @@ from kio.static.primitive import u32 from kio.static.primitive import u64 +from .errors import BufferUnderflow from .errors import UnexpectedNull T = TypeVar("T") Reader: TypeAlias = Callable[[IO[bytes]], T] +def read_exact(buffer: IO[bytes], num_bytes: int) -> bytes: + value = buffer.read(num_bytes) + if len(value) != num_bytes: + raise BufferUnderflow(f"Expected to read {num_bytes}, got {len(value)}") + return value + + def read_boolean(buffer: IO[bytes]) -> bool: - return struct.unpack(">?", buffer.read(1))[0] # type: ignore[no-any-return] + return struct.unpack(">?", read_exact(buffer, 1))[0] # type: ignore[no-any-return] def read_int8(buffer: IO[bytes]) -> i8: - return struct.unpack(">b", buffer.read(1))[0] # type: ignore[no-any-return] + return struct.unpack(">b", read_exact(buffer, 1))[0] # type: ignore[no-any-return] def read_int16(buffer: IO[bytes]) -> i16: - return struct.unpack(">h", buffer.read(2))[0] # type: ignore[no-any-return] + return struct.unpack(">h", read_exact(buffer, 2))[0] # type: ignore[no-any-return] def read_int32(buffer: IO[bytes]) -> i32: - return struct.unpack(">i", buffer.read(4))[0] # type: ignore[no-any-return] + return struct.unpack(">i", read_exact(buffer, 4))[0] # type: ignore[no-any-return] def read_int64(buffer: IO[bytes]) -> i64: - return struct.unpack(">q", buffer.read(8))[0] # type: ignore[no-any-return] + return struct.unpack(">q", read_exact(buffer, 8))[0] # type: ignore[no-any-return] def read_uint8(buffer: IO[bytes]) -> u8: - return struct.unpack(">B", buffer.read(1))[0] # type: ignore[no-any-return] + return struct.unpack(">B", read_exact(buffer, 1))[0] # type: ignore[no-any-return] def read_uint16(buffer: IO[bytes]) -> u16: - return struct.unpack(">H", buffer.read(2))[0] # type: ignore[no-any-return] + return struct.unpack(">H", read_exact(buffer, 2))[0] # type: ignore[no-any-return] def read_uint32(buffer: IO[bytes]) -> u32: - return struct.unpack(">I", buffer.read(4))[0] # type: ignore[no-any-return] + return struct.unpack(">I", read_exact(buffer, 4))[0] # type: ignore[no-any-return] def read_uint64(buffer: IO[bytes]) -> u64: - return struct.unpack(">Q", buffer.read(8))[0] # type: ignore[no-any-return] + return struct.unpack(">Q", read_exact(buffer, 8))[0] # type: ignore[no-any-return] # See description and upstream implementation. @@ -74,7 +82,7 @@ def read_unsigned_varint(buffer: IO[bytes]) -> int: # Increase shift by 7 on each iteration, looping at most 5 times. for shift in range(0, 4 * 7 + 1, 7): # Read value by a byte at a time. - (byte,) = buffer.read(1) + (byte,) = read_exact(buffer, 1) # Add 7 least significant bits to the result. seven_bit_chunk = byte & 0b01111111 result |= seven_bit_chunk << shift @@ -87,7 +95,7 @@ def read_unsigned_varint(buffer: IO[bytes]) -> int: def read_float64(buffer: IO[bytes]) -> float: - return struct.unpack(">d", buffer.read(8))[0] # type: ignore[no-any-return] + return struct.unpack(">d", read_exact(buffer, 8))[0] # type: ignore[no-any-return] def read_compact_string_as_bytes(buffer: IO[bytes]) -> bytes: @@ -97,7 +105,7 @@ def read_compact_string_as_bytes(buffer: IO[bytes]) -> bytes: raise UnexpectedNull( "Unexpectedly read null where compact string/bytes was expected" ) - return buffer.read(length) + return read_exact(buffer, length) def read_compact_string_as_bytes_nullable(buffer: IO[bytes]) -> bytes | None: @@ -105,7 +113,7 @@ def read_compact_string_as_bytes_nullable(buffer: IO[bytes]) -> bytes | None: length = read_unsigned_varint(buffer) - 1 if length == -1: return None - return buffer.read(length) + return read_exact(buffer, length) def read_compact_string(buffer: IO[bytes]) -> str: @@ -123,28 +131,28 @@ def read_legacy_bytes(buffer: IO[bytes]) -> bytes: length = read_int32(buffer) if length == -1: raise UnexpectedNull("Unexpectedly read null where bytes was expected") - return buffer.read(length) + return read_exact(buffer, length) def read_nullable_legacy_bytes(buffer: IO[bytes]) -> bytes | None: length = read_int32(buffer) if length == -1: return None - return buffer.read(length) + return read_exact(buffer, length) def read_legacy_string(buffer: IO[bytes]) -> str: length = read_int16(buffer) if length == -1: raise UnexpectedNull("Unexpectedly read null where string/bytes was expected") - return buffer.read(length).decode() + return read_exact(buffer, length).decode() def read_nullable_legacy_string(buffer: IO[bytes]) -> str | None: length = read_int16(buffer) if length == -1: return None - return buffer.read(length).decode() + return read_exact(buffer, length).decode() read_legacy_array_length: Final = read_int32 @@ -156,7 +164,7 @@ def read_compact_array_length(buffer: IO[bytes]) -> int: def read_uuid(buffer: IO[bytes]) -> UUID | None: - byte_value: bytes = buffer.read(16) + byte_value: bytes = read_exact(buffer, 16) if byte_value == uuid_zero.bytes: return None return UUID(bytes=byte_value) diff --git a/tests/serial/test_readers.py b/tests/serial/test_readers.py index 32c22ae4..5c3bc2d6 100644 --- a/tests/serial/test_readers.py +++ b/tests/serial/test_readers.py @@ -3,9 +3,11 @@ import sys import uuid from typing import IO +from typing import final import pytest +from kio.serial.errors import BufferUnderflow from kio.serial.errors import UnexpectedNull from kio.serial.readers import Reader from kio.serial.readers import read_compact_string @@ -30,6 +32,41 @@ from kio.static.constants import uuid_zero +class BufferUnderflowContract: + reader: Reader[object] + valid_serialization: bytes + + @classmethod + def read(cls, buffer: IO[bytes]) -> object: + return cls.reader(buffer) + + @final + def test_raises_buffer_underflow_when_not_enough_bytes_for_value( + self, + buffer: IO[bytes], + ) -> None: + buffer.write(self.valid_serialization[:-1]) + buffer.seek(0) + + with pytest.raises(BufferUnderflow): + self.read(buffer) + + +class LengthBufferUnderflowContract(BufferUnderflowContract): + length_num_bytes: int + + @final + def test_raises_buffer_underflow_when_not_enough_bytes_for_length( + self, + buffer: IO[bytes], + ) -> None: + buffer.write(self.valid_serialization[: self.length_num_bytes - 1]) + buffer.seek(0) + + with pytest.raises(BufferUnderflow): + self.read(buffer) + + class IntReaderContract: reader: Reader[int] lower_limit: int @@ -42,99 +79,111 @@ class IntReaderContract: def read(cls, buffer: IO[bytes]) -> int: return cls.reader(buffer) + @final def test_can_read_lower_limit_sync(self, buffer: io.BytesIO) -> None: buffer.write(self.lower_limit_as_bytes) buffer.seek(0) assert self.lower_limit == self.read(buffer) + @final def test_can_read_upper_limit_sync(self, buffer: io.BytesIO) -> None: buffer.write(self.upper_limit_as_bytes) buffer.seek(0) assert self.upper_limit == self.read(buffer) + @final def test_can_read_zero_sync(self, buffer: io.BytesIO) -> None: buffer.write(self.zero_as_bytes) buffer.seek(0) assert self.read(buffer) == 0 -class TestReadInt8(IntReaderContract): +class TestReadInt8(IntReaderContract, BufferUnderflowContract): reader = read_int8 lower_limit = -128 lower_limit_as_bytes = b"\x80" upper_limit = 127 upper_limit_as_bytes = b"\x7f" zero_as_bytes = b"\x00" + valid_serialization = zero_as_bytes -class TestReadInt16(IntReaderContract): +class TestReadInt16(IntReaderContract, BufferUnderflowContract): reader = read_int16 lower_limit = -(2**15) lower_limit_as_bytes = b"\x80\x00" upper_limit = 2**15 - 1 upper_limit_as_bytes = b"\x7f\xff" zero_as_bytes = b"\x00\x00" + valid_serialization = zero_as_bytes -class TestReadInt32(IntReaderContract): +class TestReadInt32(IntReaderContract, BufferUnderflowContract): reader = read_int32 lower_limit = -(2**31) lower_limit_as_bytes = b"\x80\x00\x00\x00" upper_limit = 2**31 - 1 upper_limit_as_bytes = b"\x7f\xff\xff\xff" zero_as_bytes = b"\x00\x00\x00\x00" + valid_serialization = zero_as_bytes -class TestReadInt64(IntReaderContract): +class TestReadInt64(IntReaderContract, BufferUnderflowContract): reader = read_int64 lower_limit = -(2**63) lower_limit_as_bytes = b"\x80\x00\x00\x00\x00\x00\x00\x00" upper_limit = 2**63 - 1 upper_limit_as_bytes = b"\x7f\xff\xff\xff\xff\xff\xff\xff" zero_as_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00" + valid_serialization = zero_as_bytes -class TestReadUint8(IntReaderContract): +class TestReadUint8(IntReaderContract, BufferUnderflowContract): reader = read_uint8 lower_limit = 0 lower_limit_as_bytes = zero_as_bytes = b"\x00" upper_limit = 2**8 - 1 upper_limit_as_bytes = b"\xff" + valid_serialization = zero_as_bytes -class TestReadUint16(IntReaderContract): +class TestReadUint16(IntReaderContract, BufferUnderflowContract): reader = read_uint16 lower_limit = 0 lower_limit_as_bytes = zero_as_bytes = b"\x00\x00" upper_limit = 2**16 - 1 upper_limit_as_bytes = b"\xff\xff" lower_limit_error_message = "argument out of range" + valid_serialization = zero_as_bytes -class TestReadUint32(IntReaderContract): +class TestReadUint32(IntReaderContract, BufferUnderflowContract): reader = read_uint32 lower_limit = 0 lower_limit_as_bytes = zero_as_bytes = b"\x00\x00\x00\x00" upper_limit = 2**32 - 1 upper_limit_as_bytes = b"\xff\xff\xff\xff" lower_limit_error_message = "argument out of range" + valid_serialization = zero_as_bytes -class TestReadUint64(IntReaderContract): +class TestReadUint64(IntReaderContract, BufferUnderflowContract): reader = read_uint64 lower_limit = 0 lower_limit_as_bytes = zero_as_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00" upper_limit = 2**64 - 1 upper_limit_as_bytes = b"\xff\xff\xff\xff\xff\xff\xff\xff" match_error_message = r"int too large to convert" + valid_serialization = zero_as_bytes -class TestReadUnsignedVarint(IntReaderContract): +class TestReadUnsignedVarint(IntReaderContract, BufferUnderflowContract): reader = read_unsigned_varint lower_limit = 0 lower_limit_as_bytes = zero_as_bytes = b"\x00" upper_limit = 2**31 - 1 upper_limit_as_bytes = b"\xff\xff\xff\xff\x07" + valid_serialization = zero_as_bytes def test_raises_value_error_for_too_long_value(self, buffer: io.BytesIO) -> None: for _ in range(5): @@ -164,7 +213,10 @@ def test_can_read_known_value( assert self.read(buffer) == expected -class TestReadFloat64: +class TestReadFloat64(BufferUnderflowContract): + reader = read_float64 + valid_serialization = struct.pack(">d", 0) + @pytest.mark.parametrize( "value", ( @@ -180,10 +232,14 @@ class TestReadFloat64: def test_can_read_value(self, buffer: io.BytesIO, value: float) -> None: buffer.write(struct.pack(">d", value)) buffer.seek(0) - assert read_float64(buffer) == value + assert self.read(buffer) == value + +class TestReadCompactStringAsBytes(LengthBufferUnderflowContract): + reader = read_compact_string_as_bytes + length_num_bytes = 1 + valid_serialization = b"\x06hello" -class TestReadCompactStringAsBytes: def test_raises_unexpected_null_for_negative_length_sync( self, buffer: io.BytesIO, @@ -191,7 +247,7 @@ def test_raises_unexpected_null_for_negative_length_sync( buffer.write(0b00000000.to_bytes(1, "little")) buffer.seek(0) with pytest.raises(UnexpectedNull): - read_compact_string_as_bytes(buffer) + self.read(buffer) def test_can_read_bytes_sync( self, @@ -202,17 +258,21 @@ def test_can_read_bytes_sync( buffer.write(byte_length.to_bytes(1, "little")) buffer.write(value) buffer.seek(0) - assert value == read_compact_string_as_bytes(buffer) + assert value == self.read(buffer) + +class TestReadCompactStringAsBytesNullable(LengthBufferUnderflowContract): + reader = read_compact_string_as_bytes_nullable + length_num_bytes = 1 + valid_serialization = b"\x06hello" -class TestReadCompactStringAsBytesNullable: def test_returns_null_for_negative_length_sync( self, buffer: io.BytesIO, ) -> None: buffer.write(0b00000000.to_bytes(1, "little")) buffer.seek(0) - assert read_compact_string_as_bytes_nullable(buffer) is None + assert self.read(buffer) is None def test_can_read_bytes_sync( self, @@ -223,10 +283,14 @@ def test_can_read_bytes_sync( buffer.write(byte_length.to_bytes(1, "little")) buffer.write(value) buffer.seek(0) - assert value == read_compact_string_as_bytes_nullable(buffer) + assert value == self.read(buffer) -class TestReadCompactString: +class TestReadCompactString(LengthBufferUnderflowContract): + reader = read_compact_string + length_num_bytes = 1 + valid_serialization = b"\x06hello" + def test_raises_unexpected_null_for_negative_length_sync( self, buffer: io.BytesIO, @@ -234,7 +298,7 @@ def test_raises_unexpected_null_for_negative_length_sync( buffer.write((0).to_bytes(1, "little")) buffer.seek(0) with pytest.raises(UnexpectedNull): - read_compact_string(buffer) + self.read(buffer) def test_can_read_string_sync( self, @@ -246,17 +310,21 @@ def test_can_read_string_sync( buffer.write(byte_length.to_bytes(1, "little")) buffer.write(byte_value) buffer.seek(0) - assert value == read_compact_string(buffer) + assert value == self.read(buffer) -class TestReadCompactStringNullable: +class TestReadCompactStringNullable(LengthBufferUnderflowContract): + reader = read_compact_string_nullable + length_num_bytes = 1 + valid_serialization = b"\x06hello" + def test_returns_null_for_negative_length_sync( self, buffer: io.BytesIO, ) -> None: buffer.write((0).to_bytes(1, "little")) buffer.seek(0) - assert read_compact_string_nullable(buffer) is None + assert self.read(buffer) is None def test_can_read_string_sync( self, @@ -268,17 +336,21 @@ def test_can_read_string_sync( buffer.write(byte_length.to_bytes(1, "little")) buffer.write(byte_value) buffer.seek(0) - assert value == read_compact_string_nullable(buffer) + assert value == self.read(buffer) + +class TestReadNullableLegacyBytes(LengthBufferUnderflowContract): + reader = read_nullable_legacy_bytes + length_num_bytes = 4 + valid_serialization = b"\x00\x00\x00\x05hello" -class TestReadNullableLegacyBytes: def test_returns_none_for_negative_length_sync( self, buffer: io.BytesIO, ) -> None: buffer.write(struct.pack(">i", -1)) buffer.seek(0) - assert read_nullable_legacy_bytes(buffer) is None + assert self.read(buffer) is None def test_can_read_bytes_sync( self, @@ -289,10 +361,14 @@ def test_can_read_bytes_sync( buffer.write(struct.pack(">i", byte_length)) buffer.write(value) buffer.seek(0) - assert value == read_nullable_legacy_bytes(buffer) + assert value == self.read(buffer) -class TestReadLegacyString: +class TestReadLegacyString(LengthBufferUnderflowContract): + reader = read_legacy_string + length_num_bytes = 2 + valid_serialization = b"\x00\x05hello" + def test_raises_unexpected_null_for_negative_length_sync( self, buffer: io.BytesIO, @@ -300,7 +376,7 @@ def test_raises_unexpected_null_for_negative_length_sync( buffer.write(struct.pack(">h", -1)) buffer.seek(0) with pytest.raises(UnexpectedNull): - read_legacy_string(buffer) + self.read(buffer) def test_can_read_string_sync( self, @@ -312,17 +388,21 @@ def test_can_read_string_sync( buffer.write(struct.pack(">h", byte_length)) buffer.write(byte_value) buffer.seek(0) - assert value == read_legacy_string(buffer) + assert value == self.read(buffer) -class TestReadNullableLegacyString: +class TestReadNullableLegacyString(LengthBufferUnderflowContract): + reader = read_nullable_legacy_string + length_num_bytes = 2 + valid_serialization = b"\x00\x05hello" + def test_returns_null_for_negative_length_sync( self, buffer: io.BytesIO, ) -> None: buffer.write(struct.pack(">h", -1)) buffer.seek(0) - assert read_nullable_legacy_string(buffer) is None + assert self.read(buffer) is None def test_can_read_string_sync( self, @@ -334,10 +414,14 @@ def test_can_read_string_sync( buffer.write(struct.pack(">h", byte_length)) buffer.write(byte_value) buffer.seek(0) - assert value == read_nullable_legacy_string(buffer) + assert value == self.read(buffer) + +class TestReadLegacyBytes(LengthBufferUnderflowContract): + reader = read_legacy_bytes + length_num_bytes = 4 + valid_serialization = b"\x00\x00\x00\x05hello" -class TestReadLegacyBytes: def test_raises_unexpected_null_for_negative_length_sync( self, buffer: io.BytesIO, @@ -345,7 +429,7 @@ def test_raises_unexpected_null_for_negative_length_sync( buffer.write(struct.pack(">i", -1)) buffer.seek(0) with pytest.raises(UnexpectedNull): - read_legacy_bytes(buffer) + self.read(buffer) def test_can_read_bytes_sync( self, @@ -357,17 +441,20 @@ def test_can_read_bytes_sync( buffer.write(struct.pack(">i", byte_length)) buffer.write(byte_value) buffer.seek(0) - assert byte_value == read_legacy_bytes(buffer) + assert byte_value == self.read(buffer) + +class TestReadUUID(BufferUnderflowContract): + reader = read_uuid + valid_serialization = uuid_zero.bytes -class TestReadUUID: def test_reads_zero_as_none(self, buffer: io.BytesIO) -> None: buffer.write(uuid_zero.bytes) buffer.seek(0) - assert read_uuid(buffer) is None + assert self.read(buffer) is None def test_can_read_uuid4(self, buffer: io.BytesIO) -> None: value = uuid.uuid4() buffer.write(value.bytes) buffer.seek(0) - assert read_uuid(buffer) == value + assert self.read(buffer) == value