Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert some of the media REST code to async/await #7110

Merged
merged 4 commits into from
Mar 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7110.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert some of synapse.rest.media to async/await.
110 changes: 49 additions & 61 deletions synapse/rest/media/v1/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import twisted.internet.error
import twisted.web.http
from twisted.internet import defer
from twisted.web.resource import Resource

from synapse.api.errors import (
Expand Down Expand Up @@ -114,15 +113,14 @@ def _start_update_recently_accessed(self):
"update_recently_accessed_media", self._update_recently_accessed
)

@defer.inlineCallbacks
def _update_recently_accessed(self):
async def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()

local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()

yield self.store.update_cached_last_access_time(
await self.store.update_cached_last_access_time(
local_media, remote_media, self.clock.time_msec()
)

Expand All @@ -138,8 +136,7 @@ def mark_recently_accessed(self, server_name, media_id):
else:
self.recently_accessed_locals.add(media_id)

@defer.inlineCallbacks
def create_content(
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
"""Store uploaded content for a local user and return the mxc URL
Expand All @@ -158,11 +155,11 @@ def create_content(

file_info = FileInfo(server_name=None, file_id=media_id)

fname = yield self.media_storage.store_file(content, file_info)
fname = await self.media_storage.store_file(content, file_info)

logger.info("Stored local media in file %r", fname)

yield self.store.store_local_media(
await self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
Expand All @@ -171,12 +168,11 @@ def create_content(
user_id=auth_user,
)

yield self._generate_thumbnails(None, media_id, media_id, media_type)
await self._generate_thumbnails(None, media_id, media_id, media_type)

return "mxc://%s/%s" % (self.server_name, media_id)

@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
async def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.

Args:
Expand All @@ -190,7 +186,7 @@ def get_local_media(self, request, media_id, name):
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
Expand All @@ -204,13 +200,12 @@ def get_local_media(self, request, media_id, name):

file_info = FileInfo(None, media_id, url_cache=url_cache)

responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)

@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.

Args:
Expand All @@ -236,8 +231,8 @@ def get_remote_media(self, request, server_name, media_id, name):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)

Expand All @@ -246,14 +241,13 @@ def get_remote_media(self, request, server_name, media_id, name):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
await respond_with_responder(
request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)

@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.

Expand All @@ -274,8 +268,8 @@ def get_remote_media_info(self, server_name, media_id):
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
with (await self.remote_media_linearizer.queue(key)):
responder, media_info = await self._get_remote_media_impl(
server_name, media_id
)

Expand All @@ -286,8 +280,7 @@ def get_remote_media_info(self, server_name, media_id):

return media_info

@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.

Expand All @@ -299,7 +292,7 @@ def _get_remote_media_impl(self, server_name, media_id):
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
media_info = await self.store.get_cached_remote_media(server_name, media_id)

# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
Expand All @@ -317,19 +310,18 @@ def _get_remote_media_impl(self, server_name, media_id):
logger.info("Media is quarantined")
raise NotFoundError()

responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
if responder:
return responder, media_info

# Failed to find the file anywhere, lets download it.

media_info = yield self._download_remote_file(server_name, media_id, file_id)
media_info = await self._download_remote_file(server_name, media_id, file_id)

responder = yield self.media_storage.fetch_media(file_info)
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info

@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.

Expand All @@ -351,7 +343,7 @@ def _download_remote_file(self, server_name, media_id, file_id):
("/_matrix/media/v1/download", server_name, media_id)
)
try:
length, headers = yield self.client.get_file(
length, headers = await self.client.get_file(
server_name,
request_path,
output_stream=f,
Expand Down Expand Up @@ -397,15 +389,15 @@ def _download_remote_file(self, server_name, media_id, file_id):
)
raise SynapseError(502, "Failed to fetch remote media")

yield finish()
await finish()

media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()

logger.info("Stored remote media in file %r", fname)

yield self.store.store_cached_remote_media(
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
Expand All @@ -423,7 +415,7 @@ def _download_remote_file(self, server_name, media_id, file_id):
"filesystem_id": file_id,
}

yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
await self._generate_thumbnails(server_name, media_id, file_id, media_type)

return media_info

Expand Down Expand Up @@ -458,16 +450,15 @@ def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):

return t_byte_source

@defer.inlineCallbacks
def generate_local_exact_thumbnail(
async def generate_local_exact_thumbnail(
self, media_id, t_width, t_height, t_method, t_type, url_cache
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)

thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
Expand All @@ -490,7 +481,7 @@ def generate_local_exact_thumbnail(
thumbnail_type=t_type,
)

output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
Expand All @@ -500,22 +491,21 @@ def generate_local_exact_thumbnail(

t_len = os.path.getsize(output_path)

yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)

return output_path

@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
async def generate_remote_exact_thumbnail(
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)

thumbnailer = Thumbnailer(input_path)
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
Expand All @@ -537,7 +527,7 @@ def generate_remote_exact_thumbnail(
thumbnail_type=t_type,
)

output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
Expand All @@ -547,7 +537,7 @@ def generate_remote_exact_thumbnail(

t_len = os.path.getsize(output_path)

yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
Expand All @@ -560,8 +550,7 @@ def generate_remote_exact_thumbnail(

return output_path

@defer.inlineCallbacks
def _generate_thumbnails(
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
"""Generate and store thumbnails for an image.
Expand All @@ -582,7 +571,7 @@ def _generate_thumbnails(
if not requirements:
return

input_path = yield self.media_storage.ensure_media_is_in_local_cache(
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)

Expand All @@ -600,7 +589,7 @@ def _generate_thumbnails(
return

if thumbnailer.transpose_method is not None:
m_width, m_height = yield defer_to_thread(
m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)

Expand All @@ -620,11 +609,11 @@ def _generate_thumbnails(
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = yield defer_to_thread(
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
Expand All @@ -646,7 +635,7 @@ def _generate_thumbnails(
url_cache=url_cache,
)

output_path = yield self.media_storage.store_file(
output_path = await self.media_storage.store_file(
t_byte_source, file_info
)
finally:
Expand All @@ -656,7 +645,7 @@ def _generate_thumbnails(

# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
Expand All @@ -667,15 +656,14 @@ def _generate_thumbnails(
t_len,
)
else:
yield self.store.store_local_thumbnail(
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)

return {"width": m_width, "height": m_height}

@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
async def delete_old_remote_media(self, before_ts):
old_media = await self.store.get_remote_media_before(before_ts)

deleted = 0

Expand All @@ -689,7 +677,7 @@ def delete_old_remote_media(self, before_ts):

# TODO: Should we delete from the backup store

with (yield self.remote_media_linearizer.queue(key)):
with (await self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
Expand All @@ -705,7 +693,7 @@ def delete_old_remote_media(self, before_ts):
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)

yield self.store.delete_remote_media(origin, media_id)
await self.store.delete_remote_media(origin, media_id)
deleted += 1

return {"deleted": deleted}
Expand Down
Loading