Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Refactor SSL MiTM out of Connection #95

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 10 additions & 268 deletions nogotofail/mitm/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,108 +13,16 @@
See the License for the specific language governing permissions and
limitations under the License.
'''
from OpenSSL import SSL
from OpenSSL import crypto
import errno
import logging
import os
import select
import socket
import struct
from nogotofail.mitm.util import tls, ssl2, extras
from nogotofail.mitm.util import close_quietly
from nogotofail.mitm.util.tls.types import Extension
import time
import uuid
import errno
import os

class ConnectionWrapper(object):
"""Wrapper around OpenSSL's Connection object to make it act like a real socket.
"""

def __init__(self, connection):
self._connection = connection
self.buffer = ""
self._is_short_send = False

def __getattr__(self, name):
return getattr(self._connection, name)

def recv(self, size, flags=0):
"""Wrapper around pyOpenSSL's Connection.recv
PyOpenSSL doesn't return "" on error like socket.recv does,
instead it throws a SSL.ZeroReturnError or (-1, "Unexpected EOF") erorrs.

Wrap recv so we don't have to deal with that noise.
"""
if flags & socket.MSG_PEEK == 0:
return self._recv(size)
if len(self.buffer) >= size:
return self.buffer[:size]
try:
self.buffer += self._recv(size - len(self.buffer))
except SSL.WantReadError:
pass
return self.buffer[:size]

def _recv(self, size):
if size <= len(self.buffer):
out = self.buffer[:size]
self.buffer = self.buffer[size:]
return out
buf = self.buffer
size -= len(buf)
try:
buf += self._connection.recv(size)
except SSL.SysCallError as e:
if e.args != (-1, "Unexpected EOF"):
raise e
except SSL.ZeroReturnError:
pass
except SSL.WantReadError as e:
# Rethrow the WantRead if we really have no data
if not buf:
raise e
except SSL.Error as e:
if e.args != (-1, "Unexpected EOF"):
raise e
self.buffer = ""
return buf

def send(self, string):
sent = self._connection.send(string)
# Track short send state for our awful fileno hacks
self._is_short_send = sent != len(string)
return sent

_always_read_fd = None
def always_read_fd(self):
"""Return an fd that is always ready for read when passed to select.select. See fileno for why this is needed."""
if ConnectionWrapper._always_read_fd:
return ConnectionWrapper._always_read_fd
ConnectionWrapper._always_read_fd = open("/dev/zero")
return ConnectionWrapper._always_read_fd

def fileno(self):
# _AWFUL_ HACK to support MSG_PEEK without breaking select.select.
# If we read data with a peeking recv then return a fd that is always selectable on read to make sure the connection keeps flowing.
# Note that if the conneciton is handling a short send then we're only waiting for write not read, so use the underlying connection.
# Once the backlog is sent the connection will start trying to read again and we'll return the always_read_fd.
if self.buffer and not self._is_short_send:
return self.always_read_fd().fileno()

return self._connection.fileno()

def stub_verify(conn, cert, errno, errdepth, code):
"""We don't verify the server when we attempt a MiTM.
If the client was connecting to a host with a bad cert
we still want to connect and MiTM them.

Hypothetically someone could MiTM our MiTM and intercept what we intercept,
use caution in what data you send through a MiTM'd connection if you don't trust
the rest of your path to the real endpoint.
"""
return True

from OpenSSL import SSL
from nogotofail.mitm.util import close_quietly, extras

class BaseConnection(object):
"""Handles the creation and bridging of both sides of the network connection
Expand Down Expand Up @@ -149,8 +57,6 @@ class BaseConnection(object):
_connected = False
_blame_in_progress = False

SSL_TIMEOUT = 2

def __init__(
self, server, client_socket, handler_selector,
ssl_handler_selector, data_handler_selector, app_blame):
Expand Down Expand Up @@ -263,114 +169,6 @@ def _start_server_connect_nonblocking(self):
if e.errno != errno.EINPROGRESS:
raise e

def _gen_ssl_connect_fn(self, connection, post_fn):
"""Generate a bridge_fn for doing an ssl handshake on connection.
Once the handshake is completed post_fn will be called
"""
def do_ssl_handshake():
try:
connection.do_handshake()
return post_fn()
except (SSL.WantReadError, SSL.WantWriteError):
pass
except SSL.Error as e:
self.handler.on_ssl_error(e)
return False
except socket.error as e:
return False
return True
return do_ssl_handshake

def start_ssl_mitm(self, client_hello):
"""Start the SSL MiTM.
This is non-blocking and will set the bridge_fns and select_fds as follows:
1. Start the SSL handshake with the server, ignore client data
2. On handshake completion call _on_server_ssl_established
3. Start the SSL handshake with the client, ignore server data
4. On completion call _on_client_ssl_established
5. At this point the SSL MiTM is set up and we switch back to bridging mode
"""
self.client_hello = client_hello
server_name = client_hello.extensions.get(Extension.TYPE.SERVER_NAME)
if server_name:
server_name = server_name.data
self.hostname = server_name
self._start_server_ssl_connection(server_name)

def _start_server_ssl_connection(self, servername=None):
context = SSL.Context(SSL.SSLv23_METHOD)
context.set_verify(SSL.VERIFY_NONE, stub_verify)
self.server_socket.setblocking(False)
connection = SSL.Connection(context, self.server_socket)
self.server_socket = ConnectionWrapper(connection)
if servername:
connection.set_tlsext_host_name(servername)
connection.set_connect_state()
self.server_bridge_fn = self._gen_ssl_connect_fn(connection,
self._on_server_ssl_established)
connection.set_connect_state()
# Stop selecting on the client until we are connected
self.set_select_fds(rlist=[self.server_socket])
# Start the handshake
self.server_bridge_fn()


def _start_client_ssl_connection(self):
server_cert = self.server_socket.get_peer_certificate()
handler_cert = self.handler.on_certificate(server_cert)
ciphers_list = self.handler.on_server_cipher_suites(self.client_hello)

context = SSL.Context(SSL.SSLv23_METHOD)
context.set_verify(SSL.VERIFY_NONE, stub_verify)
if ciphers_list is not None:
context.set_cipher_list(ciphers_list)
if handler_cert is not None:
context.use_certificate_chain_file(handler_cert)
context.use_privatekey_file(handler_cert)

# Required for anonymous/ephemeral DH cipher suites
params_path = extras.get_extras_path("./dhparam")
if os.path.exists(params_path):
context.load_tmp_dh(extras.get_extras_path("./dhparam"))
else:
self.logger.warning("Required file dhparam not found, anonymous/ephemeral DH cipher suites may not work")

# Required for anonymous/ephemeral ECDH cipher suites
# The API is not available in the old version of pyOpenSSL which we
# currently use. Without the code below, anonymous and ephemeral
# ECDH cipher suites will not be used.
if hasattr(context, "set_tmp_ecdh"):
curve = crypto.get_elliptic_curve("prime256v1")
context.set_tmp_ecdh(curve)

# Send our ServerHello to the Client. Note that the Client's ClientHello
# MUST be the first thing that self.client_socket.recv() returns
connection = SSL.Connection(context, self.client_socket)
connection.set_accept_state()
self.client_socket = ConnectionWrapper(connection)
self.client_bridge_fn = self._gen_ssl_connect_fn(connection,
self._on_client_ssl_established)
# Only listen for client events until the connection is established
self.set_select_fds(rlist=[self.client_socket])
# Start the handshake
self.client_bridge_fn()

def _on_server_ssl_established(self):
"""Once the server is connected begin connecting the client"""
self.server_bridge_fn = self._bridge_server
# Start Setting up the client connection
self._start_client_ssl_connection()
return True

def _on_client_ssl_established(self):
"""Once the client is connected return to bridging mode"""
self.client_bridge_fn = self._bridge_client
# Now we are ready to bridge in both directions
self.set_select_fds(rlist=[self.client_socket, self.server_socket])
self.ssl = True
self.handler.on_ssl_establish()
return True

def bridge(self, sock):
"""Handle bridging data from sock to the other party.

Expand Down Expand Up @@ -399,62 +197,6 @@ def close(self, handler_initiated=True):
for handler in self.data_handlers:
handler.on_close(handler_initiated)


def _check_for_ssl(self, client_request):
""" Check for a client_hello in client_request and handle setting up handlers and any mitm.

Returns if client_request was used(and should not be sent to the server)
"""
# check for a TLS Client Hello
record = tls.parse_tls(client_request)
client_hello = None
if record:
first = record.messages[0]
if isinstance(first, tls.types.HandshakeMessage)\
and isinstance(first.obj, tls.types.ClientHello):
client_hello = first.obj
else:
# Check for an SSLv2 Client Hello
record = ssl2.parse_ssl2(client_request)
if record and isinstance(record.message.obj, ssl2.types.ClientHello):
client_hello = record.message.obj

if not client_hello:
return False
return self._handle_hello(client_hello)

def _handle_hello(self, client_hello):
""" Handles the changing of handlers on a TLS client hello and optional mitm

Returns if a MiTM was created
"""
# Check for a server name and set our hostname
if not self.hostname:
server_name = client_hello.extensions.get(Extension.TYPE.SERVER_NAME)
if server_name:
server_name = server_name.data
self.hostname = server_name

# Swap to a new handler if needed.
handler_class = self.ssl_handler_selector(
self, client_hello, self.app_blame)
if handler_class:
handler = handler_class(self)
self.handler.on_remove()
self.handler = handler
self.handler.on_select()

# Check if we should start mitming this connection
should_mitm = self.handler.on_ssl(client_hello)
# Call all the data handler's on_ssl so they can do any analysis they
# need.
for handler in self.data_handlers:
handler.on_ssl(client_hello)
if should_mitm:
self.start_ssl_mitm(client_hello)
return True
return False

def _bridge_client(self):
try:
try:
Expand All @@ -465,12 +207,6 @@ def _bridge_client(self):
for handler in self.data_handlers:
if handler.peek_request(client_request):
return not self.closed
# Check for a TLS client hello we might need to intercept
if not self.ssl:
# If a MiTM was attempted discard client_request, we used it
# for establishing a MiTM with the client.
if self._check_for_ssl(client_request):
return not self.closed
client_request = self.client_socket.recv(65536)
except (socket.error, SSL.WantReadError):
# recv can still time out even if select returned this socket
Expand Down Expand Up @@ -661,6 +397,12 @@ def inject_response(self, response):
break
self.client_socket.sendall(response)

def replace_connection_handler(self, new_handler_class):
handler = new_handler_class(self)
self.handler.on_remove()
self.handler = handler
self.handler.on_select()

class RedirectConnection(BaseConnection):
"""Connection based on getting traffic from iptables redirect rules"""

Expand Down
1 change: 1 addition & 0 deletions nogotofail/mitm/connection/handlers/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from report import ClientReportDetection
from log import RawTrafficLogger
from mitm import SslMitmHandler
from http import *
from imap import *
from smtp import *
Expand Down
Loading