Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use a producer to stream back responses #3701

Merged
merged 3 commits into from
Aug 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/3701.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid timing out requests while we are streaming back the response
17 changes: 13 additions & 4 deletions synapse/http/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from twisted.python import failure
from twisted.web import resource, server
from twisted.web.server import NOT_DONE_YET
from twisted.web.static import NoRangeStaticProducer
from twisted.web.util import redirectTo

import synapse.events
Expand All @@ -42,6 +43,11 @@
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.metrics import Measure

if PY3:
from io import BytesIO
else:
from cStringIO import StringIO as BytesIO

logger = logging.getLogger(__name__)

HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
Expand Down Expand Up @@ -413,8 +419,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
return

if pretty_print:
json_bytes = (encode_pretty_printed_json(json_object) + "\n"
).encode("utf-8")
json_bytes = encode_pretty_printed_json(json_object) + b"\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
# canonicaljson already encodes to bytes
Expand Down Expand Up @@ -450,8 +455,12 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
if send_cors:
set_cors_headers(request)

request.write(json_bytes)
finish_request(request)
# todo: we can almost certainly avoid this copy and encode the json straight into
# the bytesIO, but it would involve faffing around with string->bytes wrappers.
bytes_io = BytesIO(json_bytes)

producer = NoRangeStaticProducer(request, bytes_io)
producer.start()
return NOT_DONE_YET


Expand Down
7 changes: 2 additions & 5 deletions tests/rest/client/v1/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from synapse.util import Clock

from tests import unittest
from tests.server import make_request, setup_test_homeserver
from tests.server import make_request, render, setup_test_homeserver


class CreateUserServletTestCase(unittest.TestCase):
Expand Down Expand Up @@ -77,10 +77,7 @@ def test_POST_createuser_with_valid_user(self):
)

request, channel = make_request(b"POST", url, request_data)
request.render(res)

# Advance the clock because it waits
self.clock.advance(1)
render(request, res, self.clock)

self.assertEquals(channel.result["code"], b"200")

Expand Down
8 changes: 3 additions & 5 deletions tests/rest/client/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from synapse.api.constants import Membership

from tests import unittest
from tests.server import make_request, wait_until_result
from tests.server import make_request, render


class RestTestCase(unittest.TestCase):
Expand Down Expand Up @@ -171,8 +171,7 @@ def create_room_as(self, room_creator, is_public=True, tok=None):
request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8')
)
request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
render(request, self.resource, self.hs.get_reactor())

assert channel.result["code"] == b"200", channel.result
self.auth_user_id = temp_id
Expand Down Expand Up @@ -220,8 +219,7 @@ def change_membership(self, room, src, targ, membership, tok=None, expect_code=2

request, channel = make_request("PUT", path, json.dumps(data).encode('utf8'))

request.render(self.resource)
wait_until_result(self.hs.get_reactor(), channel)
render(request, self.resource, self.hs.get_reactor())

assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
Expand Down
23 changes: 8 additions & 15 deletions tests/rest/client/v2_alpha/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
render,
setup_test_homeserver,
wait_until_result,
)

PATH_PREFIX = "/_matrix/client/v2_alpha"
Expand Down Expand Up @@ -76,8 +76,7 @@ def test_add_filter(self):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
Expand All @@ -91,8 +90,7 @@ def test_add_filter_for_other_user(self):
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
Expand All @@ -105,8 +103,7 @@ def test_add_filter_non_local_user(self):
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID),
self.EXAMPLE_FILTER_JSON,
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403")
Expand All @@ -121,8 +118,7 @@ def test_get_filter(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
Expand All @@ -131,8 +127,7 @@ def test_get_filter_non_existant(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
Expand All @@ -143,8 +138,7 @@ def test_get_filter_invalid_id(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"400")

Expand All @@ -153,7 +147,6 @@ def test_get_filter_no_id(self):
request, channel = make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID)
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"400")
26 changes: 9 additions & 17 deletions tests/rest/client/v2_alpha/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from synapse.util import Clock

from tests import unittest
from tests.server import make_request, setup_test_homeserver, wait_until_result
from tests.server import make_request, render, setup_test_homeserver


