diff --git a/quixstreams/sources/community/blob_store/blob_source.py b/quixstreams/sources/community/blob_store/blob_source.py
index 3fa0199be..914a9e40a 100644
--- a/quixstreams/sources/community/blob_store/blob_source.py
+++ b/quixstreams/sources/community/blob_store/blob_source.py
@@ -1,6 +1,6 @@
 import logging
 from pathlib import Path
-from typing import Optional, Union
+from typing import Generator, Optional, Union
 
 from quixstreams.sources.community.file import FileSource
 from quixstreams.sources.community.file.compressions import CompressionName
@@ -20,15 +20,15 @@ def __init__(
         blob_client: BlobClient,
         blob_format: Union[FormatName, Format],
         blob_compression: Optional[CompressionName] = None,
-        blob_folder: Optional[str] = None,
-        blob_file: Optional[str] = None,
+        blob_folder: Optional[Union[str, Path]] = None,
+        blob_file: Optional[Union[str, Path]] = None,
         as_replay: bool = True,
         name: Optional[str] = None,
         shutdown_timeout: float = 10.0,
     ):
         self._client = blob_client
-        self._blob_file = blob_file
-        self._blob_folder = blob_folder
+        self._blob_file = Path(blob_file) if blob_file else None
+        self._blob_folder = Path(blob_folder) if blob_folder else None
 
         super().__init__(
             filepath=self._client.location,
@@ -40,24 +40,13 @@ def __init__(
         )
 
     def _get_partition_count(self) -> int:
-        return 1
+        return self._client.get_root_folder_count(self._blob_folder)
 
-    def _check_file_partition_number(self, file: Path) -> int:
-        return 0
+    def _file_read(self, file: Path) -> Generator[dict, None, None]:
+        yield from super()._file_read(self._client.get_raw_blob_stream(file))
 
-    def run(self):
-        blobs = (
-            [self._blob_file]
-            if self._blob_file
-            else self._client.blob_finder(self._blob_folder)
-        )
-        while self._running:
-            for file in blobs:
-                self._check_file_partition_number(Path(file))
-                filestream = self._client.get_raw_blob_stream(file)
-                for record in self._formatter.file_read(filestream):
-                    if self._as_replay and (timestamp := record.get("_timestamp")):
-                        self._replay_delay(timestamp)
-                    self._produce(record)
-            self.flush()
-            return
+    def _file_list(self) -> Generator[Path, None, None]:
+        if self._blob_file:
+            yield self._blob_file
+        else:
+            yield from self._client.blob_collector(self._blob_folder)
diff --git a/quixstreams/sources/community/blob_store/clients/aws.py b/quixstreams/sources/community/blob_store/clients/aws.py
index ccc92e5e6..6f747d8f9 100644
--- a/quixstreams/sources/community/blob_store/clients/aws.py
+++ b/quixstreams/sources/community/blob_store/clients/aws.py
@@ -1,6 +1,8 @@
+import logging
 from io import BytesIO
 from os import getenv
-from typing import Generator, Optional
+from pathlib import Path
+from typing import Generator, Optional, Union
 
 from .base import BlobClient
 
@@ -11,6 +13,8 @@
     raise
 
+logger = logging.getLogger(__name__)
+
 __all__ = ("AwsS3BlobClient",)
 
@@ -55,16 +59,27 @@ def client(self):
             self._client: S3Client = boto_client("s3", **self._credentials)
         return self._client
 
-    def get_raw_blob_stream(self, blob_path: str) -> BytesIO:
-        data = self.client.get_object(Bucket=self.location, Key=blob_path)[
+    def get_raw_blob_stream(self, blob_path: Path) -> BytesIO:
+        data = self.client.get_object(Bucket=self.location, Key=str(blob_path))[
             "Body"
         ].read()
         return BytesIO(data)
 
-    def blob_finder(self, folder: str) -> Generator[str, None, None]:
-        # TODO: Recursively navigate through folders.
+    def get_root_folder_count(self, folder: Path) -> int:
         resp = self.client.list_objects(
-            Bucket=self.location, Prefix=folder, Delimiter="/"
+            Bucket=self.location, Prefix=str(folder), Delimiter="/"
         )
-        for item in resp["Contents"]:
-            yield item["Key"]
+        self._client = None
+        return len(resp["CommonPrefixes"])
+
+    def blob_collector(self, folder: Union[str, Path]) -> Generator[Path, None, None]:
+        resp = self.client.list_objects(
+            Bucket=self.location,
+            Prefix=str(folder),
+            Delimiter="/",
+        )
+        for folder in resp.get("CommonPrefixes", []):
+            yield from self.blob_collector(folder["Prefix"])
+
+        for file in resp.get("Contents", []):
+            yield Path(file["Key"])
diff --git a/quixstreams/sources/community/blob_store/clients/azure.py b/quixstreams/sources/community/blob_store/clients/azure.py
index 9590e9d2b..2830135a6 100644
--- a/quixstreams/sources/community/blob_store/clients/azure.py
+++ b/quixstreams/sources/community/blob_store/clients/azure.py
@@ -1,4 +1,5 @@
 from io import BytesIO
+from pathlib import Path
 from typing import Generator, Optional
 
 from .base import BlobClient
@@ -30,11 +31,12 @@ def client(self):
             self._client: ContainerClient = container_client
         return self._client
 
-    def blob_finder(self, folder: str) -> Generator[str, None, None]:
-        for item in self.client.list_blob_names(name_starts_with=folder):
+    def blob_collector(self, folder: Path) -> Generator[str, None, None]:
+        # TODO: Recursively navigate folders.
+        for item in self.client.list_blob_names(name_starts_with=str(folder)):
             yield item
 
-    def get_raw_blob_stream(self, blob_name) -> BytesIO:
-        blob_client = self.client.get_blob_client(blob_name)
+    def get_raw_blob_stream(self, blob_name: Path) -> BytesIO:
+        blob_client = self.client.get_blob_client(str(blob_name))
         data = blob_client.download_blob().readall()
         return BytesIO(data)
diff --git a/quixstreams/sources/community/blob_store/clients/base.py b/quixstreams/sources/community/blob_store/clients/base.py
index 5f3a90458..3cc085ef1 100644
--- a/quixstreams/sources/community/blob_store/clients/base.py
+++ b/quixstreams/sources/community/blob_store/clients/base.py
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 from dataclasses import dataclass
 from io import BytesIO
+from pathlib import Path
 from typing import Any, Iterable, Union
 
 __all__ = ("BlobClient",)
@@ -10,10 +11,10 @@ class BlobClient:
     _client: Any
     _credentials: Union[dict, str]
-    location: str
+    location: Union[str, Path]
 
-    @abstractmethod
     @property
+    @abstractmethod
     def client(self): ...
 
     """
@@ -24,7 +25,7 @@ def client(self): ...
     """
 
     @abstractmethod
-    def blob_finder(self, folder: str) -> Iterable[str]: ...
+    def blob_collector(self, folder: Path) -> Iterable[Path]: ...
 
     """
     Find all blobs starting from a root folder.
@@ -33,7 +34,12 @@ def blob_finder(self, folder: str) -> Iterable[str]: ...
     """
 
     @abstractmethod
-    def get_raw_blob_stream(self, blob_path: str) -> BytesIO: ...
+    def get_root_folder_count(self, filepath: Path) -> int: ...
+
+    """Counts the number of folders at filepath to assume partition counts."""
+
+    @abstractmethod
+    def get_raw_blob_stream(self, blob_path: Path) -> BytesIO: ...
 
     """
     Obtain a specific blob in its raw form.
diff --git a/quixstreams/sources/community/file/file.py b/quixstreams/sources/community/file/file.py
index 91313dd07..64f8f921d 100644
--- a/quixstreams/sources/community/file/file.py
+++ b/quixstreams/sources/community/file/file.py
@@ -1,7 +1,7 @@
 import logging
 from pathlib import Path
 from time import sleep
-from typing import Generator, Optional, Union
+from typing import BinaryIO, Generator, Optional, Union
 
 from quixstreams.models import Topic, TopicConfig
 from quixstreams.sources import Source
@@ -107,18 +107,6 @@ def _replay_delay(self, current_timestamp: int):
     def _get_partition_count(self) -> int:
         return len([f for f in self._filepath.iterdir()])
 
-    def default_topic(self) -> Topic:
-        """
-        Uses the file structure to generate the desired partition count for the
-        internal topic.
-        :return: the original default topic, with updated partition count
-        """
-        topic = super().default_topic()
-        topic.config = TopicConfig(
-            num_partitions=self._get_partition_count(), replication_factor=1
-        )
-        return topic
-
     def _check_file_partition_number(self, file: Path):
         """
         Checks whether the next file is the start of a new partition so the timestamp
@@ -130,6 +118,12 @@ def _check_file_partition_number(self, file: Path):
             self._previous_partition = partition
             logger.debug(f"Beginning reading partition {partition}")
 
+    def _file_read(self, file: Union[Path, BinaryIO]) -> Generator[dict, None, None]:
+        yield from self._formatter.file_read(file)
+
+    def _file_list(self) -> Generator[Path, None, None]:
+        yield from _file_finder(self._filepath)
+
     def _produce(self, record: dict):
         kafka_msg = self._producer_topic.serialize(
             key=record["_key"],
@@ -140,14 +134,26 @@ def _produce(self, record: dict):
             key=kafka_msg.key, value=kafka_msg.value, timestamp=kafka_msg.timestamp
         )
 
+    def default_topic(self) -> Topic:
+        """
+        Uses the file structure to generate the desired partition count for the
+        internal topic.
+        :return: the original default topic, with updated partition count
+        """
+        topic = super().default_topic()
+        topic.config = TopicConfig(
+            num_partitions=self._get_partition_count(), replication_factor=1
+        )
+        return topic
+
     def run(self):
         while self._running:
-            for file in _file_finder(self._filepath):
+            for file in self._file_list():
                 logger.info(f"Reading files from topic {self._filepath.name}")
                 self._check_file_partition_number(file)
-                for record in self._formatter.file_read(file):
-                    if self._as_replay:
-                        self._replay_delay(record["_timestamp"])
+                for record in self._file_read(file):
+                    if self._as_replay and (timestamp := record.get("_timestamp")):
+                        self._replay_delay(timestamp)
                     self._produce(record)
             self.flush()
             return
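
The net effect of the file.py change is a template-method split: FileSource.run() now depends only on the _file_list() and _file_read() hooks, so BlobSource overrides just those two (plus _get_partition_count()) instead of reimplementing run(). A minimal, self-contained sketch of that shape follows; the class and method bodies (LocalFileReader, FakeBlobReader, the JSON-lines decoding) are illustrative assumptions, not part of quixstreams.

from io import BytesIO
from pathlib import Path
from typing import BinaryIO, Generator, Union


class LocalFileReader:
    """Stand-in for the FileSource loop: run() only touches the two hooks."""

    def _file_list(self) -> Generator[Path, None, None]:
        # Local variant: discover files on disk.
        yield from sorted(Path(".").glob("*.jsonl"))

    def _file_read(self, file: Union[Path, BinaryIO]) -> Generator[dict, None, None]:
        # Local variant: one record per line; accepts a path or an open stream.
        data = file.read_bytes() if isinstance(file, Path) else file.read()
        for line in data.splitlines():
            yield {"_value": line}

    def run(self) -> None:
        # Shared loop: storage details stay behind _file_list()/_file_read().
        for file in self._file_list():
            for record in self._file_read(file):
                print(record)


class FakeBlobReader(LocalFileReader):
    """Blob-style subclass: only the two hooks change; run() is inherited."""

    def _file_list(self) -> Generator[Path, None, None]:
        # Keys would come from an object-store listing (stubbed here).
        yield Path("folder/blob-0.jsonl")

    def _file_read(self, file: Path) -> Generator[dict, None, None]:
        # Download to an in-memory stream, then reuse the parent's decoding.
        yield from super()._file_read(BytesIO(b"first\nsecond"))


if __name__ == "__main__":
    FakeBlobReader().run()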