Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Internal] Add unit tests for retriable requests #879

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 127 additions & 80 deletions tests/test_base_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import io
import random
from http.server import BaseHTTPRequestHandler
from typing import Iterator, List
from typing import Callable, Iterator, List, Optional, Tuple, Type
from unittest.mock import Mock

import pytest
from requests import PreparedRequest, Response, Timeout

from databricks.sdk import errors, useragent
from databricks.sdk._base_client import (_BaseClient, _RawResponse,
Expand Down Expand Up @@ -357,91 +358,137 @@ def tell(self):
assert client._is_seekable_stream(CustomSeekableStream())


@pytest.mark.parametrize(
    'input_data',
    [
        b"0123456789",  # bytes -> BytesIO
        "0123456789",  # str -> BytesIO
        io.BytesIO(b"0123456789"),  # BytesIO directly
        io.StringIO("0123456789"),  # StringIO
    ])
def test_reset_seekable_stream_on_retry(input_data):
    """Check that the full payload is re-sent on every retried attempt.

    The fixture server answers 429 until it has recorded two request bodies
    and 200 afterwards, so three identical bodies are expected in total.
    """
    bodies = []

    def handler(h: BaseHTTPRequestHandler):
        # Retry two times before succeeding: the status is chosen from the
        # number of bodies recorded *before* the current one is read.
        status = 200 if len(bodies) == 2 else 429
        h.send_response(status)
        h.end_headers()

        size = int(h.headers.get('Content-Length', 0))
        if size > 0:
            bodies.append(h.rfile.read(size))

    with http_fixture_server(handler) as host:
        client = _BaseClient()

        # Retries should reset the stream.
        client.do('POST', f'{host}/foo', data=input_data)

    expected_body = b"0123456789"
    assert bodies == [expected_body, expected_body, expected_body]


def test_reset_seekable_stream_to_their_initial_position_on_retry():
received_data = []

# Retry two times before succeeding.
def inner(h: BaseHTTPRequestHandler):
if len(received_data) == 2:
h.send_response(200)
h.end_headers()
class RetryTestCase:
    """One parametrized scenario for the stream-rewind retry test.

    Bundles a factory for the request payload, an optional starting offset
    for that payload, and the outcome the test is expected to observe.
    """

    def __init__(self, data_provider: Callable, offset: Optional[int], expected_failure: bool,
                 expected_result: bytes):
        """Store the scenario parameters.

        Args:
            data_provider: Zero-argument callable that returns the payload.
            offset: Position to seek the payload to before it is used, or
                ``None`` to leave the payload untouched.
            expected_failure: Whether the request is expected to raise.
            expected_result: Bytes each attempt is expected to deliver.
        """
        self._data_provider = data_provider
        self._offset = offset
        self._expected_result = expected_result
        self._expected_failure = expected_failure

    def get_data(self):
        """Return a new payload, positioned at the configured offset.

        Test cases are reused, so we need to construct a fresh data object
        every time.
        """
        data = self._data_provider()
        if self._offset is not None:
            # Only applied when an offset was given; the payload must then
            # support seek().
            data.seek(self._offset)
        return data

    @classmethod
    def create_non_seekable_stream(cls, data: bytes):
        """Return a ``BytesIO`` over ``data`` whose ``seekable()`` reports False."""
        result = io.BytesIO(data)
        result.seekable = lambda: False  # makes the stream appear non-seekable
        return result


class MockSession:
Copy link
Contributor Author

@ksafonov-db ksafonov-db Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mock session reads the entire input stream before failing.


def __init__(self, failure_count: int, failure_provider: Callable[[], Response]):
    """Set up the mock session.

    Args:
        failure_count: Number of upcoming requests that should fail.
        failure_provider: Callable invoked to produce each failure.
    """
    self._failure_provider = failure_provider
    self._failure_count = failure_count
    # Bodies of the requests received so far, starts out empty.
    self._received_requests: List[bytes] = []

@classmethod
def raise_timeout_exception(cls):
    """Simulate a timed-out request by raising ``Timeout``; never returns."""
    timeout_error = Timeout("Fake timeout")
    raise timeout_error

@classmethod
def return_retryable_response(cls):
    """Build a minimal HTTP 429 response with a ``Retry-After`` header.

    Both the response and its attached request are populated so that
    logging does not fail.
    """
    url = 'http://test.com/'

    request = PreparedRequest()
    request.url = url
    request.method = 'POST'
    request.headers = None
    request.body = b''

    response = Response()
    response._content = b''
    response.status_code = 429
    response.headers = {'Retry-After': '1'}
    response.url = url
    response.request = request
    return response

# following the signature of Session.request()
def request(self,
method,
url,
params=None,
data=None,
headers=None,
cookies=None,
files=None,
auth=None,
timeout=None,
allow_redirects=True,
proxies=None,
hooks=None,
stream=None,
verify=None,
cert=None,
json=None):
request_body = data.read()

if isinstance(request_body, str):
request_body = request_body.encode('utf-8') # to be able to compare with expected bytes

self._received_requests.append(request_body)
if self._failure_count > 0:
self._failure_count -= 1
return self._failure_provider()
#
else:
h.send_response(429)
h.end_headers()

content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
received_data.append(h.rfile.read(content_length))

input_data = io.BytesIO(b"0123456789")
input_data.seek(4)
# fill response fields so that logging does not fail
response = Response()
response._content = b''
response.status_code = 200
response.reason = 'OK'
response.url = url

with http_fixture_server(inner) as host:
client = _BaseClient()

# Retries should reset the stream.
client.do('POST', f'{host}/foo', data=input_data)

assert received_data == [b"456789", b"456789", b"456789"]
assert input_data.tell() == 10 # EOF
response.request = PreparedRequest()
response.request.url = url
response.request.method = method
response.request.headers = headers
response.request.body = data
return response


def test_no_retry_or_reset_on_non_seekable_stream():
requests = []

# Always respond with a response that triggers a retry.
def inner(h: BaseHTTPRequestHandler):
content_length = int(h.headers.get('Content-Length', 0))
if content_length > 0:
requests.append(h.rfile.read(content_length))
@pytest.mark.parametrize(
'test_case',
[
# bytes -> BytesIO
RetryTestCase(lambda: b"0123456789", None, False, b"0123456789"),
# str -> BytesIO
RetryTestCase(lambda: "0123456789", None, False, b"0123456789"),
# BytesIO directly
RetryTestCase(lambda: io.BytesIO(b"0123456789"), None, False, b"0123456789"),
# BytesIO directly with offset
RetryTestCase(lambda: io.BytesIO(b"0123456789"), 4, False, b"456789"),
# StringIO
RetryTestCase(lambda: io.StringIO("0123456789"), None, False, b"0123456789"),
# Non-seekable
RetryTestCase(lambda: RetryTestCase.create_non_seekable_stream(b"0123456789"), None, True,
b"0123456789")
])
@pytest.mark.parametrize('failure', [[MockSession.raise_timeout_exception, Timeout],
[MockSession.return_retryable_response, errors.TooManyRequests]])
def test_rewind_seekable_stream(test_case: RetryTestCase, failure: Tuple[Callable[[], Response], Type]):
failure_count = 2

h.send_response(429)
h.send_header('Retry-After', '1')
h.end_headers()
data = test_case.get_data()

input_data = io.BytesIO(b"0123456789")
input_data.seekable = lambda: False # makes the stream appear non-seekable
session = MockSession(failure_count, failure[0])
client = _BaseClient()
client._session = session

with http_fixture_server(inner) as host:
client = _BaseClient()
def do():
client.do('POST', f'test.com/foo', data=data)

# Should raise error immediately without retry.
with pytest.raises(DatabricksError):
client.do('POST', f'{host}/foo', data=input_data)
if test_case._expected_failure:
expected_attempts_made = 1
exception_class = failure[1]
with pytest.raises(exception_class):
do()
else:
expected_attempts_made = failure_count + 1
do()

# Verify that only one request was made (no retries).
assert requests == [b"0123456789"]
assert input_data.tell() == 10 # EOF
assert session._received_requests == [test_case._expected_result for _ in range(expected_attempts_made)]
Loading