-
Notifications
You must be signed in to change notification settings - Fork 313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow to specify static responses #1234
Changes from 4 commits
657bc92
978431a
5b4f416
af77d1b
37c0900
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,77 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
from typing import Optional, List | ||
|
||
import aiohttp | ||
import elasticsearch | ||
from aiohttp import RequestInfo, BaseConnector | ||
from aiohttp.client_proto import ResponseHandler | ||
from aiohttp.helpers import BaseTimerContext | ||
from multidict import CIMultiDictProxy, CIMultiDict | ||
from yarl import URL | ||
|
||
from esrally.utils import io | ||
|
||
|
||
class StaticTransport: | ||
def __init__(self): | ||
self.closed = False | ||
|
||
def is_closing(self): | ||
return False | ||
|
||
def close(self): | ||
self.closed = True | ||
|
||
|
||
class StaticConnector(BaseConnector): | ||
async def _create_connection(self, req: "ClientRequest", traces: List["Trace"], | ||
timeout: "ClientTimeout") -> ResponseHandler: | ||
handler = ResponseHandler(self._loop) | ||
handler.transport = StaticTransport() | ||
handler.protocol = "" | ||
return handler | ||
|
||
|
||
class StaticRequest(aiohttp.ClientRequest): | ||
RESPONSES = None | ||
|
||
async def send(self, conn: "Connection") -> "ClientResponse": | ||
self.response = self.response_class( | ||
self.method, | ||
self.original_url, | ||
writer=self._writer, | ||
continue100=self._continue, | ||
timer=self._timer, | ||
request_info=self.request_info, | ||
traces=self._traces, | ||
loop=self.loop, | ||
session=self._session, | ||
) | ||
path = self.original_url.path | ||
self.response.static_body = StaticRequest.RESPONSES.response(path) | ||
return self.response | ||
|
||
|
||
class StaticResponse(aiohttp.ClientResponse): | ||
def __init__(self, method: str, url: URL, *, writer: "asyncio.Task[None]", | ||
continue100: Optional["asyncio.Future[bool]"], timer: BaseTimerContext, request_info: RequestInfo, | ||
traces: List["Trace"], loop: asyncio.AbstractEventLoop, session: "ClientSession") -> None: | ||
super().__init__(method, url, writer=writer, continue100=continue100, timer=timer, request_info=request_info, | ||
traces=traces, loop=loop, session=session) | ||
self.static_body = None | ||
|
||
async def start(self, connection: "Connection") -> "ClientResponse": | ||
self._closed = False | ||
self._protocol = connection.protocol | ||
self._connection = connection | ||
self._headers = CIMultiDictProxy(CIMultiDict()) | ||
self.status = 200 | ||
return self | ||
|
||
async def text(self, encoding=None, errors="strict"): | ||
return self.static_body | ||
|
||
|
||
class RawClientResponse(aiohttp.ClientResponse): | ||
|
@@ -16,6 +86,65 @@ async def text(self, encoding=None, errors="strict"): | |
return self._body | ||
|
||
|
||
class ResponseMatcher: | ||
def __init__(self, responses): | ||
self.logger = logging.getLogger(__name__) | ||
self.responses = [] | ||
|
||
for response in responses: | ||
path = response["path"] | ||
if path == "*": | ||
matcher = ResponseMatcher.always() | ||
elif path.startswith("*"): | ||
matcher = ResponseMatcher.endswith(path[1:]) | ||
elif path.endswith("*"): | ||
matcher = ResponseMatcher.startswith(path[:-1]) | ||
else: | ||
matcher = ResponseMatcher.equals(path) | ||
|
||
body = response["body"] | ||
body_encoding = response.get("body-encoding", "json") | ||
if body_encoding == "raw": | ||
body = json.dumps(body).encode("utf-8") | ||
elif body_encoding == "json": | ||
body = json.dumps(body) | ||
else: | ||
raise ValueError(f"Unknown body encoding [{body_encoding}] for path [{path}]") | ||
|
||
self.responses.append((path, matcher, body)) | ||
|
||
@staticmethod | ||
def always(): | ||
# pylint: disable=unused-variable | ||
def f(p): | ||
return True | ||
return f | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. alternatively we could replace the closures with more Pythonic (IMHO? :) ) lambdas e.g.
or for
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I implemented this with lambdas earlier but then got a PEP-8 violation warning. :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LOL whaaaaaat? |
||
|
||
@staticmethod | ||
def startswith(path_pattern): | ||
def f(p): | ||
return p.startswith(path_pattern) | ||
return f | ||
|
||
@staticmethod | ||
def endswith(path_pattern): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One thought, we could enhance it in the future and support simple shell like patterns like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea. Before we enhance it, we should probably run microbenchmarks to get a better grasp of the overhead though because we're on the hot code path. |
||
def f(p): | ||
return p.endswith(path_pattern) | ||
return f | ||
|
||
@staticmethod | ||
def equals(path_pattern): | ||
def f(p): | ||
return p == path_pattern | ||
return f | ||
|
||
def response(self, path): | ||
for path_pattern, matcher, body in self.responses: | ||
if matcher(path): | ||
self.logger.debug("Path pattern [%s] matches path [%s].", path_pattern, path) | ||
return body | ||
|
||
|
||
class AIOHttpConnection(elasticsearch.AIOHttpConnection): | ||
def __init__(self, | ||
host="localhost", | ||
|
@@ -52,20 +181,42 @@ def __init__(self, | |
self._trace_configs = [trace_config] if trace_config else None | ||
self._enable_cleanup_closed = kwargs.get("enable_cleanup_closed", False) | ||
|
||
static_responses = kwargs.get("static_responses") | ||
self.use_static_responses = static_responses is not None | ||
|
||
if self.use_static_responses: | ||
# read static responses once and reuse them | ||
if not StaticRequest.RESPONSES: | ||
with open(io.normalize_path(static_responses)) as f: | ||
StaticRequest.RESPONSES = ResponseMatcher(json.load(f)) | ||
|
||
self._request_class = StaticRequest | ||
self._response_class = StaticResponse | ||
else: | ||
self._request_class = aiohttp.ClientRequest | ||
self._response_class = RawClientResponse | ||
|
||
async def _create_aiohttp_session(self): | ||
if self.loop is None: | ||
self.loop = asyncio.get_running_loop() | ||
|
||
if self.use_static_responses: | ||
connector = StaticConnector(limit=self._limit, enable_cleanup_closed=self._enable_cleanup_closed) | ||
else: | ||
connector = aiohttp.TCPConnector( | ||
limit=self._limit, | ||
use_dns_cache=True, | ||
ssl_context=self._ssl_context, | ||
enable_cleanup_closed=self._enable_cleanup_closed | ||
) | ||
|
||
self.session = aiohttp.ClientSession( | ||
headers=self.headers, | ||
auto_decompress=True, | ||
loop=self.loop, | ||
cookie_jar=aiohttp.DummyCookieJar(), | ||
response_class=RawClientResponse, | ||
connector=aiohttp.TCPConnector( | ||
limit=self._limit, | ||
use_dns_cache=True, | ||
ssl_context=self._ssl_context, | ||
enable_cleanup_closed=self._enable_cleanup_closed | ||
), | ||
request_class=self._request_class, | ||
response_class=self._response_class, | ||
connector=connector, | ||
trace_configs=self._trace_configs, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It took me a few seconds to realize that while specifying a
pipeline
it won't really target any ES. One might even use the default pipeline, ES will get launched but won't be used.Should we clarify with a note here or above in the
client-options
section that whenstatic_responses:'file'
is used, the use should also specify thebenchmark-only
pipeline and thedistribution-version
as Rally doesn't have a way to derive the version automatically as it'd normally do.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I've added a note in 37c0900.