From 4160c53bc5be9bfa52c8eb24f77ce0156bfe68ca Mon Sep 17 00:00:00 2001 From: BoppreH Date: Mon, 23 Oct 2023 21:33:58 +0200 Subject: [PATCH] Handle unsorted nature of enumerated cipher suites --- hello_tls.py | 39 +++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/hello_tls.py b/hello_tls.py index 8c17613..b027d50 100644 --- a/hello_tls.py +++ b/hello_tls.py @@ -1,4 +1,5 @@ from multiprocessing.pool import ThreadPool +from functools import total_ordering from datetime import datetime from typing import Sequence from enum import Enum @@ -16,7 +17,13 @@ # Maximum number of cipher suite groups to divide when enumerating. MAX_WORKERS_PER_PROTOCOL: int = 3 +@total_ordering class Protocol(Enum): + def __lt__(self, other): + if self.__class__ != other.__class__: + return NotImplemented + return self.value < other.value + # Keep protocols in order of preference. TLS1_3 = b"\x03\x04" TLS1_2 = b"\x03\x03" @@ -49,7 +56,13 @@ class HandshakeType(Enum): key_update = 24 message_hash = 25 +@total_ordering class CipherSuite(Enum): + def __lt__(self, other): + if self.__class__ != other.__class__: + return NotImplemented + return self.value < other.value + # For compability. TLS_EMPTY_RENEGOTIATION_INFO_SCSV = b"\x00\xff" @@ -271,19 +284,19 @@ def _prefix_length(b: bytes, width_bytes: int = 2) -> bytes: """ Returns `b` prefixed with its length, encoded as a big-endian integer of `width_bytes` bytes. """ return len(b).to_bytes(width_bytes, byteorder="big") + b - protocol_values = [protocol.value for protocol in hello_prefs.protocols] + protocol_values = [protocol for protocol in hello_prefs.protocols] max_protocol = max(protocol_values) # Record and Hanshake versions have a maximum value due to ossification. - legacy_handshake_version = min(Protocol.TLS1_2.value, max_protocol) - legacy_record_version = min(Protocol.TLS1_0.value, max_protocol) + legacy_handshake_version = min(Protocol.TLS1_2, max_protocol) + legacy_record_version = min(Protocol.TLS1_0, max_protocol) return bytes(( 0x16, # Record type: handshake. - *legacy_record_version, # Legacy record version: max TLS 1.0. + *legacy_record_version.value, # Legacy record version: max TLS 1.0. *_prefix_length(bytes([ # Handshake record. 0x01, # Handshake type: Client Hello. *_prefix_length(width_bytes=3, b=bytes([ # Client hello handshake. - *legacy_handshake_version, # Legacy client version: max TLS 1.2. + *legacy_handshake_version.value, # Legacy client version: max TLS 1.2. *32*[0x07], # Random. Any value will do. 32, # Legacy session ID length. *32*[0x07], # Legacy session ID. Any value will do. @@ -426,7 +439,7 @@ def get_server_preferred_cipher_suite(hello_prefs: TlsHelloSettings) -> CipherSu return server_hello.cipher_suite -def enumerate_server_cipher_suites(hello_prefs: TlsHelloSettings) -> Sequence[CipherSuite]: +def enumerate_server_cipher_suites(hello_prefs: TlsHelloSettings) -> set[CipherSuite]: """ Given a list of cipher suites to test, sends a sequence of Client Hello packets to the server, removing the accepted cipher suite from the list each time. @@ -434,12 +447,12 @@ def enumerate_server_cipher_suites(hello_prefs: TlsHelloSettings) -> Sequence[Ci """ logger.info(f"Testing support of {len(hello_prefs.cipher_suites)} cipher suites with protocols {hello_prefs.protocols}") cipher_suites_to_test = list(hello_prefs.cipher_suites) - accepted_cipher_suites = [] + accepted_cipher_suites = set() while cipher_suites_to_test: hello_prefs = dataclasses.replace(hello_prefs, cipher_suites=cipher_suites_to_test) cipher_suite_picked = get_server_preferred_cipher_suite(hello_prefs) if cipher_suite_picked: - accepted_cipher_suites.append(cipher_suite_picked) + accepted_cipher_suites.add(cipher_suite_picked) cipher_suites_to_test.remove(cipher_suite_picked) else: break @@ -538,7 +551,7 @@ def _x509_time_to_datetime(x509_time: bytes | None) -> datetime: class ServerScanResult: host: str port: int - cipher_suites_per_protocol: dict[str, list[CipherSuite]] + cipher_suites_per_protocol: dict[str, set[CipherSuite]] certificate_chain: list[Certificate] | None def scan_server( @@ -581,7 +594,7 @@ def scan_server( if enumerate_cipher_suites: # Add an intermediary name to appease the type checker. - result.cipher_suites_per_protocol = {p.name: [] for p in protocols} + result.cipher_suites_per_protocol = {p.name: set() for p in protocols} def start_enumeration(protocol: Protocol): """ Checks if the server supports this protocol, and if so, start enumerating cipher suites. """ @@ -593,7 +606,7 @@ def start_enumeration(protocol: Protocol): logger.info(f"Server does not support {protocol}") return # Register the cipher suite we found. - accepted_cipher_suites = [first_cipher_suite] + accepted_cipher_suites = {first_cipher_suite} result.cipher_suites_per_protocol[protocol.name] = accepted_cipher_suites # Divide remaining cipher suites in groups and enumerate them in parallel. # Use % to distribute "desirable" cipher suites evenly. @@ -602,7 +615,7 @@ def start_enumeration(protocol: Protocol): logger.debug(f"Starting enumeration of cipher suites for {protocol}") for cipher_suite_group in groups: prefs = dataclasses.replace(hello_prefs, protocols=[protocol], cipher_suites=cipher_suite_group) - add_task(lambda prefs=prefs: accepted_cipher_suites.extend(enumerate_server_cipher_suites(prefs))) + add_task(lambda prefs=prefs: accepted_cipher_suites.update(enumerate_server_cipher_suites(prefs))) for protocol in protocols: add_task(start_enumeration, (protocol,)) @@ -675,6 +688,8 @@ class EnhancedJSONEncoder(json.JSONEncoder): def default(self, o): if dataclasses.is_dataclass(o): return dataclasses.asdict(o) + if isinstance(o, set): + return sorted(o) elif isinstance(o, Enum): return o.name elif isinstance(o, datetime):