Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Need to also emit legacy state for read override connectors when state is not specified #16569

Merged
merged 2 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions airbyte-cdk/python/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Changelog

## 0.1.84
- Emit legacy state format when state is unspecified for connectors that override read

## 0.1.83
- Fix per-stream to send legacy format for connectors that override read

Expand Down
27 changes: 19 additions & 8 deletions airbyte-cdk/python/airbyte_cdk/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Generic, Iterable, List, Mapping, MutableMapping, TypeVar, Union

from airbyte_cdk.connector import BaseConnector, DefaultConnectorMixin, TConfig
Expand Down Expand Up @@ -44,7 +45,7 @@ class Source(
ABC,
):
# can be overridden to change an input state
def read_state(self, state_path: str) -> List[AirbyteStateMessage]:
def read_state(self, state_path: str) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
"""
Retrieves the input state of a sync by reading from the specified JSON file. Incoming state can be deserialized into either
a JSON object for legacy state input or as a list of AirbyteStateMessages for the per-stream state format. Regardless of the
Expand All @@ -55,7 +56,7 @@ def read_state(self, state_path: str) -> List[AirbyteStateMessage]:
if state_path:
state_obj = json.loads(open(state_path, "r").read())
if not state_obj:
return []
return self._emit_legacy_state_format({})
is_per_stream_state = isinstance(state_obj, List)
if is_per_stream_state:
parsed_state_messages = []
Expand All @@ -66,13 +67,23 @@ def read_state(self, state_path: str) -> List[AirbyteStateMessage]:
parsed_state_messages.append(parsed_message)
return parsed_state_messages
else:
# Existing connectors that override read() might not be able to interpret the new state format. We temporarily
# send state in the old format for these connectors, but once all have been upgraded, this block can be removed
# vars(self.__class__) checks if the current class directly overrides the read() function
if "read" in vars(self.__class__):
return state_obj
return self._emit_legacy_state_format(state_obj)
return self._emit_legacy_state_format({})

def _emit_legacy_state_format(self, state_obj) -> Union[List[AirbyteStateMessage], MutableMapping[str, Any]]:
"""
Existing connectors that override read() might not be able to interpret the new state format. We temporarily
send state in the old format for these connectors, but once all have been upgraded, this method can be removed,
and we can then emit state in the list format.
"""
# vars(self.__class__) checks if the current class directly overrides the read() function
if "read" in vars(self.__class__):
return defaultdict(dict, state_obj)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you share why you call defaultdict with a dict default factory?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great question! I'm not sure it actually makes a difference, but the original code before these changes also wrapped the state in a defaultdict, so I wanted to preserve the same behavior. Reference to the prior code:

state = defaultdict(dict, state_obj)

else:
if state_obj:
return [AirbyteStateMessage(type=AirbyteStateType.LEGACY, data=state_obj)]
return []
else:
return []

# can be overridden to change an input catalog
def read_catalog(self, catalog_path: str) -> ConfiguredAirbyteCatalog:
Expand Down
2 changes: 1 addition & 1 deletion airbyte-cdk/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

setup(
name="airbyte-cdk",
version="0.1.83",
version="0.1.84",
description="A framework for writing Airbyte Connectors.",
long_description=README,
long_description_content_type="text/markdown",
Expand Down
18 changes: 13 additions & 5 deletions airbyte-cdk/python/unit_tests/sources/test_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import tempfile
from collections import defaultdict
from contextlib import nullcontext as does_not_raise
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple
from unittest.mock import MagicMock
Expand Down Expand Up @@ -224,9 +225,9 @@ def streams(self, config):
does_not_raise(),
id="test_incoming_legacy_state",
),
pytest.param([], [], does_not_raise(), id="test_empty_incoming_stream_state"),
pytest.param(None, [], does_not_raise(), id="test_none_incoming_state"),
pytest.param({}, [], does_not_raise(), id="test_empty_incoming_legacy_state"),
pytest.param([], defaultdict(dict, {}), does_not_raise(), id="test_empty_incoming_stream_state"),
pytest.param(None, defaultdict(dict, {}), does_not_raise(), id="test_none_incoming_state"),
pytest.param({}, defaultdict(dict, {}), does_not_raise(), id="test_empty_incoming_legacy_state"),
pytest.param(
[
{
Expand Down Expand Up @@ -301,8 +302,15 @@ def test_read_state_sends_new_legacy_format_if_source_does_not_implement_read():
assert actual == expected_state


def test_read_state_nonexistent(source):
assert [] == source.read_state("")
@pytest.mark.parametrize(
"source, expected_state",
[
pytest.param(MockSource(), {}, id="test_source_implementing_read_returns_legacy_format"),
pytest.param(MockAbstractSource(), [], id="test_source_not_implementing_read_returns_per_stream_format"),
],
)
def test_read_state_nonexistent(source, expected_state):
assert source.read_state("") == expected_state


def test_read_catalog(source):
Expand Down