diff --git a/aiohttp/abc.py b/aiohttp/abc.py index 245d9020052..ae8aa6bb809 100644 --- a/aiohttp/abc.py +++ b/aiohttp/abc.py @@ -9,11 +9,6 @@ class AbstractRouter(metaclass=ABCMeta): def resolve(self, request): """Return MATCH_INFO for given request""" - @asyncio.coroutine - @abstractmethod - def reverse(self, method, endpoint, **kwargs): - """Return URL string for """ - class AbstractMatchInfo(metaclass=ABCMeta): @@ -21,8 +16,3 @@ class AbstractMatchInfo(metaclass=ABCMeta): @abstractmethod def handler(self): """Return handler for match info""" - - @property - @abstractmethod - def endpoint(self): - """Return endpoint for match info""" diff --git a/aiohttp/web.py b/aiohttp/web.py index a97b78f6021..91f62d55848 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1,3 +1,4 @@ +import abc import asyncio import binascii import collections @@ -839,30 +840,158 @@ class HTTPVersionNotSupported(HTTPServerError): class UrlMappingMatchInfo(dict, AbstractMatchInfo): - def __init__(self, match_dict, entry): + def __init__(self, match_dict, route): super().__init__(match_dict) - self._entry = entry + self._route = route @property def handler(self): - return self._entry.handler + return self._route.handler @property - def endpoint(self): - return self._entry.endpoint + def route(self): + return self._route -BaseEntry = collections.namedtuple('BaseEntry', - 'regex method handler endpoint path type') +class Route(metaclass=abc.ABCMeta): + def __init__(self, method, handler, name): + self._method = method + self._handler = handler + self._name = name + + @property + def method(self): + return self._method + + @property + def handler(self): + return self._handler + + @property + def name(self): + return self._name + + @abc.abstractmethod + def match(self, path): + """Return dict with info for given path or + None if route cannot process path.""" + + @abc.abstractmethod + def url(self, **kwargs): + """Construct url for route with additional params.""" + + @staticmethod + def _append_query(url, query): + if query is not None: + return url + "?" + urlencode(query) + else: + return url + + +class PlainRoute(Route): + def __init__(self, method, handler, name, path): + super().__init__(method, handler, name) + self._path = path + + def match(self, path): + # string comparsion about 10 times faster than regexp matching + if self._path == path: + return {} + else: + return None + + def url(self, *, query=None): + return self._append_query(self._path, query) + + def __repr__(self): + name = "'" + self.name + "' " if self.name is not None else "" + return " {handler!r}".format( + name=name, method=self.method, path=self._path, + handler=self.handler) + + +class DynamicRoute(Route): + + def __init__(self, method, handler, name, pattern, formatter): + super().__init__(method, handler, name) + self._pattern = pattern + self._formatter = formatter + + def match(self, path): + match = self._pattern.match(path) + if match is None: + return None + else: + return match.groupdict() + + def url(self, *, parts, query=None): + url = self._formatter.format_map(parts) + return self._append_query(url, query) + + def __repr__(self): + name = "'" + self.name + "' " if self.name is not None else "" + return (" {handler!r}" + .format(name=name, method=self.method, + formatter=self._formatter, handler=self.handler)) -class Entry(BaseEntry): - DYNAMIC = "DYNAMIC" - STATIC = "STATIC" - PLAIN = "PLAIN" +class StaticRoute(Route): -class UrlDispatcher(AbstractRouter): + def __init__(self, name, prefix, directory): + assert prefix.startswith('/'), prefix + assert prefix.endswith('/'), prefix + super().__init__('GET', self.handle, name) + self._prefix = prefix + self._prefix_len = len(self._prefix) + self._directory = directory + + def match(self, path): + if not path.startswith(self._prefix): + return None + return {'filename': path[self._prefix_len:]} + + def url(self, *, filename, query=None): + while filename.startswith('/'): + filename = filename[1:] + url = self._prefix + filename + return self._append_query(url, query) + + @asyncio.coroutine + def handle(self, request): + resp = StreamResponse(request) + filename = request.match_info['filename'] + filepath = os.path.join(self._directory, filename) + if '..' in filename: + raise HTTPNotFound(request) + if not os.path.exists(filepath) or not os.path.isfile(filepath): + raise HTTPNotFound(request) + + ct = mimetypes.guess_type(filename)[0] + if not ct: + ct = 'application/octet-stream' + resp.content_type = ct + + resp.headers['transfer-encoding'] = 'chunked' + resp.send_headers() + + with open(filepath, 'rb') as f: + chunk = f.read(1024) + while chunk: + resp.write(chunk) + chunk = f.read(1024) + + yield from resp.write_eof() + return resp + + def __repr__(self): + name = "'" + self.name + "' " if self.name is not None else "" + return " {directory!r}".format( + name=name, method=self.method, path=self._prefix, + directory=self._directory) + + +class UrlDispatcher(AbstractRouter, collections.abc.Mapping): DYN = re.compile(r'^\{[a-zA-Z][_a-zA-Z0-9]*\}$') GOOD = r'[^{}/]+' @@ -873,160 +1002,91 @@ class UrlDispatcher(AbstractRouter): def __init__(self): super().__init__() self._urls = [] - self._endpoints = {} + self._routes = {} @asyncio.coroutine def resolve(self, request): path = request.path method = request.method allowed_methods = set() - for entry in self._urls: - match = entry.regex.match(path) - if match is None: + for route in self._urls: + match_dict = route.match(path) + if match_dict is None: continue - if entry.method != method: - allowed_methods.add(entry.method) + route_method = route.method + if route_method != method: + allowed_methods.add(route_method) else: - break + return UrlMappingMatchInfo(match_dict, route) else: if allowed_methods: raise HTTPMethodNotAllowed(request, method, allowed_methods) else: raise HTTPNotFound(request) - matchdict = match.groupdict() - return UrlMappingMatchInfo(matchdict, entry) + def __iter__(self): + return iter(self._routes) - @asyncio.coroutine - def reverse(self, method, endpoint, *, parts=None, filename=None, - query=None): - method = method.upper() - entry = self._endpoints.get((method, endpoint)) - if entry is None: - raise KeyError("[{}] {!r} endpoint not found" - .format(method, endpoint)) - - if filename is not None and entry.type is not Entry.STATIC: - raise ValueError("Cannot use filename with non-static route") - - if entry.type is Entry.DYNAMIC: - if parts is None: - raise ValueError( - "Dynamic endpoint requires nonempty parts parameter") - url = entry.path.format_map(parts) - elif entry.type is Entry.PLAIN: - if parts: - raise ValueError( - "Plain endpoint doesn't allow parts parameter") - url = entry.path - elif entry.type is Entry.STATIC: - if filename is None: - raise ValueError( - "filename must be not empty for static routes") - while filename.startswith('/'): - filename = filename[1:] - url = entry.path + filename - else: - raise ValueError( - "Not supported endpoint type {}".format(entry.type)) + def __len__(self): + return len(self._routes) - if query is not None: - qs = "?" + urlencode(query) - else: - qs = "" - return url + qs - - def _register_endpoint(self, new_entry): - endpoint = new_entry.endpoint - method = new_entry.method - if endpoint is not None: - key = (method, endpoint) - if key in self._endpoints: - entry = self._endpoints[key] - raise ValueError('Duplicate endpoint {!r}, ' - 'already handled by [{}] {} -> {!r}' - .format(endpoint, - entry.method, - entry.path, - entry.handler)) + def __contains__(self, name): + return name in self._routes + + def __getitem__(self, name): + return self._routes[name] + + def _register_endpoint(self, route): + name = route.name + if name is not None: + if name in self._routes: + raise ValueError('Duplicate {!r}, ' + 'already handled by {!r}' + .format(name, self._routes[name])) else: - self._endpoints[key] = new_entry - self._urls.append(new_entry) + self._routes[name] = route + self._urls.append(route) - def add_route(self, method, path, handler, *, endpoint=None): + def add_route(self, method, path, handler, *, name=None): assert path.startswith('/') assert callable(handler), handler method = method.upper() assert method in self.METHODS, method - regexp = [] - entry_type = Entry.PLAIN + parts = [] + factory = PlainRoute for part in path.split('/'): if not part: continue if self.DYN.match(part): - regexp.append('(?P<'+part[1:-1]+'>'+self.GOOD+')') - entry_type = Entry.DYNAMIC + parts.append('(?P<'+part[1:-1]+'>'+self.GOOD+')') + factory = DynamicRoute elif self.PLAIN.match(part): - regexp.append(re.escape(part)) + parts.append(re.escape(part)) else: raise ValueError("Invalid path '{}'['{}']".format(path, part)) - pattern = '/' + '/'.join(regexp) - if path.endswith('/') and pattern != '/': - pattern += '/' - compiled = re.compile('^' + pattern + '$') - new_entry = Entry(compiled, method, handler, - endpoint, path, entry_type) - self._register_endpoint(new_entry) - - def _static_file_handler_maker(self, path): - @asyncio.coroutine - def _handler(request): - resp = StreamResponse(request) - filename = request.match_info['filename'] - filepath = os.path.join(path, filename) - if '..' in filename: - raise HTTPNotFound(request) - if not os.path.exists(filepath) or not os.path.isfile(filepath): - raise HTTPNotFound(request) - - ct = mimetypes.guess_type(filename)[0] - if not ct: - ct = 'application/octet-stream' - resp.content_type = ct - - resp.headers['transfer-encoding'] = 'chunked' - resp.send_headers() - - with open(filepath, 'rb') as f: - chunk = f.read(1024) - while chunk: - resp.write(chunk) - chunk = f.read(1024) - - yield from resp.write_eof() - return resp - - return _handler - - def add_static(self, prefix, path, *, endpoint=None): + if factory is PlainRoute: + route = PlainRoute(method, handler, name, path) + else: + pattern = '/' + '/'.join(parts) + if path.endswith('/') and pattern != '/': + pattern += '/' + compiled = re.compile('^' + pattern + '$') + route = DynamicRoute(method, handler, name, compiled, path) + self._register_endpoint(route) + + def add_static(self, prefix, path, *, name=None): """ Adds static files view :param prefix - url prefix :param path - folder with files """ assert prefix.startswith('/') - assert os.path.exists(path), 'Path does not exist %s' % path + assert os.path.isdir(path), 'Path does not directory %s' % path path = os.path.abspath(path) - method = 'GET' - suffix = r'(?P.*)' # match everything after static prefix if not prefix.endswith('/'): prefix += '/' - compiled = re.compile('^' + prefix + suffix + '$') - new_entry = Entry( - compiled, method, - self._static_file_handler_maker(path), - endpoint, prefix, Entry.STATIC) - self._register_endpoint(new_entry) + route = StaticRoute(name, prefix, path) + self._register_endpoint(route) ############################################################ diff --git a/tests/test_urldispatch.py b/tests/test_urldispatch.py index f46c3b297d7..2c459e8bbac 100644 --- a/tests/test_urldispatch.py +++ b/tests/test_urldispatch.py @@ -4,7 +4,7 @@ from unittest import mock import aiohttp.web from aiohttp.web import (UrlDispatcher, Request, Response, - HTTPMethodNotAllowed, HTTPNotFound, Entry) + HTTPMethodNotAllowed, HTTPNotFound) from aiohttp.multidict import MultiDict from aiohttp.protocol import HttpVersion, RawRequestMessage @@ -38,7 +38,7 @@ def test_add_route_root(self): self.assertIsNotNone(info) self.assertEqual(0, len(info)) self.assertIs(handler, info.handler) - self.assertIsNone(info.endpoint) + self.assertIsNone(info.route.name) def test_add_route_simple(self): handler = lambda req: Response(req) @@ -48,7 +48,7 @@ def test_add_route_simple(self): self.assertIsNotNone(info) self.assertEqual(0, len(info)) self.assertIs(handler, info.handler) - self.assertIsNone(info.endpoint) + self.assertIsNone(info.route.name) def test_add_with_matchdict(self): handler = lambda req: Response(req) @@ -58,16 +58,16 @@ def test_add_with_matchdict(self): self.assertIsNotNone(info) self.assertEqual({'to': 'tail'}, info) self.assertIs(handler, info.handler) - self.assertIsNone(info.endpoint) + self.assertIsNone(info.route.name) - def test_add_with_endpoint(self): + def test_add_with_name(self): handler = lambda req: Response(req) self.router.add_route('GET', '/handler/to/path', handler, - endpoint='endpoint') + name='name') req = self.make_request('GET', '/handler/to/path') info = self.loop.run_until_complete(self.router.resolve(req)) self.assertIsNotNone(info) - self.assertEqual('endpoint', info.endpoint) + self.assertEqual('name', info.route.name) def test_add_with_tailing_slash(self): handler = lambda req: Response(req) @@ -148,94 +148,104 @@ def test_raise_method_not_found(self): exc = ctx.exception self.assertEqual(404, exc.status) - def test_double_add_url_with_the_same_endpoint(self): - self.router.add_route('GET', '/get', lambda r: None, endpoint='name') + def test_double_add_url_with_the_same_name(self): + self.router.add_route('GET', '/get', lambda r: None, name='name') - regexp = ("Duplicate endpoint 'name', " - r"already handled by \[GET\] /get -> ") + regexp = ("Duplicate 'name', already handled by") with self.assertRaisesRegex(ValueError, regexp): self.router.add_route('GET', '/get_other', lambda r: None, - endpoint='name') + name='name') def test_reverse_plain(self): - self.router.add_route('GET', '/get', lambda r: None, endpoint='name') + self.router.add_route('GET', '/get', lambda r: None, name='name') - url = self.loop.run_until_complete(self.router.reverse('GET', 'name')) + url = self.router['name'].url() self.assertEqual('/get', url) - def test_reverse_plain_with_parts(self): - self.router.add_route('GET', '/get', lambda r: None, endpoint='name') - - with self.assertRaisesRegex( - ValueError, - "Plain endpoint doesn't allow parts parameter"): - self.loop.run_until_complete( - self.router.reverse('GET', 'name', parts={'a': 'b'})) - - def test_reverse_unknown_endpoint(self): - with self.assertRaisesRegex( - KeyError, - r"\[GET\] 'unknown' endpoint not found"): - self.loop.run_until_complete(self.router.reverse('GET', 'unknown')) + def test_reverse_unknown_route_name(self): + with self.assertRaises(KeyError): + self.router['unknown'] def test_reverse_dynamic(self): self.router.add_route('GET', '/get/{name}', - lambda r: None, endpoint='name') + lambda r: None, name='name') - url = self.loop.run_until_complete( - self.router.reverse('GET', 'name', parts={'name': 'John'})) + url = self.router['name'].url(parts={'name': 'John'}) self.assertEqual('/get/John', url) - def test_reverse_dynamic_without_parts(self): - self.router.add_route('GET', '/get/{name}', - lambda r: None, endpoint='name') - - with self.assertRaisesRegex( - ValueError, - "Dynamic endpoint requires nonempty parts parameter"): - self.loop.run_until_complete(self.router.reverse('GET', 'name')) - def test_reverse_with_qs(self): - self.router.add_route('GET', '/get', lambda r: None, endpoint='name') - - url = self.loop.run_until_complete( - self.router.reverse('GET', 'name', query=[('a', 'b'), ('c', 1)])) + self.router.add_route('GET', '/get', lambda r: None, name='name') + url = self.router['name'].url(query=[('a', 'b'), ('c', 1)]) self.assertEqual('/get?a=b&c=1', url) - def test_reverse_nonstatic_with_filename(self): - self.router.add_route('GET', '/get', lambda r: None, endpoint='name') - - with self.assertRaisesRegex( - ValueError, - "Cannot use filename with non-static route"): - self.loop.run_until_complete(self.router.reverse('GET', 'name', - filename='a.txt')) - def test_reverse_static(self): self.router.add_static('/st', os.path.dirname(aiohttp.__file__), - endpoint='static') - - url = self.loop.run_until_complete( - self.router.reverse('GET', 'static', filename='/dir/a.txt')) + name='static') + url = self.router['static'].url(filename='/dir/a.txt') self.assertEqual('/st/dir/a.txt', url) - def test_reverse_static_without_filename(self): - self.router.add_static('/st', os.path.dirname(aiohttp.__file__), - endpoint='static') - - with self.assertRaisesRegex( - ValueError, - 'filename must be not empty for static routes'): - self.loop.run_until_complete(self.router.reverse('GET', 'static')) - - def test_reverse_unknown_endpoint_type(self): - self.router._register_endpoint(Entry('compiled', 'GET', 'handler', - 'endpoint', '/path', 'UNKNOWN')) - - with self.assertRaisesRegex( - ValueError, - 'Not supported endpoint type UNKNOWN'): - self.loop.run_until_complete( - self.router.reverse('GET', 'endpoint')) + def test_plain_not_match(self): + self.router.add_route('GET', '/get/path', + lambda r: None, name='name') + route = self.router['name'] + self.assertIsNone(route.match('/another/path')) + + def test_dynamic_not_match(self): + self.router.add_route('GET', '/get/{name}', + lambda r: None, name='name') + route = self.router['name'] + self.assertIsNone(route.match('/another/path')) + + def test_static_not_match(self): + self.router.add_static('/pre', os.path.dirname(aiohttp.__file__), + name='name') + route = self.router['name'] + self.assertIsNone(route.match('/another/path')) + + def test_dynamic_with_trailing_slash(self): + self.router.add_route('GET', '/get/{name}/', + lambda r: None, name='name') + route = self.router['name'] + self.assertEqual({'name': 'John'}, route.match('/get/John/')) + + def test_len(self): + self.router.add_route('GET', '/get1', + lambda r: None, name='name1') + self.router.add_route('GET', '/get2', + lambda r: None, name='name2') + self.assertEqual(2, len(self.router)) + + def test_iter(self): + self.router.add_route('GET', '/get1', + lambda r: None, name='name1') + self.router.add_route('GET', '/get2', + lambda r: None, name='name2') + self.assertEqual({'name1', 'name2'}, set(iter(self.router))) + + def test_contains(self): + self.router.add_route('GET', '/get1', + lambda r: None, name='name1') + self.router.add_route('GET', '/get2', + lambda r: None, name='name2') + self.assertIn('name1', self.router) + self.assertNotIn('name3', self.router) + + def test_plain_repr(self): + self.router.add_route('GET', '/get/path', + lambda r: None, name='name') + self.assertRegex(repr(self.router['name']), + r"