Skip to content

Commit

Permalink
Fix all method doc-strings
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Jan 16, 2025
1 parent 74381a7 commit ca6801b
Show file tree
Hide file tree
Showing 34 changed files with 656 additions and 136 deletions.
4 changes: 4 additions & 0 deletions plugins/pillow/superduper_pillow/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def _encode_data(self, item):
return buffer.getvalue()

def decode_data(self, item):
"""Decode the data.
:param item: The data to decode.
"""
try:
return PIL.Image.open(io.BytesIO(item))
except Exception as e:
Expand Down
14 changes: 11 additions & 3 deletions superduper/backends/base/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,20 @@ def exists(

@abstractmethod
def put_bytes(self, serialized: bytes, file_id: str):
"""Save bytes in artifact store""" ""
"""Save bytes in artifact store
:param serialized: Bytes to save
:param file_id: Identifier of artifact in the store
"""
pass

@abstractmethod
def put_file(self, file_path: str, file_id: str) -> str:
"""Save file in artifact store and return file_id."""
"""Save file in artifact store and return file_id.
:param file_path: Path to file
:param file_id: Identifier of artifact in the store
"""
pass

def save_artifact(self, r: t.Dict):
Expand Down Expand Up @@ -131,7 +139,7 @@ def save_artifact(self, r: t.Dict):
def delete_artifact(self, artifact_ids: t.List[str]):
"""Delete artifact from artifact store.
:param r: dictionary with mandatory fields
:param artifact_ids: list of artifact ids to delete.
"""
for artifact_id in artifact_ids:
try:
Expand Down
7 changes: 6 additions & 1 deletion superduper/backends/base/cdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ class CDCBackend(BaseBackend):

@abstractmethod
def handle_event(self, event_type, table, ids):
"""Handle an incoming event."""
"""Handle an incoming event.
:param event_type: The type of event.
:param table: The table to handle.
:param ids: The ids to handle.
"""
pass

@property
Expand Down
12 changes: 9 additions & 3 deletions superduper/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ def type(self) -> str:
"""Return the type of compute engine."""
pass

# TODO is this used anywhere?
@abstractmethod
def release_futures(self, context: str):
"""Release futures from backend."""
"""Release futures from backend.
:param context: Context of futures to release.
"""
pass

# TODO needed?
Expand Down Expand Up @@ -94,7 +96,11 @@ def initialize(self):
"""Connect to address."""

def create_handler(self, *args, **kwargs):
"""Create handler on component declare."""
"""Create handler on component declare.
:param args: *args for `create_handler`
:param kwargs: *kwargs for `create_handler`
"""

@property
def db(self) -> 'Datalayer':
Expand Down
2 changes: 1 addition & 1 deletion superduper/backends/base/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def create_table_and_schema(self, identifier: str, schema: "Schema"):
"""Create a schema in the data-backend.
:param identifier: The identifier of the table.
:param mapping: The mapping of the schema.
:param schema: The schema to create.
"""

@abstractmethod
Expand Down
6 changes: 4 additions & 2 deletions superduper/backends/base/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def create_artifact_relation(self, uuid, artifact_ids):
Create a relation between an artifact and a component version.
:param uuid: UUID of component version
:param artifact: artifact
:param artifact_ids: artifact
"""
artifact_ids = (
[artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids
Expand All @@ -109,7 +109,7 @@ def delete_artifact_relation(self, uuid, artifact_ids):
Delete a relation between an artifact and a component version.
:param uuid: UUID of component version
:param artifact: artifact
:param artifact_ids: artifact ids
"""
artifact_ids = (
[artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids
Expand All @@ -127,6 +127,7 @@ def get_artifact_relations(self, uuid=None, artifact_id=None):
"""
Get all relations between an artifact and a component version.
:param uuid: UUID of component version
:param artifact_id: artifact
"""
if uuid is None and artifact_id is None:
Expand Down Expand Up @@ -455,6 +456,7 @@ def replace_object(
:param identifier: identifier of object
:param type_id: type of object
:param version: version of object
:param uuid: UUID of object
"""
if version is None and uuid is None:
assert isinstance(type_id, str)
Expand Down
30 changes: 25 additions & 5 deletions superduper/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None):
self.identifier = re.sub('[\-]+', '-', self.identifier)

def unpack(self):
"""Unpack the query."""
parts = _unpack(self.parts)
return type(self)(
db=self.db,
Expand Down Expand Up @@ -206,7 +207,7 @@ def __getitem__(self, item):
def set_db(self, value: 'Datalayer'):
"""Set the datalayer to use to execute the query.
:param db: The datalayer to use to execute the query.
:param value: The datalayer to use to execute the query.
"""

def _set_the_db(r, db):
Expand Down Expand Up @@ -352,7 +353,13 @@ def dict(
uuid: bool = True,
refs: bool = False,
):
"""Return the query as a dictionary."""
"""Return the query as a dictionary.
:param metadata: Include metadata.
:param defaults: Include defaults.
:param uuid: Include UUID.
:param refs: Include references.
"""
query, documents = self._dump_query()
documents = [Document(r) for r in documents]
return Document(
Expand Down Expand Up @@ -563,14 +570,22 @@ def swap_keys(r: str | list | dict):
return out

def tolist(self, db=None, eager_mode=False, **kwargs):
"""Execute and convert to list."""
"""Execute and convert to list.
:param db: Datalayer instance.
:param eager_mode: Eager mode.
:param kwargs: Additional keyword arguments.
"""
return self.execute(db=db, eager_mode=eager_mode, **kwargs).tolist()

def execute(self, db=None, eager_mode=False, handle_outputs=True, **kwargs):
"""
Execute the query.
:param db: Datalayer instance.
:param eager_mode: Eager mode.
:param handle_outputs: Handle outputs.
:param kwargs: Additional keyword arguments.
"""
if self.type == 'select' and handle_outputs and 'outputs' in str(self):
query = self.complete_uuids(db=db or self.db)
Expand Down Expand Up @@ -870,9 +885,10 @@ class Model(_BaseQuery):
:param parts: The parts of the query.
"""

type: t.ClassVar[str] = 'predict'

table: str
identifier: str = ''
type: t.ClassVar[str] = 'predict'

def execute(self):
"""Execute the model as a query."""
Expand All @@ -893,7 +909,11 @@ def do_execute(self, db=None):
return Document({'_base': r})

def dict(self, metadata: bool = True, defaults: bool = True):
"""Return the query as a dictionary."""
"""Return the query as a dictionary.
:param metadata: Include metadata.
:param defaults: Include defaults.
"""
query, documents = self._dump_query()
documents = [Document(r) for r in documents]
return Document(
Expand Down
14 changes: 10 additions & 4 deletions superduper/backends/base/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ def close_connection(self):
pass

def consume(self, *args, **kwargs):
"""Start consuming messages from queue."""
"""Start consuming messages from queue.
:param args: positional arguments
:param kwargs: keyword arguments
"""
logging.info(f"Started consuming on queue: {self.queue_name}")
try:
self.start_consuming()
Expand All @@ -85,16 +89,18 @@ def __init__(self, uri: t.Optional[str]):
self.queue: t.Dict = defaultdict(lambda: [])

@abstractmethod
def build_consumer(self, **kwargs):
"""Build a consumer instance."""
def build_consumer(self, **kwargs) -> BaseQueueConsumer:
"""Build a consumer instance.
:param kwargs: keyword arguments to consumer.
"""

@abstractmethod
def publish(self, events: t.List[Event]):
"""
Publish events to local queue.
:param events: list of events
:param to: Component name for events to be published.
"""

@property
Expand Down
5 changes: 4 additions & 1 deletion superduper/backends/local/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def __iter__(self):
return iter(self._cache.keys())

def expire(self, item):
"""Expire an item from the cache."""
"""Expire an item from the cache.
:param item: The item to expire.
"""
try:
del self._cache[item]
for (t, i), uuid in self._component_to_uuid.items():
Expand Down
5 changes: 4 additions & 1 deletion superduper/backends/local/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def build(cls, CFG, **kwargs):
)

def drop(self, force: bool = False):
"""Drop the cluster."""
"""Drop the cluster.
:param force: Force drop the cluster.
"""
if not force:
if not click.confirm(
"Are you sure you want to drop the cache? ",
Expand Down
13 changes: 4 additions & 9 deletions superduper/backends/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,21 @@ def name(self) -> str:
"""The name of the backend."""
return "local"

# TODO needed?
def release_futures(self, context: str):
"""Release futures for a given context."""
"""Release futures for a given context.
:param context: The apply context to release futures for.
"""
try:
del self.futures[context]
except KeyError:
logging.warn(f'Could not release futures for context {context}')

# TODO needed? (we have .put)
# TODO hook to do what?
def component_hook(self, *args, **kwargs):
"""Hook for component."""
pass

def submit(self, job: Job) -> str:
"""
Submits a function for local execution.
:param job: The `Job` to be executed.
:param dependencies: List of `job_ids`
"""
args, kwargs = job.get_args_kwargs(self.futures[job.context])

Expand Down
23 changes: 8 additions & 15 deletions superduper/backends/local/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,6 @@ def __init__(self, uri: t.Optional[str] = None):
self._component_uuid_mapping: t.Dict = {}
self.lock = threading.Lock()

def show_pending_create_events(self, type_id: str | None = None):
if type_id is None:
return [
{'type_id': type_id, 'identifier': e.component['identifier']}
for e in self.queue['_apply']
]
else:
return [
e.component['identifier']
for e in self.queue['_apply']
if e.component['type_id'] == type_id
]

def list(self):
"""List all components."""
return self.queue.keys()
Expand Down Expand Up @@ -84,7 +71,10 @@ def list_uuids(self):
return list(self._component_uuid_mapping.values())

def build_consumer(self, **kwargs):
"""Build consumer client."""
"""Build consumer client.
:param kwargs: Additional arguments.
"""
return LocalQueueConsumer()

def publish(self, events: t.List[Event]):
Expand All @@ -111,7 +101,10 @@ def start_consuming(self):
"""Start consuming."""

def consume(self, db: 'Datalayer', queue: t.Dict[str, t.List[Event]]):
"""Consume the current queue and run jobs."""
"""Consume the current queue and run jobs.
:param db: Datalayer instance.
:param queue: Queue to consume.
"""
keys = list(queue.keys())[:]
for k in keys:
consume_events(events=queue[k], table=k, db=db)
Expand Down
9 changes: 6 additions & 3 deletions superduper/backends/local/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)

if t.TYPE_CHECKING:
from superduper import Component
from superduper import Component, VectorIndex


class LocalVectorSearchBackend(VectorSearchBackend):
Expand Down Expand Up @@ -176,8 +176,11 @@ def find_nearest_from_array(self, h, n=100, within_ids=None):
_ids = [self.index[i] for i in ix]
return _ids, scores

def initialize(self, vector_index):
"""Initialize the vector index."""
def initialize(self, vector_index: 'VectorIndex'):
"""Initialize the vector index.
:param vector_index: Vector index to initialize
"""
vector_index.copy_vectors()

def add(self, items: t.Sequence[VectorItem] = (), cache: bool = False) -> None:
Expand Down
Loading

0 comments on commit ca6801b

Please sign in to comment.