Skip to content

Commit

Permalink
[Low-code CDK] Add ability to remove fields (#14402)
Browse files Browse the repository at this point in the history
  • Loading branch information
sherifnada authored Jul 12, 2022
1 parent f4524e3 commit 743e6c2
Show file tree
Hide file tree
Showing 11 changed files with 238 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever
from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.streams.core import Stream


Expand All @@ -23,13 +24,25 @@ def __init__(
schema_loader: SchemaLoader,
retriever: Retriever,
cursor_field: Optional[List[str]] = None,
transformations: List[RecordTransformation] = None,
checkpoint_interval: Optional[int] = None,
):
"""
:param name: stream name
:param primary_key: the primary key of the stream
:param schema_loader:
:param retriever:
:param cursor_field:
:param transformations: A list of transformations to be applied to each output record in the stream. Transformations are applied
in the order in which they are defined.
"""
self._name = name
self._primary_key = primary_key
self._cursor_field = cursor_field or []
self._schema_loader = schema_loader
self._retriever = retriever
self._transformations = transformations or []
self._checkpoint_interval = checkpoint_interval

@property
Expand Down Expand Up @@ -84,7 +97,15 @@ def read_records(
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
return self._retriever.read_records(sync_mode, cursor_field, stream_slice, stream_state)
for record in self._retriever.read_records(sync_mode, cursor_field, stream_slice, stream_state):
yield self._apply_transformations(record)

def _apply_transformations(self, record: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Run every configured transformation over a record, in the order in which they were defined.

    Each transformation receives the output of the previous one, so their effects accumulate.

    :param record: the record to transform
    :return: the record after all transformations were applied; the input record itself when no transformations are configured
    """
    output_record = record
    for transformation in self._transformations:
        # Feed the previous result forward. Passing `record` here would discard the work of every
        # transformation but the last one whenever a transformation returns a new object.
        output_record = transformation.transform(output_record)

    return output_record

def get_json_schema(self) -> Mapping[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from airbyte_cdk.sources.declarative.requesters.paginators.next_page_url_paginator import NextPageUrlPaginator
from airbyte_cdk.sources.declarative.requesters.paginators.offset_paginator import OffsetPaginator
from airbyte_cdk.sources.declarative.stream_slicers.datetime_stream_slicer import DatetimeStreamSlicer
from airbyte_cdk.sources.declarative.transformations import RemoveFields
from airbyte_cdk.sources.streams.http.requests_native_auth.token import TokenAuthenticator

CLASS_TYPES_REGISTRY: Mapping[str, Type] = {
Expand All @@ -18,4 +19,5 @@
"NextPageUrlPaginator": NextPageUrlPaginator,
"OffsetPaginator": OffsetPaginator,
"TokenAuthenticator": TokenAuthenticator,
"RemoveFields": RemoveFields,
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from airbyte_cdk.sources.declarative.requesters.retriers.default_retrier import DefaultRetrier
from airbyte_cdk.sources.declarative.requesters.retriers.retrier import Retrier
from airbyte_cdk.sources.declarative.types import Config
from airbyte_cdk.sources.streams.http.auth import HttpAuthenticator
from airbyte_cdk.sources.streams.http.auth import HttpAuthenticator, NoAuth


class HttpRequester(Requester):
Expand All @@ -26,7 +26,7 @@ def __init__(
path: [str, InterpolatedString],
http_method: Union[str, HttpMethod] = HttpMethod.GET,
request_options_provider: Optional[RequestOptionsProvider] = None,
authenticator: HttpAuthenticator,
authenticator: HttpAuthenticator = None,
retrier: Optional[Retrier] = None,
config: Config,
):
Expand All @@ -35,7 +35,7 @@ def __init__(
elif isinstance(request_options_provider, dict):
request_options_provider = InterpolatedRequestOptionsProvider(config=config, **request_options_provider)
self._name = name
self._authenticator = authenticator
self._authenticator = authenticator or NoAuth()
if type(url_base) == str:
url_base = InterpolatedString(url_base)
self._url_base = url_base
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#


# RecordTransformation is depended upon by every class in this module (it is the ABC everything implements). For this reason,
# the order of imports matters, i.e. this file must fully import RecordTransformation before importing anything which depends on it.
# Otherwise there will be a circular dependency: the load order would be __init__.py --> RemoveFields (which tries to import
# RecordTransformation) --> __init__.py --> circular import error, since loading this file causes it to try to import itself down the line.
# So we add the split directive below to tell isort to sort imports while keeping RecordTransformation as the first import.
from .transformation import RecordTransformation

# isort: split
from .remove_fields import RemoveFields

__all__ = ["RecordTransformation", "RemoveFields"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping

import dpath.exceptions
import dpath.util
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.declarative.types import FieldPointer


class RemoveFields(RecordTransformation):
    """
    A transformation which removes fields from a record. The fields to remove are designated using FieldPointers.
    During transformation, if a field or any of its parents does not exist in the record, no error is thrown.

    If an input field pointer references an item in a list (e.g: ["k", 0] in the object {"k": ["a", "b", "c"]}) then
    the object at that index is set to None rather than being removed from the list. TODO change this behavior.

    It's possible to remove objects nested in lists e.g: removing [".", 0, "k"] from {".": [{"k": "V"}]} results in {".": [{}]}

    Usage syntax:

    ```yaml
        my_stream:
            <other parameters..>
            transformations:
                - type: RemoveFields
                  field_pointers:
                    - ["path", "to", "field1"]
                    - ["path2"]
    ```
    """

    def __init__(self, field_pointers: List[FieldPointer]):
        """
        :param field_pointers: pointers to the fields that should be removed
        """
        self._field_pointers = field_pointers

    def transform(self, record: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Delete every configured field from the record. The record is modified in place.

        :param record: The record to be transformed
        :return: the input record with the requested fields removed
        """
        for field_pointer in self._field_pointers:
            self._delete_if_present(record, field_pointer)

        return record

    @staticmethod
    def _delete_if_present(record: Mapping[str, Any], field_pointer: FieldPointer) -> None:
        """Delete the field designated by the pointer, doing nothing when the path is absent from the record."""
        try:
            dpath.util.delete(record, field_pointer)
        except dpath.exceptions.PathNotFound:
            # the (potentially nested) property does not exist: nothing to remove
            return
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from abc import ABC, abstractmethod
from typing import Any, Mapping


class RecordTransformation(ABC):
    """
    Implementations of this class define transformations that can be applied to records of a stream.
    """

    @abstractmethod
    def transform(self, record: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        :param record: the input record to be transformed
        :return: the transformed record
        """

    def __eq__(self, other: object) -> bool:
        """
        Two transformations are equal when they are of the same class and hold the same attribute values.

        :param other: the object to compare against
        :return: True/False for objects of the same class; NotImplemented otherwise, so that Python falls back to its
                 default comparison instead of raising on objects without a __dict__ (e.g. None or an int)
        """
        if type(other) is not type(self):
            return NotImplemented
        return other.__dict__ == self.__dict__
5 changes: 4 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/sources/declarative/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@

from __future__ import annotations

from typing import Any, Mapping
from typing import Any, List, Mapping

Record = Mapping[str, Any]
# A FieldPointer designates a path to a field inside a mapping. For example, retrieving ["k1", "k1.2"] in the object {"k1": {"k1.2":
# "hello"}} returns "hello"
FieldPointer = List[str]
Config = Mapping[str, Any]
StreamSlice = Mapping[str, Any]
StreamState = Mapping[str, Any]
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import Any, List, Mapping

from airbyte_cdk.sources.declarative.checks.connection_checker import ConnectionChecker
from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource
from airbyte_cdk.sources.declarative.parsers.factory import DeclarativeComponentFactory
from airbyte_cdk.sources.declarative.parsers.yaml_parser import YamlParser
Expand All @@ -16,7 +17,7 @@ def __init__(self, path_to_yaml):
self._source_config = self._read_and_parse_yaml_file(path_to_yaml)

@property
def connection_checker(self):
def connection_checker(self) -> ConnectionChecker:
check = self._source_config["check"]
if "class_name" not in check:
check["class_name"] = "airbyte_cdk.sources.declarative.checks.check_stream.CheckStream"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from unittest.mock import MagicMock
from unittest import mock
from unittest.mock import MagicMock, call

from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.transformations import RecordTransformation


def test():
Expand All @@ -27,20 +29,29 @@ def test():
retriever.read_records.return_value = records
retriever.stream_slices.return_value = stream_slices

no_op_transform = mock.create_autospec(spec=RecordTransformation)
no_op_transform.transform = MagicMock(side_effect=lambda x: x)
transformations = [no_op_transform]

stream = DeclarativeStream(
name=name,
primary_key=primary_key,
cursor_field=cursor_field,
schema_loader=schema_loader,
retriever=retriever,
transformations=transformations,
checkpoint_interval=checkpoint_interval,
)

assert stream.name == name
assert stream.get_json_schema() == json_schema
assert stream.state == state
assert stream.read_records(SyncMode.full_refresh, cursor_field, None, None) == records
assert list(stream.read_records(SyncMode.full_refresh, cursor_field, None, None)) == records
assert stream.primary_key == primary_key
assert stream.cursor_field == cursor_field
assert stream.stream_slices(sync_mode=SyncMode.incremental, cursor_field=cursor_field, stream_state=None) == stream_slices
assert stream.state_checkpoint_interval == checkpoint_interval
for transformation in transformations:
assert len(transformation.transform.call_args_list) == len(records)
expected_calls = [call(record) for record in records]
transformation.transform.assert_has_calls(expected_calls, any_order=False)
50 changes: 50 additions & 0 deletions airbyte-cdk/python/unit_tests/sources/declarative/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.sources.declarative.schema.json_schema import JsonSchema
from airbyte_cdk.sources.declarative.stream_slicers.datetime_stream_slicer import DatetimeStreamSlicer
from airbyte_cdk.sources.declarative.transformations import RemoveFields
from airbyte_cdk.sources.streams.http.requests_native_auth.token import TokenAuthenticator

factory = DeclarativeComponentFactory()
Expand Down Expand Up @@ -304,3 +305,52 @@ def test_full_config_with_defaults():
assert stream._retriever._paginator._interpolated_paginator._next_page_token_template._mapping == {
"next_page_token": "{{ decoded_response.metadata.next}}"
}


class TestCreateTransformations:
# the tabbing matters
base_options = """
name: "lists"
primary_key: id
url_base: "https://api.sendgrid.com"
schema_loader:
file_path: "./source_sendgrid/schemas/{{options.name}}.yaml"
retriever:
requester:
path: "/v3/marketing/lists"
request_parameters:
page_size: 10
record_selector:
extractor:
transform: ".result[]"
"""

def test_no_transformations(self):
content = f"""
the_stream:
class_name: airbyte_cdk.sources.declarative.declarative_stream.DeclarativeStream
options:
{self.base_options}
"""
config = parser.parse(content)
component = factory.create_component(config["the_stream"], input_config)()
assert isinstance(component, DeclarativeStream)
assert [] == component._transformations

def test_remove_fields(self):
content = f"""
the_stream:
class_name: airbyte_cdk.sources.declarative.declarative_stream.DeclarativeStream
options:
{self.base_options}
transformations:
- type: RemoveFields
field_pointers:
- ["path", "to", "field1"]
- ["path2"]
"""
config = parser.parse(content)
component = factory.create_component(config["the_stream"], input_config)()
assert isinstance(component, DeclarativeStream)
expected = [RemoveFields(field_pointers=[["path", "to", "field1"], ["path2"]])]
assert expected == component._transformations
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#
# Copyright (c) 2022 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping

import pytest
from airbyte_cdk.sources.declarative.transformations import RemoveFields
from airbyte_cdk.sources.declarative.types import FieldPointer


@pytest.mark.parametrize(
    ["input_record", "field_pointers", "expected"],
    [
        pytest.param({"k1": "v", "k2": "v"}, [["k1"]], {"k2": "v"}, id="remove a field that exists (flat dict)"),
        pytest.param({"k1": "v", "k2": "v"}, [["k3"]], {"k1": "v", "k2": "v"}, id="remove a field that doesn't exist (flat dict)"),
        pytest.param({"k1": "v", "k2": "v"}, [["k1"], ["k2"]], {}, id="remove multiple fields that exist (flat dict)"),
        # TODO: should we instead splice the element out of the array? I think that's the more intuitive solution
        #  Otherwise one could just set the field's value to null.
        pytest.param({"k1": [1, 2]}, [["k1", 0]], {"k1": [None, 2]}, id="remove field inside array (int index)"),
        pytest.param({"k1": [1, 2]}, [["k1", "0"]], {"k1": [None, 2]}, id="remove field inside array (string index)"),
        pytest.param(
            {"k1": "v", "k2": "v", "k3": [0, 1], "k4": "v"},
            [["k1"], ["k2"], ["k3", 0]],
            {"k3": [None, 1], "k4": "v"},
            id="test all cases (flat)",
        ),
        # The pointer targets the existing flat key "k1" so that the out-of-range index, not a missing parent key, is what gets exercised
        pytest.param({"k1": [0, 1]}, [["k1", 10]], {"k1": [0, 1]}, id="remove array index that doesn't exist (flat)"),
        pytest.param({".": {"k1": [0, 1]}}, [[".", "k1", 10]], {".": {"k1": [0, 1]}}, id="remove array index that doesn't exist (nested)"),
        pytest.param({".": {"k2": "v", "k1": "v"}}, [[".", "k1"]], {".": {"k2": "v"}}, id="remove nested field that exists"),
        pytest.param(
            {".": {"k2": "v", "k1": "v"}}, [[".", "k3"]], {".": {"k2": "v", "k1": "v"}}, id="remove field that doesn't exist (nested)"
        ),
        pytest.param({".": {"k2": "v", "k1": "v"}}, [[".", "k1"], [".", "k2"]], {".": {}}, id="remove multiple fields that exist (nested)"),
        pytest.param(
            {".": {"k1": [0, 1]}}, [[".", "k1", 0]], {".": {"k1": [None, 1]}}, id="remove multiple fields that exist in arrays (nested)"
        ),
        pytest.param(
            {".": {"k1": [{"k2": "v", "k3": "v"}, {"k4": "v"}]}},
            [[".", "k1", 0, "k2"], [".", "k1", 1, "k4"]],
            {".": {"k1": [{"k3": "v"}, {}]}},
            id="remove fields that exist in arrays (deeply nested)",
        ),
    ],
)
def test_remove_fields(input_record: Mapping[str, Any], field_pointers: List[FieldPointer], expected: Mapping[str, Any]):
    """Check that RemoveFields deletes exactly the pointed-to fields and silently ignores pointers that match nothing."""
    transformation = RemoveFields(field_pointers)
    assert transformation.transform(input_record) == expected

0 comments on commit 743e6c2

Please sign in to comment.