Skip to content

Commit

Permalink
Improve chosing of TLS signature and curve in Automatons (#4449)
Browse files Browse the repository at this point in the history
  • Loading branch information
gpotter2 committed Jul 12, 2024
1 parent 3333075 commit 6b26e21
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 64 deletions.
68 changes: 48 additions & 20 deletions scapy/layers/tls/automaton_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@
from scapy.packet import Raw
from scapy.compat import bytes_encode

# Typing imports
from typing import (
Optional,
)


class TLSClientAutomaton(_TLSAutomaton):
"""
Expand All @@ -95,12 +100,16 @@ class TLSClientAutomaton(_TLSAutomaton):
:param mycert:
:param mykey: may be provided as filenames. They will be used in the (or post)
handshake, should the server ask for client authentication.
:param client_hello: may hold a TLSClientHello or SSLv2ClientHello to be
sent to the server. This is particularly useful for extensions
tweaking. If not set, a default is populated accordingly.
:param client_hello: may hold a TLSClientHello, TLS13ClientHello or
SSLv2ClientHello to be sent to the server. This is particularly useful
for extensions tweaking. If not set, a default is populated accordingly.
:param version: is a quicker way to advertise a protocol version ("sslv2",
"tls1", "tls12", etc.) It may be overridden by the previous
"tls1", "tls12", "tls13", etc.) It may be overridden by the previous
'client_hello'.
:param session_ticket_file_in: path to a file that contains a session ticket
acquired in a previous session.
:param session_ticket_file_out: path to store any session ticket acquired during
this session.
:param data: is a list of raw data to be sent to the server once the
handshake has been completed. Both 'stop_server' and 'quit' will
work this way.
Expand All @@ -114,9 +123,10 @@ def parse_args(self, server="127.0.0.1", dport=4433, server_name=None,
session_ticket_file_out=None,
psk=None, psk_mode=None,
data=None,
ciphersuite=None,
curve=None,
ciphersuite: Optional[int] = None,
curve: Optional[str] = None,
supported_groups=None,
supported_signature_algorithms=None,
**kargs):

super(TLSClientAutomaton, self).parse_args(mycert=mycert,
Expand Down Expand Up @@ -157,16 +167,29 @@ def parse_args(self, server="127.0.0.1", dport=4433, server_name=None,
if supported_groups is None:
supported_groups = ["secp256r1", "secp384r1", "x448"]
if conf.crypto_valid_advanced:
supported_groups.append("x25519")
supported_groups.extend([
"x25519",
"ffdhe2048",
])
self.supported_groups = supported_groups

if supported_signature_algorithms is None:
supported_signature_algorithms = [
"sha256+rsa",
]
supported_signature_algorithms.insert(0, "sha256+rsaepss")
self.supported_signature_algorithms = supported_signature_algorithms

self.curve = None
self.ciphersuite = None

if ciphersuite is not None:
if ciphersuite in _tls_cipher_suites.keys():
self.ciphersuite = ciphersuite
else:
self.vprint("Unrecognized cipher suite.")

if self.advertised_tls_version == 0x0304:
self.ciphersuite = 0x1301
if ciphersuite is not None:
cs = int(ciphersuite, 16)
if cs in _tls_cipher_suites.keys():
self.ciphersuite = cs
if conf.crypto_valid_advanced:
# Default to x25519 if supported
self.curve = 29
Expand All @@ -192,14 +215,16 @@ def vprint_sessioninfo(self):
if self.verbose:
s = self.cur_session
v = _tls_version[s.tls_version]
self.vprint("Version : %s" % v)
self.vprint("Version : %s" % v)
cs = s.wcs.ciphersuite.name
self.vprint("Cipher suite : %s" % cs)
self.vprint("Cipher suite : %s" % cs)
kx_groupname = s.kx_group
self.vprint("Server temp key : %s" % kx_groupname)
if s.tls_version >= 0x0304:
ms = s.tls13_master_secret
else:
ms = s.master_secret
self.vprint("Master secret : %s" % repr_hex(ms))
self.vprint("Master secret : %s" % repr_hex(ms))
if s.server_certs:
self.vprint("Server certificate chain: %r" % s.server_certs)
if s.tls_version >= 0x0304:
Expand Down Expand Up @@ -306,11 +331,13 @@ def should_add_ClientHello(self):
if self.client_hello:
p = self.client_hello
else:
p = TLSClientHello()
p = TLSClientHello(ciphers=self.ciphersuite)
ext = []
# Add TLS_Ext_SignatureAlgorithms for TLS 1.2 ClientHello
if self.cur_session.advertised_tls_version == 0x0303:
ext += [TLS_Ext_SignatureAlgorithms(sig_algs=["sha256+rsa"])]
ext += [TLS_Ext_SignatureAlgorithms(
sig_algs=self.supported_signature_algorithms,
)]
# Add TLS_Ext_ServerName
if self.server_name:
ext += TLS_Ext_ServerName(
Expand Down Expand Up @@ -1147,8 +1174,9 @@ def tls13_should_add_ClientHello(self):
ext += TLS_Ext_KeyShare_CH(
client_shares=[KeyShareEntry(group=self.curve)]
)
ext += TLS_Ext_SignatureAlgorithms(sig_algs=["sha256+rsaepss",
"sha256+rsa"])
ext += TLS_Ext_SignatureAlgorithms(
sig_algs=self.supported_signature_algorithms,
)
p.ext = ext
self.add_msg(p)
raise self.TLS13_ADDED_CLIENTHELLO()
Expand Down Expand Up @@ -1215,7 +1243,7 @@ def TLS13_HANDLED_ALERT_FROM_SERVERFLIGHT1(self):
self.vprint(self.cur_pkt.mysummary())
raise self.CLOSE_NOTIFY()

@ATMT.condition(TLS13_RECEIVED_SERVERFLIGHT1, prio=4)
@ATMT.condition(TLS13_RECEIVED_SERVERFLIGHT1, prio=5)
def tls13_missing_ServerHello(self):
raise self.MISSING_SERVERHELLO()

Expand Down
121 changes: 101 additions & 20 deletions scapy/layers/tls/automaton_srv.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,24 @@
from scapy.layers.tls.basefields import _tls_version
from scapy.layers.tls.session import tlsSession
from scapy.layers.tls.crypto.groups import _tls_named_groups
from scapy.layers.tls.extensions import TLS_Ext_SupportedVersion_SH, \
TLS_Ext_SupportedGroups, TLS_Ext_Cookie, \
TLS_Ext_SignatureAlgorithms, TLS_Ext_PSKKeyExchangeModes, \
TLS_Ext_EarlyDataIndicationTicket
from scapy.layers.tls.keyexchange_tls13 import TLS_Ext_KeyShare_SH, \
KeyShareEntry, TLS_Ext_KeyShare_HRR, TLS_Ext_PreSharedKey_CH, \
TLS_Ext_PreSharedKey_SH
from scapy.layers.tls.extensions import (
TLS_Ext_Cookie,
TLS_Ext_EarlyDataIndicationTicket,
TLS_Ext_PSKKeyExchangeModes,
TLS_Ext_RenegotiationInfo,
TLS_Ext_SignatureAlgorithms,
TLS_Ext_SupportedGroups,
TLS_Ext_SupportedVersion_SH,
)
from scapy.layers.tls.keyexchange import _tls_hash_sig
from scapy.layers.tls.keyexchange_tls13 import (
TLS_Ext_KeyShare_SH,
KeyShareEntry,
TLS_Ext_KeyShare_HRR,
TLS_Ext_PreSharedKey_CH,
TLS_Ext_PreSharedKey_SH,
get_usable_tls13_sigalgs,
)
from scapy.layers.tls.handshake import TLSCertificate, TLSCertificateRequest, \
TLSCertificateVerify, TLSClientHello, TLSClientKeyExchange, TLSFinished, \
TLSServerHello, TLSServerHelloDone, TLSServerKeyExchange, \
Expand All @@ -55,8 +66,17 @@
TLSApplicationData
from scapy.layers.tls.record_tls13 import TLS13
from scapy.layers.tls.crypto.hkdf import TLS13_HKDF
from scapy.layers.tls.crypto.suites import _tls_cipher_suites_cls, \
get_usable_ciphersuites
from scapy.layers.tls.crypto.suites import (
_tls_cipher_suites_cls,
_tls_cipher_suites,
get_usable_ciphersuites,
)

# Typing imports
from typing import (
Optional,
Union,
)

if conf.crypto_valid:
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -89,7 +109,8 @@ class TLSServerAutomaton(_TLSAutomaton):

def parse_args(self, server="127.0.0.1", sport=4433,
mycert=None, mykey=None,
preferred_ciphersuite=None,
preferred_ciphersuite: Optional[int] = None,
preferred_signature_algorithm: Union[str, int, None] = None,
client_auth=False,
is_echo_server=True,
max_client_idle_time=60,
Expand Down Expand Up @@ -120,36 +141,65 @@ def parse_args(self, server="127.0.0.1", sport=4433,
self.remote_ip = None
self.remote_port = None

self.preferred_ciphersuite = preferred_ciphersuite
self.client_auth = client_auth
self.is_echo_server = is_echo_server
self.max_client_idle_time = max_client_idle_time
self.curve = None
self.preferred_ciphersuite = None
self.preferred_signature_algorithm = None
self.cookie = cookie
self.psk_secret = psk
self.psk_mode = psk_mode

if handle_session_ticket is None:
handle_session_ticket = session_ticket_file is not None
if handle_session_ticket:
session_ticket_file = session_ticket_file or get_temp_file()
self.handle_session_ticket = handle_session_ticket
self.session_ticket_file = session_ticket_file
for (group_id, ng) in _tls_named_groups.items():
if ng == curve:
self.curve = group_id

if preferred_ciphersuite is not None:
if preferred_ciphersuite in _tls_cipher_suites:
self.preferred_ciphersuite = preferred_ciphersuite
else:
self.vprint("Unrecognized cipher suite.")

if preferred_signature_algorithm is not None:
if preferred_signature_algorithm in _tls_hash_sig:
self.preferred_signature_algorithm = preferred_signature_algorithm
else:
for (sig_id, nc) in _tls_hash_sig.items():
if nc == preferred_signature_algorithm:
self.preferred_signature_algorithm = sig_id
break
else:
self.vprint("Unrecognized signature algorithm.")

if curve:
for (group_id, ng) in _tls_named_groups.items():
if ng == curve:
self.curve = group_id
break
else:
self.vprint("Unrecognized curve.")

def vprint_sessioninfo(self):
if self.verbose:
s = self.cur_session
v = _tls_version[s.tls_version]
self.vprint("Version : %s" % v)
self.vprint("Version : %s" % v)
cs = s.wcs.ciphersuite.name
self.vprint("Cipher suite : %s" % cs)
self.vprint("Cipher suite : %s" % cs)
kx_groupname = s.kx_group
self.vprint("Server temp key : %s" % kx_groupname)
if s.tls_version >= 0x0304:
sigalg = _tls_hash_sig[s.selected_sig_alg]
self.vprint("Negotiated sig_alg : %s" % sigalg)
if s.tls_version < 0x0304:
ms = s.master_secret
else:
ms = s.tls13_master_secret
self.vprint("Master secret : %s" % repr_hex(ms))
self.vprint("Master secret : %s" % repr_hex(ms))
if s.client_certs:
self.vprint("Client certificate chain: %r" % s.client_certs)

Expand Down Expand Up @@ -273,6 +323,13 @@ def should_handle_ClientHello(self):
self.raise_on_packet(TLSClientHello,
self.HANDLED_CLIENTHELLO)

@ATMT.condition(RECEIVED_CLIENTFLIGHT1, prio=3)
def tls13_should_handle_ChangeCipherSpec_after_tls13_retry(self):
# Middlebox compatibility mode after a HelloRetryRequest.
if self.cur_session.tls13_retry:
self.raise_on_packet(TLSChangeCipherSpec,
self.RECEIVED_CLIENTFLIGHT1)

@ATMT.state()
def HANDLED_CLIENTHELLO(self):
"""
Expand Down Expand Up @@ -309,8 +366,6 @@ def should_add_ServerHello(self):
"""
Selecting a cipher suite should be no trouble as we already caught
the None case previously.
Also, we do not manage extensions at all.
"""
if isinstance(self.mykey, PrivKeyRSA):
kx = "RSA"
Expand All @@ -320,7 +375,11 @@ def should_add_ServerHello(self):
c = usable_suites[0]
if self.preferred_ciphersuite in usable_suites:
c = self.preferred_ciphersuite
self.add_msg(TLSServerHello(cipher=c))

# Some extensions
ext = [TLS_Ext_RenegotiationInfo()]

self.add_msg(TLSServerHello(cipher=c, ext=ext))
raise self.ADDED_SERVERHELLO()

@ATMT.state()
Expand Down Expand Up @@ -568,6 +627,12 @@ def tls13_HANDLED_CLIENTHELLO(self):
if self.curve in e.groups:
# Here, we need to send an HelloRetryRequest
raise self.tls13_PREPARE_HELLORETRYREQUEST()

# Signature Algorithms extension is mandatory
if not s.advertised_sig_algs:
self.vprint("Missing signature_algorithms extension in ClientHello!")
raise self.CLOSE_NOTIFY()

raise self.tls13_PREPARE_SERVERFLIGHT1()

@ATMT.state()
Expand Down Expand Up @@ -818,6 +883,22 @@ def tls13_ADDED_CERTIFICATE(self):
@ATMT.condition(tls13_ADDED_CERTIFICATE)
def tls13_should_add_CertificateVerifiy(self):
if not self.cur_session.tls13_psk_secret:
# If we have a preferred signature algorithm, and the client supports
# it, use that.
if self.cur_session.advertised_sig_algs:
usable_sigalgs = get_usable_tls13_sigalgs(
self.cur_session.advertised_sig_algs,
self.mykey,
location="certificateverify",
)
if not usable_sigalgs:
self.vprint("No usable signature algorithm!")
raise self.CLOSE_NOTIFY()
pref_alg = self.preferred_signature_algorithm
if pref_alg in usable_sigalgs:
self.cur_session.selected_sig_alg = pref_alg
else:
self.cur_session.selected_sig_alg = usable_sigalgs[0]
self.add_msg(TLSCertificateVerify())
raise self.tls13_ADDED_CERTIFICATEVERIFY()

Expand Down
7 changes: 7 additions & 0 deletions scapy/layers/tls/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,11 @@ def dispatch_hook(cls, _pkt=None, *args, **kargs):
return TLS13ServerHello
return TLSServerHello

def build(self, *args, **kargs):
if self.getfieldval("sid") == b"" and self.tls_session:
self.sid = self.tls_session.sid
return super(TLSServerHello, self).build(*args, **kargs)

def post_build(self, p, pay):
if self.random_bytes is None:
p = p[:10] + randstring(28) + p[10 + 28:]
Expand Down Expand Up @@ -707,6 +712,8 @@ def build(self):
fval = self.getfieldval("random_bytes")
if fval is None:
self.random_bytes = _tls_hello_retry_magic
if self.getfieldval("sid") == b"" and self.tls_session:
self.sid = self.tls_session.sid
return _TLSHandshake.build(self)

def tls_session_update(self, msg_str):
Expand Down
2 changes: 1 addition & 1 deletion scapy/layers/tls/handshake_sslv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ class SSLv2ServerFinished(_SSLv2Handshake):

def build(self, *args, **kargs):
fval = self.getfieldval("sid")
if fval == b"":
if fval == b"" and self.tls_session:
self.sid = self.tls_session.sid
return super(SSLv2ServerFinished, self).build(*args, **kargs)

Expand Down
Loading

0 comments on commit 6b26e21

Please sign in to comment.