Skip to content

Commit

Permalink
forgot to also emit dictionary instead of empty lists when state not …
Browse files Browse the repository at this point in the history
…specified (airbytehq#16569)
  • Loading branch information
brianjlai authored and robbinhan committed Sep 29, 2022
1 parent 52847ea commit 007dfe3
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 14 deletions.
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.1.84
- Emit the legacy state format (a dictionary) when state is unspecified for connectors that override read

## 0.1.83
- Fix per-stream to send legacy format for connectors that override read

Expand Down
27 changes: 19 additions & 8 deletions airbyte-cdk/python/airbyte_cdk/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Generic, Iterable, List, Mapping, MutableMapping, TypeVar, Union

from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig
Expand Down Expand Up @@ -44,7 +45,7 @@ class Source(
ABC,
):
# can be overridden to change an input state
def read_state(self, state_path: str) -> List[AirbyteStateMessage]:
def read_state(self, state_path: str) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
"""
Retrieves the input state of a sync by reading from the specified JSON file. Incoming state can be deserialized into either
a JSON object for legacy state input or as a list of AirbyteStateMessages for the per-stream state format. Regardless of the
Expand All @@ -55,7 +56,7 @@ def read_state(self, state_path: str) -> List[AirbyteStateMessage]:
if state_path:
state_obj = json.loads(open(state_path, "r").read())
if not state_obj:
return []
return self._emit_legacy_state_format({})
is_per_stream_state = isinstance(state_obj, List)
if is_per_stream_state:
parsed_state_messages = []
Expand All @@ -66,13 +67,23 @@ def read_state(self, state_path: str) -> List[AirbyteStateMessage]:
parsed_state_messages.append(parsed_message)
return parsed_state_messages
else:
# Existing connectors that override read() might not be able to interpret the new state format. We temporarily
# send state in the old format for these connectors, but once all have been upgraded, this block can be removed
# vars(self.__class__) checks if the current class directly overrides the read() function
if "read" in vars(self.__class__):
return state_obj
return self._emit_legacy_state_format(state_obj)
return self._emit_legacy_state_format({})

def _emit_legacy_state_format(self, state_obj) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
"""
Existing connectors that override read() might not be able to interpret the new state format. We temporarily
send state in the old format for these connectors, but once all have been upgraded, this method can be removed,
and we can then emit state in the list format.
"""
# vars(self.__class__) checks if the current class directly overrides the read() function
if "read" in vars(self.__class__):
return defaultdict(dict, state_obj)
else:
if state_obj:
return [AirbyteStateMessage(type=AirbyteStateType.LEGACY, data=state_obj)]
return []
else:
return []

# can be overridden to change an input catalog
def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog:
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.1.83",
version="0.1.84",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
18 changes: 13 additions & 5 deletions airbyte-cdk/python/unit_tests/sources/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import tempfile
from collections import defaultdict
from contextlib import nullcontext as does_not_raise
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
from unittest.mock import MagicMock
Expand Down Expand Up @@ -224,9 +225,9 @@ def streams(self, config):
does_not_raise(),
id="test_incoming_legacy_state",
),
pytest.param([], [], does_not_raise(), id="test_empty_incoming_stream_state"),
pytest.param(None, [], does_not_raise(), id="test_none_incoming_state"),
pytest.param({}, [], does_not_raise(), id="test_empty_incoming_legacy_state"),
pytest.param([], defaultdict(dict, {}), does_not_raise(), id="test_empty_incoming_stream_state"),
pytest.param(None, defaultdict(dict, {}), does_not_raise(), id="test_none_incoming_state"),
pytest.param({}, defaultdict(dict, {}), does_not_raise(), id="test_empty_incoming_legacy_state"),
pytest.param(
[
{
Expand Down Expand Up @@ -301,8 +302,15 @@ def test_read_state_sends_new_legacy_format_if_source_does_not_implement_read():
assert actual == expected_state


def test_read_state_nonexistent(source):
assert [] == source.read_state("")
@pytest.mark.parametrize(
"source, expected_state",
[
pytest.param(MockSource(), {}, id="test_source_implementing_read_returns_legacy_format"),
pytest.param(MockAbstractSource(), [], id="test_source_not_implementing_read_returns_per_stream_format"),
],
)
def test_read_state_nonexistent(source, expected_state):
assert source.read_state("") == expected_state


def test_read_catalog(source):
Expand Down

0 comments on commit 007dfe3

Please sign in to comment.