created more shared methods and further unified classes
tim-quix committed Nov 26, 2024
1 parent 90519da commit 124db0b
Showing 5 changed files with 75 additions and 57 deletions.
37 changes: 13 additions & 24 deletions quixstreams/sources/community/blob_store/blob_source.py
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Optional, Union
+from typing import Generator, Optional, Union

 from quixstreams.sources.community.file import FileSource
 from quixstreams.sources.community.file.compressions import CompressionName
@@ -20,15 +20,15 @@ def __init__(
         blob_client: BlobClient,
         blob_format: Union[FormatName, Format],
         blob_compression: Optional[CompressionName] = None,
-        blob_folder: Optional[str] = None,
-        blob_file: Optional[str] = None,
+        blob_folder: Optional[Union[str, Path]] = None,
+        blob_file: Optional[Union[str, Path]] = None,
         as_replay: bool = True,
         name: Optional[str] = None,
         shutdown_timeout: float = 10.0,
     ):
         self._client = blob_client
-        self._blob_file = blob_file
-        self._blob_folder = blob_folder
+        self._blob_file = Path(blob_file) if blob_file else None
+        self._blob_folder = Path(blob_folder) if blob_folder else None

         super().__init__(
             filepath=self._client.location,
@@ -40,24 +40,13 @@ def __init__(
         )

     def _get_partition_count(self) -> int:
-        return 1
+        return self._client.get_root_folder_count(self._blob_folder)

-    def _check_file_partition_number(self, file: Path) -> int:
-        return 0
+    def _file_read(self, file: Path) -> Generator[dict, None, None]:
+        yield from super()._file_read(self._client.get_raw_blob_stream(file))

-    def run(self):
-        blobs = (
-            [self._blob_file]
-            if self._blob_file
-            else self._client.blob_finder(self._blob_folder)
-        )
-        while self._running:
-            for file in blobs:
-                self._check_file_partition_number(Path(file))
-                filestream = self._client.get_raw_blob_stream(file)
-                for record in self._formatter.file_read(filestream):
-                    if self._as_replay and (timestamp := record.get("_timestamp")):
-                        self._replay_delay(timestamp)
-                    self._produce(record)
-            self.flush()
-            return
+    def _file_list(self) -> Generator[Path, None, None]:
+        if self._blob_file:
+            yield self._blob_file
+        else:
+            yield from self._client.blob_collector(self._blob_folder)
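
Note: BlobSource no longer overrides run(); the shared loop in FileSource (see file.py below) now drives both sources through the _file_list() and _file_read() hooks. Roughly, the control flow after this commit looks like the following sketch (paraphrased from run() below, not the literal implementation):

    for file in source._file_list():               # local paths or blob keys
        source._check_file_partition_number(file)
        for record in source._file_read(file):     # formatter-decoded dicts
            if source._as_replay and (timestamp := record.get("_timestamp")):
                source._replay_delay(timestamp)
            source._produce(record)
    source.flush()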
31 changes: 23 additions & 8 deletions quixstreams/sources/community/blob_store/clients/aws.py
@@ -1,6 +1,8 @@
+import logging
 from io import BytesIO
 from os import getenv
-from typing import Generator, Optional
+from pathlib import Path
+from typing import Generator, Optional, Union

 from .base import BlobClient

@@ -11,6 +13,8 @@
     raise


+logger = logging.getLogger(__name__)
+
 __all__ = ("AwsS3BlobClient",)


@@ -55,16 +59,27 @@ def client(self):
         self._client: S3Client = boto_client("s3", **self._credentials)
         return self._client

-    def get_raw_blob_stream(self, blob_path: str) -> BytesIO:
-        data = self.client.get_object(Bucket=self.location, Key=blob_path)[
+    def get_raw_blob_stream(self, blob_path: Path) -> BytesIO:
+        data = self.client.get_object(Bucket=self.location, Key=str(blob_path))[
             "Body"
         ].read()
         return BytesIO(data)

-    def blob_finder(self, folder: str) -> Generator[str, None, None]:
-        # TODO: Recursively navigate through folders.
+    def get_root_folder_count(self, folder: Path) -> int:
         resp = self.client.list_objects(
-            Bucket=self.location, Prefix=folder, Delimiter="/"
+            Bucket=self.location, Prefix=str(folder), Delimiter="/"
         )
-        for item in resp["Contents"]:
-            yield item["Key"]
-        self._client = None
+        return len(resp["CommonPrefixes"])
+
+    def blob_collector(self, folder: Union[str, Path]) -> Generator[Path, None, None]:
+        resp = self.client.list_objects(
+            Bucket=self.location,
+            Prefix=str(folder),
+            Delimiter="/",
+        )
+        for folder in resp.get("CommonPrefixes", []):
+            yield from self.blob_collector(folder["Prefix"])
+
+        for file in resp.get("Contents", []):
+            yield Path(file["Key"])
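
Note: blob_collector descends into every CommonPrefixes entry that list_objects returns for the "/" delimiter before yielding the Contents keys at the current level. A runnable toy illustration of that traversal, with an invented bucket layout standing in for the real S3 responses:

    from pathlib import Path

    # Fake list_objects responses keyed by prefix; all names are made up.
    FAKE_LISTING = {
        "my-topic/": {"CommonPrefixes": [{"Prefix": "my-topic/0/"}, {"Prefix": "my-topic/1/"}]},
        "my-topic/0/": {"Contents": [{"Key": "my-topic/0/0001.jsonl"}]},
        "my-topic/1/": {"Contents": [{"Key": "my-topic/1/0001.jsonl"}]},
    }

    def blob_collector(folder: str):
        resp = FAKE_LISTING[folder]
        for sub in resp.get("CommonPrefixes", []):   # recurse into subfolders first
            yield from blob_collector(sub["Prefix"])
        for file in resp.get("Contents", []):        # then yield files at this level
            yield Path(file["Key"])

    print(list(blob_collector("my-topic/")))
    # [PosixPath('my-topic/0/0001.jsonl'), PosixPath('my-topic/1/0001.jsonl')]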
10 changes: 6 additions & 4 deletions quixstreams/sources/community/blob_store/clients/azure.py
@@ -1,4 +1,5 @@
 from io import BytesIO
+from pathlib import Path
 from typing import Generator, Optional

 from .base import BlobClient
@@ -30,11 +31,12 @@ def client(self):
         self._client: ContainerClient = container_client
         return self._client

-    def blob_finder(self, folder: str) -> Generator[str, None, None]:
-        for item in self.client.list_blob_names(name_starts_with=folder):
+    def blob_collector(self, folder: Path) -> Generator[str, None, None]:
+        # TODO: Recursively navigate folders.
+        for item in self.client.list_blob_names(name_starts_with=str(folder)):
             yield item

-    def get_raw_blob_stream(self, blob_name) -> BytesIO:
-        blob_client = self.client.get_blob_client(blob_name)
+    def get_raw_blob_stream(self, blob_name: Path) -> BytesIO:
+        blob_client = self.client.get_blob_client(str(blob_name))
         data = blob_client.download_blob().readall()
         return BytesIO(data)
14 changes: 10 additions & 4 deletions quixstreams/sources/community/blob_store/clients/base.py
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 from dataclasses import dataclass
 from io import BytesIO
+from pathlib import Path
 from typing import Any, Iterable, Union

 __all__ = ("BlobClient",)
@@ -10,10 +11,10 @@
 class BlobClient:
     _client: Any
     _credentials: Union[dict, str]
-    location: str
+    location: Union[str, Path]

-    @abstractmethod
     @property
+    @abstractmethod
     def client(self): ...

     """
@@ -24,7 +25,7 @@ def client(self): ...
     """

     @abstractmethod
-    def blob_finder(self, folder: str) -> Iterable[str]: ...
+    def blob_collector(self, folder: Path) -> Iterable[Path]: ...

     """
     Find all blobs starting from a root folder.
@@ -33,7 +34,12 @@ def blob_finder(self, folder: str) -> Iterable[str]: ...
     """

     @abstractmethod
-    def get_raw_blob_stream(self, blob_path: str) -> BytesIO: ...
+    def get_root_folder_count(self, filepath: Path) -> int: ...
+
+    """Counts the number of folders at filepath to assume partition counts."""
+
+    @abstractmethod
+    def get_raw_blob_stream(self, blob_path: Path) -> BytesIO: ...

     """
     Obtain a specific blob in its raw form.
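
Note: these three abstract methods are the whole contract BlobSource relies on. A hypothetical client satisfying it against the local filesystem, purely to illustrate the expected semantics (not part of this commit; the import path is taken from the file headers above):

    from io import BytesIO
    from pathlib import Path
    from typing import Iterable

    from quixstreams.sources.community.blob_store.clients.base import BlobClient

    class LocalBlobClient(BlobClient):
        def __init__(self, root: Path):
            self.location = root   # stands in for a bucket/container name
            self._credentials = {}
            self._client = None

        @property
        def client(self):
            return self            # nothing to authenticate locally

        def blob_collector(self, folder: Path) -> Iterable[Path]:
            # Recursively yield files under folder, relative to the root,
            # like the recursive S3 listing above.
            for path in sorted((self.location / folder).rglob("*")):
                if path.is_file():
                    yield path.relative_to(self.location)

        def get_root_folder_count(self, filepath: Path) -> int:
            # One partition per top-level folder, as the docstring above describes.
            return sum(1 for p in (self.location / filepath).iterdir() if p.is_dir())

        def get_raw_blob_stream(self, blob_path: Path) -> BytesIO:
            return BytesIO((self.location / blob_path).read_bytes())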
40 changes: 23 additions & 17 deletions quixstreams/sources/community/file/file.py
@@ -1,7 +1,7 @@
 import logging
 from pathlib import Path
 from time import sleep
-from typing import Generator, Optional, Union
+from typing import BinaryIO, Generator, Optional, Union

 from quixstreams.models import Topic, TopicConfig
 from quixstreams.sources import Source
@@ -107,18 +107,6 @@ def _replay_delay(self, current_timestamp: int):
     def _get_partition_count(self) -> int:
         return len([f for f in self._filepath.iterdir()])

-    def default_topic(self) -> Topic:
-        """
-        Uses the file structure to generate the desired partition count for the
-        internal topic.
-        :return: the original default topic, with updated partition count
-        """
-        topic = super().default_topic()
-        topic.config = TopicConfig(
-            num_partitions=self._get_partition_count(), replication_factor=1
-        )
-        return topic
-
     def _check_file_partition_number(self, file: Path):
         """
         Checks whether the next file is the start of a new partition so the timestamp
@@ -130,6 +118,12 @@ def _check_file_partition_number(self, file: Path):
             self._previous_partition = partition
             logger.debug(f"Beginning reading partition {partition}")

+    def _file_read(self, file: Union[Path, BinaryIO]) -> Generator[dict, None, None]:
+        yield from self._formatter.file_read(file)
+
+    def _file_list(self) -> Generator[Path, None, None]:
+        yield from _file_finder(self._filepath)
+
     def _produce(self, record: dict):
         kafka_msg = self._producer_topic.serialize(
             key=record["_key"],
@@ -140,14 +134,26 @@ def _produce(self, record: dict):
             key=kafka_msg.key, value=kafka_msg.value, timestamp=kafka_msg.timestamp
         )

+    def default_topic(self) -> Topic:
+        """
+        Uses the file structure to generate the desired partition count for the
+        internal topic.
+        :return: the original default topic, with updated partition count
+        """
+        topic = super().default_topic()
+        topic.config = TopicConfig(
+            num_partitions=self._get_partition_count(), replication_factor=1
+        )
+        return topic
+
     def run(self):
         while self._running:
-            for file in _file_finder(self._filepath):
+            for file in self._file_list():
                 logger.info(f"Reading files from topic {self._filepath.name}")
                 self._check_file_partition_number(file)
-                for record in self._formatter.file_read(file):
-                    if self._as_replay:
-                        self._replay_delay(record["_timestamp"])
+                for record in self._file_read(file):
+                    if self._as_replay and (timestamp := record.get("_timestamp")):
+                        self._replay_delay(timestamp)
                     self._produce(record)
             self.flush()
             return
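
Note: the visible parts of _produce and run imply a loose contract for the records that _file_read yields: each is a dict carrying at least a "_key", plus an optional "_timestamp" (presumably epoch milliseconds) that drives the replay delay; the rest of the serialize() call is collapsed above. An invented example record, assuming a "_value" field for the message body:

    record = {
        "_key": b"sensor-1",              # becomes the Kafka message key
        "_value": {"temperature": 21.4},  # assumed field for the message value
        "_timestamp": 1732600000000,      # optional; drives _replay_delay()
    }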
