Skip to content

Commit

Permalink
Handle unsorted nature of enumerated cipher suites
Browse files Browse the repository at this point in the history
  • Loading branch information
boppreh committed Oct 23, 2023
1 parent 0eba87a commit 4160c53
Showing 1 changed file with 27 additions and 12 deletions.
39 changes: 27 additions & 12 deletions hello_tls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from multiprocessing.pool import ThreadPool
from functools import total_ordering
from datetime import datetime
from typing import Sequence
from enum import Enum
Expand All @@ -16,7 +17,13 @@
# Maximum number of cipher suite groups to divide when enumerating.
MAX_WORKERS_PER_PROTOCOL: int = 3

@total_ordering
class Protocol(Enum):
def __lt__(self, other):
if self.__class__ != other.__class__:
return NotImplemented
return self.value < other.value

# Keep protocols in order of preference.
TLS1_3 = b"\x03\x04"
TLS1_2 = b"\x03\x03"
Expand Down Expand Up @@ -49,7 +56,13 @@ class HandshakeType(Enum):
key_update = 24
message_hash = 25

@total_ordering
class CipherSuite(Enum):
def __lt__(self, other):
if self.__class__ != other.__class__:
return NotImplemented
return self.value < other.value

# For compability.
TLS_EMPTY_RENEGOTIATION_INFO_SCSV = b"\x00\xff"

Expand Down Expand Up @@ -271,19 +284,19 @@ def _prefix_length(b: bytes, width_bytes: int = 2) -> bytes:
""" Returns `b` prefixed with its length, encoded as a big-endian integer of `width_bytes` bytes. """
return len(b).to_bytes(width_bytes, byteorder="big") + b

protocol_values = [protocol.value for protocol in hello_prefs.protocols]
protocol_values = [protocol for protocol in hello_prefs.protocols]
max_protocol = max(protocol_values)
# Record and Hanshake versions have a maximum value due to ossification.
legacy_handshake_version = min(Protocol.TLS1_2.value, max_protocol)
legacy_record_version = min(Protocol.TLS1_0.value, max_protocol)
legacy_handshake_version = min(Protocol.TLS1_2, max_protocol)
legacy_record_version = min(Protocol.TLS1_0, max_protocol)

return bytes((
0x16, # Record type: handshake.
*legacy_record_version, # Legacy record version: max TLS 1.0.
*legacy_record_version.value, # Legacy record version: max TLS 1.0.
*_prefix_length(bytes([ # Handshake record.
0x01, # Handshake type: Client Hello.
*_prefix_length(width_bytes=3, b=bytes([ # Client hello handshake.
*legacy_handshake_version, # Legacy client version: max TLS 1.2.
*legacy_handshake_version.value, # Legacy client version: max TLS 1.2.
*32*[0x07], # Random. Any value will do.
32, # Legacy session ID length.
*32*[0x07], # Legacy session ID. Any value will do.
Expand Down Expand Up @@ -426,20 +439,20 @@ def get_server_preferred_cipher_suite(hello_prefs: TlsHelloSettings) -> CipherSu

return server_hello.cipher_suite

def enumerate_server_cipher_suites(hello_prefs: TlsHelloSettings) -> Sequence[CipherSuite]:
def enumerate_server_cipher_suites(hello_prefs: TlsHelloSettings) -> set[CipherSuite]:
"""
Given a list of cipher suites to test, sends a sequence of Client Hello packets to the server,
removing the accepted cipher suite from the list each time.
Returns a list of all cipher suites the server accepted.
"""
logger.info(f"Testing support of {len(hello_prefs.cipher_suites)} cipher suites with protocols {hello_prefs.protocols}")
cipher_suites_to_test = list(hello_prefs.cipher_suites)
accepted_cipher_suites = []
accepted_cipher_suites = set()
while cipher_suites_to_test:
hello_prefs = dataclasses.replace(hello_prefs, cipher_suites=cipher_suites_to_test)
cipher_suite_picked = get_server_preferred_cipher_suite(hello_prefs)
if cipher_suite_picked:
accepted_cipher_suites.append(cipher_suite_picked)
accepted_cipher_suites.add(cipher_suite_picked)
cipher_suites_to_test.remove(cipher_suite_picked)
else:
break
Expand Down Expand Up @@ -538,7 +551,7 @@ def _x509_time_to_datetime(x509_time: bytes | None) -> datetime:
class ServerScanResult:
host: str
port: int
cipher_suites_per_protocol: dict[str, list[CipherSuite]]
cipher_suites_per_protocol: dict[str, set[CipherSuite]]
certificate_chain: list[Certificate] | None

def scan_server(
Expand Down Expand Up @@ -581,7 +594,7 @@ def scan_server(

if enumerate_cipher_suites:
# Add an intermediary name to appease the type checker.
result.cipher_suites_per_protocol = {p.name: [] for p in protocols}
result.cipher_suites_per_protocol = {p.name: set() for p in protocols}

def start_enumeration(protocol: Protocol):
""" Checks if the server supports this protocol, and if so, start enumerating cipher suites. """
Expand All @@ -593,7 +606,7 @@ def start_enumeration(protocol: Protocol):
logger.info(f"Server does not support {protocol}")
return
# Register the cipher suite we found.
accepted_cipher_suites = [first_cipher_suite]
accepted_cipher_suites = {first_cipher_suite}
result.cipher_suites_per_protocol[protocol.name] = accepted_cipher_suites
# Divide remaining cipher suites in groups and enumerate them in parallel.
# Use % to distribute "desirable" cipher suites evenly.
Expand All @@ -602,7 +615,7 @@ def start_enumeration(protocol: Protocol):
logger.debug(f"Starting enumeration of cipher suites for {protocol}")
for cipher_suite_group in groups:
prefs = dataclasses.replace(hello_prefs, protocols=[protocol], cipher_suites=cipher_suite_group)
add_task(lambda prefs=prefs: accepted_cipher_suites.extend(enumerate_server_cipher_suites(prefs)))
add_task(lambda prefs=prefs: accepted_cipher_suites.update(enumerate_server_cipher_suites(prefs)))

for protocol in protocols:
add_task(start_enumeration, (protocol,))
Expand Down Expand Up @@ -675,6 +688,8 @@ class EnhancedJSONEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
if isinstance(o, set):
return sorted(o)
elif isinstance(o, Enum):
return o.name
elif isinstance(o, datetime):
Expand Down

0 comments on commit 4160c53

Please sign in to comment.