Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework Auth Plugins to Support HTTP Auth #194

Merged
merged 1 commit into from
Aug 28, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions tests/test_websocketproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ class TestPlugin(token_plugins.BasePlugin):
def lookup(self, token):
return (self.source + token).split(',')

self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
lambda *args, **kwargs: None)
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
staticmethod(lambda *args, **kwargs: None))

self.handler.server.token_plugin = TestPlugin("somehost,")
self.handler.new_websocket_client()
self.handler.validate_connection()

self.assertEqual(self.handler.server.target_host, "somehost")
self.assertEqual(self.handler.server.target_port, "blah")
Expand All @@ -119,18 +119,18 @@ def test_auth_plugin(self):
class TestPlugin(auth_plugins.BasePlugin):
def authenticate(self, headers, target_host, target_port):
if target_host == self.source:
raise auth_plugins.AuthenticationError("some error")
raise auth_plugins.AuthenticationError(response_msg="some_error")

self.stubs.Set(websocketproxy.ProxyRequestHandler, 'do_proxy',
self.stubs.Set(websocketproxy.ProxyRequestHandler, 'send_auth_error',
staticmethod(lambda *args, **kwargs: None))

self.handler.server.auth_plugin = TestPlugin("somehost")
self.handler.server.target_host = "somehost"
self.handler.server.target_port = "someport"

self.assertRaises(auth_plugins.AuthenticationError,
self.handler.new_websocket_client)
self.handler.validate_connection)

self.handler.server.target_host = "someotherhost"
self.handler.new_websocket_client()
self.handler.validate_connection()

50 changes: 47 additions & 3 deletions websockify/auth_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@ def authenticate(self, headers, target_host, target_port):


class AuthenticationError(Exception):
pass
def __init__(self, log_msg=None, response_code=403, response_headers={}, response_msg=None):
self.code = response_code
self.headers = response_headers
self.msg = response_msg

if log_msg is None:
log_msg = response_msg

super(AuthenticationError, self).__init__('%s %s' % (self.code, log_msg))


class InvalidOriginError(AuthenticationError):
Expand All @@ -16,8 +24,44 @@ def __init__(self, expected, actual):
self.actual_origin = actual

super(InvalidOriginError, self).__init__(
"Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))
response_msg='Invalid Origin',
log_msg="Invalid Origin Header: Expected one of "
"%s, got '%s'" % (expected, actual))


class BasicHTTPAuth(object):
def __init__(self, src=None):
self.src = src

def authenticate(self, headers, target_host, target_port):
import base64

auth_header = headers.get('Authorization')
if auth_header:
if not auth_header.startswith('Basic '):
raise AuthenticationError(response_code=403)

try:
user_pass_raw = base64.b64decode(auth_header[6:])
except TypeError:
raise AuthenticationError(response_code=403)

user_pass = user_pass_raw.split(':', 1)
if len(user_pass) != 2:
raise AuthenticationError(response_code=403)

if not self.validate_creds:
raise AuthenticationError(response_code=403)

else:
raise AuthenticationError(response_code=401,
response_headers={'WWW-Authenticate': 'Basic realm="Websockify"'})

def validate_creds(username, password):
if '%s:%s' % (username, password) == self.src:
return True
else:
return False

class ExpectOrigin(object):
def __init__(self, src=None):
Expand Down
12 changes: 10 additions & 2 deletions websockify/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,13 @@ def handle_websocket(self):
"""Upgrade a connection to Websocket, if requested. If this succeeds,
new_websocket_client() will be called. Otherwise, False is returned.
"""

if (self.headers.get('upgrade') and
self.headers.get('upgrade').lower() == 'websocket'):

# ensure connection is authorized, and determine the target
self.validate_connection()

if not self.do_websocket_handshake():
return False

Expand Down Expand Up @@ -549,6 +553,10 @@ def new_websocket_client(self):
""" Do something with a WebSockets client connection. """
raise Exception("WebSocketRequestHandler.new_websocket_client() must be overloaded")

def validate_connection(self):
""" Ensure that the connection is a valid connection, and set the target. """
pass

def do_HEAD(self):
if self.only_upgrade:
self.send_error(405, "Method Not Allowed")
Expand Down Expand Up @@ -789,7 +797,7 @@ def do_handshake(self, sock, address):
"""
ready = select.select([sock], [], [], 3)[0]


if not ready:
raise self.EClose("ignoring socket not ready")
# Peek, but do not read the data so that we have a opportunity
Expand Down Expand Up @@ -903,7 +911,7 @@ def do_SIGTERM(self, sig, stack):

def top_new_client(self, startsock, address):
""" Do something with a WebSockets client connection. """
# handler process
# handler process
client = None
try:
try:
Expand Down
33 changes: 24 additions & 9 deletions websockify/websocketproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
except: from BaseHTTPServer import HTTPServer
import select
from websockify import websocket
from websockify import auth_plugins as auth
try:
from urllib.parse import parse_qs, urlparse
except:
Expand All @@ -37,20 +38,34 @@ class ProxyRequestHandler(websocket.WebSocketRequestHandler):
< - Client send
<. - Client send partial
"""

def send_auth_error(self, ex):
self.send_response(ex.code, ex.msg)
self.send_header('Content-Type', 'text/html')
for name, val in ex.headers.items():
self.send_header(name, val)

self.end_headers()

def validate_connection(self):
if self.server.token_plugin:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)

if self.server.auth_plugin:
try:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
except auth.AuthenticationError:
ex = sys.exc_info()[1]
self.send_auth_error(ex)
raise

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like auth should be before token plugin. Reduces the ability for non-authorized connections to do a denial of service to the system by inducing disk or DB activity due to token authorization.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that currently, the authenticate method receives the decode host and port (that way you could say "person X is only authorized to connect to host/port Y"), so changing it would break backwards compatibility. Hmm...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you're right. I suppose if it becomes an issue, we could always add an optional early auth call that doesn't have the target resolved yet. And really, it's probably better to do authorization on the token. But having the target is useful for authorization in many cases. And truth be told, if the token->target information is on disk or in a DB, then user auth info probably is too. Anyways, it was really just a thought that came to me, I'm fine with the change as is.

def new_websocket_client(self):
"""
Called after a new WebSocket connection has been established.
"""
# Checks if we receive a token, and look
# for a valid target for it then
if self.server.token_plugin:
(self.server.target_host, self.server.target_port) = self.get_target(self.server.token_plugin, self.path)

if self.server.auth_plugin:
self.server.auth_plugin.authenticate(
headers=self.headers, target_host=self.server.target_host,
target_port=self.server.target_port)
# Checking for a token is done in validate_connection()

# Connect to the target
if self.server.wrap_cmd:
Expand Down