From 1e2b5c2256d31e34083935f8adb2c8433cd40f7f Mon Sep 17 00:00:00 2001 From: Solly Ross Date: Tue, 25 Aug 2015 16:44:24 -0400 Subject: [PATCH] Rework Auth Plugins to Support HTTP Auth This commit reworks auth plugins slightly to enable support for HTTP authentication. By raising an AuthenticationError, auth plugins can now return HTTP responses to the upgrade request (such as 401). Related to kanaka/noVNC#522 --- tests/test_websocketproxy.py | 14 +++++----- websockify/auth_plugins.py | 50 +++++++++++++++++++++++++++++++++--- websockify/websocket.py | 12 +++++++-- websockify/websocketproxy.py | 33 +++++++++++++++++------- 4 files changed, 88 insertions(+), 21 deletions(-) diff --git a/tests/test_websocketproxy.py b/tests/test_websocketproxy.py index 8103ef6b..92fd5dbe 100644 --- a/tests/test_websocketproxy.py +++ b/tests/test_websocketproxy.py @@ -106,11 +106,11 @@ class TestPlugin(token_plugins.BasePlugin): def lookup(self, token): return (self.source + token).split(',') - self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', - lambda *args, **kwargs: None) + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error', + staticmethod(lambda *args, **kwargs: None)) self.handler.server.token_plugin = TestPlugin("somehost,") - self.handler.new_websocket_client() + self.handler.validate_connection() self.assertEqual(self.handler.server.target_host, "somehost") self.assertEqual(self.handler.server.target_port, "blah") @@ -119,9 +119,9 @@ def test_auth_plugin(self): class TestPlugin(auth_plugins.BasePlugin): def authenticate(self, headers, target_host, target_port): if target_host == self.source: - raise auth_plugins.AuthenticationError("some error") + raise auth_plugins.AuthenticationError(response_msg="some_error") - self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy', + self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error', staticmethod(lambda *args, **kwargs: None)) self.handler.server.auth_plugin = TestPlugin("somehost") @@ -129,8 +129,8 @@ def authenticate(self, headers, target_host, target_port): self.handler.server.target_port = "someport" self.assertRaises(auth_plugins.AuthenticationError, - self.handler.new_websocket_client) + self.handler.validate_connection) self.handler.server.target_host = "someotherhost" - self.handler.new_websocket_client() + self.handler.validate_connection() diff --git a/websockify/auth_plugins.py b/websockify/auth_plugins.py index 647c26e6..924d5de2 100644 --- a/websockify/auth_plugins.py +++ b/websockify/auth_plugins.py @@ -7,7 +7,15 @@ def authenticate(self, headers, target_host, target_port): class AuthenticationError(Exception): - pass + def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None): + self.code = response_code + self.headers = response_headers + self.msg = response_msg + + if log_msg is None: + log_msg = response_msg + + super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg)) class InvalidOriginError(AuthenticationError): @@ -16,8 +24,44 @@ def __init__(self, expected, actual): self.actual_origin = actual super(InvalidOriginError, self).__init__( - "Invalid Origin Header: Expected one of " - "%s, got '%s'" % (expected, actual)) + response_msg='Invalid Origin', + log_msg="Invalid Origin Header: Expected one of " + "%s, got '%s'" % (expected, actual)) + + +class BasicHTTPAuth(object): + def __init__(self, src=None): + self.src = src + + def authenticate(self, headers, target_host, target_port): + import base64 + + auth_header = headers.get('Authorization') + if auth_header: + if not auth_header.startswith('Basic '): + raise AuthenticationError(response_code=403) + + try: + user_pass_raw = base64.b64decode(auth_header[6:]) + except TypeError: + raise AuthenticationError(response_code=403) + + user_pass = user_pass_raw.split(':', 1) + if len(user_pass) != 2: + raise AuthenticationError(response_code=403) + + if not self.validate_creds: + raise AuthenticationError(response_code=403) + + else: + raise AuthenticationError(response_code=401, + response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'}) + + def validate_creds(username, password): + if '%s:%s' % (username, password) == self.src: + return True + else: + return False class ExpectOrigin(object): def __init__(self, src=None): diff --git a/websockify/websocket.py b/websockify/websocket.py index 1cbf5832..7fa96515 100644 --- a/websockify/websocket.py +++ b/websockify/websocket.py @@ -474,9 +474,13 @@ def handle_websocket(self): """Upgrade a connection to Websocket, if requested. If this succeeds, new_websocket_client() will be called. Otherwise, False is returned. """ + if (self.headers.get('upgrade') and self.headers.get('upgrade').lower() == 'websocket'): + # ensure connection is authorized, and determine the target + self.validate_connection() + if not self.do_websocket_handshake(): return False @@ -549,6 +553,10 @@ def new_websocket_client(self): """ Do something with a WebSockets client connection. """ raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded") + def validate_connection(self): + """ Ensure that the connection is a valid connection, and set the target. """ + pass + def do_HEAD(self): if self.only_upgrade: self.send_error(405, "Method Not Allowed") @@ -789,7 +797,7 @@ def do_handshake(self, sock, address): """ ready = select.select([sock], [], [], 3)[0] - + if not ready: raise self.EClose("ignoring socket not ready") # Peek, but do not read the data so that we have a opportunity @@ -903,7 +911,7 @@ def do_SIGTERM(self, sig, stack): def top_new_client(self, startsock, address): """ Do something with a WebSockets client connection. """ - # handler process + # handler process client = None try: try: diff --git a/websockify/websocketproxy.py b/websockify/websocketproxy.py index 029b6f33..46ab5459 100755 --- a/websockify/websocketproxy.py +++ b/websockify/websocketproxy.py @@ -18,6 +18,7 @@ except: from BaseHTTPServer import HTTPServer import select from websockify import websocket +from websockify import auth_plugins as auth try: from urllib.parse import parse_qs, urlparse except: @@ -37,20 +38,34 @@ class ProxyRequestHandler(websocket.WebSocketRequestHandler): < - Client send <. - Client send partial """ + + def send_auth_error(self, ex): + self.send_response(ex.code, ex.msg) + self.send_header('Content-Type', 'text/html') + for name, val in ex.headers.items(): + self.send_header(name, val) + + self.end_headers() + + def validate_connection(self): + if self.server.token_plugin: + (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path) + + if self.server.auth_plugin: + try: + self.server.auth_plugin.authenticate( + headers=self.headers, target_host=self.server.target_host, + target_port=self.server.target_port) + except auth.AuthenticationError: + ex = sys.exc_info()[1] + self.send_auth_error(ex) + raise def new_websocket_client(self): """ Called after a new WebSocket connection has been established. """ - # Checks if we receive a token, and look - # for a valid target for it then - if self.server.token_plugin: - (self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path) - - if self.server.auth_plugin: - self.server.auth_plugin.authenticate( - headers=self.headers, target_host=self.server.target_host, - target_port=self.server.target_port) + # Checking for a token is done in validate_connection() # Connect to the target if self.server.wrap_cmd: