diff --git a/src/hello_tls/protocol.py b/src/hello_tls/protocol.py index 4de4991..4f6dd0d 100644 --- a/src/hello_tls/protocol.py +++ b/src/hello_tls/protocol.py @@ -28,38 +28,36 @@ class ServerHello: cipher_suite: CipherSuite group: Optional[Group] -def _make_stream_parser(packets: Iterable[bytes]) -> Tuple[Callable[[int], bytes], Callable[[], int]]: +def _bytes_to_int(b: bytes) -> int: + return int.from_bytes(b, byteorder='big') + +def parse_server_hello(packets: Iterable[bytes]) -> ServerHello: """ - Returns helper functions to parse a stream of packets. + Parses a Server Hello packet and returns the cipher suite accepted by the server. """ start = 0 packets_iter = iter(packets) - data = b'' + data = next(packets_iter) # Buffer first packet. def read_next(length: int) -> bytes: + """ Returns the next `length` unparsed bytes. """ nonlocal start, data while start + length > len(data): try: + # This is quadratic, but there are few packets to loop over. data += next(packets_iter) except StopIteration: raise BadServerResponse('Server response ended unexpectedly') value = data[start:start+length] start += length return value - return read_next, lambda: start - -def _bytes_to_int(b: bytes) -> int: - return int.from_bytes(b, byteorder='big') - -def parse_server_hello(packets: Iterable[bytes]) -> ServerHello: - """ - Parses a Server Hello packet and returns the cipher suite accepted by the server. - """ - read_next, current_position = _make_stream_parser(packets) + if data.startswith(b'HTTP/'): + raise BadServerResponse('Server responded with plaintext HTTP, not TLS', data) + record_type = RecordType(read_next(1)) legacy_record_version = read_next(2) record_length = _bytes_to_int(read_next(2)) - record_end = current_position() + record_length + record_end = start + record_length if record_type == RecordType.ALERT: # Server responded with an error. alert_level = AlertLevel(read_next(1)) @@ -78,11 +76,11 @@ def parse_server_hello(packets: Iterable[bytes]) -> ServerHello: cipher_suite = CipherSuite(read_next(2)) compression_method = CompressionMethod(read_next(1)) extensions_length = _bytes_to_int(read_next(2)) - extensions_end = current_position() + extensions_length + extensions_end = start + extensions_length group = None - while current_position() < extensions_end: + while start < extensions_end: extension_type = ExtensionType(read_next(2)) extension_data_length = read_next(2) extension_data = read_next(_bytes_to_int(extension_data_length)) diff --git a/src/hello_tls/scan.py b/src/hello_tls/scan.py index cd8ea77..6e5f67d 100644 --- a/src/hello_tls/scan.py +++ b/src/hello_tls/scan.py @@ -101,7 +101,11 @@ def packet_stream() -> Iterator[bytes]: raise EmptyServerResponse() else: break - server_hello = parse_server_hello(packet_stream()) + + try: + server_hello = parse_server_hello(packet_stream()) + except ValueError as e: + raise BadServerResponse('Error parsing server response') from e if server_hello.version not in client_hello.protocols: # Server picked a protocol we didn't ask for.