Commit

Python CDK: add schema transformer class

Dmytro Rezchykov committed Sep 16, 2021
1 parent 2b7c56e commit a21e032
Showing 7 changed files with 380 additions and 4 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
@@ -1,5 +1,8 @@
# Changelog

## 0.1.21
- Add the ability to normalize output objects according to their JSON schema.

## 0.1.20
- Allow using `requests.auth.AuthBase` as authenticators instead of custom CDK authenticators.
- Implement Oauth2Authenticator, MultipleTokenAuthenticator and TokenAuthenticator authenticators.
23 changes: 22 additions & 1 deletion airbyte-cdk/python/airbyte_cdk/sources/abstract_source.py
@@ -26,7 +26,8 @@
import copy
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple
from functools import lru_cache
from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Tuple

from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import (
@@ -35,6 +36,7 @@
AirbyteMessage,
AirbyteRecordMessage,
AirbyteStateMessage,
AirbyteStream,
ConfiguredAirbyteCatalog,
ConfiguredAirbyteStream,
Status,
@@ -45,6 +47,7 @@
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http.http import HttpStream
from airbyte_cdk.sources.utils.schema_helpers import InternalConfig, split_config
from airbyte_cdk.sources.utils.transform import Transformer


class AbstractSource(Source, ABC):
@@ -70,6 +73,9 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
:return: A list of the streams in this source connector.
"""

# Map of stream name to stream instance, used for applying output object transformation
_stream_instances: Dict[str, AirbyteStream] = {}

@property
def name(self) -> str:
"""Source name"""
@@ -101,6 +107,7 @@ def read(
# TODO assert all streams exist in the connector
# get the streams once in case the connector needs to make any queries to generate them
stream_instances = {s.name: s for s in self.streams(config)}
self._stream_instances = stream_instances
for configured_stream in catalog.streams:
stream_instance = stream_instances.get(configured_stream.stream.name)
if not stream_instance:
@@ -227,7 +234,21 @@ def _checkpoint_state(self, stream_name, stream_state, connector_state, logger):
connector_state[stream_name] = stream_state
return AirbyteMessage(type=MessageType.STATE, state=AirbyteStateMessage(data=connector_state))

@lru_cache(maxsize=None)
def _get_stream_transformer_and_schema(self, stream_name: str) -> Tuple[Transformer, dict]:
"""
Look up a stream's transformer object and JSON schema by stream name.
This function is called frequently, so caching is used to avoid the costly
get_json_schema operation.
:param stream_name: name of the stream from the catalog.
:return: tuple of the stream's transformer object and its discovered JSON schema.
"""
stream_instance = self._stream_instances.get(stream_name)
return stream_instance.transformer, stream_instance.get_json_schema()

def _as_airbyte_record(self, stream_name: str, data: Mapping[str, Any]):
now_millis = int(datetime.now().timestamp()) * 1000
transformer, schema = self._get_stream_transformer_and_schema(stream_name)
transformer.transform(data, schema)
message = AirbyteRecordMessage(stream=stream_name, data=data, emitted_at=now_millis)
return AirbyteMessage(type=MessageType.RECORD, record=message)
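
For illustration, a minimal sketch of what the new read path does per record (the schema and record values here are invented, not from this commit): the transformer/schema pair is cached per stream name via lru_cache, and each record is then normalized in place before being wrapped in an AirbyteRecordMessage.

```
from airbyte_cdk.sources.utils.transform import TransformConfig, Transformer

# Hypothetical stream schema and raw API payload, for illustration only.
schema = {"type": "object", "properties": {"id": {"type": "integer"}}}
transformer = Transformer(TransformConfig.DefaultSchemaNormalization)

record = {"id": "42"}  # raw payload with a string id
transformer.transform(record, schema)
assert record == {"id": 42}  # coerced in place, as _as_airbyte_record now does
```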
7 changes: 7 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/streams/core.py
@@ -31,6 +31,7 @@
from airbyte_cdk.logger import AirbyteLogger
from airbyte_cdk.models import AirbyteStream, SyncMode
from airbyte_cdk.sources.utils.schema_helpers import ResourceSchemaLoader
from airbyte_cdk.sources.utils.transform import TransformConfig, Transformer


def package_name_from_class(cls: object) -> str:
@@ -39,6 +40,9 @@ def package_name_from_class(cls: object) -> str:
return module.__name__.split(".")[0]


_default_transformer = Transformer(TransformConfig.NoTransform)


class Stream(ABC):
"""
Base abstract class for an Airbyte Stream. Makes no assumption of the Stream's underlying transport protocol.
@@ -47,6 +51,9 @@ class Stream(ABC):
# Use self.logger in subclasses to log any messages
logger = AirbyteLogger() # TODO use native "logging" loggers with custom handlers

# Transformer object to perform output data transformation
transformer: Transformer = _default_transformer

@property
def name(self) -> str:
"""
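
Opting a stream into normalization is then a one-line class attribute. A hedged sketch (the stream class is hypothetical, and a real subclass must also implement Stream's abstract methods):

```
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.utils.transform import TransformConfig, Transformer


class MyStream(Stream):  # hypothetical stream; abstract methods omitted for brevity
    # Override the module-level NoTransform default so output records are
    # normalized against this stream's JSON schema.
    transformer = Transformer(TransformConfig.DefaultSchemaNormalization)
```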
150 changes: 150 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/utils/transform.py
@@ -0,0 +1,150 @@
#
# MIT License
#
# Copyright (c) 2020 Airbyte
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
from enum import Flag, auto
from typing import Any, Callable, Dict

from jsonschema import Draft7Validator, validators


class TransformConfig(Flag):
"""
Transformer class config. Configs can be combined using the bitwise OR operator, e.g.
```
TransformConfig.DefaultSchemaNormalization | TransformConfig.CustomSchemaNormalization
```
"""

NoTransform = auto()
DefaultSchemaNormalization = auto()
CustomSchemaNormalization = auto()


class Transformer:
"""
Class for transforming objects before output.
"""

def __init__(self, config: TransformConfig):
"""
Initialize Transformer instance.
:param config: transform config to apply to objects
"""
if TransformConfig.NoTransform in config and config != TransformConfig.NoTransform:
raise Exception("NoTransform option cannot be combined with other flags.")
self._config = config
all_validators = {
key: self.__normalize_and_validate(key, orig_validator)
for key, orig_validator in Draft7Validator.VALIDATORS.items()
if key in ["type", "array", "$ref", "properties", "items"]
}
self._normalizer = validators.create(meta_schema=Draft7Validator.META_SCHEMA, validators=all_validators)

def __normalize(self, original_item: Any, subschema: Dict[str, Any]):
"""
Applies the transform functions selected by the config to an object's field.
:param original_item: original value of the field.
:param subschema: part of the JSON schema containing the field's type/format data.
:return: final field value.
"""
if TransformConfig.DefaultSchemaNormalization in self._config:
original_item = self.default_convert(original_item, subschema)

if TransformConfig.CustomSchemaNormalization in self._config:
raise NotImplementedError("Custom normalization is not implemented yet")
return original_item

@staticmethod
def default_convert(original_item: Any, subschema: Dict[str, Any]) -> Any:
"""
Default transform function, used when the TransformConfig.DefaultSchemaNormalization flag is set.
:param original_item: original value of the field.
:param subschema: part of the JSON schema containing the field's type/format data.
:return: transformed field value.
"""
target_type = subschema["type"]
if original_item is None and "null" in target_type:
return None
if isinstance(target_type, list):
target_type = [t for t in target_type if t != "null"]
if len(target_type) != 1:
return original_item
target_type = target_type[0]
try:
if target_type == "string":
return str(original_item)
elif target_type == "number":
return float(original_item)
elif target_type == "integer":
return int(original_item)
elif target_type == "boolean":
return bool(original_item)
except ValueError:
return original_item
return original_item

def __normalize_and_validate(self, schema_key: str, original_validator: Callable):
"""
Traverse object fields using the native jsonschema validator and apply the normalization function.
:param schema_key: JSON schema key currently being validated/normalized.
:param original_validator: native jsonschema validator callback.
"""

def normalizator(validator_instance, val, instance, schema):
def resolve(subschema):
if "$ref" in subschema:
_, resolved = validator_instance.resolver.resolve(subschema["$ref"])
return resolved
return subschema

if schema_key == "type" and instance is not None:
if "object" in val:
for k, subschema in schema["properties"].items():
if k in instance:
subschema = resolve(subschema)
instance[k] = self.__normalize(instance[k], subschema)
elif "array" in val:
subschema = schema.get("items")
subschema = resolve(subschema)
for index, item in enumerate(instance):
instance[index] = self.__normalize(item, subschema)
# Run the native jsonschema traversal algorithm after field normalization is done.
yield from original_validator(validator_instance, val, instance, schema)

return normalizator

def transform(self, instance: Dict[str, Any], schema: Dict[str, Any]):
"""
Normalize and validate according to config.
:param instance: object instance for normalization/transformation. All modifications are made in place on the existing object.
:param schema: object's JSON schema for normalization.
"""
if TransformConfig.NoTransform in self._config:
return
normalizer = self._normalizer(schema)
for e in normalizer.iter_errors(instance):
"""
Just calling normalizer.validate() would raise an exception on the
first validation error and stop processing the rest of the schema,
so iterate over all errors instead.
"""
# TODO: log warning
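
Putting the pieces together, a hedged end-to-end sketch of the new Transformer on nested data (schema and record invented for illustration): the wrapped "type" validator walks objects and arrays recursively, coercing leaf values toward their declared types and leaving anything it cannot convert unchanged.

```
from airbyte_cdk.sources.utils.transform import TransformConfig, Transformer

# Hypothetical schema with a nested array, to show the recursive traversal.
schema = {
    "type": "object",
    "properties": {
        "value": {"type": "string"},
        "scores": {"type": "array", "items": {"type": "number"}},
    },
}
record = {"value": 23, "scores": ["1.5", "2"]}
Transformer(TransformConfig.DefaultSchemaNormalization).transform(record, schema)
# The integer is stringified and the array elements become floats, in place.
assert record == {"value": "23", "scores": [1.5, 2.0]}
```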
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
@@ -35,7 +35,7 @@

setup(
name="airbyte-cdk",
version="0.1.20",
version="0.1.21",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
47 changes: 45 additions & 2 deletions airbyte-cdk/python/unit_tests/sources/test_source.py
@@ -34,6 +34,7 @@
from airbyte_cdk.sources import AbstractSource, Source
from airbyte_cdk.sources.streams.core import Stream
from airbyte_cdk.sources.streams.http.http import HttpStream
from airbyte_cdk.sources.utils.transform import TransformConfig, Transformer


class MockSource(Source):
@@ -81,6 +82,7 @@ def abstract_source(mocker):
class MockHttpStream(MagicMock, HttpStream):
url_base = "http://example.com"
path = "/dummy/path"
get_json_schema = MagicMock()

def supports_incremental(self):
return True
@@ -92,6 +94,7 @@ def __init__(self, *args, **kvargs):

class MockStream(MagicMock, Stream):
page_size = None
get_json_schema = MagicMock()

def __init__(self, *args, **kvargs):
MagicMock.__init__(self)
@@ -145,8 +148,7 @@ def test_read_catalog(source):
def test_internal_config(abstract_source, catalog):
streams = abstract_source.streams(None)
assert len(streams) == 2
http_stream = streams[0]
non_http_stream = streams[1]
http_stream, non_http_stream = streams
assert isinstance(http_stream, HttpStream)
assert not isinstance(non_http_stream, HttpStream)
http_stream.read_records.return_value = [{}] * 3
@@ -216,3 +218,44 @@ def test_internal_config_limit(abstract_source, catalog):
logger_info_args = [call[0][0] for call in logger_mock.info.call_args_list]
read_log_record = [_l for _l in logger_info_args if _l.startswith("Read")]
assert read_log_record[0].startswith(f"Read {STREAM_LIMIT} ")


SCHEMA = {"type": "object", "properties": {"value": {"type": "string"}}}


def test_source_config_no_transform(abstract_source, catalog):
logger_mock = MagicMock()
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [[{"value": 23}] * 5] * 2
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2 * 5
assert [r.record.data for r in records] == [{"value": 23}] * 2 * 5
assert http_stream.get_json_schema.call_count == 1
assert non_http_stream.get_json_schema.call_count == 1


def test_source_config_transform(abstract_source, catalog):
logger_mock = MagicMock()
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = Transformer(TransformConfig.DefaultSchemaNormalization)
non_http_stream.transformer = Transformer(TransformConfig.DefaultSchemaNormalization)
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}] * 2


def test_source_config_transform_and_no_transform(abstract_source, catalog):
logger_mock = MagicMock()
streams = abstract_source.streams(None)
http_stream, non_http_stream = streams
http_stream.transformer = Transformer(TransformConfig.DefaultSchemaNormalization)
http_stream.get_json_schema.return_value = non_http_stream.get_json_schema.return_value = SCHEMA
http_stream.read_records.return_value, non_http_stream.read_records.return_value = [{"value": 23}], [{"value": 23}]
records = [r for r in abstract_source.read(logger=logger_mock, config={}, catalog=catalog, state={})]
assert len(records) == 2
assert [r.record.data for r in records] == [{"value": "23"}, {"value": 23}]