
[ISSUE #20771] limiting the number of requests performed to the backe… #21525

Merged
merged 12 commits on Jan 24, 2023
9 changes: 7 additions & 2 deletions airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
@@ -2,17 +2,20 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

import json
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.models import (
AirbyteCatalog,
AirbyteConnectionStatus,
AirbyteLogMessage,
AirbyteMessage,
AirbyteStateMessage,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
Level,
Status,
SyncMode,
)
@@ -232,7 +235,8 @@ def _read_incremental(
has_slices = False
for _slice in slices:
has_slices = True
logger.debug("Processing stream slice", extra={"slice": _slice})
if logger.isEnabledFor(logging.DEBUG):
yield AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"slice:{json.dumps(_slice)}"))
records = stream_instance.read_records(
sync_mode=SyncMode.incremental,
stream_slice=_slice,
@@ -281,7 +285,8 @@ def _read_full_refresh(
)
total_records_counter = 0
for _slice in slices:
logger.debug("Processing stream slice", extra={"slice": _slice})
if logger.isEnabledFor(logging.DEBUG):
yield AirbyteMessage(type=MessageType.LOG, log=AirbyteLogMessage(level=Level.INFO, message=f"slice:{json.dumps(_slice)}"))
record_data_or_messages = stream_instance.read_records(
stream_slice=_slice,
sync_mode=SyncMode.full_refresh,
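
The replaced logger.debug call becomes an explicit log message in the connector's output stream, emitted only when the supplied logger is debug-enabled. A minimal standalone sketch of that gating pattern, using a simplified stand-in for the AirbyteMessage/AirbyteLogMessage models rather than the real airbyte_cdk classes:

import json
import logging
from dataclasses import dataclass
from typing import Any, Iterable, Iterator, Mapping

@dataclass
class SliceLogMessage:
    # Simplified stand-in for an AirbyteLogMessage wrapped in an AirbyteMessage of type LOG.
    level: str
    message: str

def emit_slice_logs(logger: logging.Logger, slices: Iterable[Mapping[str, Any]]) -> Iterator[SliceLogMessage]:
    """Yield one LOG message per slice, but only when debug logging is enabled."""
    for _slice in slices:
        if logger.isEnabledFor(logging.DEBUG):
            # json.dumps keeps the slice readable and easy to parse back out of the log stream.
            yield SliceLogMessage(level="INFO", message=f"slice:{json.dumps(_slice)}")

debug_logger = logging.getLogger("airbyte.debug")
debug_logger.setLevel(logging.DEBUG)
print(list(emit_slice_logs(debug_logger, [{"date": "2022-01-01"}])))
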
@@ -116,9 +116,10 @@


class ModelToComponentFactory:
def __init__(self, is_test_read=False):
def __init__(self, limit_pages_fetched_per_slice: int = None, limit_slices_fetched: int = None):
self._init_mappings()
self._is_test_read = is_test_read
self._limit_pages_fetched_per_slice = limit_pages_fetched_per_slice
self._limit_slices_fetched = limit_slices_fetched

def _init_mappings(self):
self.PYDANTIC_MODEL_TO_CONSTRUCTOR: [Type[BaseModel], Callable] = {
@@ -482,8 +483,8 @@ def create_default_paginator(self, model: DefaultPaginatorModel, config: Config,
config=config,
options=model.options,
)
if self._is_test_read:
return PaginatorTestReadDecorator(paginator)
if self._limit_pages_fetched_per_slice:
return PaginatorTestReadDecorator(paginator, self._limit_pages_fetched_per_slice)
return paginator

def create_dpath_extractor(self, model: DpathExtractorModel, config: Config, **kwargs) -> DpathExtractor:
@@ -681,7 +682,7 @@ def create_simple_retriever(self, model: SimpleRetrieverModel, config: Config, *
self._create_component_from_model(model=model.stream_slicer, config=config) if model.stream_slicer else SingleSlice(options={})
)

if self._is_test_read:
if self._limit_slices_fetched:
return SimpleRetrieverTestReadDecorator(
name=model.name,
paginator=paginator,
@@ -690,6 +691,7 @@ def create_simple_retriever(self, model: SimpleRetrieverModel, config: Config, *
record_selector=record_selector,
stream_slicer=stream_slicer,
config=config,
maximum_number_of_slices=self._limit_slices_fetched,
options=model.options,
)
return SimpleRetriever(
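
The factory change above replaces the single is_test_read boolean with two independent, explicit limits, so page capping and slice capping can be enabled separately. A hedged usage sketch; the keyword arguments come from the diff, while the import path is assumed from the declarative CDK layout at the time of this PR:

# Import path assumed, not confirmed by the diff itself.
from airbyte_cdk.sources.declarative.parsers.model_to_component_factory import ModelToComponentFactory

# Hypothetical test-read setup: cap each slice at 5 pages and the whole read at 5 slices.
factory = ModelToComponentFactory(limit_pages_fetched_per_slice=5, limit_slices_fetched=5)

# Passing only a page limit wraps paginators in PaginatorTestReadDecorator while
# retrievers stay plain SimpleRetriever instances; the two limits act independently.
pages_only_factory = ModelToComponentFactory(limit_pages_fetched_per_slice=5)
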
@@ -172,9 +172,11 @@ class PaginatorTestReadDecorator(Paginator):
_DEFAULT_PAGINATION_LIMIT = 5

def __init__(self, decorated, maximum_number_of_pages: int = None):
if maximum_number_of_pages and maximum_number_of_pages < 1:
raise ValueError(f"The maximum number of pages on a test read needs to be strictly positive. Got {maximum_number_of_pages}")
self._maximum_number_of_pages = maximum_number_of_pages if maximum_number_of_pages else self._DEFAULT_PAGINATION_LIMIT
self._decorated = decorated
self._page_count = self._PAGE_COUNT_BEFORE_FIRST_NEXT_CALL
self._maximum_number_of_pages = maximum_number_of_pages if maximum_number_of_pages else self._DEFAULT_PAGINATION_LIMIT

def next_page_token(self, response: requests.Response, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
if self._page_count >= self._maximum_number_of_pages:
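
For illustration, a self-contained toy version of the page-capping behavior this decorator adds, including the validation introduced in the diff. The real PaginatorTestReadDecorator presumably also delegates the rest of the Paginator interface to the decorated instance, which this sketch omits:

from typing import Any, List, Mapping, Optional

class TinyPaginator:
    """Toy paginator that always claims there is a next page."""
    def next_page_token(self, response: Any, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
        return {"next": True}

class TinyPaginatorLimiter:
    """Toy equivalent of PaginatorTestReadDecorator: stop pagination after a fixed number of pages."""
    _DEFAULT_PAGINATION_LIMIT = 5

    def __init__(self, decorated, maximum_number_of_pages: Optional[int] = None):
        if maximum_number_of_pages and maximum_number_of_pages < 1:
            raise ValueError(f"The maximum number of pages on a test read needs to be strictly positive. Got {maximum_number_of_pages}")
        self._maximum_number_of_pages = maximum_number_of_pages or self._DEFAULT_PAGINATION_LIMIT
        self._decorated = decorated
        self._page_count = 1  # the first page has already been fetched by the time next_page_token is called

    def next_page_token(self, response: Any, last_records: List[Mapping[str, Any]]) -> Optional[Mapping[str, Any]]:
        if self._page_count >= self._maximum_number_of_pages:
            return None  # signal "no more pages" once the cap is reached
        self._page_count += 1
        return self._decorated.next_page_token(response, last_records)

paginator = TinyPaginatorLimiter(TinyPaginator(), maximum_number_of_pages=2)
assert [paginator.next_page_token(None, []) for _ in range(3)] == [{"next": True}, None, None]
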
@@ -417,18 +417,26 @@ def _parse_records_and_emit_request_and_responses(self, request, response, strea
yield from self.parse_response(response, stream_slice=stream_slice, stream_state=stream_state)


@dataclass
class SimpleRetrieverTestReadDecorator(SimpleRetriever):
"""
In some cases, we want to limit the number of requests that are made to the backend source. This class allows for limiting the number of
slices that are queried throughout a read command.
"""

_MAXIMUM_NUMBER_OF_SLICES = 5
maximum_number_of_slices: int = 5

def __post_init__(self, options: Mapping[str, Any]):
super().__post_init__(options)
if self.maximum_number_of_slices and self.maximum_number_of_slices < 1:
raise ValueError(
f"The maximum number of slices on a test read needs to be strictly positive. Got {self.maximum_number_of_slices}"
)

def stream_slices(
self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Optional[StreamState] = None
) -> Iterable[Optional[Mapping[str, Any]]]:
return islice(super().stream_slices(sync_mode=sync_mode, stream_state=stream_state), self._MAXIMUM_NUMBER_OF_SLICES)
return islice(super().stream_slices(sync_mode=sync_mode, stream_state=stream_state), self.maximum_number_of_slices)


def _prepared_request_to_airbyte_message(request: requests.PreparedRequest) -> AirbyteMessage:
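
The slice cap itself is just itertools.islice applied to whatever the parent class yields. A minimal sketch of that idea in isolation, independent of the SimpleRetriever machinery:

from itertools import islice
from typing import Any, Iterable, Iterator, Mapping

def limit_slices(slices: Iterable[Mapping[str, Any]], maximum_number_of_slices: int = 5) -> Iterator[Mapping[str, Any]]:
    """Yield at most maximum_number_of_slices slices, mirroring SimpleRetrieverTestReadDecorator.stream_slices."""
    if maximum_number_of_slices < 1:
        raise ValueError(f"The maximum number of slices on a test read needs to be strictly positive. Got {maximum_number_of_slices}")
    return islice(slices, maximum_number_of_slices)

all_slices = ({"date": f"2022-01-0{day}"} for day in range(1, 10))
assert list(limit_slices(all_slices, maximum_number_of_slices=4)) == [{"date": f"2022-01-0{day}"} for day in range(1, 5)]
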
@@ -633,19 +633,25 @@ def test_response_to_airbyte_message(test_name, response_body, response_headers,


def test_limit_stream_slices():
maximum_number_of_slices = 4
stream_slicer = MagicMock()
stream_slicer.stream_slices.return_value = [{"date": f"2022-01-0{day}"} for day in range(1, 10)]
stream_slicer.stream_slices.return_value = _generate_slices(maximum_number_of_slices * 2)
retriever = SimpleRetrieverTestReadDecorator(
name="stream_name",
primary_key=primary_key,
requester=MagicMock(),
paginator=MagicMock(),
record_selector=MagicMock(),
stream_slicer=stream_slicer,
maximum_number_of_slices=maximum_number_of_slices,
options={},
config={},
)

truncated_slices = retriever.stream_slices(sync_mode=SyncMode.incremental, stream_state=None)
truncated_slices = list(retriever.stream_slices(sync_mode=SyncMode.incremental, stream_state=None))

assert truncated_slices == [{"date": f"2022-01-0{day}"} for day in range(1, 6)]
assert truncated_slices == _generate_slices(maximum_number_of_slices)


def _generate_slices(number_of_slices):
return [{"date": f"2022-01-0{day + 1}"} for day in range(number_of_slices)]
@@ -12,6 +12,7 @@
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
from jsonschema.exceptions import ValidationError
from unittest.mock import patch

logger = logging.getLogger("airbyte")

@@ -542,6 +543,95 @@ def test_manifest_without_at_least_one_stream(self, construct_using_pydantic_mod
ManifestDeclarativeSource(source_config=manifest, construct_using_pydantic_models=construct_using_pydantic_models)


@patch("airbyte_cdk.sources.declarative.declarative_source.DeclarativeSource.read")
def test_given_debug_when_read_then_set_log_level(self, declarative_source_read):
any_valid_manifest = {
"version": "version",
"definitions": {
"schema_loader": {"name": "{{ options.stream_name }}", "file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml"},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {"type": "CursorPagination", "cursor_value": "{{ response._metadata.next }}"},
},
"requester": {
"path": "/v3/marketing/lists",
"authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"},
"request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
"streams": [
{
"type": "DeclarativeStream",
"$options": {"name": "lists", "primary_key": "id", "url_base": "https://api.sendgrid.com"},
"schema_loader": {
"name": "{{ options.stream_name }}",
"file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml",
},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {
"type": "CursorPagination",
"cursor_value": "{{ response._metadata.next }}",
"page_size": 10,
},
},
"requester": {
"path": "/v3/marketing/lists",
"authenticator": {"type": "BearerAuthenticator", "api_token": "{{ config.apikey }}"},
"request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
{
"type": "DeclarativeStream",
"$options": {"name": "stream_with_custom_requester", "primary_key": "id", "url_base": "https://api.sendgrid.com"},
"schema_loader": {
"name": "{{ options.stream_name }}",
"file_path": "./source_sendgrid/schemas/{{ options.name }}.yaml",
},
"retriever": {
"paginator": {
"type": "DefaultPaginator",
"page_size": 10,
"page_size_option": {"inject_into": "request_parameter", "field_name": "page_size"},
"page_token_option": {"inject_into": "path"},
"pagination_strategy": {
"type": "CursorPagination",
"cursor_value": "{{ response._metadata.next }}",
"page_size": 10,
},
},
"requester": {
"type": "CustomRequester",
"class_name": "unit_tests.sources.declarative.external_component.SampleCustomComponent",
"path": "/v3/marketing/lists",
"custom_request_parameters": {"page_size": 10},
},
"record_selector": {"extractor": {"field_pointer": ["result"]}},
},
},
],
"check": {"type": "CheckStream", "stream_names": ["lists"]},
}
source = ManifestDeclarativeSource(source_config=any_valid_manifest, debug=True, construct_using_pydantic_models=True)

debug_logger = logging.getLogger("logger.debug")
list(source.read(debug_logger, {}, {}, {}))

assert debug_logger.isEnabledFor(logging.DEBUG)


def test_generate_schema():
schema_str = ManifestDeclarativeSource.generate_schema()
schema = json.loads(schema_str)
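
The test_given_debug_when_read_then_set_log_level test above only asserts the observable contract of the new debug flag: after read runs, the logger that was passed in is debug-enabled. A minimal sketch of that contract under an assumed implementation; the real wiring lives in ManifestDeclarativeSource, and the class below is an illustrative stand-in:

import logging
from typing import Any, Iterable, Mapping

class DebugAwareSource:
    """Illustrative only: raises the supplied logger to DEBUG when the source was built with debug=True."""
    def __init__(self, debug: bool = False):
        self._debug = debug

    def read(self, logger: logging.Logger, config: Mapping[str, Any]) -> Iterable[Any]:
        if self._debug:
            logger.setLevel(logging.DEBUG)
        return []

source = DebugAwareSource(debug=True)
debug_logger = logging.getLogger("logger.debug")
list(source.read(debug_logger, {}))
assert debug_logger.isEnabledFor(logging.DEBUG)
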
51 changes: 51 additions & 0 deletions airbyte-cdk/python/unit_tests/sources/test_abstract_source.py
@@ -331,6 +331,57 @@ def test_valid_full_refresh_read_with_slices(mocker):
assert expected == messages


def test_read_full_refresh_with_slices_sends_slice_messages(mocker):
"""Given the logger is debug and a full refresh, AirbyteMessages are sent for slices"""
debug_logger = logging.getLogger("airbyte.debug")
debug_logger.setLevel(logging.DEBUG)
slices = [{"1": "1"}, {"2": "2"}]
stream = MockStream(
[({"sync_mode": SyncMode.full_refresh, "stream_slice": s}, [s]) for s in slices],
name="s1",
)

mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)

src = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream, SyncMode.full_refresh),
]
)

messages = src.read(debug_logger, {}, catalog)

assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages)))


def test_read_incremental_with_slices_sends_slice_messages(mocker):
"""Given the logger is debug and a incremental, AirbyteMessages are sent for slices"""
debug_logger = logging.getLogger("airbyte.debug")
debug_logger.setLevel(logging.DEBUG)
slices = [{"1": "1"}, {"2": "2"}]
stream = MockStream(
[({"sync_mode": SyncMode.incremental, "stream_slice": s, 'stream_state': {}}, [s]) for s in slices],
name="s1",
)

MockStream.supports_incremental = mocker.PropertyMock(return_value=True)
mocker.patch.object(MockStream, "get_json_schema", return_value={})
mocker.patch.object(MockStream, "stream_slices", return_value=slices)

src = MockSource(streams=[stream])
catalog = ConfiguredAirbyteCatalog(
streams=[
_configured_stream(stream, SyncMode.incremental),
]
)

messages = src.read(debug_logger, {}, catalog)

assert 2 == len(list(filter(lambda message: message.log and message.log.message.startswith("slice:"), messages)))


class TestIncrementalRead:
@pytest.mark.parametrize(
"use_legacy",
22 changes: 13 additions & 9 deletions airbyte-cdk/python/unit_tests/sources/test_source.py
@@ -390,6 +390,7 @@ def test_internal_config_limit(abstract_source, catalog):
logger_mock.level = logging.DEBUG
del catalog.streams[1]
STREAM_LIMIT = 2
SLICE_DEBUG_LOG_COUNT = 1
FULL_RECORDS_NUMBER = 3
streams = abstract_source.streams(None)
http_stream = streams[0]
@@ -398,7 +399,7 @@

catalog.streams[0].sync_mode = SyncMode.full_refresh
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT
assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
# Check if log line matches number of limit
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
@@ -407,13 +408,13 @@
# No limit, check if state record produced for incremental stream
catalog.streams[0].sync_mode = SyncMode.incremental
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == FULL_RECORDS_NUMBER + 1
assert len(records) == FULL_RECORDS_NUMBER + SLICE_DEBUG_LOG_COUNT + 1
assert records[-1].type == Type.STATE

# Set limit and check if state is produced when limit is set for incremental stream
logger_mock.reset_mock()
records = [r for r in abstract_source.read(logger=logger_mock, config=internal_config, catalog=catalog, state={})]
assert len(records) == STREAM_LIMIT + 1
assert len(records) == STREAM_LIMIT + SLICE_DEBUG_LOG_COUNT + 1
assert records[-1].type == Type.STATE
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
@@ -425,40 +426,43 @@

def test_source_config_no_transform(abstract_source, catalog):
logger_mock = MagicMock()
SLICE_DEBUG_LOG_COUNT = 1
logger_mock.level = logging.DEBUG
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [[{"value": 23}] * 5] * 2
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2 * 5
assert [r.record.data for r in records] == [{"value": 23}] * 2 * 5
assert len(records) == 2 * (5 + SLICE_DEBUG_LOG_COUNT)
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": 23}] * 2 * 5
assert http_stream.get_json_schema.call_count == 5
assert non_http_stream.get_json_schema.call_count == 5


def test_source_config_transform(abstract_source, catalog):
logger_mock = MagicMock()
logger_mock.level = logging.DEBUG
SLICE_DEBUG_LOG_COUNT = 2
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
non_http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}] * 2
assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}] * 2


def test_source_config_transform_and_no_transform(abstract_source, catalog):
logger_mock = MagicMock()
logger_mock.level = logging.DEBUG
SLICE_DEBUG_LOG_COUNT = 2
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = TypeTransformer(TransformConfig.DefaultSchemaNormalization)
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}, {"value": 23}]
assert len(records) == 2 + SLICE_DEBUG_LOG_COUNT
assert [r.record.data for r in records if r.type == Type.RECORD] == [{"value": "23"}, {"value": 23}]
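
The adjusted assertions in these tests all follow the same arithmetic: every stream slice processed while the logger is debug-enabled now contributes one extra slice log message on top of its records. A small worked check of the figures used in test_source_config_no_transform, with values copied from the test rather than new behavior:

# Two streams, five records each, plus one slice debug log per stream (one slice each).
streams = 2
records_per_stream = 5
slice_debug_logs_per_stream = 1  # SLICE_DEBUG_LOG_COUNT in the test
assert streams * (records_per_stream + slice_debug_logs_per_stream) == 12
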