From ca6801b3ad6f3ce6170683acdeed8342bca01637 Mon Sep 17 00:00:00 2001 From: Duncan Blythe Date: Thu, 16 Jan 2025 18:32:28 +0100 Subject: [PATCH] Fix all method doc-strings --- plugins/pillow/superduper_pillow/encoder.py | 4 + superduper/backends/base/artifacts.py | 14 +- superduper/backends/base/cdc.py | 7 +- superduper/backends/base/compute.py | 12 +- superduper/backends/base/data_backend.py | 2 +- superduper/backends/base/metadata.py | 6 +- superduper/backends/base/query.py | 30 +- superduper/backends/base/queue.py | 14 +- superduper/backends/local/cache.py | 5 +- superduper/backends/local/cluster.py | 5 +- superduper/backends/local/compute.py | 13 +- superduper/backends/local/queue.py | 23 +- superduper/backends/local/vector_search.py | 9 +- superduper/base/datalayer.py | 34 +- superduper/base/document.py | 35 +- superduper/base/event.py | 10 +- superduper/base/leaf.py | 17 +- superduper/base/logger.py | 2 +- superduper/components/cdc.py | 4 + superduper/components/component.py | 43 ++- superduper/components/dataset.py | 5 +- superduper/components/datatype.py | 76 ++++- superduper/components/graph.py | 8 +- superduper/components/listener.py | 11 +- superduper/components/model.py | 46 ++- superduper/components/schema.py | 7 + superduper/components/template.py | 1 + superduper/components/vector_index.py | 4 - superduper/misc/download.py | 303 ++++++++++++++++++ superduper/misc/special_dicts.py | 5 +- .../unittest/backends/local/test_artifacts.py | 4 +- test/unittest/component/test_listener.py | 2 +- test/unittest/test_docstrings.py | 29 +- test/utils/component/datatype.py | 2 +- 34 files changed, 656 insertions(+), 136 deletions(-) create mode 100644 superduper/misc/download.py diff --git a/plugins/pillow/superduper_pillow/encoder.py b/plugins/pillow/superduper_pillow/encoder.py index de0abad1c..35f0f9f9f 100644 --- a/plugins/pillow/superduper_pillow/encoder.py +++ b/plugins/pillow/superduper_pillow/encoder.py @@ -23,6 +23,10 @@ def _encode_data(self, item): 
return buffer.getvalue() def decode_data(self, item): + """Decode the data. + + :param item: The data to decode. + """ try: return PIL.Image.open(io.BytesIO(item)) except Exception as e: diff --git a/superduper/backends/base/artifacts.py b/superduper/backends/base/artifacts.py index eabf6125a..2575861c3 100644 --- a/superduper/backends/base/artifacts.py +++ b/superduper/backends/base/artifacts.py @@ -92,12 +92,20 @@ def exists( @abstractmethod def put_bytes(self, serialized: bytes, file_id: str): - """Save bytes in artifact store""" "" + """Save bytes in artifact store + + :param serialized: Bytes to save + :param file_id: Identifier of artifact in the store + """ pass @abstractmethod def put_file(self, file_path: str, file_id: str) -> str: - """Save file in artifact store and return file_id.""" + """Save file in artifact store and return file_id. + + :param file_path: Path to file + :param file_id: Identifier of artifact in the store + """ pass def save_artifact(self, r: t.Dict): @@ -131,7 +139,7 @@ def save_artifact(self, r: t.Dict): def delete_artifact(self, artifact_ids: t.List[str]): """Delete artifact from artifact store. - :param r: dictionary with mandatory fields + :param artifact_ids: list of artifact ids to delete. """ for artifact_id in artifact_ids: try: diff --git a/superduper/backends/base/cdc.py b/superduper/backends/base/cdc.py index dd6494291..2dc4c07b3 100644 --- a/superduper/backends/base/cdc.py +++ b/superduper/backends/base/cdc.py @@ -12,7 +12,12 @@ class CDCBackend(BaseBackend): @abstractmethod def handle_event(self, event_type, table, ids): - """Handle an incoming event.""" + """Handle an incoming event. + + :param event_type: The type of event. + :param table: The table to handle. + :param ids: The ids to handle. 
+ """ pass @property diff --git a/superduper/backends/base/compute.py b/superduper/backends/base/compute.py index f6adb88ca..c46b82260 100644 --- a/superduper/backends/base/compute.py +++ b/superduper/backends/base/compute.py @@ -24,10 +24,12 @@ def type(self) -> str: """Return the type of compute engine.""" pass - # TODO is this used anywhere? @abstractmethod def release_futures(self, context: str): - """Release futures from backend.""" + """Release futures from backend. + + :param context: Context of futures to release. + """ pass # TODO needed? @@ -94,7 +96,11 @@ def initialize(self): """Connect to address.""" def create_handler(self, *args, **kwargs): - """Create handler on component declare.""" + """Create handler on component declare. + + :param args: *args for `create_handler` + :param kwargs: *kwargs for `create_handler` + """ @property def db(self) -> 'Datalayer': diff --git a/superduper/backends/base/data_backend.py b/superduper/backends/base/data_backend.py index 80bb10f96..aff5c6dea 100644 --- a/superduper/backends/base/data_backend.py +++ b/superduper/backends/base/data_backend.py @@ -94,7 +94,7 @@ def create_table_and_schema(self, identifier: str, schema: "Schema"): """Create a schema in the data-backend. :param identifier: The identifier of the table. - :param mapping: The mapping of the schema. + :param schema: The schema to create. """ @abstractmethod diff --git a/superduper/backends/base/metadata.py b/superduper/backends/base/metadata.py index 635a6db85..12fb25805 100644 --- a/superduper/backends/base/metadata.py +++ b/superduper/backends/base/metadata.py @@ -92,7 +92,7 @@ def create_artifact_relation(self, uuid, artifact_ids): Create a relation between an artifact and a component version. 
:param uuid: UUID of component version - :param artifact: artifact + :param artifact_ids: artifact """ artifact_ids = ( [artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids @@ -109,7 +109,7 @@ def delete_artifact_relation(self, uuid, artifact_ids): Delete a relation between an artifact and a component version. :param uuid: UUID of component version - :param artifact: artifact + :param artifact_ids: artifact ids """ artifact_ids = ( [artifact_ids] if not isinstance(artifact_ids, list) else artifact_ids @@ -127,6 +127,7 @@ def get_artifact_relations(self, uuid=None, artifact_id=None): """ Get all relations between an artifact and a component version. + :param uuid: UUID of component version :param artifact_id: artifact """ if uuid is None and artifact_id is None: @@ -455,6 +456,7 @@ def replace_object( :param identifier: identifier of object :param type_id: type of object :param version: version of object + :param uuid: UUID of object """ if version is None and uuid is None: assert isinstance(type_id, str) diff --git a/superduper/backends/base/query.py b/superduper/backends/base/query.py index 6bc2d1eff..5391e1853 100644 --- a/superduper/backends/base/query.py +++ b/superduper/backends/base/query.py @@ -60,6 +60,7 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None): self.identifier = re.sub('[\-]+', '-', self.identifier) def unpack(self): + """Unpack the query.""" parts = _unpack(self.parts) return type(self)( db=self.db, @@ -206,7 +207,7 @@ def __getitem__(self, item): def set_db(self, value: 'Datalayer'): """Set the datalayer to use to execute the query. - :param db: The datalayer to use to execute the query. + :param value: The datalayer to use to execute the query. """ def _set_the_db(r, db): @@ -352,7 +353,13 @@ def dict( uuid: bool = True, refs: bool = False, ): - """Return the query as a dictionary.""" + """Return the query as a dictionary. + + :param metadata: Include metadata. + :param defaults: Include defaults. 
+ :param uuid: Include UUID. + :param refs: Include references. + """ query, documents = self._dump_query() documents = [Document(r) for r in documents] return Document( @@ -563,7 +570,12 @@ def swap_keys(r: str | list | dict): return out def tolist(self, db=None, eager_mode=False, **kwargs): - """Execute and convert to list.""" + """Execute and convert to list. + + :param db: Datalayer instance. + :param eager_mode: Eager mode. + :param kwargs: Additional keyword arguments. + """ return self.execute(db=db, eager_mode=eager_mode, **kwargs).tolist() def execute(self, db=None, eager_mode=False, handle_outputs=True, **kwargs): @@ -571,6 +583,9 @@ def execute(self, db=None, eager_mode=False, handle_outputs=True, **kwargs): Execute the query. :param db: Datalayer instance. + :param eager_mode: Eager mode. + :param handle_outputs: Handle outputs. + :param kwargs: Additional keyword arguments. """ if self.type == 'select' and handle_outputs and 'outputs' in str(self): query = self.complete_uuids(db=db or self.db) @@ -870,9 +885,10 @@ class Model(_BaseQuery): :param parts: The parts of the query. """ + type: t.ClassVar[str] = 'predict' + table: str identifier: str = '' - type: t.ClassVar[str] = 'predict' def execute(self): """Execute the model as a query.""" @@ -893,7 +909,11 @@ def do_execute(self, db=None): return Document({'_base': r}) def dict(self, metadata: bool = True, defaults: bool = True): - """Return the query as a dictionary.""" + """Return the query as a dictionary. + + :param metadata: Include metadata. + :param defaults: Include defaults. 
+ """ query, documents = self._dump_query() documents = [Document(r) for r in documents] return Document( diff --git a/superduper/backends/base/queue.py b/superduper/backends/base/queue.py index 6f74793e5..d30610342 100644 --- a/superduper/backends/base/queue.py +++ b/superduper/backends/base/queue.py @@ -58,7 +58,11 @@ def close_connection(self): pass def consume(self, *args, **kwargs): - """Start consuming messages from queue.""" + """Start consuming messages from queue. + + :param args: positional arguments + :param kwargs: keyword arguments + """ logging.info(f"Started consuming on queue: {self.queue_name}") try: self.start_consuming() @@ -85,8 +89,11 @@ def __init__(self, uri: t.Optional[str]): self.queue: t.Dict = defaultdict(lambda: []) @abstractmethod - def build_consumer(self, **kwargs): - """Build a consumer instance.""" + def build_consumer(self, **kwargs) -> BaseQueueConsumer: + """Build a consumer instance. + + :param kwargs: keyword arguments to consumer. + """ @abstractmethod def publish(self, events: t.List[Event]): @@ -94,7 +101,6 @@ def publish(self, events: t.List[Event]): Publish events to local queue. :param events: list of events - :param to: Component name for events to be published. """ @property diff --git a/superduper/backends/local/cache.py b/superduper/backends/local/cache.py index 246b21335..3160cd4b5 100644 --- a/superduper/backends/local/cache.py +++ b/superduper/backends/local/cache.py @@ -101,7 +101,10 @@ def __iter__(self): return iter(self._cache.keys()) def expire(self, item): - """Expire an item from the cache.""" + """Expire an item from the cache. + + :param item: The item to expire. 
+ """ try: del self._cache[item] for (t, i), uuid in self._component_to_uuid.items(): diff --git a/superduper/backends/local/cluster.py b/superduper/backends/local/cluster.py index 8c1e3bfb8..07805962f 100644 --- a/superduper/backends/local/cluster.py +++ b/superduper/backends/local/cluster.py @@ -44,7 +44,10 @@ def build(cls, CFG, **kwargs): ) def drop(self, force: bool = False): - """Drop the cluster.""" + """Drop the cluster. + + :param force: Force drop the cluster. + """ if not force: if not click.confirm( "Are you sure you want to drop the cache? ", diff --git a/superduper/backends/local/compute.py b/superduper/backends/local/compute.py index 007370c55..f2b2c3fad 100644 --- a/superduper/backends/local/compute.py +++ b/superduper/backends/local/compute.py @@ -43,26 +43,21 @@ def name(self) -> str: """The name of the backend.""" return "local" - # TODO needed? def release_futures(self, context: str): - """Release futures for a given context.""" + """Release futures for a given context. + + :param context: The apply context to release futures for. + """ try: del self.futures[context] except KeyError: logging.warn(f'Could not release futures for context {context}') - # TODO needed? (we have .put) - # TODO hook to do what? - def component_hook(self, *args, **kwargs): - """Hook for component.""" - pass - def submit(self, job: Job) -> str: """ Submits a function for local execution. :param job: The `Job` to be executed. 
- :param dependencies: List of `job_ids` """ args, kwargs = job.get_args_kwargs(self.futures[job.context]) diff --git a/superduper/backends/local/queue.py b/superduper/backends/local/queue.py index 8bb76137c..c3f4dbd4b 100644 --- a/superduper/backends/local/queue.py +++ b/superduper/backends/local/queue.py @@ -30,19 +30,6 @@ def __init__(self, uri: t.Optional[str] = None): self._component_uuid_mapping: t.Dict = {} self.lock = threading.Lock() - def show_pending_create_events(self, type_id: str | None = None): - if type_id is None: - return [ - {'type_id': type_id, 'identifier': e.component['identifier']} - for e in self.queue['_apply'] - ] - else: - return [ - e.component['identifier'] - for e in self.queue['_apply'] - if e.component['type_id'] == type_id - ] - def list(self): """List all components.""" return self.queue.keys() @@ -84,7 +71,10 @@ def list_uuids(self): return list(self._component_uuid_mapping.values()) def build_consumer(self, **kwargs): - """Build consumer client.""" + """Build consumer client. + + :param kwargs: Additional arguments. + """ return LocalQueueConsumer() def publish(self, events: t.List[Event]): @@ -111,7 +101,10 @@ def start_consuming(self): """Start consuming.""" def consume(self, db: 'Datalayer', queue: t.Dict[str, t.List[Event]]): - """Consume the current queue and run jobs.""" + """Consume the current queue and run jobs. + :param db: Datalayer instance. + :param queue: Queue to consume. 
+ """ keys = list(queue.keys())[:] for k in keys: consume_events(events=queue[k], table=k, db=db) diff --git a/superduper/backends/local/vector_search.py b/superduper/backends/local/vector_search.py index 9a24c1e1f..13c958338 100644 --- a/superduper/backends/local/vector_search.py +++ b/superduper/backends/local/vector_search.py @@ -11,7 +11,7 @@ ) if t.TYPE_CHECKING: - from superduper import Component + from superduper import Component, VectorIndex class LocalVectorSearchBackend(VectorSearchBackend): @@ -176,8 +176,11 @@ def find_nearest_from_array(self, h, n=100, within_ids=None): _ids = [self.index[i] for i in ix] return _ids, scores - def initialize(self, vector_index): - """Initialize the vector index.""" + def initialize(self, vector_index: 'VectorIndex'): + """Initialize the vector index. + + :param vector_index: Vector index to initialize + """ vector_index.copy_vectors() def add(self, items: t.Sequence[VectorItem] = (), cache: bool = False) -> None: diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py index 9ffb65ccf..d2a912fbb 100644 --- a/superduper/base/datalayer.py +++ b/superduper/base/datalayer.py @@ -122,6 +122,7 @@ def drop(self, force: bool = False, data: bool = False): Drop all data, artifacts, and metadata. :param force: Force drop. + :param data: Drop data. """ if not force and not click.confirm( "!!!WARNING USE WITH CAUTION AS YOU WILL" @@ -159,6 +160,7 @@ def show( 'vector_index', 'job']. :param identifier: Identifying string to component. :param version: (Optional) Numerical version - specify for full metadata. + :param uuid: (Optional) UUID of the component. """ if uuid is not None: return self.metadata.get_component_by_uuid(uuid) @@ -349,9 +351,9 @@ def on_event(self, table: str, ids: t.List[str], event_type: 'str'): """ Trigger computation jobs after data insertion. - :param query: The select or update query object that reduces - the scope of computations. + :param table: The table to trigger computation jobs on. 
:param ids: IDs that further reduce the scope of computations. + :param event_type: The type of event to trigger. """ from superduper.base.event import Change @@ -424,17 +426,6 @@ def _update(self, update: Query, refresh: bool = True) -> UpdateResult: return updated_ids - @deprecated - def add(self, object: t.Any): - """ - Note: The use of `add` is deprecated, use `apply` instead. - - :param object: Object to be stored. - :param dependencies: List of jobs which should execute before component - initialization begins. - """ - return self.apply(object) - def apply( self, object: t.Union[Component, t.Sequence[t.Any], t.Any], @@ -448,10 +439,8 @@ def apply( and linked to the primary database through metadata. :param object: Object to be stored. - :param dependencies: List of jobs which should execute before component - initialization begins. + :param force: Force apply. :param wait: Wait for apply events. - :return: Tuple containing the added object(s) and the original object(s). """ result = apply.apply(db=self, object=object, force=force, wait=wait) return result @@ -473,6 +462,7 @@ def remove( :param identifier: Identifier of the component (refer to `container.base.Component`). :param version: [Optional] Numerical version to remove. + :param recursive: Toggle to remove all descendants of the component. :param force: Force skip confirmation (use with caution). """ # TODO: versions = [version] if version is not None else ... @@ -535,10 +525,10 @@ def load( :param identifier: Identifier of the component (see `container.base.Component`). :param version: [Optional] Numerical version. + :param uuid: [Optional] UUID of the component to load. + :param huuid: [Optional] human-readable UUID of the component to load. :param allow_hidden: Toggle to ``True`` to allow loading of deprecated components. - :param huuid: [Optional] human-readable UUID of the component to load. - :param uuid: [Optional] UUID of the component to load. 
""" if version is not None: assert type_id is not None @@ -682,9 +672,6 @@ def replace(self, object: t.Any): (Use with caution!) :param object: The object to replace. - :param upsert: Toggle to ``True`` to enable replacement even if - the object doesn't exist yet. - :param force: set to `True` to skip confirm # TODO """ old_uuid = None try: @@ -725,7 +712,10 @@ def _replace_fn(component): self.metadata.create_component(serialized) def expire(self, uuid): - """Expire a component from the cache.""" + """Expire a component from the cache. + + :param uuid: The UUID of the component to expire. + """ self.cluster.cache.expire(uuid) self.metadata.expire(uuid) parents = self.metadata.get_component_version_parents(uuid) diff --git a/superduper/base/document.py b/superduper/base/document.py index df39491fc..833ca7c22 100644 --- a/superduper/base/document.py +++ b/superduper/base/document.py @@ -46,11 +46,19 @@ def __init__(self, getters=None): self.add_getter(k, v) def add_getter(self, name: str, getter: t.Callable): - """Add a getter for a reference type.""" + """Add a getter for a reference type. + + :param name: The name of the getter. + :param getter: The getter. + """ self._getters[name].append(getter) def run(self, name, data): - """Run the getters one by one until one returns a value.""" + """Run the getters one by one until one returns a value. + + :param name: The name of the getter. + :param data: The data to get. + """ if name not in self._getters: return data for getter in self._getters[name]: @@ -151,6 +159,11 @@ def __init__( self.schema = schema def map(self, fn, condition): + """Map a function over the document. + + :param fn: The function to map. + :param condition: The condition to map over. 
+ """ def _map(r): if isinstance(r, dict): out = {} @@ -178,7 +191,10 @@ def diff(self, other: 'Document'): return Document(out, schema=self.schema) def update(self, other: t.Union['Document', dict]): - """Update document with values from other.""" + """Update document with values from other. + + :param other: The other document to update with. + """ schema = self.schema or Schema('tmp', fields={}) if isinstance(other, Document) and other.schema: @@ -201,6 +217,9 @@ def encode( :param schema: The schema to use. :param leaves_to_keep: The types of leaves to keep. + :param metadata: Whether to include metadata. + :param defaults: Whether to include defaults. + :param keep_schema: Whether to keep the schema. """ builds: t.Dict[str, dict] = self.get(KEY_BUILDS, {}) blobs: t.Dict[str, bytes] = self.get(KEY_BLOBS, {}) @@ -259,6 +278,7 @@ def decode( :param r: The encoded data. :param schema: The schema to use. :param db: The datalayer to use. + :param getters: The getters to use. """ if '_variables' in r: variables = {**r['_variables'], 'output_prefix': CFG.output_prefix} @@ -267,7 +287,6 @@ def decode( ) schema = schema or r.get(KEY_SCHEMA) schema = get_schema(db, schema) - builds = r.get(KEY_BUILDS, {}) # TODO is this the right place for this? @@ -334,7 +353,6 @@ def variables(self) -> t.List[str]: def set_variables(self, **kwargs) -> 'Document': """Set free variables of self. - :param db: The datalayer to use. :param kwargs: The vales to set the variables to `_replace_variables`. """ from superduper.base.variables import _replace_variables @@ -346,7 +364,12 @@ def __repr__(self) -> str: return f'Document({repr(dict(self))})' @staticmethod - def decode_blobs(schema, r): + def decode_blobs(schema: 'Schema', r: t.Dict): + """Decode blobs in a document. + + :param schema: The schema to use. + :param r: The document to decode. 
+ """ for k, v in schema.fields.items(): if k not in r: continue diff --git a/superduper/base/event.py b/superduper/base/event.py index 12f2128e4..03d4cc332 100644 --- a/superduper/base/event.py +++ b/superduper/base/event.py @@ -127,7 +127,10 @@ class Create(Event): parent: str | None = None def execute(self, db: 'Datalayer'): - """Execute the create event.""" + """Execute the create event. + + :param db: Datalayer instance. + """ # TODO decide where to assign version artifact_ids, _ = db._find_artifacts(self.component) db.metadata.create_artifact_relation(self.component['uuid'], artifact_ids) @@ -180,7 +183,10 @@ def execute( self, db: 'Datalayer', ): - """Execute the create event.""" + """Execute the create event. + + :param db: Datalayer instance. + """ # TODO decide where to assign version artifact_ids, _ = db._find_artifacts(self.component) db.metadata.create_artifact_relation(self.component['uuid'], artifact_ids) diff --git a/superduper/base/leaf.py b/superduper/base/leaf.py index 52b570c94..a17dd26cf 100644 --- a/superduper/base/leaf.py +++ b/superduper/base/leaf.py @@ -159,13 +159,15 @@ def leaves(self): if isinstance(getattr(self, f.name), Leaf) } + # TODO signature is inverted from `Component.encode` def encode(self, leaves_to_keep=(), metadata: bool = True, defaults: bool = True): """Encode itself. After encoding everything is a vanilla dictionary (JSON + bytes). - :param schema: Schema instance. :param leaves_to_keep: Leaves to keep. + :param metadata: Include metadata. + :param defaults: Include default values. """ from superduper.base.document import _deep_flat_encode @@ -233,7 +235,6 @@ def _replace_uuids_with_keys(record): def set_variables(self, **kwargs) -> 'Leaf': """Set free variables of self. - :param db: Datalayer instance. :param kwargs: Keyword arguments to pass to `_replace_variables`. 
""" from superduper import Document @@ -264,8 +265,13 @@ def defaults(self): out[f.name] = value return out + # TODO the signature does not agree with the `Component.dict` method def dict(self, metadata: bool = True, defaults: bool = True): - """Return dictionary representation of the object.""" + """Return dictionary representation of the object. + + :param metadata: Include metadata. + :param defaults: Include default values. + """ from superduper import Document r = asdict(self) @@ -356,6 +362,7 @@ def __call__(self, *args, **kwargs): return self.compile()(*args, **kwargs) def compile(self): + """Compile the address.""" raise NotImplementedError @@ -394,6 +401,7 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None, parent=None): self.parent = parent def compile(self): + """Compile the import.""" return self.parent @@ -419,6 +427,7 @@ def __post_init__(self, db: t.Optional['Datalayer'] = None): self.parent = object(*self.args, **self.kwargs) def compile(self): + """Compile the import call.""" return self.parent @@ -434,6 +443,7 @@ class Attribute(Address): attribute: str def compile(self): + """Compile the attribute.""" parent = self.parent.compile() return getattr(parent, self.attribute) @@ -450,6 +460,7 @@ class Index(Address): index: int def compile(self): + """Compile the index.""" parent = self.parent.compile() return parent[self.index] diff --git a/superduper/base/logger.py b/superduper/base/logger.py index f96cf16a1..81fcdf298 100644 --- a/superduper/base/logger.py +++ b/superduper/base/logger.py @@ -95,7 +95,7 @@ def multikey_success(msg: str, *args): """Log a message with the SUCCESS level. :param msg: The message to log. - param args: Additional arguments to log. + :param args: Additional arguments to log. 
""" logger.opt(depth=1).success(" ".join(map(str, (msg, *args)))) diff --git a/superduper/components/cdc.py b/superduper/components/cdc.py index fe7dc6d9d..7a3907456 100644 --- a/superduper/components/cdc.py +++ b/superduper/components/cdc.py @@ -19,6 +19,10 @@ class CDC(Component): cdc_table: str def handle_update_or_same(self, other): + """Handle the case in which the component is update without breaking changes. + + :param other: The other component to handle.""" + super().handle_update_or_same(other) other.cdc_table = self.cdc_table diff --git a/superduper/components/component.py b/superduper/components/component.py index 0f1ddfe88..68bd0d965 100644 --- a/superduper/components/component.py +++ b/superduper/components/component.py @@ -188,10 +188,6 @@ class Component(Leaf, metaclass=ComponentMeta): build_variables: t.Dict | None = None build_template: str | None = None - # TODO what's this? - def refresh(self): - pass - @staticmethod def sort_components(components): """Sort components based on topological order. @@ -216,7 +212,11 @@ def huuid(self): """Return a human-readable uuid.""" return f'{self.type_id}:{self.identifier}:{self.uuid}' - def handle_update_or_same(self, other): + def handle_update_or_same(self, other: 'Component'): + """Handle when a component is changed without breaking changes. + + :param other: The other component to handle. + """ other.uuid = self.uuid other.version = self.version @@ -259,7 +259,10 @@ def _find_refs(r): return sorted(list(set(out))) def get_children(self, deep: bool = False) -> t.List["Component"]: - """Get all the children of the component.""" + """Get all the children of the component. + + :param deep: If set `True` get all recursively. + """ r = self.dict().encode(leaves_to_keep=(Component,)) out = [v for v in r['_builds'].values() if isinstance(v, Component)] lookup = {} @@ -322,6 +325,7 @@ def get_triggers(self, event_type, requires: t.Sequence[str] | None = None): Get all the triggers for the component. 
:param event_type: event_type + :param requires: the methods which should run first """ # Get all of the methods in the class which have the `@trigger` decorator # and which match the event type @@ -370,6 +374,7 @@ def create_jobs( :param event_type: The event type. :param ids: The ids of the component. :param jobs: The jobs of the component. + :param requires: The requirements of the component. """ # TODO replace this with a DAG check max_it = 100 @@ -530,7 +535,10 @@ def __post_init__(self, db): raise ValueError('identifier cannot be empty or None') def cleanup(self, db: Datalayer): - """Method to clean the component.""" + """Method to clean the component. + + :param db: The `Datalayer` to use for the operation. + """ db.cluster.cache.drop(self) def _get_metadata(self): @@ -548,7 +556,10 @@ def dependencies(self): return () def init(self, db=None): - """Method to help initiate component field dependencies.""" + """Method to help initiate component field dependencies. + + :param db: The `Datalayer` to use for the operation. + """ self.db = self.db or db self.unpack(db=db) @@ -557,6 +568,8 @@ def unpack(self, db=None): """Method to unpack the component. This method is used to initialize all the fields of the component and leaf + + :param db: The database to use for the operation. """ def _init(item): @@ -634,6 +647,7 @@ def read(path: str, db: t.Optional[Datalayer] = None): Read a `Component` instance from a directory created with `.export`. :param path: Path to the directory containing the component. + :param db: Datalayer instance to be used to read the component. Expected directory structure: ``` @@ -704,6 +718,12 @@ def export( Save `self` to a directory using super-duper protocol. :param path: Path to the directory to save the component. + :param format: Format to save the component in (json/ yaml). + :param zip: Whether to zip the directory. + :param defaults: Whether to save default values. + :param metadata: Whether to save metadata. 
+ :param hr: Whether to save human-readable blobs. + :param component: Name of the component file. Created directory structure: ``` @@ -830,7 +850,12 @@ def _zip_export(path): def dict( self, metadata: bool = True, defaults: bool = True, refs: bool = False ) -> 'Document': - """A dictionary representation of the component.""" + """A dictionary representation of the component. + + :param metadata: If set `true` include metadata. + :param defaults: If set `true` include defaults. + :param refs: If set `true` include references. + """ from superduper import Document r = super().dict(metadata=metadata, defaults=defaults) diff --git a/superduper/components/dataset.py b/superduper/components/dataset.py index 07fb6600f..3a48a1300 100644 --- a/superduper/components/dataset.py +++ b/superduper/components/dataset.py @@ -50,7 +50,10 @@ def data(self): return self._data def init(self, db=None): - """Initialization method.""" + """Initialization method. + + :param db: The database to use for the operation. + """ db = db or self.db super().init(db=db) if self.pin: diff --git a/superduper/components/datatype.py b/superduper/components/datatype.py index a84b89ced..20ca9bbbe 100644 --- a/superduper/components/datatype.py +++ b/superduper/components/datatype.py @@ -71,7 +71,6 @@ def encode_data(self, item): """Decode the item as `bytes`. :param item: The item to decode. - :param info: The optional information dictionary. """ @abstractmethod @@ -79,7 +78,6 @@ def decode_data(self, item): """Decode the item from bytes. :param item: The item to decode. - :param info: The optional information dictionary. """ @@ -96,11 +94,17 @@ class BaseVector(BaseDataType): @abstractmethod def encode_data(self, item): - pass + """Encode the item as `bytes`. + + :param item: The item to encode. + """ @abstractmethod def decode_data(self, item): - pass + """Decode the item from `bytes`. + + :param item: The item to decode. 
+ """ class NativeVector(BaseVector): @@ -110,12 +114,19 @@ class NativeVector(BaseVector): dtype: str = 'float' def encode_data(self, item): + """Encode the item as a list of floats. + + :param item: The item to encode. + """ if isinstance(item, numpy.ndarray): item = item.tolist() return item def decode_data(self, item): - # TODO: + """Decode the item from a list of floats. + + :param item: The item to decode. + """ return numpy.array(item).astype(self.dtype) @@ -150,9 +161,17 @@ def datatype_impl(self): return datatype def encode_data(self, item): + """Encode the item as `bytes`. + + :param item: The item to encode. + """ return self.datatype_impl.encode_data(item=item) def decode_data(self, item): + """Decode the item from `bytes`. + + :param item: The item to decode. + """ return self.datatype_impl.decode_data(item=item) @@ -169,9 +188,17 @@ def __post_init__(self, db): return super().__post_init__(db) def encode_data(self, item): + """Encode the item as a string. + + :param item: The dictionary or list (json-able object) to encode. + """ return json.dumps(item) def decode_data(self, item): + """Decode the item from string form. + + :param item: The item to decode. + """ return json.loads(item) @@ -179,6 +206,10 @@ class _Encodable: encodable: t.ClassVar[str] = 'encodable' def encode_data(self, item): + """Encode the item as `bytes`. + + :param item: The item to encode. + """ return self._encode_data(item) @@ -186,6 +217,10 @@ class _Artifact: encodable: t.ClassVar[str] = 'artifact' def encode_data(self, item): + """Encode the item as `bytes`. + + :param item: The item to encode. + """ return Blob(bytes=self._encode_data(item)) @@ -194,6 +229,10 @@ def _encode_data(self, item): return pickle.dumps(item) def decode_data(self, item): + """Decode the item from `bytes`. + + :param item: The item to decode. 
+ """ return pickle.loads(item) @@ -210,6 +249,10 @@ def _encode_data(self, item): return dill.dumps(item, recurse=True) def decode_data(self, item): + """Decode the item from `bytes`. + + :param item: The item to decode. + """ return dill.loads(item) @@ -233,10 +276,18 @@ class File(BaseDataType): encodable: t.ClassVar[str] = 'file' def encode_data(self, item): + """Encode the item as a file path. + + :param item: The file path to encode. + """ assert os.path.exists(item) return FileItem(path=item) def decode_data(self, item): + """Decode the item placeholder. + + :param item: The file path to decode. + """ return item @@ -266,10 +317,12 @@ def reference(self): @abstractmethod def init(self): + """Initialize the object.""" pass @abstractmethod def unpack(self): + """Unpack the object to its original form.""" pass @@ -287,11 +340,13 @@ def __post_init__(self, db=None): return super().__post_init__(db) def init(self): + """Initialize the file to local disk.""" if self.path: return self.path = self.db.artifact_store.get_file(self.identifier) def unpack(self): + """Get the path out of the object.""" self.init() return self.path @@ -316,12 +371,15 @@ def __post_init__(self, db=None): self.identifier = get_hash(self.bytes) return super().__post_init__(db) + # TODO why do some of these methods have `init(self, db=None)`? def init(self): + """Initialize the blob.""" if self.bytes: return self.bytes = self.db.artifact_store.get_bytes(self.identifier) def unpack(self): + """Get the bytes out of the blob.""" self.init() return self.bytes @@ -377,11 +435,19 @@ def __post_init__(self, db): return super().__post_init__(db) def encode_data(self, item): + """Encode the data. + + :param item: The data to encode. + """ if item.dtype != self.dtype: raise TypeError(f'dtype was {item.dtype}, expected {self.dtype}') return memoryview(item).tobytes() def decode_data(self, item): + """Decode the data. + + :param item: The data to decode. 
+ """ shape = self.shape if isinstance(shape, int): shape = (self.shape,) diff --git a/superduper/components/graph.py b/superduper/components/graph.py index 9446f552b..d0a04ad2d 100644 --- a/superduper/components/graph.py +++ b/superduper/components/graph.py @@ -173,7 +173,10 @@ def __post_init__(self, db, example): self.signature = 'singleton' def predict(self, *args): - """Single prediction.""" + """Single prediction. + + :param args: Model input + """ if self.signature == 'singleton': return args[0] return OutputWrapper( @@ -476,6 +479,9 @@ def predict(self, *args, **kwargs): Single data point prediction passes the args and kwargs to the defined node flow in the graph. + + :param args: Arguments for the model. + :param kwargs: Keyword arguments for the model. """ # Validate the node for incompletion # TODO: Move to to_graph method and validate the graph diff --git a/superduper/components/listener.py b/superduper/components/listener.py index 80428e936..542a8b5ec 100644 --- a/superduper/components/listener.py +++ b/superduper/components/listener.py @@ -61,13 +61,22 @@ def __post_init__(self, db): return super().__post_init__(db) def handle_update_or_same(self, other): + """If the component is new, but does not contain breaking changes. + + :param other: Other listener object. + """ super().handle_update_or_same(other) other.output_table = self.output_table def dict( self, metadata: bool = True, defaults: bool = True, refs: bool = False ) -> t.Dict: - """Convert to dictionary.""" + """Convert to dictionary. + + :param metadata: Include metadata. + :param defaults: Include default values. + :param refs: Include references. 
+ """ out = super().dict(metadata=metadata, defaults=defaults, refs=refs) if not metadata: try: diff --git a/superduper/components/model.py b/superduper/components/model.py index c232a8cfc..2b7acc595 100644 --- a/superduper/components/model.py +++ b/superduper/components/model.py @@ -30,6 +30,7 @@ if t.TYPE_CHECKING: from superduper.base.datalayer import Datalayer from superduper.components.dataset import Dataset + from superduper.backends.base.cluster import Cluster EncoderArg = t.Union[BaseDataType, str, None] @@ -434,8 +435,11 @@ def _wrapper(self, data): args, kwargs = self.handle_input_type(data, self.signature) return self.predict(*args, **kwargs) - def declare_component(self, cluster): - """Declare model on compute.""" + def declare_component(self, cluster: 'Cluster'): + """Declare model on compute. + + :param cluster: Cluster instance to declare the model. + """ super().declare_component(cluster) if self.deploy or self.serve: cluster.compute.put(self) @@ -835,6 +839,7 @@ def to_vector_index( :param key: Key to be bound to the model :param select: Object for selecting which data is processed :param predict_kwargs: Keyword arguments to self.model.predict + :param identifier: Identifier for the listener :param kwargs: Additional keyword arguments """ from superduper.components.vector_index import VectorIndex @@ -860,8 +865,8 @@ def to_listener( :param key: Key to be bound to the model :param select: Object for selecting which data is processed - :param identifier: A string used to identify the model. :param predict_kwargs: Keyword arguments to self.model.predict + :param identifier: Identifier for the listener :param kwargs: Additional keyword arguments to pass to `Listener` """ from superduper.components.listener import Listener @@ -995,7 +1000,10 @@ def fit_in_db(self): ) def append_metrics(self, d: t.Dict[str, float]) -> None: - """Append metrics to the model.""" + """Append metrics to the model. + + :param d: Dictionary of metrics to append. 
+ """ assert self.trainer is not None if self.trainer.metric_values is not None: for k, v in d.items(): @@ -1309,10 +1317,6 @@ def inputs(self) -> Inputs: """Instance of `Inputs` to represent model params.""" return self.models[0].inputs - def declare_component(self, cluster): - """Declare model on compute.""" - cluster.compute.put(self) - def on_create(self, db: Datalayer): """Post create hook. @@ -1379,23 +1383,37 @@ def _pre_create(self, db): if not self.datatype: self.datatype = self.models[self.model].datatype - @override def predict(self, *args, **kwargs) -> t.Any: + """Predict on a single data point. + + :param args: Positional arguments to predict on. + :param kwargs: Keyword arguments to predict on. + """ logging.info(f'Predicting with model {self.model}') return self.models[self.model].predict(*args, **kwargs) - @override def fit(self, *args, **kwargs) -> t.Any: + """Fit the model on the given data. + + :param args: Arguments to fit on. + :param kwargs: Keyword arguments to fit on. + """ logging.info(f'Fitting with model {self.model}') return self.models[self.model].fit(*args, **kwargs) - @override def predict_batches(self, dataset) -> t.List: + """Predict on a series of data points defined in the dataset. + + :param dataset: Series of data points to predict on. + """ logging.info(f'Predicting with model {self.model}') return self.models[self.model].predict_batches(dataset) - @override def init(self, db=None): + """Initialize the model. + + :param db: DataLayer instance. + """ if hasattr(self.models[self.model], 'shape'): self.shape = getattr(self.models[self.model], 'shape') self.example = self.models[self.model].example @@ -1425,6 +1443,10 @@ def _build_prompt(self, query, docs): return self.prompt_template.format(context=context, query=query) def predict(self, query: str): + """Predict on a single query string. + + :param query: Query string. 
+ """ assert self.db, 'db cannot be None' select = self.select.set_variables(db=self.db, query=query) self.db.execute(select) diff --git a/superduper/components/schema.py b/superduper/components/schema.py index 041529b4f..3e4cc28ac 100644 --- a/superduper/components/schema.py +++ b/superduper/components/schema.py @@ -80,6 +80,12 @@ def __post_init__(self, db): self.fields[k] = v def update(self, other: 'Schema'): + """Update the schema with another schema. + + Note this is not in place. + + :param other: Schema to update with. + """ new_fields = self.fields.copy() new_fields.update(other.fields) return Schema(self.identifier, fields=new_fields) @@ -155,6 +161,7 @@ def encode_data(self, out, builds, blobs, files, leaves_to_keep=()): :param builds: Builds. :param blobs: Blobs. :param files: Files. + :param leaves_to_keep: `Leaf` instances to keep (don't encode) """ for k, field in self.fields.items(): if not isinstance(field, BaseDataType): diff --git a/superduper/components/template.py b/superduper/components/template.py index 3b01e3810..21a181ae8 100644 --- a/superduper/components/template.py +++ b/superduper/components/template.py @@ -201,6 +201,7 @@ def read(path: str, db: t.Optional[Datalayer] = None): Read a `Component` instance from a directory created with `.export`. :param path: Path to the directory containing the component. + :param db: Datalayer instance to be used to read the component. Expected directory structure: ``` diff --git a/superduper/components/vector_index.py b/superduper/components/vector_index.py index 277a6a075..ce8cb6801 100644 --- a/superduper/components/vector_index.py +++ b/superduper/components/vector_index.py @@ -119,10 +119,6 @@ def __post_init__(self, db): self.cdc_table = self.cdc_table or self.indexing_listener.outputs return super().__post_init__(db) - def refresh(self): - if self.cdc_table.startswith(CFG.output_prefix): - self.cdc_table = self.indexing_listener.outputs - # TODO why this? 
def __hash__(self): return hash((self.type_id, self.identifier)) diff --git a/superduper/misc/download.py b/superduper/misc/download.py new file mode 100644 index 000000000..20c2c7fe8 --- /dev/null +++ b/superduper/misc/download.py @@ -0,0 +1,303 @@ +import os +import re +import signal +import sys +import tempfile +import typing as t +import warnings +from contextlib import contextmanager +from io import BytesIO +from multiprocessing.pool import ThreadPool + +import boto3 +import requests +from tqdm import tqdm + +from superduper import CFG, logging + +# from superduper.components.datatype import _BaseEncodable +from superduper.components.model import Model + + +class TimeoutException(Exception): + """ + Timeout exception. + + :param args: *args of `Exception` + :param kwargs: **kwargs of `Exception` + """ + + +def timeout_handler(signum, frame): + """Timeout handler to raise an TimeoutException. + + :param signum: signal number + :param frame: frame + """ + raise TimeoutException() + + +@contextmanager +def timeout(seconds): + """Context manager to set a timeout. + + :param seconds: seconds until timeout + """ + old_handler = signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(seconds) + try: + yield + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + + +class Fetcher: + """Fetches data from a URI. 
+ + :param headers: headers to be used for download + :param n_workers: number of download workers + """ + + DIALECTS: t.ClassVar = ('http', 's3', 'file') + + def __init__(self, headers: t.Optional[t.Dict] = None, n_workers: int = 0): + session = boto3.Session() + self.headers = headers + self.s3_client = session.client("s3") + self.request_session = requests.Session() + self.request_adapter = requests.adapters.HTTPAdapter( + max_retries=3, + pool_connections=n_workers if n_workers else 1, + pool_maxsize=n_workers * 10, + ) + self.request_session.mount("http://", self.request_adapter) + self.request_session.mount("https://", self.request_adapter) + + @classmethod + def is_uri(cls, uri: str): + """Helper function to check if uri is one of the dialects. + + :param uri: uri string. + """ + return uri.split('://')[0] in cls.DIALECTS + + def _download_s3_folder(self, uri): + folder_objects = [] + path = uri.split('s3://')[-1] + bucket_name = path.split('/')[0] + prefix = uri.split(bucket_name + '/')[-1] + + objects = self.s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix) + + if 'Contents' in objects: + for obj in objects['Contents']: + s3_key = obj['Key'] + if not s3_key.endswith('/'): + data = self._download_s3_object(uri + s3_key) + info = { + 'file': s3_key, + 'data': data, + 'folder': prefix.split('/')[-1], + } + folder_objects.append(info) + return folder_objects + + def _download_s3_object(self, uri): + f = BytesIO() + path = uri.split('s3://')[-1] + bucket_name = path.split('/')[0] + file = '/'.join(path.split('/')[1:]) + self.s3_client.download_fileobj(bucket_name, file, f) + return f.getvalue() + + def _download_file(self, path): + path = re.split('^file://', path)[-1] + with open(path, 'rb') as f: + return f.read() + + def _download_from_uri(self, uri): + return self.request_session.get(uri, headers=self.headers).content + + def __call__(self, uri: str): + """Download data from a URI. 
+ + :param uri: uri to download from + """ + if uri.startswith('file://'): + return self._download_file(uri) + elif uri.startswith('s3://'): + if uri.endswith('/'): + return self._download_s3_folder(uri) + return self._download_s3_object(uri) + elif uri.startswith('http://') or uri.startswith('https://'): + return self._download_from_uri(uri) + else: + raise NotImplementedError(f'unknown type of URI "{uri}"') + + def save(self, uri: str, contents: t.Union[t.List, bytes], file_id: str): + """Save downloaded bytes to a cached directory. + + :param uri: uri to download from + :param contents: downloaded contents + :param file_id: file id + """ + download_folder = CFG.downloads.folder + + if not download_folder: + download_folder = os.path.join( + tempfile.gettempdir(), "superduper", "ArtifactStore" + ) + + save_folder = os.path.join(download_folder, file_id) + os.makedirs(save_folder, exist_ok=True) + + if isinstance(contents, list): + folder_path = None + for content in contents: + name = content['file'] + folder = content['folder'] + data = content['data'] + folder_path = os.path.join(save_folder, folder) + os.makedirs(folder_path, exist_ok=True) + + path = os.path.join(folder_path, name) + with open(path, 'wb') as f: + f.write(data) + return folder_path + + else: + base_name = uri.split('/')[-1] + + path = os.path.join(save_folder, base_name) + with open(path, 'wb') as f: + f.write(contents) + return path + + +class BaseDownloader: + """Base class for downloading files. 
+ + :param uris: list of uris/ file names to fetch + :param n_workers: number of multiprocessing workers + :param timeout: set seconds until request times out + :param headers: dictionary of request headers passed to``requests`` package + :param raises: raises error ``True``/``False`` + """ + + def __init__( + self, + uris: t.List[str], + n_workers: int = 0, + timeout: t.Optional[int] = None, + headers: t.Optional[t.Dict] = None, + raises: bool = True, + ): + self.timeout = timeout + self.n_workers = n_workers + self.uris = uris + self.headers = headers or {} + self.raises = raises + self.fetcher = Fetcher(headers=headers, n_workers=n_workers) + self.results: t.Dict = {} + + def go(self): + """Download all files. + + Uses a :py:class:`multiprocessing.pool.ThreadPool` to parallelize + connections. + """ + logging.info(f'number of workers {self.n_workers}') + prog = tqdm(total=len(self.uris)) + prog.prefix = 'downloading from uris' + self.failed = 0 + prog.prefix = "failed: 0" + + def f(i): + prog.update() + try: + if self.timeout is not None: + with timeout(self.timeout): + self._download(i) + else: + self._download(i) + except TimeoutException: + logging.warning(f'timed out {i}') + except Exception as e: + if self.raises: + raise e + warnings.warn(str(e)) + self.failed += 1 + prog.prefix = f"failed: {self.failed} [{e}]" + + if self.n_workers == 0: + self._sequential_go(f) + return + + self._parallel_go(f) + + def _download(self, i): + k = self.uris[i] + self.results[k] = self.fetcher(k) + + def _parallel_go(self, f): + pool = ThreadPool(self.n_workers) + try: + pool.map(f, range(len(self.uris))) + except KeyboardInterrupt: + logging.warning("--keyboard interrupt--") + pool.terminate() + pool.join() + sys.exit(1) # Kill this subprocess so it doesn't hang + + pool.close() + pool.join() + + def _sequential_go(self, f): + for i in range(len(self.uris)): + f(i) + + +class DownloadFiles(Model): + """Download files from a list of URIs. 
+
+    :param num_workers: number of multiprocessing workers
+    :param postprocess: postprocess function to apply to the results
+    :param timeout: set seconds until request times out
+    :param headers: dictionary of request headers passed to ``requests`` package
+    :param raises: raises error ``True``/``False``
+    :param signature: signature of the model
+    """
+
+    num_workers: int = 10
+    postprocess: t.Optional[t.Callable] = None
+    timeout: t.Optional[int] = None
+    headers: t.Optional[t.Dict] = None
+    raises: bool = True
+    signature: str = 'singleton'
+
+    def predict(self, uri):
+        """Predict a single URI.
+
+        :param uri: uri to predict
+        """
+        return self.predict_batches([uri])[0]
+
+    def predict_batches(self, dataset):
+        """Predict a batch of URIs.
+
+        :param dataset: list of uris/ file names to fetch
+        """
+        downloader = BaseDownloader(
+            uris=dataset,
+            n_workers=self.num_workers,
+            timeout=self.timeout,
+            headers=self.headers or {},
+            raises=self.raises,
+        )
+        downloader.go()
+        results = [downloader.results[uri] for uri in dataset]
+        results = [self.datatype.decoder(r) for r in results]
+        if self.postprocess:
+            results = [self.postprocess(r) for r in results]
+        return results
diff --git a/superduper/misc/special_dicts.py b/superduper/misc/special_dicts.py
index e28dab693..7f29cbe02 100644
--- a/superduper/misc/special_dicts.py
+++ b/superduper/misc/special_dicts.py
@@ -173,7 +173,10 @@ def _str2var(x, item, variable):
         return x
 
     def create_template(self, **kwargs):
-        """Convert all instances of string to variable."""
+        """Convert all instances of string to variable.
+ + :param kwargs: mappings from values to variable names + """ r = self for k, v in kwargs.items(): r = SuperDuperFlatEncode._str2var(r, v, k) diff --git a/test/unittest/backends/local/test_artifacts.py b/test/unittest/backends/local/test_artifacts.py index 93291cba8..16bd28e0c 100644 --- a/test/unittest/backends/local/test_artifacts.py +++ b/test/unittest/backends/local/test_artifacts.py @@ -62,7 +62,7 @@ def test_save_and_load_directory( # test save and load directory test_component = TestComponent(path=random_directory, identifier="test") - db.add(test_component) + db.apply(test_component) test_component_loaded = db.load("TestComponent", "test") test_component_loaded.init() # assert that the paths are different @@ -83,7 +83,7 @@ def test_save_and_load_file(db, artifact_store: FileSystemArtifactStore): # test save and load file file = os.path.abspath(__file__) test_component = TestComponent(path=file, identifier="test") - db.add(test_component) + db.apply(test_component) test_component_loaded = db.load("TestComponent", "test") test_component_loaded.init() diff --git a/test/unittest/component/test_listener.py b/test/unittest/component/test_listener.py index c34d6ff38..25e472d90 100644 --- a/test/unittest/component/test_listener.py +++ b/test/unittest/component/test_listener.py @@ -175,7 +175,7 @@ def test_listener_cleanup(db, data): identifier="listener1", ) - db.add(listener1) + db.apply(listener1) doc = db[listener1.outputs].select().tolist()[0] result = Document(doc.unpack())[listener1.outputs] assert isinstance(result, type(data)) diff --git a/test/unittest/test_docstrings.py b/test/unittest/test_docstrings.py index 47e1ec5c3..989de9df5 100644 --- a/test/unittest/test_docstrings.py +++ b/test/unittest/test_docstrings.py @@ -95,14 +95,12 @@ def check_class_docstring(cls, line): ) -def check_method_docstring(method, cls, line): - str_ = f'{cls.__module__}.{cls.__name__}.{method.__name__}' - print(str_) - if 'builtin' in str_: - return +def 
check_method_docstring(method, line): + doc_string = method.__doc__ if doc_string is None: - raise MissingDocstring(method, cls.__module__, parent=cls.__name__) + msg = str(method) + raise MissingDocstring(method.__module__, msg, line=line) params = { k: v for k, v in inspect.signature(method).parameters.items() if k != 'self' @@ -111,28 +109,28 @@ def check_method_docstring(method, cls, line): if len(doc_params) != len(params): raise MismatchingDocParameters( - module=cls.__module__, - name=method.__name__, + module=method.__module__, + name=str(method), msg=f'Got {len(params)} parameters but doc-string has {len(doc_params)}.', - parent=cls, + parent=None, line=line, ) for i, (p, (dp, expl)) in enumerate(zip(params, doc_params.items())): if p != dp: raise MismatchingDocParameters( - module=cls.__module__, - name=method.__name__, + module=method.__module__, + name=str(method), msg=f'At position {i}: {p} != {dp}', - parent=cls.__name__, + parent=None, line=line, ) if not expl.strip(): raise MissingParameterExplanation( - module=cls.__module__, - name=method.__name__, + module=method.__module__, + name=str(method), msg=f'Missing explanation of parameter {dp}', - parent=cls.__name__, + parent=None, line=line, ) @@ -325,6 +323,5 @@ def test_method_docstrings(test_case): test_case = METHOD_TEST_CASES[test_case] check_method_docstring( TEST_CASES[test_case]["::object"], - TEST_CASES[test_case]["::object"].__class__, TEST_CASES[test_case]["::line"], ) diff --git a/test/utils/component/datatype.py b/test/utils/component/datatype.py index 652e75c35..f584bb6db 100644 --- a/test/utils/component/datatype.py +++ b/test/utils/component/datatype.py @@ -140,7 +140,7 @@ def check_component_with_db(data, datatype, db): x=data, child=ChildComponent("child", y=2), ) - db.add(c) + db.apply(c) pprint(c) print_sep()