Skip to content

Commit

Permalink
use the Python 3.5 coroutine syntax
Browse files Browse the repository at this point in the history
Closes #130.
  • Loading branch information
andreasots committed Jun 2, 2017
1 parent 4ba9fda commit 0810e52
Show file tree
Hide file tree
Showing 20 changed files with 148 additions and 250 deletions.
71 changes: 14 additions & 57 deletions common/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import urllib.request

import aiohttp
import async_timeout

from common import config
from common import utils
Expand Down Expand Up @@ -59,37 +60,29 @@ def request(url, data=None, method='GET', maxtries=3, headers={}, timeout=5, **k
# Limit the number of parallel HTTP connections to a server.
http_request_session = aiohttp.ClientSession(connector=aiohttp.TCPConnector(limit=6))
atexit.register(lambda: asyncio.get_event_loop().run_until_complete(http_request_session.close()))
@asyncio.coroutine
def request_coro(url, data=None, method='GET', maxtries=3, headers={}, timeout=5, allow_redirects=True):
async def request_coro(url, data=None, method='GET', maxtries=3, headers={}, timeout=5, allow_redirects=True):
headers["User-Agent"] = "LRRbot/2.0 (https://lrrbot.mrphlip.com/)"
firstex = None

# FIXME(#130): aiohttp fails to decode HEAD requests with Content-Encoding set. Do GET requests instead.
real_method = method
if method == 'HEAD':
real_method = 'GET'

if method == 'GET':
params = data
data = None
else:
params = None
while True:
try:
res = yield from asyncio.wait_for(http_request_session.request(real_method, url, params=params, data=data, headers=headers, allow_redirects=allow_redirects), timeout)
if method == "HEAD":
yield from res.release()
return res
status_class = res.status // 100
if status_class != 2:
yield from res.read()
if status_class == 4:
maxtries = 1
yield from res.release()
raise urllib.error.HTTPError(res.url, res.status, res.reason, res.headers, None)
text = yield from res.text()
yield from res.release()
return text
with async_timeout.timeout(timeout):
async with http_request_session.request(method, url, params=params, data=data, headers=headers, allow_redirects=allow_redirects) as res:
if method == "HEAD":
return res
status_class = res.status // 100
if status_class != 2:
await res.read()
if status_class == 4:
maxtries = 1
raise urllib.error.HTTPError(res.url, res.status, res.reason, res.headers, None)
text = await res.text()
return text
except utils.PASSTHROUGH_EXCEPTIONS:
raise
except Exception as e:
Expand All @@ -101,39 +94,3 @@ def request_coro(url, data=None, method='GET', maxtries=3, headers={}, timeout=5
else:
break
raise firstex

def api_request(uri, *args, **kwargs):
# Send the information to the server
try:
res = request(config.config['siteurl'] + uri, *args, **kwargs)
except utils.PASSTHROUGH_EXCEPTIONS:
raise
except Exception:
log.exception("Error at server in %s" % uri)
else:
try:
res = json.loads(res)
except ValueError:
log.exception("Error parsing server response from %s: %s", uri, res)
else:
if 'success' not in res:
log.error("Error at server in %s" % uri)
return res

@asyncio.coroutine
def api_request_coro(uri, *args, **kwargs):
try:
res = yield from request_coro(config.config['siteurl'] + uri, *args, **kwargs)
except utils.PASSTHROUGH_EXCEPTIONS:
raise
except Exception:
log.exception("Error at server in %s" % uri)
else:
try:
res = json.loads(res)
except ValueError:
log.exception("Error parsing server response from %s: %s", uri, res)
else:
if 'success' not in res:
log.error("Error at server in %s" % uri)
return res
21 changes: 8 additions & 13 deletions common/patreon.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import datetime
import json

Expand All @@ -8,20 +7,18 @@
from common.config import config
from common import http

@asyncio.coroutine
def request_token(grant_type, **data):
async def request_token(grant_type, **data):
data.update({
"grant_type": grant_type,
"client_id": config["patreon_clientid"],
"client_secret": config["patreon_clientsecret"],
})
data = yield from http.request_coro("https://api.patreon.com/oauth2/token", data=data, method="POST")
data = await http.request_coro("https://api.patreon.com/oauth2/token", data=data, method="POST")
data = json.loads(data)
expiry = datetime.datetime.now(pytz.utc) + datetime.timedelta(seconds=data["expires_in"])
return data["access_token"], data["refresh_token"], expiry

@asyncio.coroutine
def get_token(engine, metadata, user):
async def get_token(engine, metadata, user):
def filter_by_user(query, user):
if isinstance(user, int):
return query.where(users.c.id == user)
Expand All @@ -47,7 +44,7 @@ def filter_by_user(query, user):
if access_token is None:
raise Exception("User not logged in")
if expiry < datetime.datetime.now(pytz.utc):
access_token, refresh_token, expiry = yield from request_token("refresh_token", refresh_token=refresh_token)
access_token, refresh_token, expiry = await request_token("refresh_token", refresh_token=refresh_token)
with engine.begin() as conn:
conn.execute(patreon_users.update().where(patreon_users.c.id == patreon_id),
access_token=access_token,
Expand All @@ -57,15 +54,13 @@ def filter_by_user(query, user):

return access_token

@asyncio.coroutine
def get_campaigns(token, include=["creator", "goals", "rewards"]):
async def get_campaigns(token, include=["creator", "goals", "rewards"]):
data = {"include": ",".join(include)}
headers = {"Authorization": "Bearer %s" % token}
data = yield from http.request_coro("https://api.patreon.com/oauth2/api/current_user/campaigns", data=data, headers=headers)
data = await http.request_coro("https://api.patreon.com/oauth2/api/current_user/campaigns", data=data, headers=headers)
return json.loads(data)

@asyncio.coroutine
def current_user(token):
async def current_user(token):
headers = {"Authorization": "Bearer %s" % token}
data = yield from http.request_coro("https://api.patreon.com/oauth2/api/current_user", headers=headers)
data = await http.request_coro("https://api.patreon.com/oauth2/api/current_user", headers=headers)
return json.loads(data)
8 changes: 3 additions & 5 deletions common/slack.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
import asyncio
import json

from common.config import config
from common import http

@asyncio.coroutine
def send_message(text, **keys):
async def send_message(text, **keys):
keys['text'] = text

headers = {
"Content-Type": "application/json",
}

if config['slack_webhook_url'] is not None:
yield from http.request_coro(config['slack_webhook_url'], method="POST", data=json.dumps(keys), headers=headers)
await http.request_coro(config['slack_webhook_url'], method="POST", data=json.dumps(keys), headers=headers)

def escape(text):
return text \
.replace("&", "&amp;") \
.replace("<", "&lt;") \
.replace(">", "&gt;")
.replace(">", "&gt;")
35 changes: 14 additions & 21 deletions common/twitch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import random
import asyncio
import socket
import dateutil.parser

Expand Down Expand Up @@ -98,8 +97,7 @@ def is_stream_live(username=None):
channel_data = get_info(username, use_fallback=False)
return channel_data and channel_data['live']

@asyncio.coroutine
def get_subscribers(channel, token, count=5, offset=None, latest=True):
async def get_subscribers(channel, token, count=5, offset=None, latest=True):
headers = {
"Authorization": "OAuth %s" % token,
"Client-ID": config['twitch_clientid'],
Expand All @@ -110,15 +108,14 @@ def get_subscribers(channel, token, count=5, offset=None, latest=True):
}
if offset is not None:
data['offset'] = str(offset)
res = yield from common.http.request_coro("https://api.twitch.tv/kraken/channels/%s/subscriptions" % channel, headers=headers, data=data)
res = await common.http.request_coro("https://api.twitch.tv/kraken/channels/%s/subscriptions" % channel, headers=headers, data=data)
subscriber_data = json.loads(res)
return [
(sub['user']['display_name'], sub['user'].get('logo'), sub['created_at'], sub.get('updated_at', sub['created_at']))
for sub in subscriber_data['subscriptions']
]

@asyncio.coroutine
def get_follows_channels(username=None):
async def get_follows_channels(username=None):
if username is None:
username = config["username"]
headers = {
Expand All @@ -128,15 +125,14 @@ def get_follows_channels(username=None):
follows = []
total = 1
while len(follows) < total:
data = yield from common.http.request_coro(url, headers=headers)
data = await common.http.request_coro(url, headers=headers)
data = json.loads(data)
total = data["_total"]
follows += data["follows"]
url = data["_links"]["next"]
return follows

@asyncio.coroutine
def get_streams_followed(token):
async def get_streams_followed(token):
url = "https://api.twitch.tv/kraken/streams/followed"
headers = {
"Authorization": "OAuth %s" % token,
Expand All @@ -145,38 +141,35 @@ def get_streams_followed(token):
streams = []
total = 1
while len(streams) < total:
data = yield from common.http.request_coro(url, headers=headers)
data = await common.http.request_coro(url, headers=headers)
data = json.loads(data)
total = data["_total"]
streams += data["streams"]
url = data["_links"]["next"]
return streams

@asyncio.coroutine
def follow_channel(target, token):
async def follow_channel(target, token):
headers = {
"Authorization": "OAuth %s" % token,
"Client-ID": config['twitch_clientid'],
}
yield from common.http.request_coro("https://api.twitch.tv/kraken/users/%s/follows/channels/%s" % (config["username"], target),
data={"notifications": "false"}, method="PUT", headers=headers)
await common.http.request_coro("https://api.twitch.tv/kraken/users/%s/follows/channels/%s" % (config["username"], target),
data={"notifications": "false"}, method="PUT", headers=headers)

@asyncio.coroutine
def unfollow_channel(target, token):
async def unfollow_channel(target, token):
headers = {
"Authorization": "OAuth %s" % token,
"Client-ID": config['twitch_clientid'],
}
yield from common.http.request_coro("https://api.twitch.tv/kraken/users/%s/follows/channels/%s" % (config["username"], target),
method="DELETE", headers=headers)
await common.http.request_coro("https://api.twitch.tv/kraken/users/%s/follows/channels/%s" % (config["username"], target),
method="DELETE", headers=headers)

@asyncio.coroutine
def get_videos(channel=None, offset=0, limit=10, broadcasts=False, hls=False):
async def get_videos(channel=None, offset=0, limit=10, broadcasts=False, hls=False):
channel = channel or config["channel"]
headers = {
"Client-ID": config['twitch_clientid'],
}
data = yield from common.http.request_coro("https://api.twitch.tv/kraken/channels/%s/videos" % channel, headers=headers, data={
data = await common.http.request_coro("https://api.twitch.tv/kraken/channels/%s/videos" % channel, headers=headers, data={
"offset": str(offset),
"limit": str(limit),
"broadcasts": "true" if broadcasts else "false",
Expand Down
15 changes: 6 additions & 9 deletions common/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
log = logging.getLogger("common.url")

@utils.cache(60 * 60, params=[0])
@asyncio.coroutine
def canonical_url(url, depth=10):
async def canonical_url(url, depth=10):
urls = []
while depth > 0:
if not url.startswith("http://") and not url.startswith("https://"):
url = "http://" + url
urls.append(url)
try:
res = yield from request_coro(url, method="HEAD", allow_redirects=False)
res = await request_coro(url, method="HEAD", allow_redirects=False)
if res.status in range(300, 400) and "Location" in res.headers:
url = res.headers["Location"]
depth -= 1
Expand All @@ -30,10 +29,9 @@ def canonical_url(url, depth=10):
return urls

@utils.cache(24 * 60 * 60)
@asyncio.coroutine
def get_tlds():
async def get_tlds():
tlds = set()
data = yield from request_coro("https://data.iana.org/TLD/tlds-alpha-by-domain.txt")
data = await request_coro("https://data.iana.org/TLD/tlds-alpha-by-domain.txt")
for line in data.splitlines():
if not line.startswith("#"):
line = line.strip().lower()
Expand All @@ -43,13 +41,12 @@ def get_tlds():
return tlds

@utils.cache(24 * 60 * 60)
@asyncio.coroutine
def url_regex():
async def url_regex():
parens = ["()", "[]", "{}", "<>", '""', "''"]

# Sort TLDs in decreasing order by length to avoid incorrect matches.
# For example: if 'co' is before 'com', 'example.com/path' is matched as 'example.co'.
tlds = sorted((yield from get_tlds()), key=lambda e: len(e), reverse=True)
tlds = sorted((await get_tlds()), key=lambda e: len(e), reverse=True)
re_tld = "(?:" + "|".join(map(re.escape, tlds)) + ")"
re_hostname = "(?:(?:(?:[\w-]+\.)+" + re_tld + "\.?)|(?:\d{,3}(?:\.\d{,3}){3})|(?:\[[0-9a-fA-F:.]+\]))"
re_url = "((?:https?://)?" + re_hostname + "(?::\d+)?(?:/[\x5E\s\u200b]*)?)"
Expand Down
Loading

0 comments on commit 0810e52

Please sign in to comment.