Skip to content
This repository has been archived by the owner on Jul 13, 2023. It is now read-only.

Commit

Permalink
feat: add a --proxy_protocol for the partner endpoint
Browse files Browse the repository at this point in the history
allows the endpoint to get remote client ips from the ELB in proxy mode

closes #761
  • Loading branch information
pjenvey committed Dec 16, 2016
1 parent a9cffbb commit f482e64
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 45 deletions.
15 changes: 12 additions & 3 deletions autopush/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from autobahn.twisted.resource import WebSocketResource
from autobahn.twisted.websocket import WebSocketServerFactory
from twisted.internet import reactor, task
from twisted.internet.endpoints import SSL4ServerEndpoint, TCP4ServerEndpoint
from twisted.logger import Logger
from twisted.web.server import Site

Expand Down Expand Up @@ -315,6 +316,10 @@ def _parse_endpoint(sysargs, use_files=True):
parser.add_argument('--client_certs',
help="Allowed TLS client certificates",
type=str, env_var='CLIENT_CERTS', default="{}")
parser.add_argument('--proxy_protocol',
help="Enable HAProxy Proxy Protocol handling",
action="store_true", default=False,
env_var='PROXY_PROTOCOL')

add_shared_args(parser)

Expand Down Expand Up @@ -599,14 +604,18 @@ def endpoint_main(sysargs=None, use_files=True):

# start the senderIDs refresh timer
if args.ssl_key:
context_factory = AutopushSSLContextFactory(
ssl_cf = AutopushSSLContextFactory(
args.ssl_key,
args.ssl_cert,
dh_file=args.ssl_dh_param,
require_peer_certs=settings.enable_tls_auth)
reactor.listenSSL(args.port, site, context_factory)
endpoint = SSL4ServerEndpoint(reactor, args.port, ssl_cf)
else:
reactor.listenTCP(args.port, site)
endpoint = TCP4ServerEndpoint(reactor, args.port)
if args.proxy_protocol:
from twisted.protocols.haproxy import proxyEndpoint
endpoint = proxyEndpoint(endpoint)
endpoint.listen(site)

# Start the table rotation checker/updater
l = task.LoopingCall(settings.update_rotating_tables)
Expand Down
102 changes: 60 additions & 42 deletions autopush/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from nose.tools import eq_, ok_
from twisted.internet import reactor
from twisted.internet.defer import inlineCallbacks, returnValue, Deferred
from twisted.internet.endpoints import SSL4ServerEndpoint, TCP4ServerEndpoint
from twisted.internet.threads import deferToThread
from twisted.logger import (
globalLogPublisher,
Expand Down Expand Up @@ -58,12 +59,22 @@

@implementer(ILogObserver)
class TestingLogObserver(object):
def __init__(self, test_callback):
self.success = False
self._test_callback = test_callback
def __init__(self):
self._events = []

def __call__(self, event):
self.success |= self._test_callback(event)
self._events.append(event)

def logged(self, predicate):
"""Determine if any log events satisfy the callable"""
assert callable(predicate)
return any(predicate(e) for e in self._events)

def logged_ci(self, predicate):
"""Determine if any log client_infos satisfy the callable"""
assert callable(predicate)
return self.logged(
lambda e: 'client_info' in e and predicate(e['client_info']))


def setUp():
Expand Down Expand Up @@ -354,6 +365,9 @@ def setUp(self):
StatusResource,
)

self.logs = TestingLogObserver()
globalLogPublisher.addObserver(self.logs)

router_table = os.environ.get("ROUTER_TABLE", "router_int_test")
storage_table = os.environ.get("STORAGE_TABLE", "storage_int_test")
message_table = os.environ.get("MESSAGE_TABLE", "message_int_test")
Expand Down Expand Up @@ -421,18 +435,27 @@ def setUp(self):
mount_health_handlers(site, settings)
self._settings = settings
if is_https:
endpoint = reactor.listenSSL(self.endpoint_port, site,
self.endpoint_SSLCF())
ep = SSL4ServerEndpoint(
reactor,
self.endpoint_port,
self.endpoint_SSLCF())
else:
endpoint = reactor.listenTCP(self.endpoint_port, site)
self.website = endpoint
ep = TCP4ServerEndpoint(reactor, self.endpoint_port)
ep = self.wrap_endpoint(ep)
ep.listen(site).addCallback(self._endpoint_listening)

def _endpoint_listening(self, port):
self.website = port

def make_client_certs(self):
return None

def endpoint_SSLCF(self):
raise NotImplementedError # pragma: nocover

def wrap_endpoint(self, ep):
return ep

@inlineCallbacks
def tearDown(self):
dones = [self.websocket.stopListening(), self.website.stopListening(),
Expand All @@ -442,6 +465,7 @@ def tearDown(self):

# Dirty reactor unless we shut down the cached connections
yield self._settings.agent._pool.closeCachedConnections()
globalLogPublisher.removeObserver(self.logs)

@inlineCallbacks
def quick_register(self, use_webpush=False, sslcontext=None):
Expand Down Expand Up @@ -630,21 +654,13 @@ def test_webpush_data_delivery_to_connected_client(self):
# Invalid UTF-8 byte sequence.
data = b"\xc3\x28\xa0\xa1\xe2\x28\xa1"

def message_size_logged(event):
if 'client_info' in event:
if 'message_size' in event['client_info']:
return True
return False

obs = TestingLogObserver(message_size_logged)
globalLogPublisher.addObserver(obs)
result = yield client.send_notification(data=data)
ok_(result is not None)
eq_(result["messageType"], "notification")
eq_(result["channelID"], chan)
eq_(result["data"], "wyigoeIooQ")
ok_(obs.success, "message_size not logged")
globalLogPublisher.removeObserver(obs)
ok_(self.logs.logged_ci(lambda ci: 'message_size' in ci),
"message_size not logged")
yield self.shut_down(client)

@inlineCallbacks
Expand All @@ -661,15 +677,6 @@ def test_webpush_data_delivery_to_disconnected_client(self):
data=b"\xc3\x28\xa0\xa1\xe2\x28\xa1", result="wyigoeIooQ"),
}

