Skip to content

Commit

Permalink
Avoid decoding request body unless it needs to be logged.
Browse files Browse the repository at this point in the history
Signed-off-by: dblock <dblock@amazon.com>
  • Loading branch information
dblock committed Nov 10, 2023
1 parent d8dc547 commit 6e5bd2c
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
### Removed
- Removed leftover support for Python 2.7 ([#548](https://github.com/opensearch-project/opensearch-py/pull/548))
### Fixed
- Avoid decoding request body unless it needs to be logged ([#571](https://github.com/opensearch-project/opensearch-py/pull/571))
### Security
### Dependencies
- Bumps `sphinx` from <7.1 to <7.3
Expand Down
6 changes: 3 additions & 3 deletions opensearchpy/_async/http_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ async def perform_request(
except Exception as e:
self.log_request_fail(
method,
str(url),
url,
url_path,
orig_body,
self.loop.time() - start,
Expand All @@ -337,7 +337,7 @@ async def perform_request(
if not (200 <= response.status < 300) and response.status not in ignore:
self.log_request_fail(
method,
str(url),
url,
url_path,
orig_body,
duration,
Expand All @@ -351,7 +351,7 @@ async def perform_request(
)

self.log_request_success(
method, str(url), url_path, orig_body, response.status, raw_data, duration
method, url, url_path, orig_body, response.status, raw_data, duration
)

return response.status, response.headers, raw_data
Expand Down
29 changes: 12 additions & 17 deletions opensearchpy/connection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,16 @@ def _pretty_json(self, data: Union[str, bytes]) -> str:
# non-json data or a bulk request
return data # type: ignore

def _log_request_response(
self, body: Optional[Union[str, bytes]], response: Optional[str]
) -> None:
if logger.isEnabledFor(logging.DEBUG):
if body and isinstance(body, bytes):
body = body.decode("utf-8", "ignore")
logger.debug("> %s", body)
if response is not None:
logger.debug("< %s", response)

def _log_trace(
self,
method: str,
Expand Down Expand Up @@ -246,17 +256,11 @@ def log_request_success(
"""Log a successful API call."""
# TODO: optionally pass in params instead of full_url and do urlencode only when needed

# body has already been serialized to utf-8, deserialize it for logging
# TODO: find a better way to avoid (de)encoding the body back and forth
if body and isinstance(body, bytes):
body = body.decode("utf-8", "ignore")

logger.info(
"%s %s [status:%s request:%.3fs]", method, full_url, status_code, duration
)
logger.debug("> %s", body)
logger.debug("< %s", response)

self._log_request_response(body, response)
self._log_trace(method, path, body, status_code, response, duration)

def log_request_fail(
Expand All @@ -283,18 +287,9 @@ def log_request_fail(
exc_info=exception is not None,
)

# body has already been serialized to utf-8, deserialize it for logging
# TODO: find a better way to avoid (de)encoding the body back and forth
if body and isinstance(body, bytes):
body = body.decode("utf-8", "ignore")

logger.debug("> %s", body)

self._log_request_response(body, response)
self._log_trace(method, path, body, status_code, response, duration)

if response is not None:
logger.debug("< %s", response)

def _raise_error(
self,
status_code: int,
Expand Down
37 changes: 36 additions & 1 deletion test_opensearchpy/test_async/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from opensearchpy import AIOHttpConnection, AsyncOpenSearch, __versionstr__, serializer
from opensearchpy.compat import reraise_exceptions
from opensearchpy.connection import Connection, async_connections
from opensearchpy.exceptions import ConnectionError, TransportError
from opensearchpy.exceptions import ConnectionError, NotFoundError, TransportError
from test_opensearchpy.TestHttpServer import TestHTTPServer

pytestmark: MarkDecorator = pytest.mark.asyncio
Expand Down Expand Up @@ -303,6 +303,41 @@ async def test_uncompressed_body_logged(self, logger: Any) -> None:
assert '> {"example": "body"}' == req[0][0] % req[0][1:]
assert "< {}" == resp[0][0] % resp[0][1:]

@patch("opensearchpy.connection.base.logger", return_value=MagicMock())
async def test_body_not_logged(self, logger: Any) -> None:
logger.isEnabledFor.return_value = False

con = await self._get_mock_connection()
await con.perform_request("GET", "/", body=b'{"example": "body"}')

assert logger.isEnabledFor.call_count == 1
assert logger.debug.call_count == 0

@patch("opensearchpy.connection.base.logger")
async def test_failure_body_logged(self, logger: Any) -> None:
con = await self._get_mock_connection(response_code=404)
with pytest.raises(NotFoundError) as e:
await con.perform_request("GET", "/invalid", body=b'{"example": "body"}')
assert str(e.value) == "NotFoundError(404, '{}')"

assert 2 == logger.debug.call_count
req, resp = logger.debug.call_args_list

assert '> {"example": "body"}' == req[0][0] % req[0][1:]
assert "< {}" == resp[0][0] % resp[0][1:]

@patch("opensearchpy.connection.base.logger", return_value=MagicMock())
async def test_failure_body_not_logged(self, logger: Any) -> None:
logger.isEnabledFor.return_value = False

con = await self._get_mock_connection(response_code=404)
with pytest.raises(NotFoundError) as e:
await con.perform_request("GET", "/invalid")
assert str(e.value) == "NotFoundError(404, '{}')"

assert logger.isEnabledFor.call_count == 1
assert logger.debug.call_count == 0

async def test_surrogatepass_into_bytes(self) -> None:
buf = b"\xe4\xbd\xa0\xe5\xa5\xbd\xed\xa9\xaa"
con = await self._get_mock_connection(response_body=buf)
Expand Down
74 changes: 56 additions & 18 deletions test_opensearchpy/test_connection/test_requests_http_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from typing import Any

import pytest
from mock import Mock, patch
from mock import MagicMock, Mock, patch
from requests.auth import AuthBase

from opensearchpy.connection import Connection, RequestsHttpConnection
Expand All @@ -52,15 +52,15 @@ class TestRequestsHttpConnection(TestCase):
def _get_mock_connection(
self,
connection_params: Any = {},
status_code: int = 200,
response_code: int = 200,
response_body: bytes = b"{}",
) -> Any:
con = RequestsHttpConnection(**connection_params)

def _dummy_send(*args: Any, **kwargs: Any) -> Any:
dummy_response = Mock()
dummy_response.headers = {}
dummy_response.status_code = status_code
dummy_response.status_code = response_code
dummy_response.content = response_body
dummy_response.request = args[0]
dummy_response.cookies = {}
Expand Down Expand Up @@ -229,28 +229,28 @@ def test_repr(self) -> None:
)

def test_conflict_error_is_returned_on_409(self) -> None:
con = self._get_mock_connection(status_code=409)
con = self._get_mock_connection(response_code=409)
self.assertRaises(ConflictError, con.perform_request, "GET", "/", {}, "")

def test_not_found_error_is_returned_on_404(self) -> None:
con = self._get_mock_connection(status_code=404)
con = self._get_mock_connection(response_code=404)
self.assertRaises(NotFoundError, con.perform_request, "GET", "/", {}, "")

def test_request_error_is_returned_on_400(self) -> None:
con = self._get_mock_connection(status_code=400)
con = self._get_mock_connection(response_code=400)
self.assertRaises(RequestError, con.perform_request, "GET", "/", {}, "")

@patch("opensearchpy.connection.base.logger")
def test_head_with_404_doesnt_get_logged(self, logger: Any) -> None:
con = self._get_mock_connection(status_code=404)
con = self._get_mock_connection(response_code=404)
self.assertRaises(NotFoundError, con.perform_request, "HEAD", "/", {}, "")
self.assertEqual(0, logger.warning.call_count)

@patch("opensearchpy.connection.base.tracer")
@patch("opensearchpy.connection.base.logger")
def test_failed_request_logs_and_traces(self, logger: Any, tracer: Any) -> None:
con = self._get_mock_connection(
response_body=b'{"answer": 42}', status_code=500
response_body=b'{"answer": 42}', response_code=500
)
self.assertRaises(
TransportError,
Expand Down Expand Up @@ -326,7 +326,7 @@ def test_uncompressed_body_logged(self, logger: Any) -> None:

con = self._get_mock_connection(
connection_params={"http_compress": True},
status_code=500,
response_code=500,
response_body=b'{"hello":"world"}',
)
with pytest.raises(TransportError):
Expand All @@ -337,6 +337,41 @@ def test_uncompressed_body_logged(self, logger: Any) -> None:
self.assertEqual('> {"example": "body2"}', req[0][0] % req[0][1:])
self.assertEqual('< {"hello":"world"}', resp[0][0] % resp[0][1:])

@patch("opensearchpy.connection.base.logger", return_value=MagicMock())
def test_body_not_logged(self, logger: Any) -> None:
logger.isEnabledFor.return_value = False

con = self._get_mock_connection()
con.perform_request("GET", "/", body=b'{"example": "body"}')

self.assertEqual(logger.isEnabledFor.call_count, 1)
self.assertEqual(logger.debug.call_count, 0)

@patch("opensearchpy.connection.base.logger")
def test_failure_body_logged(self, logger: Any) -> None:
con = self._get_mock_connection(response_code=404)
with pytest.raises(NotFoundError) as e:
con.perform_request("GET", "/invalid", body=b'{"example": "body"}')
self.assertEqual(str(e.value), "NotFoundError(404, '{}')")

self.assertEqual(2, logger.debug.call_count)
req, resp = logger.debug.call_args_list

self.assertEqual('> {"example": "body"}', req[0][0] % req[0][1:])
self.assertEqual("< {}", resp[0][0] % resp[0][1:])

@patch("opensearchpy.connection.base.logger", return_value=MagicMock())
def test_failure_body_not_logged(self, logger: Any) -> None:
logger.isEnabledFor.return_value = False

con = self._get_mock_connection(response_code=404)
with pytest.raises(NotFoundError) as e:
con.perform_request("GET", "/invalid")
self.assertEqual(str(e.value), "NotFoundError(404, '{}')")

self.assertEqual(logger.isEnabledFor.call_count, 1)
self.assertEqual(logger.debug.call_count, 0)

def test_defaults(self) -> None:
con = self._get_mock_connection()
request = self._get_request(con, "GET", "/")
Expand Down Expand Up @@ -403,7 +438,7 @@ def send_raise(*_: Any, **__: Any) -> Any:

with pytest.raises(RecursionError) as e:
conn.perform_request("GET", "/")
assert str(e.value) == "Wasn't modified!"
self.assertEqual(str(e.value), "Wasn't modified!")

def mock_session(self) -> Any:
access_key = uuid.uuid4().hex
Expand Down Expand Up @@ -472,7 +507,7 @@ def test_aws_signer_signs_with_query_string(self, mock_sign: Any) -> None:
)


class TestRequestsConnectionRedirect:
class TestRequestsConnectionRedirect(TestCase):
server1: TestHTTPServer
server2: TestHTTPServer

Expand All @@ -495,20 +530,23 @@ def test_redirect_failure_when_allow_redirect_false(self) -> None:
conn = RequestsHttpConnection("localhost", port=8080, use_ssl=False, timeout=60)
with pytest.raises(TransportError) as e:
conn.perform_request("GET", "/redirect", allow_redirects=False)
assert e.value.status_code == 302
self.assertEqual(e.value.status_code, 302)

# allow_redirects = True (Default)
def test_redirect_success_when_allow_redirect_true(self) -> None:
conn = RequestsHttpConnection("localhost", port=8080, use_ssl=False, timeout=60)
user_agent = conn._get_default_user_agent()
status, headers, data = conn.perform_request("GET", "/redirect")
assert status == 200
self.assertEqual(status, 200)
data = json.loads(data)
assert data["headers"] == {
"Host": "localhost:8090",
"Accept-Encoding": "identity",
"User-Agent": user_agent,
}
self.assertEqual(
data["headers"],
{
"Host": "localhost:8090",
"Accept-Encoding": "identity",
"User-Agent": user_agent,
},
)


class TestSignerWithFrozenCredentials(TestRequestsHttpConnection):
Expand Down
Loading

0 comments on commit 6e5bd2c

Please sign in to comment.