Skip to content

Commit

Permalink
Raise BufferUnderflow when too few bytes to parse
Browse files Browse the repository at this point in the history
  • Loading branch information
aiven-anton committed Jan 29, 2024
1 parent 2f9d9cc commit 9ac9098
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 58 deletions.
8 changes: 6 additions & 2 deletions src/kio/serial/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ class UnexpectedNull(DecodeError):
...


class SchemaError(SerialError):
class BufferUnderflow(DecodeError):
...


class OutOfBoundValue(SerialError):
class SchemaError(SerialError):
...


class EncodeError(SerialError):
...


class OutOfBoundValue(EncodeError):
...
44 changes: 26 additions & 18 deletions src/kio/serial/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,46 +23,54 @@
from kio.static.primitive import u32
from kio.static.primitive import u64

from .errors import BufferUnderflow
from .errors import UnexpectedNull

T = TypeVar("T")
Reader: TypeAlias = Callable[[IO[bytes]], T]


def read_exact(buffer: IO[bytes], num_bytes: int) -> bytes:
value = buffer.read(num_bytes)
if len(value) != num_bytes:
raise BufferUnderflow(f"Expected to read {num_bytes}, got {len(value)}")
return value


def read_boolean(buffer: IO[bytes]) -> bool:
return struct.unpack(">?", buffer.read(1))[0] # type: ignore[no-any-return]
return struct.unpack(">?", read_exact(buffer, 1))[0] # type: ignore[no-any-return]


def read_int8(buffer: IO[bytes]) -> i8:
return struct.unpack(">b", buffer.read(1))[0] # type: ignore[no-any-return]
return struct.unpack(">b", read_exact(buffer, 1))[0] # type: ignore[no-any-return]


def read_int16(buffer: IO[bytes]) -> i16:
return struct.unpack(">h", buffer.read(2))[0] # type: ignore[no-any-return]
return struct.unpack(">h", read_exact(buffer, 2))[0] # type: ignore[no-any-return]


def read_int32(buffer: IO[bytes]) -> i32:
return struct.unpack(">i", buffer.read(4))[0] # type: ignore[no-any-return]
return struct.unpack(">i", read_exact(buffer, 4))[0] # type: ignore[no-any-return]


def read_int64(buffer: IO[bytes]) -> i64:
return struct.unpack(">q", buffer.read(8))[0] # type: ignore[no-any-return]
return struct.unpack(">q", read_exact(buffer, 8))[0] # type: ignore[no-any-return]


def read_uint8(buffer: IO[bytes]) -> u8:
return struct.unpack(">B", buffer.read(1))[0] # type: ignore[no-any-return]
return struct.unpack(">B", read_exact(buffer, 1))[0] # type: ignore[no-any-return]


def read_uint16(buffer: IO[bytes]) -> u16:
return struct.unpack(">H", buffer.read(2))[0] # type: ignore[no-any-return]
return struct.unpack(">H", read_exact(buffer, 2))[0] # type: ignore[no-any-return]


def read_uint32(buffer: IO[bytes]) -> u32:
return struct.unpack(">I", buffer.read(4))[0] # type: ignore[no-any-return]
return struct.unpack(">I", read_exact(buffer, 4))[0] # type: ignore[no-any-return]


def read_uint64(buffer: IO[bytes]) -> u64:
return struct.unpack(">Q", buffer.read(8))[0] # type: ignore[no-any-return]
return struct.unpack(">Q", read_exact(buffer, 8))[0] # type: ignore[no-any-return]


# See description and upstream implementation.
Expand All @@ -74,7 +82,7 @@ def read_unsigned_varint(buffer: IO[bytes]) -> int:
# Increase shift by 7 on each iteration, looping at most 5 times.
for shift in range(0, 4 * 7 + 1, 7):
# Read value by a byte at a time.
(byte,) = buffer.read(1)
(byte,) = read_exact(buffer, 1)
# Add 7 least significant bits to the result.
seven_bit_chunk = byte & 0b01111111
result |= seven_bit_chunk << shift
Expand All @@ -87,7 +95,7 @@ def read_unsigned_varint(buffer: IO[bytes]) -> int:


def read_float64(buffer: IO[bytes]) -> float:
return struct.unpack(">d", buffer.read(8))[0] # type: ignore[no-any-return]
return struct.unpack(">d", read_exact(buffer, 8))[0] # type: ignore[no-any-return]


def read_compact_string_as_bytes(buffer: IO[bytes]) -> bytes:
Expand All @@ -97,15 +105,15 @@ def read_compact_string_as_bytes(buffer: IO[bytes]) -> bytes:
raise UnexpectedNull(
"Unexpectedly read null where compact string/bytes was expected"
)
return buffer.read(length)
return read_exact(buffer, length)


def read_compact_string_as_bytes_nullable(buffer: IO[bytes]) -> bytes | None:
# Apache Kafka® uses the string length plus 1.
length = read_unsigned_varint(buffer) - 1
if length == -1:
return None
return buffer.read(length)
return read_exact(buffer, length)


def read_compact_string(buffer: IO[bytes]) -> str:
Expand All @@ -123,28 +131,28 @@ def read_legacy_bytes(buffer: IO[bytes]) -> bytes:
length = read_int32(buffer)
if length == -1:
raise UnexpectedNull("Unexpectedly read null where bytes was expected")
return buffer.read(length)
return read_exact(buffer, length)


def read_nullable_legacy_bytes(buffer: IO[bytes]) -> bytes | None:
length = read_int32(buffer)
if length == -1:
return None
return buffer.read(length)
return read_exact(buffer, length)


def read_legacy_string(buffer: IO[bytes]) -> str:
length = read_int16(buffer)
if length == -1:
raise UnexpectedNull("Unexpectedly read null where string/bytes was expected")
return buffer.read(length).decode()
return read_exact(buffer, length).decode()


def read_nullable_legacy_string(buffer: IO[bytes]) -> str | None:
length = read_int16(buffer)
if length == -1:
return None
return buffer.read(length).decode()
return read_exact(buffer, length).decode()


read_legacy_array_length: Final = read_int32
Expand All @@ -156,7 +164,7 @@ def read_compact_array_length(buffer: IO[bytes]) -> int:


def read_uuid(buffer: IO[bytes]) -> UUID | None:
byte_value: bytes = buffer.read(16)
byte_value: bytes = read_exact(buffer, 16)
if byte_value == uuid_zero.bytes:
return None
return UUID(bytes=byte_value)
Expand Down
Loading

0 comments on commit 9ac9098

Please sign in to comment.