class RegisterRestServletTestCase(unittest.TestCase):
Expand Down Expand Up @@ -72,8 +72,7 @@ def test_POST_appservice_registration_valid(self):
request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = {
Expand All @@ -89,25 +88,22 @@ def test_POST_appservice_registration_invalid(self):
request, channel = make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"401", channel.result)

def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666})
request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password")

def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"})
request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username")
Expand All @@ -126,8 +122,7 @@ def test_POST_user_valid(self):
self.device_handler.check_device_registered = Mock(return_value=device_id)

request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

det_data = {
"user_id": user_id,
Expand All @@ -149,8 +144,7 @@ def test_POST_disabled_registration(self):
self.registration_handler.register = Mock(return_value=("@user:id", "t"))

request, channel = make_request(b"POST", self.url, request_data)
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled")
Expand All @@ -162,8 +156,7 @@ def test_POST_guest_registration(self):
self.registration_handler.register = Mock(return_value=(user_id, None))

request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

det_data = {
"user_id": user_id,
Expand All @@ -177,8 +170,7 @@ def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False

request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}")
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
5 changes: 2 additions & 3 deletions tests/rest/client/v2_alpha/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
render,
setup_test_homeserver,
wait_until_result,
)

PATH_PREFIX = "/_matrix/client/v2_alpha"
Expand Down Expand Up @@ -69,8 +69,7 @@ def get_user_by_req(request, allow_guest=False, rights="access"):

def test_sync_argless(self):
request, channel = make_request("GET", "/_matrix/client/r0/sync")
request.render(self.resource)
wait_until_result(self.clock, channel)
render(request, self.resource, self.clock)

self.assertEqual(channel.result["code"], b"200")
self.assertTrue(
Expand Down
23 changes: 19 additions & 4 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class FakeChannel(object):
"""

result = attr.ib(default=attr.Factory(dict))
_producer = None

@property
def json_body(self):
Expand All @@ -49,6 +50,15 @@ def write(self, content):

self.result["body"] += content

def registerProducer(self, producer, streaming):
self._producer = producer

def unregisterProducer(self):
if self._producer is None:
return

self._producer = None

def requestDone(self, _self):
self.result["done"] = True

Expand Down Expand Up @@ -111,14 +121,19 @@ def make_request(method, path, content=b""):
return req, channel


def wait_until_result(clock, channel, timeout=100):
def wait_until_result(clock, request, timeout=100):
"""
Wait until the channel has a result.
Wait until the request is finished.
"""
clock.run()
x = 0

while not channel.result:
while not request.finished:

# If there's a producer, tell it to resume producing so we get content
if request._channel._producer:
request._channel._producer.resumeProducing()

x += 1

if x > timeout:
Expand All @@ -129,7 +144,7 @@ def wait_until_result(clock, channel, timeout=100):

def render(request, resource, clock):
request.render(resource)
wait_until_result(clock, request._channel)
wait_until_result(clock, request)


class ThreadedMemoryReactorClock(MemoryReactorClock):
Expand Down
17 changes: 6 additions & 11 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from synapse.util import Clock

from tests import unittest
from tests.server import make_request, setup_test_homeserver
from tests.server import make_request, render, setup_test_homeserver


class JsonResourceTests(unittest.TestCase):
Expand Down Expand Up @@ -37,7 +37,7 @@ def _callback(request, **kwargs):
)

request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"})
Expand All @@ -55,7 +55,7 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(channel.result["code"], b'500')

Expand All @@ -78,13 +78,8 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
render(request, res, self.reactor)

# No error has been raised yet
self.assertTrue("code" not in channel.result)

# Advance time, now there's an error
self.reactor.advance(1)
self.assertEqual(channel.result["code"], b'500')

def test_callback_synapseerror(self):
Expand All @@ -100,7 +95,7 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foo")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(channel.result["code"], b'403')
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
Expand All @@ -121,7 +116,7 @@ def _callback(request, **kwargs):
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)

request, channel = make_request(b"GET", b"/_matrix/foobar")
request.render(res)
render(request, res, self.reactor)

self.assertEqual(channel.result["code"], b'400')
self.assertEqual(channel.json_body["error"], "Unrecognized request")
Expand Down