Skip to content

Commit

Permalink
Merge pull request #1002 from rkingsbury/groupquery
Browse files · Browse the repository at this point in the history
GroupBuilder: fix query kwarg and add tests
  • Loading branch information
rkingsbury authored Oct 8, 2024
2 parents 216f1a6 + c1ef139 commit 2ab7815
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/getting_started/group_builder.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Note that unlike the previous `MapBuilder` example, we didn't call the source an

`GroupBuilder` inherits from `MapBuilder` so it has the same configurational parameters.

- query: A query to apply to items in the source Store.
- projection: list of the fields you want to project. This can reduce the data transfer load if you only need certain fields or sub-documents from the source documents
- timeout: optional timeout on the process function
- store_process_timeout: adds the process time into the target document for profiling
Expand Down
19 changes: 9 additions & 10 deletions src/maggma/builders/group_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
Args:
source: source store
target: target store
query: optional query to filter source store
query: optional query to filter items from the source store.
projection: list of keys to project from the source for
processing. Limits data transfer to improve efficiency.
delete_orphans: Whether to delete documents on target store
Expand All @@ -57,7 +57,7 @@ def __init__(
self.source = source
self.target = target
self.grouping_keys = grouping_keys
self.query = query
self.query = query if query else {}
self.projection = projection
self.kwargs = kwargs
self.timeout = timeout
Expand Down Expand Up @@ -119,8 +119,9 @@ def get_items(self):

self.total = len(groups)
for group in groups:
docs = list(self.source.query(criteria=dict(zip(self.grouping_keys, group)), properties=projection))
yield docs
group_criteria = dict(zip(self.grouping_keys, group))
group_criteria.update(self.query)
yield list(self.source.query(criteria=group_criteria, properties=projection))

def process_item(self, item: list[dict]) -> dict[tuple, dict]: # type: ignore
keys = [d[self.source.key] for d in item]
Expand Down Expand Up @@ -184,9 +185,7 @@ def get_ids_to_process(self) -> Iterable:
"""
Gets the IDs that need to be processed.
"""
query = self.query or {}

distinct_from_target = list(self.target.distinct(self._target_keys_field, criteria=query))
distinct_from_target = list(self.target.distinct(self._target_keys_field, criteria=self.query))
processed_ids = []
# Not always guaranteed that MongoDB will unpack the list so we
# have to make sure we do that
Expand All @@ -196,19 +195,19 @@ def get_ids_to_process(self) -> Iterable:
else:
processed_ids.append(d)

all_ids = set(self.source.distinct(self.source.key, criteria=query))
all_ids = set(self.source.distinct(self.source.key, criteria=self.query))
self.logger.debug(f"Found {len(all_ids)} total docs in source")

if self.retry_failed:
failed_keys = self.target.distinct(self._target_keys_field, criteria={"state": "failed", **query})
failed_keys = self.target.distinct(self._target_keys_field, criteria={"state": "failed", **self.query})
unprocessed_ids = all_ids - (set(processed_ids) - set(failed_keys))
self.logger.debug(f"Found {len(failed_keys)} failed IDs in target")
else:
unprocessed_ids = all_ids - set(processed_ids)

self.logger.info(f"Found {len(unprocessed_ids)} IDs to process")

new_ids = set(self.source.newer_in(self.target, criteria=query, exhaustive=False))
new_ids = set(self.source.newer_in(self.target, criteria=self.query, exhaustive=False))

self.logger.info(f"Found {len(new_ids)} updated IDs to process")
return list(new_ids | unprocessed_ids)
Expand Down
10 changes: 5 additions & 5 deletions tests/builders/test_group_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Tests for group builder
"""

from datetime import datetime
from datetime import datetime, timezone
from random import randint

import pytest
Expand All @@ -13,7 +13,7 @@

@pytest.fixture(scope="module")
def now():
return datetime.utcnow()
return datetime.now(timezone.utc)


@pytest.fixture()
Expand Down Expand Up @@ -62,9 +62,9 @@ def unary_function(self, items: list[dict]) -> dict:


def test_grouping(source, target, docs):
builder = DummyGrouper(source, target, grouping_keys=["a"])
builder = DummyGrouper(source, target, query={"k": {"$ne": 3}}, grouping_keys=["a"])

assert len(docs) == len(builder.get_ids_to_process())
assert len(docs) - 1 == len(builder.get_ids_to_process()), f"{len(docs) -1} != {len(builder.get_ids_to_process())}"
assert len(builder.get_groups_from_keys([d["k"] for d in docs])) == 3

to_process = list(builder.get_items())
Expand All @@ -75,4 +75,4 @@ def test_grouping(source, target, docs):

builder.update_targets(processed)

assert len(builder.get_ids_to_process()) == 0
assert len(builder.get_ids_to_process()) == 0, f"{len(builder.get_ids_to_process())} != 0"

0 comments on commit 2ab7815

Please sign in to comment.