Skip to content

Commit

Permalink
Handle parsing errors more gracefully
Browse files Browse the repository at this point in the history
  • Loading branch information
boppreh committed Jan 18, 2024
1 parent 8e0d5a4 commit 6dd7b40
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
30 changes: 14 additions & 16 deletions src/hello_tls/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,38 +28,36 @@ class ServerHello:
cipher_suite: CipherSuite
group: Optional[Group]

def _make_stream_parser(packets: Iterable[bytes]) -> Tuple[Callable[[int], bytes], Callable[[], int]]:
def _bytes_to_int(b: bytes) -> int:
return int.from_bytes(b, byteorder='big')

def parse_server_hello(packets: Iterable[bytes]) -> ServerHello:
"""
Returns helper functions to parse a stream of packets.
Parses a Server Hello packet and returns the cipher suite accepted by the server.
"""
start = 0
packets_iter = iter(packets)
data = b''
data = next(packets_iter) # Buffer first packet.
def read_next(length: int) -> bytes:
""" Returns the next `length` unparsed bytes. """
nonlocal start, data
while start + length > len(data):
try:
# This is quadratic, but there are few packets to loop over.
data += next(packets_iter)
except StopIteration:
raise BadServerResponse('Server response ended unexpectedly')
value = data[start:start+length]
start += length
return value
return read_next, lambda: start

def _bytes_to_int(b: bytes) -> int:
return int.from_bytes(b, byteorder='big')

def parse_server_hello(packets: Iterable[bytes]) -> ServerHello:
"""
Parses a Server Hello packet and returns the cipher suite accepted by the server.
"""
read_next, current_position = _make_stream_parser(packets)

if data.startswith(b'HTTP/'):
raise BadServerResponse('Server responded with plaintext HTTP, not TLS', data)

record_type = RecordType(read_next(1))
legacy_record_version = read_next(2)
record_length = _bytes_to_int(read_next(2))
record_end = current_position() + record_length
record_end = start + record_length
if record_type == RecordType.ALERT:
# Server responded with an error.
alert_level = AlertLevel(read_next(1))
Expand All @@ -78,11 +76,11 @@ def parse_server_hello(packets: Iterable[bytes]) -> ServerHello:
cipher_suite = CipherSuite(read_next(2))
compression_method = CompressionMethod(read_next(1))
extensions_length = _bytes_to_int(read_next(2))
extensions_end = current_position() + extensions_length
extensions_end = start + extensions_length

group = None

while current_position() < extensions_end:
while start < extensions_end:
extension_type = ExtensionType(read_next(2))
extension_data_length = read_next(2)
extension_data = read_next(_bytes_to_int(extension_data_length))
Expand Down
6 changes: 5 additions & 1 deletion src/hello_tls/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def packet_stream() -> Iterator[bytes]:
raise EmptyServerResponse()
else:
break
server_hello = parse_server_hello(packet_stream())

try:
server_hello = parse_server_hello(packet_stream())
except ValueError as e:
raise BadServerResponse('Error parsing server response') from e

if server_hello.version not in client_hello.protocols:
# Server picked a protocol we didn't ask for.
Expand Down

0 comments on commit 6dd7b40

Please sign in to comment.