Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #3701 from matrix-org/rav/use_producer_for_responses
Browse files Browse the repository at this point in the history
Use a producer to stream back responses
  • Loading branch information
richvdh authored Aug 17, 2018
2 parents 3f8709f + d82fa0e commit 6326039
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 64 deletions.
1 change: 1 addition & 0 deletions changelog.d/3701.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid timing out requests while we are streaming back the response
17 changes: 13 additions & 4 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET
from twisted.web.static import NoRangeStaticProducer
from twisted.web.util import redirectTo

import synapse.events
Expand All @@ -40,6 +41,11 @@
from synapse.util.caches import intern_dict
from synapse.util.logcontext import preserve_fn

if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO

logger = logging.getLogger(__name__)

HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
Expand Down Expand Up @@ -389,8 +395,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
return

if pretty_print:
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
).encode("utf-8")
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
Expand Down Expand Up @@ -426,8 +431,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
if send_cors:
set_cors_headers(request)

request.write(json_bytes)
finish_request(request)
# todo: we can almost certainly avoid this copy and encode the json straight into
# the bytesIO, but it would involve faffing around with string->bytes wrappers.
bytes_io = BytesIO(json_bytes)

producer = NoRangeStaticProducer(request, bytes_io)
producer.start()
return NOT_DONE_YET


Expand Down
7 changes: 2 additions & 5 deletions tests/rest/client/v1/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from synapse.util import Clock

from tests import unittest
from tests.server import make_request, setup_test_homeserver
from tests.server import make_request, render, setup_test_homeserver


class CreateUserServletTestCase(unittest.TestCase):
Expand Down Expand Up @@ -77,10 +77,7 @@ def test_POST_createuser_with_valid_user(self):
)

request, channel = make_request(b"POST", url, request_data)
request.render(res)

# Advance the clock because it waits
self.clock.advance(1)
render(request, res, self.clock)

self.assertEquals(channel.result["code"], b"200")

Expand Down
8 changes: 3 additions & 5 deletions tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from synapse.api.constants import Membership

from tests import unittest
from tests.server import make_request, wait_until_result
from tests.server import make_request, render


class RestTestCase(unittest.TestCase):
Expand Down Expand Up @@ -171,8 +171,7 @@ def create_room_as(self, room_creator, is_public=True, tok=None):
request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8')
)
request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
render(request, self.resource, self.hs.get_reactor())

assert channel.result["code"] == b"200", channel.result
self.auth_user_id = temp_id
Expand Down Expand Up @@ -220,8 +219,7 @@ def change_membership(self, room, src, targ, membership, tok=None, expect_code=2

request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))

request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
render(request, self.resource, self.hs.get_reactor())

assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
Expand Down
23 changes: 8 additions & 15 deletions tests/rest/client/v2_alpha/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
render,
setup_test_homeserver,
wait_until_result,
)

PATH_PREFIX = "/_matrix/client/v2_alpha"
Expand Down Expand Up @@ -76,8 +76,7 @@ def test_add_filter(self):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
Expand All @@ -91,8 +90,7 @@ def test_add_filter_for_other_user(self):
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
Expand All @@ -105,8 +103,7 @@ def test_add_filter_non_local_user(self):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403")
Expand All @@ -121,8 +118,7 @@ def test_get_filter(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
Expand All @@ -131,8 +127,7 @@ def test_get_filter_non_existant(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
Expand All @@ -143,8 +138,7 @@ def test_get_filter_invalid_id(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"400")

Expand All @@ -153,7 +147,6 @@ def test_get_filter_no_id(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"400")
26 changes: 9 additions & 17 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from synapse.util import Clock

from tests import unittest
from tests.server import make_request, setup_test_homeserver, wait_until_result
from tests.server import make_request, render, setup_test_homeserver


class RegisterRestServletTestCase(unittest.TestCase):
Expand Down Expand Up @@ -72,8 +72,7 @@ def test_POST_appservice_registration_valid(self):
request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
Expand All @@ -89,25 +88,22 @@ def test_POST_appservice_registration_invalid(self):
request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"401", channel.result)

def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")

def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"})
request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username")
Expand All @@ -126,8 +122,7 @@ def test_POST_user_valid(self):
self.device_handler.check_device_registered = Mock(return_value=device_id)

request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

det_data = {
"user_id": user_id,
Expand All @@ -149,8 +144,7 @@ def test_POST_disabled_registration(self):
self.registration_handler.register = Mock(return_value=("@user:id", "t"))

request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
Expand All @@ -162,8 +156,7 @@ def test_POST_guest_registration(self):
self.registration_handler.register = Mock(return_value=(user_id, None))

request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

det_data = {
"user_id": user_id,
Expand All @@ -177,8 +170,7 @@ def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False

request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
5 changes: 2 additions & 3 deletions tests/rest/client/v2_alpha/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
render,
setup_test_homeserver,
wait_until_result,
)

PATH_PREFIX = "/_matrix/client/v2_alpha"
Expand Down Expand Up @@ -69,8 +69,7 @@ def get_user_by_req(request, allow_guest=False, rights="access"):

def test_sync_argless(self):
request, channel = make_request("GET", "/_matrix/client/r0/sync")
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"200")
self.assertTrue(
Expand Down
23 changes: 19 additions & 4 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class FakeChannel(object):
"""

result = attr.ib(default=attr.Factory(dict))
_producer = None

@property
def json_body(self):
Expand All @@ -49,6 +50,15 @@ def write(self, content):

self.result["body"] += content

def registerProducer(self, producer, streaming):
self._producer = producer

def unregisterProducer(self):
if self._producer is None:
return

self._producer = None

def requestDone(self, _self):
self.result["done"] = True

Expand Down Expand Up @@ -111,14 +121,19 @@ def make_request(method, path, content=b""):
return req, channel


def wait_until_result(clock, channel, timeout=100):
def wait_until_result(clock, request, timeout=100):
"""
Wait until the channel has a result.
Wait until the request is finished.
"""
clock.run()
x = 0

while not channel.result:
while not request.finished:

# If there's a producer, tell it to resume producing so we get content
if request._channel._producer:
request._channel._producer.resumeProducing()

x += 1

if x > timeout:
Expand All @@ -129,7 +144,7 @@ def wait_until_result(clock, channel, timeout=100):

def render(request, resource, clock):
request.render(resource)
wait_until_result(clock, request._channel)
wait_until_result(clock, request)


class ThreadedMemoryReactorClock(MemoryReactorClock):
Expand Down
17 changes: 6 additions & 11 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from synapse.util import Clock

from tests import unittest
from tests.server import make_request, setup_test_homeserver
from tests.server import make_request, render, setup_test_homeserver


class JsonResourceTests(unittest.TestCase):
Expand Down Expand Up @@ -37,7 +37,7 @@ def _callback(request, **kwargs):
)

request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
Expand All @@ -55,7 +55,7 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(channel.result["code"], b'500')

Expand All @@ -78,13 +78,8 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
render(request, res, self.reactor)

# No error has been raised yet
self.assertTrue("code" not in channel.result)

# Advance time, now there's an error
self.reactor.advance(1)
self.assertEqual(channel.result["code"], b'500')

def test_callback_synapseerror(self):
Expand All @@ -100,7 +95,7 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(channel.result["code"], b'403')
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
Expand All @@ -121,7 +116,7 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foobar")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(channel.result["code"], b'400')
self.assertEqual(channel.json_body["error"], "Unrecognized request")
Expand Down

0 comments on commit 6326039

Please sign in to comment.