Skip to content

Commit

Permalink
add better test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
emmaling27 committed Jan 12, 2023
1 parent 845aa30 commit eb1527e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,17 @@ def write(
:return: Iterable of AirbyteStateMessages wrapped in AirbyteMessage structs
"""
config = cast(ConvexConfig, config)
writer = ConvexWriter(
ConvexClient(config, self.__stream_metadata(configured_catalog.streams))
)
writer = ConvexWriter(ConvexClient(config, self.__stream_metadata(configured_catalog.streams)))

# Setup: Clear tables if in overwrite mode; add indexes if in append_dedup mode.
streams_to_delete = []
indexes_to_add = {}
for configured_stream in configured_catalog.streams:
if configured_stream.destination_sync_mode == DestinationSyncMode.overwrite:
streams_to_delete.append(configured_stream.stream.name)
elif (
configured_stream.destination_sync_mode
== DestinationSyncMode.append_dedup
and configured_stream.primary_key
):
elif configured_stream.destination_sync_mode == DestinationSyncMode.append_dedup and configured_stream.primary_key:

indexes_to_add[
configured_stream.stream.name
] = configured_stream.primary_key
indexes_to_add[configured_stream.stream.name] = configured_stream.primary_key
if len(streams_to_delete) != 0:
writer.delete_stream_entries(streams_to_delete)
if len(indexes_to_add) != 0:
Expand All @@ -78,9 +70,7 @@ def write(
yield message
elif message.type == Type.RECORD and message.record is not None:
if message.record.namespace is not None:
message.record.stream = (
f"{message.record.namespace}_{message.record.stream}"
)
message.record.stream = f"{message.record.namespace}_{message.record.stream}"
msg = message.record.dict()
writer.queue_write_operation(msg)
else:
Expand All @@ -90,9 +80,7 @@ def write(
# Make sure to flush any records still in the queue
writer.flush()

def __stream_metadata(
self, streams: List[ConfiguredAirbyteStream]
) -> Mapping[str, Any]:
def __stream_metadata(self, streams: List[ConfiguredAirbyteStream]) -> Mapping[str, Any]:
stream_metadata = {}
for s in streams:
# Only send a primary key for dedup sync
Expand All @@ -109,9 +97,7 @@ def __stream_metadata(
stream_metadata[name] = stream
return stream_metadata

def check(
self, logger: Logger, config: Mapping[str, Any]
) -> AirbyteConnectionStatus:
def check(self, logger: Logger, config: Mapping[str, Any]) -> AirbyteConnectionStatus:
"""
Tests whether the input configuration can be used to successfully connect to the destination with the needed permissions,
e.g. whether a provided API token or password can be used to connect and write to the destination.
Expand All @@ -132,5 +118,4 @@ def check(
if resp.status_code == 200:
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
else:
return AirbyteConnectionStatus(
status=Status.FAILED, message=f"An exception occurred: {repr(resp)}"
return AirbyteConnectionStatus(status=Status.FAILED, message=f"An exception occurred: {repr(resp)}")
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ def configured_catalog_fixture() -> ConfiguredAirbyteCatalog:
return ConfiguredAirbyteCatalog(streams=[append_stream, overwrite_stream, dedup_stream])


def _state(data: Dict[str, Any]) -> AirbyteMessage:
def state(data: Dict[str, Any]) -> AirbyteMessage:
return AirbyteMessage(type=Type.STATE, state=AirbyteStateMessage(data=data))


def _record(stream: str, str_value: str, int_value: int) -> AirbyteMessage:
def record(stream: str, str_value: str, int_value: int) -> AirbyteMessage:
return AirbyteMessage(
type=Type.RECORD,
record=AirbyteRecordMessage(stream=stream, data={"str_col": str_value, DEDUP_INDEX_FIELD: int_value}, emitted_at=0),
Expand Down Expand Up @@ -103,30 +103,43 @@ def test_check(config: ConvexConfig):
@responses.activate
def test_write(config: ConvexConfig, configured_catalog: ConfiguredAirbyteCatalog):
setup_responses(config)
append_stream, overwrite_stream = configured_catalog.streams[0].stream.name, configured_catalog.streams[1].stream.name

first_state_message = _state({"state": "1"})
first_record_chunk = [_record(append_stream, str(i), i) for i in range(5)] + [_record(overwrite_stream, str(i), i) for i in range(5)]

second_state_message = _state({"state": "2"})
second_record_chunk = [_record(append_stream, str(i), i) for i in range(5, 10)] + [
_record(overwrite_stream, str(i), i) for i in range(5, 10)
]
append_stream, overwrite_stream, dedup_stream = (
configured_catalog.streams[0].stream.name,
configured_catalog.streams[1].stream.name,
configured_catalog.streams[2].stream.name,
)

first_state_message = state({"state": "1"})
first_append_chunk = [record(append_stream, str(i), i) for i in range(5)]
first_overwrite_chunk = [record(overwrite_stream, str(i), i) for i in range(5)]
first_dedup_chunk = [record(dedup_stream, str(i), i) for i in range(10)]
first_record_chunk = first_append_chunk + first_overwrite_chunk + first_dedup_chunk
destination = DestinationConvex()

expected_states = [first_state_message, second_state_message]
output_states = list(
output_state = list(
destination.write(
config, configured_catalog, [*first_record_chunk, first_state_message, *second_record_chunk, second_state_message]
config,
configured_catalog,
[
*first_record_chunk,
first_state_message,
],
)
)
assert expected_states == output_states, "Checkpoint state messages were expected from the destination"

third_state_message = _state({"state": "3"})
third_record_chunk = [_record(append_stream, str(i), i) for i in range(10, 15)] + [
_record(overwrite_stream, str(i), i) for i in range(10, 15)
]

output_states = list(destination.write(config, configured_catalog, [*third_record_chunk, third_state_message]))
assert [third_state_message] == output_states
)[0]
assert first_state_message == output_state

second_state_message = state({"state": "2"})
second_append_chunk = [record(append_stream, str(i), i) for i in range(5, 10)]
second_overwrite_chunk = [record(overwrite_stream, str(i), i) for i in range(5, 10)]
second_dedup_chunk = [record(dedup_stream, str(i + 2), i) for i in range(5)]
second_record_chunk = second_append_chunk + second_overwrite_chunk + second_dedup_chunk
output_state = list(
destination.write(
config,
configured_catalog,
[
*second_record_chunk,
second_state_message,
],
)
)[0]
assert second_state_message == output_state

0 comments on commit eb1527e

Please sign in to comment.