def message_size_logged(event):
if 'client_info' in event:
if 'message_size' in event['client_info']:
return True
return False

obs = TestingLogObserver(message_size_logged)
globalLogPublisher.addObserver(obs)

client = Client("ws://localhost:9010/", use_webpush=True)
yield client.connect()
yield client.hello()
Expand All @@ -695,8 +702,8 @@ def message_size_logged(event):
ok_("encoding" in headers)
yield client.ack(chan, result["version"])

ok_(obs.success, "message_size not logged")
globalLogPublisher.removeObserver(obs)
ok_(self.logs.logged_ci(lambda ci: 'message_size' in ci),
"message_size not logged")
yield self.shut_down(client)

@inlineCallbacks
Expand Down Expand Up @@ -838,26 +845,15 @@ def test_topic_no_delivery_on_reconnect(self):

@inlineCallbacks
def test_basic_delivery_with_vapid(self):

def message_size_logged(event):
if 'client_info' in event:
if 'router_key' in event['client_info']:
return True
return False

obs = TestingLogObserver(message_size_logged)
globalLogPublisher.addObserver(obs)

data = str(uuid.uuid4())
client = yield self.quick_register(use_webpush=True)
vapid_info = _get_vapid()
result = yield client.send_notification(data=data, vapid=vapid_info)
eq_(result["headers"]["encryption"], client._crypto_key)
eq_(result["data"], base64url_encode(data))
eq_(result["messageType"], "notification")
ok_(obs.success, "message_size not logged")
globalLogPublisher.removeObserver(obs)

ok_(self.logs.logged_ci(lambda ci: 'router_key' in ci),
"router_key not logged")
yield self.shut_down(client)

@inlineCallbacks
Expand Down Expand Up @@ -1938,6 +1934,28 @@ def test_registration(self):
eq_(ca_data['body'], base64url_encode(data))


class TestProxyProtocol(IntegrationBase):

def wrap_endpoint(self, ep):
from twisted.protocols.haproxy import proxyEndpoint
return proxyEndpoint(ep)

@inlineCallbacks
def test_proxy_protocol(self):
ip = '198.51.100.22'
proto_line = 'PROXY TCP4 {} 203.0.113.7 35646 80\r\n'.format(ip)
# the proxy proto. line comes before the request: we can sneak
# it in before the verb
response, body = yield _agent(
'{}GET'.format(proto_line),
"http://localhost:9020/v1/err",
)
eq_(response.code, 418)
payload = json.loads(body)
eq_(payload['error'], "Test Error")
ok_(self.logs.logged_ci(lambda ci: ci.get('remote_ip') == ip))


@inlineCallbacks
def _agent(method, url, contextFactory=None, headers=None, body=None):
kwargs = {}
Expand Down
5 changes: 5 additions & 0 deletions autopush/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ def test_client_certs(self):
], False)
ok_(not returncode)

def test_proxy_protocol(self):
endpoint_main([
"--proxy_protocol",
], False)

@patch('hyper.tls', spec=hyper.tls)
def test_client_certs_parse(self, mock):
ap = make_settings(self.TestArg)
Expand Down
3 changes: 3 additions & 0 deletions configs/autopush_endpoint.ini.sample
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,6 @@ port = 8082
; e.g.:
; {"client1": ["2C:78:31.."], "client2": ["3F:D0:E0..", "E2:19:B1.."]}
#client_certs =

; Enable HAProxy Proxy Protocol handling
#proxy_protocol

0 comments on commit f482e64

Please sign in to comment.