Skip to content

Commit

Permalink
Rework Auth Plugins to Support HTTP Auth
Browse files Browse the repository at this point in the history
This commit reworks auth plugins slightly to enable
support for HTTP authentication.  By raising an
AuthenticationError, auth plugins can now return
HTTP responses to the upgrade request (such as 401).

Related to novnc/noVNC#522
  • Loading branch information
DirectXMan12 committed Aug 25, 2015
1 parent 69c04c8 commit c13ad47
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 21 deletions.
14 changes: 7 additions & 7 deletions tests/test_websocketproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ class TestPlugin(token_plugins.BasePlugin):
def lookup(self, token):
return (self.source + token).split(',')

self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
lambda *args, **kwargs: None)
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
staticmethod(lambda *args, **kwargs: None))

self.handler.server.token_plugin = TestPlugin("somehost,")
self.handler.new_websocket_client()
self.handler.validate_connection()

self.assertEqual(self.handler.server.target_host, "somehost")
self.assertEqual(self.handler.server.target_port, "blah")
Expand All @@ -119,18 +119,18 @@ def test_auth_plugin(self):
class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port):
if target_host == self.source:
raise auth_plugins.AuthenticationError("some error")
raise auth_plugins.AuthenticationError(response_msg="some_error")

self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
staticmethod(lambda *args, **kwargs: None))

self.handler.server.auth_plugin = TestPlugin("somehost")
self.handler.server.target_host = "somehost"
self.handler.server.target_port = "someport"

self.assertRaises(auth_plugins.AuthenticationError,
self.handler.new_websocket_client)
self.handler.validate_connection)

self.handler.server.target_host = "someotherhost"
self.handler.new_websocket_client()
self.handler.validate_connection()

49 changes: 46 additions & 3 deletions websockify/auth_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@ def authenticate(self, headers, target_host, target_port):


class AuthenticationError(Exception):
pass
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
self.code = response_code
self.headers = response_headers
self.msg = response_msg

if log_msg is None:
log_msg = response_msg

super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg))


class InvalidOriginError(AuthenticationError):
Expand All @@ -16,9 +24,44 @@ def __init__(self, expected, actual):
self.actual_origin = actual

super(InvalidOriginError, self).__init__(
"Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
response_msg='Invalid Origin',
log_msg="Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))


class BasicHTTPAuth(object):
def __init__(self, src=None):
self.src = src

def authenticate(self, headers, target_host, target_port):
import base64

auth_header = headers.getheader('Authorization')
if auth_header:
if not auth_header.startswith('Basic '):
raise AuthenticationError(response_code=403)

try:
user_pass_raw = base64.b64decode(auth_header[6:])
except TypeError:
raise AuthenticationError(response_code=403)

user_pass = user_pass_raw.split(':', 1)
if len(user_pass) != 2:
raise AuthenticationError(response_code=403)

if not self.validate_creds:
raise AuthenticationError(response_code=403)

else:
raise AuthenticationError(response_code=401,
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})

def validate_creds(username, password):
if '%s:%s' % (username, password) == self.src:
return True
else:
return False

class ExpectOrigin(object):
def __init__(self, src=None):
Expand Down
12 changes: 10 additions & 2 deletions websockify/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,13 @@ def handle_websocket(self):
"""Upgrade a connection to Websocket, if requested. If this succeeds,
new_websocket_client() will be called. Otherwise, False is returned.
"""

if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):

# ensure connection is authorized, and determine the target
self.validate_connection()

if not self.do_websocket_handshake():
return False

Expand Down Expand Up @@ -549,6 +553,10 @@ def new_websocket_client(self):
""" Do something with a WebSockets client connection. """
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")

def validate_connection(self):
""" Ensure that the connection is a valid connection, and set the target. """
pass

def do_HEAD(self):
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
Expand Down Expand Up @@ -789,7 +797,7 @@ def do_handshake(self, sock, address):
"""
ready = select.select([sock], [], [], 3)[0]


if not ready:
raise self.EClose("ignoring socket not ready")
# Peek, but do not read the data so that we have a opportunity
Expand Down Expand Up @@ -903,7 +911,7 @@ def do_SIGTERM(self, sig, stack):

def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
# handler process
# handler process
client = None
try:
try:
Expand Down
33 changes: 24 additions & 9 deletions websockify/websocketproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
except: from BaseHTTPServer import HTTPServer
import select
from websockify import websocket
from websockify import auth_plugins as auth
try:
from urllib.parse import parse_qs, urlparse
except:
Expand All @@ -37,20 +38,34 @@ class ProxyRequestHandler(websocket.WebSocketRequestHandler):
< - Client send
<. - Client send partial
"""

def send_auth_error(self, ex):
self.send_response(ex.code, ex.msg)
self.send_header('Content-Type', 'text/html')
for name, val in ex.headers.items():
self.send_header(name, val)

self.end_headers()

def validate_connection(self):
if self.server.token_plugin:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)

if self.server.auth_plugin:
try:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
except auth.AuthenticationError:
ex = sys.exc_info()[1]
self.send_auth_error(ex)
raise

def new_websocket_client(self):
"""
Called after a new WebSocket connection has been established.
"""
# Checks if we receive a token, and look
# for a valid target for it then
if self.server.token_plugin:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)

if self.server.auth_plugin:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
# Checking for a token is done in validate_connection()

# Connect to the target
if self.server.wrap_cmd:
Expand Down

0 comments on commit c13ad47

Please sign in to comment.