From cf46aae2b23ca8df5a450e396e07fe5b9f4b378b Mon Sep 17 00:00:00 2001 From: Aric Coady Date: Sun, 17 Dec 2023 09:03:07 -0800 Subject: [PATCH] Switched from `black` to `ruff`. Line-length set to 100. Dict union operator. --- .github/workflows/build.yml | 2 +- Makefile | 2 +- docs/examples.ipynb | 19 ++++-- lupyne/engine/analyzers.py | 4 +- lupyne/engine/documents.py | 56 ++++++++++------ lupyne/engine/indexers.py | 66 ++++++++++++++----- lupyne/engine/queries.py | 12 +++- lupyne/engine/utils.py | 4 +- lupyne/services/base.py | 8 ++- lupyne/services/graphql.py | 12 ++-- lupyne/services/rest.py | 6 +- pyproject.toml | 15 +++-- tests/conftest.py | 14 ++-- tests/test_engine.py | 123 +++++++++++++++++++++++++++++------- tests/test_graphql.py | 11 +++- tests/test_rest.py | 7 +- 16 files changed, 273 insertions(+), 88 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 67ef46b..2f4e5d1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: 3.x - - run: pip install black[jupyter] ruff mypy + - run: pip install ruff mypy - run: make lint docs: diff --git a/Makefile b/Makefile index f8d83a7..f2c8ab1 100644 --- a/Makefile +++ b/Makefile @@ -3,8 +3,8 @@ check: python -m pytest -s --cov-append --cov=lupyne.services tests/test_rest.py tests/test_graphql.py lint: - black --check . ruff . + ruff format --check . mypy -p lupyne.engine html: diff --git a/docs/examples.ipynb b/docs/examples.ipynb index 22503ca..8968077 100644 --- a/docs/examples.ipynb +++ b/docs/examples.ipynb @@ -66,8 +66,10 @@ "metadata": {}, "outputs": [], "source": [ - "indexer = engine.Indexer('tempIndex') # Indexer combines Writer and Searcher; StandardAnalyzer is the default\n", - "indexer.set('fieldname', engine.Field.Text, stored=True) # default indexed text settings for documents\n", + "# Indexer combines Writer and Searcher; StandardAnalyzer is the default\n", + "indexer = engine.Indexer('tempIndex')\n", + "# default indexed text settings for documents\n", + "indexer.set('fieldname', engine.Field.Text, stored=True)\n", "indexer.add(fieldname=text) # add document\n", "indexer.commit() # commit changes and refresh searcher\n", "\n", @@ -99,8 +101,15 @@ "from org.apache.lucene.queries import spans\n", "\n", "q1 = search.TermQuery(index.Term('text', 'lucene'))\n", - "q2 = search.PhraseQuery.Builder().add(index.Term('text', 'search')).add(index.Term('text', 'engine')).build()\n", - "search.BooleanQuery.Builder().add(q1, search.BooleanClause.Occur.MUST).add(q2, search.BooleanClause.Occur.MUST).build()" + "q2 = (\n", + " search.PhraseQuery.Builder()\n", + " .add(index.Term('text', 'search'))\n", + " .add(index.Term('text', 'engine'))\n", + " .build()\n", + ")\n", + "search.BooleanQuery.Builder().add(q1, search.BooleanClause.Occur.MUST).add(\n", + " q2, search.BooleanClause.Occur.MUST\n", + ").build()" ] }, { @@ -535,7 +544,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.0" + "version": "3.12.1" }, "vscode": { "interpreter": { diff --git a/lupyne/engine/analyzers.py b/lupyne/engine/analyzers.py index 3b0625f..210e489 100644 --- a/lupyne/engine/analyzers.py +++ b/lupyne/engine/analyzers.py @@ -139,7 +139,9 @@ def parse(self, query: str, field='', op='', parser=None, **attrs) -> search.Que **attrs: additional attributes to set on the parser """ # parsers aren't thread-safe (nor slow), so create one each time - cls = queryparser.classic.QueryParser if 
isinstance(field, str) else queryparser.classic.MultiFieldQueryParser + cls = queryparser.classic.MultiFieldQueryParser + if isinstance(field, str): + cls = queryparser.classic.QueryParser args = field, self if isinstance(field, Mapping): boosts = HashMap() diff --git a/lupyne/engine/documents.py b/lupyne/engine/documents.py index f74172f..b67de25 100644 --- a/lupyne/engine/documents.py +++ b/lupyne/engine/documents.py @@ -27,17 +27,25 @@ class Field(FieldType): # type: ignore indexOptions = property(FieldType.indexOptions, FieldType.setIndexOptions) omitNorms = property(FieldType.omitNorms, FieldType.setOmitNorms) stored = property(FieldType.stored, FieldType.setStored) - storeTermVectorOffsets = property(FieldType.storeTermVectorOffsets, FieldType.setStoreTermVectorOffsets) - storeTermVectorPayloads = property(FieldType.storeTermVectorPayloads, FieldType.setStoreTermVectorPayloads) - storeTermVectorPositions = property(FieldType.storeTermVectorPositions, FieldType.setStoreTermVectorPositions) + storeTermVectorOffsets = property( + FieldType.storeTermVectorOffsets, FieldType.setStoreTermVectorOffsets + ) + storeTermVectorPayloads = property( + FieldType.storeTermVectorPayloads, FieldType.setStoreTermVectorPayloads + ) + storeTermVectorPositions = property( + FieldType.storeTermVectorPositions, FieldType.setStoreTermVectorPositions + ) storeTermVectors = property(FieldType.storeTermVectors, FieldType.setStoreTermVectors) tokenized = property(FieldType.tokenized, FieldType.setTokenized) properties = {name for name in locals() if not name.startswith('__')} types = {int: 'long', float: 'double', str: 'string'} - types.update(NUMERIC='long', BINARY='string', SORTED='string', SORTED_NUMERIC='long', SORTED_SET='string') + types.update( + NUMERIC='long', BINARY='string', SORTED='string', SORTED_NUMERIC='long', SORTED_SET='string' + ) dimensions = property( - getattr(FieldType, 'pointDataDimensionCount', getattr(FieldType, 'pointDimensionCount', None)), + FieldType.pointDimensionCount, lambda self, count: self.setDimensions(count, Long.BYTES), ) @@ -54,7 +62,8 @@ def __init__(self, name: str, docValuesType='', indexOptions='', dimensions=0, * self.indexOptions = getattr(index.IndexOptions, indexOptions.upper()) if docValuesType: self.docValuesType = getattr(index.DocValuesType, docValuesType.upper()) - self.docValueClass = getattr(document, docValuesType.title().replace('_', '') + 'DocValuesField') + name = docValuesType.title().replace('_', '') + self.docValueClass = getattr(document, name + 'DocValuesField') if self.stored or self.indexed or self.dimensions: settings = self.settings del settings['docValuesType'] @@ -62,9 +71,12 @@ def __init__(self, name: str, docValuesType='', indexOptions='', dimensions=0, * assert self.stored or self.indexed or self.docvalues or self.dimensions @classmethod - def String(cls, name: str, tokenized=False, omitNorms=True, indexOptions='DOCS', **settings) -> 'Field': + def String( + cls, name: str, tokenized=False, omitNorms=True, indexOptions='DOCS', **settings + ) -> 'Field': """Return Field with default settings for strings.""" - return cls(name, tokenized=tokenized, omitNorms=omitNorms, indexOptions=indexOptions, **settings) + settings.update(tokenized=tokenized, omitNorms=omitNorms, indexOptions=indexOptions) + return cls(name, **settings) @classmethod def Text(cls, name: str, indexOptions='DOCS_AND_FREQS_AND_POSITIONS', **settings) -> 'Field': @@ -128,8 +140,8 @@ def __init__(self, name: str, sep: str = '.', **settings): def values(self, value: str) -> 
Iterator[str]: """Generate component field values in order.""" value = value.split(self.sep) # type: ignore - for index in range(1, len(value) + 1): - yield self.sep.join(value[:index]) + for stop in range(1, len(value) + 1): + yield self.sep.join(value[:stop]) def items(self, *values: str) -> Iterator[document.Field]: """Generate indexed component fields.""" @@ -326,9 +338,8 @@ def dict(self, *names: str, **defaults) -> dict: *names: names of multi-valued fields to return as a list **defaults: include only given fields, using default values as necessary """ - defaults.update((name, self[name]) for name in (defaults or self) if name in self) - defaults.update((name, self.getlist(name)) for name in names) - return defaults + defaults |= {name: self[name] for name in (defaults or self) if name in self} + return defaults | {name: self.getlist(name) for name in names} class Hit(Document): @@ -415,7 +426,9 @@ def highlights(self, query: search.Query, **fields: int) -> Iterator[dict]: query: lucene Query **fields: mapping of fields to maxinum number of passages """ - mapping = self.searcher.highlighter.highlightFields(list(fields), query, list(self.ids), list(fields.values())) + mapping = self.searcher.highlighter.highlightFields( + list(fields), query, list(self.ids), list(fields.values()) + ) mapping = {field: lucene.JArray_string.cast_(mapping.get(field)) for field in fields} return (dict(zip(mapping, values)) for values in zip(*mapping.values())) @@ -423,7 +436,9 @@ def docvalues(self, field: str, type=None) -> dict: """Return mapping of docs to docvalues.""" return self.searcher.docvalues(field, type).select(self.ids) - def groupby(self, func: Callable, count: Optional[int] = None, docs: Optional[int] = None) -> 'Groups': + def groupby( + self, func: Callable, count: Optional[int] = None, docs: Optional[int] = None + ) -> 'Groups': """Return ordered list of [Hits][lupyne.engine.documents.Hits] grouped by value of function applied to doc ids. Optionally limit the number of groups and docs per group. 
@@ -507,9 +522,14 @@ def __len__(self): def __iter__(self): return map(convert, self.allMatchingGroups) - def search(self, searcher, query: search.Query, count: Optional[int] = None, start: int = 0) -> Groups: + def search( + self, searcher, query: search.Query, count: Optional[int] = None, start: int = 0 + ) -> Groups: """Run query and return [Groups][lupyne.engine.documents.Groups].""" if count is None: - count = sum(index.DocValues.getSorted(reader, self.field).valueCount for reader in searcher.readers) or 1 - topgroups = super().search(searcher, query, start, count - start) + count = sum( + index.DocValues.getSorted(reader, self.field).valueCount + for reader in searcher.readers + ) + topgroups = super().search(searcher, query, start, max(count - start, 1)) return Groups(searcher, topgroups.groups, topgroups.totalHitCount) diff --git a/lupyne/engine/indexers.py b/lupyne/engine/indexers.py index 540d2f8..20f9c42 100644 --- a/lupyne/engine/indexers.py +++ b/lupyne/engine/indexers.py @@ -148,7 +148,8 @@ def suggest(self, name: str, value, count: int = 1, **attrs) -> list: checker = spell.DirectSpellChecker() for attr in attrs: setattr(checker, attr, attrs[attr]) - return [word.string for word in checker.suggestSimilar(index.Term(name, value), count, self.indexReader)] + words = checker.suggestSimilar(index.Term(name, value), count, self.indexReader) + return [word.string for word in words] def sortfield(self, name: str, type=None, reverse=False) -> search.SortField: """Return lucene SortField, deriving the the type from FieldInfos if necessary. @@ -172,12 +173,15 @@ def docvalues(self, name: str, type=None) -> DocValues.Sorted: name: field name type: int or float for converting values """ - type = {int: int, float: util.NumericUtils.sortableLongToDouble}.get(type, util.BytesRef.utf8ToString) + types = {int: int, float: util.NumericUtils.sortableLongToDouble} + type = types.get(type, util.BytesRef.utf8ToString) docValuesType = self.fieldinfos[name].docValuesType.toString().title().replace('_', '') method = getattr(index.MultiDocValues, f'get{docValuesType}Values') return getattr(DocValues, docValuesType)(method(self.indexReader, name), len(self), type) - def copy(self, dest, query: search.Query = None, exclude: search.Query = None, merge: int = 0) -> int: + def copy( + self, dest, query: search.Query = None, exclude: search.Query = None, merge: int = 0 + ) -> int: """Copy the index to the destination directory. Optimized to use hard links if the destination is a file system path. 
@@ -221,14 +225,18 @@ def terms(self, name: str, value='', stop='', counts=False, distance=0, prefix=0 termsenum.seekCeil(util.BytesRef(value)) terms = itertools.chain([termsenum.term()], util.BytesRefIterator.cast_(termsenum)) terms = map(operator.methodcaller('utf8ToString'), terms) - predicate = partial(operator.gt, stop) if stop else operator.methodcaller('startswith', value) + predicate = ( + partial(operator.gt, stop) if stop else operator.methodcaller('startswith', value) + ) if not distance: terms = itertools.takewhile(predicate, terms) # type: ignore return ((term, termsenum.docFreq()) for term in terms) if counts else terms def docs(self, name: str, value, counts=False) -> Iterator: """Generate doc ids which contain given term, optionally with frequency counts.""" - docsenum = index.MultiTerms.getTermPostingsEnum(self.indexReader, name, util.BytesRef(value)) + docsenum = index.MultiTerms.getTermPostingsEnum( + self.indexReader, name, util.BytesRef(value) + ) docs = iter(docsenum.nextDoc, index.PostingsEnum.NO_MORE_DOCS) if docsenum else () return ((doc, docsenum.freq()) for doc in docs) if counts else iter(docs) # type: ignore @@ -236,19 +244,28 @@ def positions(self, name: str, value, payloads=False, offsets=False) -> Iterator """Generate doc ids and positions which contain given term. Optionally with offsets, or only ones with payloads.""" - docsenum = index.MultiTerms.getTermPostingsEnum(self.indexReader, name, util.BytesRef(value)) + docsenum = index.MultiTerms.getTermPostingsEnum( + self.indexReader, name, util.BytesRef(value) + ) for doc in iter(docsenum.nextDoc, index.PostingsEnum.NO_MORE_DOCS) if docsenum else (): # type: ignore positions = (docsenum.nextPosition() for _ in range(docsenum.freq())) if payloads: - positions = ((position, docsenum.payload.utf8ToString()) for position in positions if docsenum.payload) + positions = ( + (position, docsenum.payload.utf8ToString()) + for position in positions + if docsenum.payload + ) elif offsets: - positions = ((docsenum.startOffset(), docsenum.endOffset()) for position in positions) + positions = ( + (docsenum.startOffset(), docsenum.endOffset()) for position in positions + ) yield doc, list(positions) def vector(self, id, field): terms = self.getTermVector(id, field) termsenum = terms.iterator() if terms else index.TermsEnum.EMPTY - return termsenum, map(operator.methodcaller('utf8ToString'), util.BytesRefIterator.cast_(termsenum)) + terms = map(operator.methodcaller('utf8ToString'), util.BytesRefIterator.cast_(termsenum)) + return termsenum, terms def termvector(self, id: int, field: str, counts=False) -> Iterator: """Generate terms for given doc id and field, optionally with frequency counts.""" @@ -390,7 +407,15 @@ def collector(self, count=None, sort=None, reverse=False, scores=False, mincount return search.TopFieldCollector.create(sort, count, mincount) def search( - self, query=None, count=None, sort=None, reverse=False, scores=False, mincount=1000, timeout=None, **parser + self, + query=None, + count=None, + sort=None, + reverse=False, + scores=False, + mincount=1000, + timeout=None, + **parser, ) -> Hits: """Run query and return [Hits][lupyne.engine.documents.Hits]. 
@@ -440,7 +465,9 @@ def facets(self, query, *fields: str, **query_map: dict) -> dict: counts[facet] = {key: self.count(Query.all(query, queries[key])) for key in queries} return counts - def groupby(self, field: str, query, count: Optional[int] = None, start: int = 0, **attrs) -> Groups: + def groupby( + self, field: str, query, count: Optional[int] = None, start: int = 0, **attrs + ) -> Groups: """Return [Hits][lupyne.engine.documents.Hits] grouped by field using a [GroupingSearch][lupyne.engine.documents.GroupingSearch].""" return GroupingSearch(field, **attrs).search(self, self.parse(query), count, start) @@ -450,7 +477,8 @@ def spellchecker(self, field: str) -> SpellChecker: try: return self.spellcheckers[field] except KeyError: - return self.spellcheckers.setdefault(field, SpellChecker(self.terms(field, counts=True))) + spellchecker = SpellChecker(self.terms(field, counts=True)) + return self.spellcheckers.setdefault(field, spellchecker) def complete(self, field: str, prefix: str, count: Optional[int] = None) -> list: """Return ordered suggested words for prefix.""" @@ -475,7 +503,9 @@ class MultiSearcher(IndexSearcher): def __init__(self, reader, analyzer=None): super().__init__(reader, analyzer) - self.indexReaders = [index.DirectoryReader.cast_(context.reader()) for context in self.context.children()] + self.indexReaders = [ + index.DirectoryReader.cast_(context.reader()) for context in self.context.children() + ] self.version = sum(reader.version for reader in self.indexReaders) def __getattr__(self, name): @@ -484,7 +514,8 @@ def __getattr__(self, name): def openIfChanged(self): readers = list(map(index.DirectoryReader.openIfChanged, self.indexReaders)) if any(readers): - return index.MultiReader([new or old.incRef() or old for new, old in zip(readers, self.indexReaders)]) + readers = [new or old.incRef() or old for new, old in zip(readers, self.indexReaders)] + return index.MultiReader(readers) @property def timestamp(self): @@ -513,7 +544,8 @@ def __init__(self, directory, mode: str = 'a', analyzer=None, version=None, **at config.openMode = index.IndexWriterConfig.OpenMode.values()['wra'.index(mode)] for name, value in attrs.items(): setattr(config, name, value) - self.policy = config.indexDeletionPolicy = index.SnapshotDeletionPolicy(config.indexDeletionPolicy) + self.policy = index.SnapshotDeletionPolicy(config.indexDeletionPolicy) + config.indexDeletionPolicy = self.policy super().__init__(self.shared.directory(directory), config) self.fields = {} # type: dict @@ -568,7 +600,9 @@ def update(self, name: str, value='', document=(), **terms): fields = list(doc.iterator()) types = [Field.cast_(field.fieldType()) for field in fields] noindex = index.IndexOptions.NONE - if any(ft.stored() or ft.indexOptions() != noindex or Field.dimensions.fget(ft) for ft in types): + if any( + ft.stored() or ft.indexOptions() != noindex or Field.dimensions.fget(ft) for ft in types + ): self.updateDocument(term, doc) elif fields: self.updateDocValues(term, *fields) diff --git a/lupyne/engine/queries.py b/lupyne/engine/queries.py index 7f6016b..20cd696 100644 --- a/lupyne/engine/queries.py +++ b/lupyne/engine/queries.py @@ -59,7 +59,9 @@ def disjunct(cls, multiplier, *queries, **terms): """Return lucene DisjunctionMaxQuery from queries and terms.""" queries = list(queries) for name, values in terms.items(): - queries += (cls.term(name, value) for value in ([values] if isinstance(values, str) else values)) + if isinstance(values, str): + values = [values] + queries += (cls.term(name, value) for 
value in values) return cls(search.DisjunctionMaxQuery, Arrays.asList(queries), multiplier) @classmethod @@ -73,7 +75,10 @@ def span(cls, *term) -> 'SpanQuery': def near(cls, name: str, *values, **kwargs) -> 'SpanQuery': """Return [SpanNearQuery][lupyne.engine.queries.SpanQuery.near] from terms. Term values which supply another field name will be masked.""" - spans = (cls.span(name, value) if isinstance(value, str) else cls.span(*value).mask(name) for value in values) + spans = ( + cls.span(name, value) if isinstance(value, str) else cls.span(*value).mask(name) + for value in values + ) return SpanQuery.near(*spans, **kwargs) @classmethod @@ -265,7 +270,8 @@ def __getitem__(self, id: int): class SortedNumeric(Sorted): def __getitem__(self, id: int): if self.docvalues.advanceExact(id): - return tuple(self.type(self.docvalues.nextValue()) for _ in range(self.docvalues.docValueCount())) + indices = range(self.docvalues.docValueCount()) + return tuple(self.type(self.docvalues.nextValue()) for _ in indices) class SortedSet(Sorted): def __getitem__(self, id: int): diff --git a/lupyne/engine/utils.py b/lupyne/engine/utils.py index 619d2db..2c1f8fc 100644 --- a/lupyne/engine/utils.py +++ b/lupyne/engine/utils.py @@ -62,4 +62,6 @@ def convert(value): if not Number.instance_(value): return value.toString() if Object.instance_(value) else value value = Number.cast_(value) - return value.doubleValue() if Float.instance_(value) or Double.instance_(value) else int(value.longValue()) + if Float.instance_(value) or Double.instance_(value): + return value.doubleValue() + return int(value.longValue()) diff --git a/lupyne/services/base.py b/lupyne/services/base.py index 0751fa2..66b69a0 100644 --- a/lupyne/services/base.py +++ b/lupyne/services/base.py @@ -36,7 +36,9 @@ def multi_valued(annotations): class Document: """stored fields""" - __annotations__ = {field.name.value: Optional[convert(field.type)] for field in schema.get('Document', [])} + __annotations__ = { + field.name.value: Optional[convert(field.type)] for field in schema.get('Document', []) + } locals().update(dict.fromkeys(__annotations__)) locals().update(dict.fromkeys(multi_valued(__annotations__), ())) @@ -91,7 +93,9 @@ def index(self) -> dict: """index information""" searcher = self.searcher if isinstance(searcher, engine.MultiSearcher): # pragma: no cover - return {reader.directory().toString(): reader.numDocs() for reader in searcher.indexReaders} + return { + reader.directory().toString(): reader.numDocs() for reader in searcher.indexReaders + } return {searcher.directory.toString(): len(searcher)} def refresh(self, spellcheckers: bool = False) -> dict: diff --git a/lupyne/services/graphql.py b/lupyne/services/graphql.py index 279e128..69eb3cb 100644 --- a/lupyne/services/graphql.py +++ b/lupyne/services/graphql.py @@ -26,7 +26,9 @@ async def lifespan(app: Starlette): # pragma: no cover def selections(*fields) -> dict: """Return tree of field name selections.""" - return {selection.name: selections(selection) for field in fields for selection in field.selections} + return { + selection.name: selections(selection) for field in fields for selection in field.selections + } def doc_type(cls): @@ -109,9 +111,11 @@ def terms(self, info: Info) -> IndexedFields: """indexed field names""" fields = {} for name, selected in selections(*info.selected_fields).items(): - counts = 'counts' in selected - terms = root.searcher.terms(name, counts=counts) - fields[name] = Terms(**dict(zip(['values', 'counts'], zip(*terms)))) if counts else 
Terms(values=terms) + if 'counts' in selected: + values, counts = zip(*root.searcher.terms(name, counts=True)) + fields[name] = Terms(values=values, counts=counts) + else: + fields[name] = Terms(values=root.searcher.terms(name)) return IndexedFields(**fields) @doc_field( diff --git a/lupyne/services/rest.py b/lupyne/services/rest.py index 8ab8146..954a069 100644 --- a/lupyne/services/rest.py +++ b/lupyne/services/rest.py @@ -47,5 +47,9 @@ def search(q: str, count: int = None, sort: str = '') -> dict: async def headers(request, call_next): start = time.time() response = await call_next(request) - response.headers.update({'x-response-time': str(time.time() - start), 'age': str(int(root.age)), 'etag': root.etag}) + response.headers.update({ + 'x-response-time': str(time.time() - start), + 'age': str(int(root.age)), + 'etag': root.etag, + }) return response diff --git a/pyproject.toml b/pyproject.toml index c257cd5..9540ad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,12 +48,17 @@ version = {attr = "lupyne.__version__"} [tool.setuptools.package-data] lupyne = ["py.typed"] -[tool.black] -line-length = 120 -skip-string-normalization = true - [tool.ruff] -ignore = ["E501", "F402"] +line-length = 100 +ignore = ["F402"] +extend-include = ["*.ipynb"] + +[tool.ruff.format] +preview = true +quote-style = "preserve" + +[tool.ruff.lint.per-file-ignores] +"*.ipynb" = ["F821"] [[tool.mypy.overrides]] module = ["lucene", "jcc", "java.*", "org.apache.*"] diff --git a/tests/conftest.py b/tests/conftest.py index 7c6eb57..e074bf4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,7 +37,12 @@ def fixture(gen): @pytest.fixture def fields(): return [ - engine.Field.Text('text', storeTermVectors=True, storeTermVectorPositions=True, storeTermVectorOffsets=True), + engine.Field.Text( + 'text', + storeTermVectors=True, + storeTermVectorPositions=True, + storeTermVectorOffsets=True, + ), engine.Field.String('article', stored=True), engine.Field.String('amendment', stored=True), engine.Field.String('date', stored=True, docValuesType='sorted'), @@ -48,7 +53,8 @@ def fields(): @fixture def constitution(): lines = open(fixtures / 'constitution.txt') - items = itertools.groupby(lines, lambda l: l.startswith('Article ') or l.startswith('Amendment ')) # noqa: E741 + key = lambda l: l.startswith('Article ') or l.startswith('Amendment ') # noqa + items = itertools.groupby(lines, key) for _, (header,) in items: _, lines = next(items) header, num = header.rstrip('.\n').split(None, 1) @@ -56,7 +62,7 @@ def constitution(): if header == 'Amendment': num, date = num.split() date = datetime.strptime(date, '%m/%d/%Y').date() - fields.update({header.lower(): num, 'date': str(date), 'year': date.year}) + fields |= {header.lower(): num, 'date': str(date), 'year': date.year} yield fields @@ -77,7 +83,7 @@ def zipcodes(): @pytest.fixture def index(tempdir, fields, constitution): with engine.IndexWriter(tempdir) as writer: - writer.fields.update({field.name: field for field in fields}) + writer.fields |= {field.name: field for field in fields} for doc in constitution: writer.add(doc) return tempdir diff --git a/tests/test_engine.py b/tests/test_engine.py index d25b5df..90d5b23 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -40,7 +40,10 @@ def test_analyzers(tempdir): token.charTerm = token.type = '' token.offset, token.positionIncrement = (0, 0), 0 assert str(stemmer.parse('searches', field=['body', 'title'])) == 'body:search title:search' - assert str(stemmer.parse('searches', field={'body': 1.0, 
'title': 2.0})) == '(body:search)^1.0 (title:search)^2.0' + assert ( + str(stemmer.parse('searches', field={'body': 1.0, 'title': 2.0})) + == '(body:search)^1.0 (title:search)^2.0' + ) indexer = engine.Indexer(tempdir, analyzer=stemmer) indexer.set('text', engine.Field.Text) indexer.add(text='searches') @@ -141,7 +144,11 @@ def test_searcher(tempdir, fields, constitution): assert list(indexer.terms('text', 'right', 'right_')) == ['right'] assert dict(indexer.terms('text', 'right', 'right_', counts=True)) == {'right': 13} assert list(indexer.terms('text', 'right', distance=1)) == ['eight', 'right', 'rights'] - assert dict(indexer.terms('text', 'right', distance=1, counts=True)) == {'eight': 3, 'right': 13, 'rights': 1} + assert dict(indexer.terms('text', 'right', distance=1, counts=True)) == { + 'eight': 3, + 'right': 13, + 'rights': 1, + } assert list(indexer.terms('text', 'senite', distance=2)) == ['senate', 'sent'] word, count = next(indexer.terms('text', 'people', counts=True)) assert word == 'people' and count == 8 @@ -184,14 +191,20 @@ def test_searcher(tempdir, fields, constitution): assert 'persons' in indexer.termvector(id, 'text') assert dict(indexer.termvector(id, 'text', counts=True))['persons'] == 2 assert dict(indexer.positionvector(id, 'text'))['persons'] in ([3, 26], [10, 48]) - assert dict(indexer.positionvector(id, 'text', offsets=True))['persons'] == [(46, 53), (301, 308)] + assert dict(indexer.positionvector(id, 'text', offsets=True))['persons'] == [ + (46, 53), + (301, 308), + ] analyzer = analysis.core.WhitespaceAnalyzer() query = indexer.morelikethis(0, analyzer=analyzer) assert {'text:united', 'text:states'} <= set(str(query).split()) assert str(indexer.morelikethis(0, 'article', analyzer=analyzer)) == '' query = indexer.morelikethis(0, minDocFreq=3, analyzer=analyzer) assert {'text:establish', 'text:united', 'text:states'} <= set(str(query).split()) - assert str(indexer.morelikethis('jury', 'text', minDocFreq=4, minTermFreq=1, analyzer=analyzer)) == 'text:jury' + assert ( + str(indexer.morelikethis('jury', 'text', minDocFreq=4, minTermFreq=1, analyzer=analyzer)) + == 'text:jury' + ) assert str(indexer.morelikethis('jury', 'article', analyzer=analyzer)) == '' @@ -204,7 +217,11 @@ def test_spellcheck(tempdir, fields, constitution): assert indexer.complete('missing', '') == [] assert {'shall', 'states'} <= set(indexer.complete('text', '')[:8]) assert indexer.complete('text', 'con')[:2] == ['congress', 'constitution'] - assert indexer.complete('text', 'congress') == indexer.complete('text', 'con', count=1) == ['congress'] + assert ( + indexer.complete('text', 'congress') + == indexer.complete('text', 'con', count=1) + == ['congress'] + ) assert indexer.complete('text', 'congresses') == [] assert indexer.suggest('text', 'write') == ['writs'] assert indexer.suggest('text', 'write', 3) == ['writs', 'writ', 'written'] @@ -316,9 +333,13 @@ def test_queries(): near = Q.near('text', 'lucene', ('alias', 'search'), slop=-1, inOrder=False) assert str(near) == 'spanNear([text:lucene, mask(alias:search) as text], -1, false)' assert ( - str(span - near) == 'spanNot(text:lucene, spanNear([text:lucene, mask(alias:search) as text], -1, false), 0, 0)' + str(span - near) + == 'spanNot(text:lucene, spanNear([text:lucene, mask(alias:search) as text], -1, false), 0, 0)' + ) + assert ( + str(span | near) + == 'spanOr([text:lucene, spanNear([text:lucene, mask(alias:search) as text], -1, false)])' ) - assert str(span | near) == 'spanOr([text:lucene, spanNear([text:lucene, mask(alias:search) as 
text], -1, false)])' assert str(span.mask('alias')) == 'mask(text:lucene) as alias' assert str(span.boost(2.0)) == '(text:lucene)^2.0' assert str(span.containing(span)) == 'SpanContaining(text:lucene, text:lucene)' @@ -328,7 +349,10 @@ def test_queries(): assert str(Q.points('point', 0.0, 1.0)) == 'point:{0.0 1.0}' assert str(Q.points('point', 0)) == 'point:{0}' assert str(Q.points('point', 0, 1)) == 'point:{0 1}' - assert str(Q.ranges('point', (0.0, 1.0), (2.0, 3.0), upper=True)) == 'point:[0.0 TO 1.0],[2.0 TO 3.0]' + assert ( + str(Q.ranges('point', (0.0, 1.0), (2.0, 3.0), upper=True)) + == 'point:[0.0 TO 1.0],[2.0 TO 3.0]' + ) assert str(Q.ranges('point', (0.0, 1.0), lower=False)).startswith('point:[4.9E-324 TO 0.9999') assert str(Q.ranges('point', (None, 0.0), upper=True)) == 'point:[-Infinity TO 0.0]' assert str(Q.ranges('point', (0.0, None))) == 'point:[0.0 TO Infinity]' @@ -339,7 +363,9 @@ def test_queries(): def test_grouping(tempdir, indexer, zipcodes): - field = indexer.fields['location'] = engine.NestedField('state.county.city', docValuesType='sorted') + field = indexer.fields['location'] = engine.NestedField( + 'state.county.city', docValuesType='sorted' + ) for doc in zipcodes: if doc['state'] in ('CA', 'AK', 'WY', 'PR'): lat, lng = ('{0:08.3f}'.format(doc.pop(lt)) for lt in ['latitude', 'longitude']) @@ -444,7 +470,10 @@ def test_shape(indexer, zipcodes): assert indexer.count(latlon.within(geo.Circle(*circle))) == 1 lat, lon = point (pg,) = geo.Polygon.fromGeoJSON( - json.dumps({'type': 'Polygon', 'coordinates': [[(lon, lat), (lon + 1, lat), (lon, lat + 1), (lon, lat)]]}) + json.dumps({ + 'type': 'Polygon', + 'coordinates': [[(lon, lat), (lon + 1, lat), (lon, lat + 1), (lon, lat)]], + }) ) assert indexer.count(latlon.contains(pg)) == 0 assert indexer.count(latlon.disjoint(pg)) == 2639 @@ -468,15 +497,30 @@ def test_fields(indexer, constitution): with pytest.raises(lucene.JavaError): with engine.utils.suppress(search.TimeLimitingCollector.TimeExceededException): document.Field('name', 'value', document.FieldType()) - assert str(engine.Field.String('')) == str(document.StringField('', '', document.Field.Store.NO).fieldType()) - assert str(engine.Field.Text('')) == str(document.TextField('', '', document.Field.Store.NO).fieldType()) + assert str(engine.Field.String('')) == str( + document.StringField('', '', document.Field.Store.NO).fieldType() + ) + assert str(engine.Field.Text('')) == str( + document.TextField('', '', document.Field.Store.NO).fieldType() + ) assert str(engine.DateTimeField('')) == str(document.DoublePoint('', 0.0).fieldType()) settings = {'docValuesType': 'NUMERIC', 'indexOptions': 'DOCS'} field = engine.Field('', **settings) assert field.settings == engine.Field('', **field.settings).settings == settings field = engine.NestedField('', stored=True) - assert field.settings == {'stored': True, 'tokenized': False, 'omitNorms': True, 'indexOptions': 'DOCS'} - attrs = 'stored', 'omitNorms', 'storeTermVectors', 'storeTermVectorPositions', 'storeTermVectorOffsets' + assert field.settings == { + 'stored': True, + 'tokenized': False, + 'omitNorms': True, + 'indexOptions': 'DOCS', + } + attrs = ( + 'stored', + 'omitNorms', + 'storeTermVectors', + 'storeTermVectorPositions', + 'storeTermVectorOffsets', + ) field = engine.Field('', indexOptions='docs', **dict.fromkeys(attrs, True)) (field,) = field.items(' ') assert all(getattr(field.fieldType(), attr)() for attr in attrs) @@ -486,7 +530,9 @@ def test_fields(indexer, constitution): for doc in constitution: if 'amendment' 
in doc: indexer.add( - amendment='{:02}'.format(int(doc['amendment'])), date=doc['date'], size='{:04}'.format(len(doc['text'])) + amendment='{:02}'.format(int(doc['amendment'])), + date=doc['date'], + size='{:04}'.format(len(doc['text'])), ) indexer.commit() assert set(indexer.fieldinfos) == {'amendment', 'Y', 'Y-m', 'Y-m-d', 'size'} @@ -527,7 +573,9 @@ def test_numeric(indexer, constitution): for doc in constitution: if 'amendment' in doc: indexer.add( - amendment=int(doc['amendment']), date=[tuple(map(int, doc['date'].split('-')))], size=len(doc['text']) + amendment=int(doc['amendment']), + date=[tuple(map(int, doc['date'].split('-')))], + size=len(doc['text']), ) indexer.commit() query = field.prefix((1791, 12)) @@ -539,7 +587,10 @@ def test_numeric(indexer, constitution): query = field.range(datetime.date(1919, 1, 1), datetime.date(1921, 12, 31)) hits = indexer.search(query) assert [hit['amendment'] for hit in hits] == [18, 19] - assert [datetime.datetime.utcfromtimestamp(float(hit['date'])).year for hit in hits] == [1919, 1920] + assert [datetime.datetime.utcfromtimestamp(float(hit['date'])).year for hit in hits] == [ + 1919, + 1920, + ] assert indexer.count(field.within(seconds=100)) == indexer.count(field.within(weeks=1)) == 0 query = field.duration([2009], days=-100 * 365) assert indexer.count(query) == 12 @@ -571,7 +622,10 @@ def test_highlighting(tempdir, constitution): indexer.add(text=doc['text']) indexer.commit() query = Q.term('text', 'right') - assert engine.Analyzer.highlight(indexer.analyzer, query, 'text', "word right word") == "word right word" + assert ( + engine.Analyzer.highlight(indexer.analyzer, query, 'text', "word right word") + == "word right word" + ) hits = indexer.search(query) highlights = list(hits.highlights(query, text=1)) assert len(hits) == len(highlights) @@ -618,15 +672,32 @@ def test_docvalues(tempdir): indexer.set('tags', docValuesType='sorted_set') indexer.set('sizes', docValuesType='sorted_numeric') indexer.set('points', docValuesType='sorted_numeric') - indexer.add(id='0', title='zero', size=0, point=0.5, priority='low', tags=['red'], sizes=[0], points=[0.5]) + indexer.add( + id='0', + title='zero', + size=0, + point=0.5, + priority='low', + tags=['red'], + sizes=[0], + points=[0.5], + ) indexer.commit() with pytest.raises(AttributeError): indexer.sortfield('id') sortfield = indexer.sortfield('id', type='string', reverse=True) - assert sortfield.field == 'id' and sortfield.reverse and sortfield.type == search.SortField.Type.STRING + assert ( + sortfield.field == 'id' + and sortfield.reverse + and sortfield.type == search.SortField.Type.STRING + ) sortfield = indexer.sortfield('title') - assert sortfield.field == 'title' and not sortfield.reverse and sortfield.type == search.SortField.Type.STRING + assert ( + sortfield.field == 'title' + and not sortfield.reverse + and sortfield.type == search.SortField.Type.STRING + ) assert indexer.sortfield('size', type=int).type == search.SortField.Type.LONG assert indexer.sortfield('point', type=float).type == search.SortField.Type.DOUBLE assert indexer.sortfield('priority').type == search.SortField.Type.STRING @@ -636,7 +707,15 @@ def test_docvalues(tempdir): segments = indexer.segments indexer.update( - 'id', id='0', title='one', size=1, point=1.5, priority='high', tags=['blue'], sizes=[1], points=[1.5] + 'id', + id='0', + title='one', + size=1, + point=1.5, + priority='high', + tags=['blue'], + sizes=[1], + points=[1.5], ) indexer.commit() assert indexer.segments != segments diff --git 
a/tests/test_graphql.py b/tests/test_graphql.py index aef36e8..42726e0 100644 --- a/tests/test_graphql.py +++ b/tests/test_graphql.py @@ -54,12 +54,17 @@ def test_search(client): assert hit['sortkeys'] == {'year': None} assert hit['doc'] == {'amendment': '2'} data = client.execute( - '''{ search(q: "text:right", count: 1, sort: ["-year"]) - { count hits { id score sortkeys { year } doc { amendment } } } }''' + """{ search(q: "text:right", count: 1, sort: ["-year"]) + { count hits { id score sortkeys { year } doc { amendment } } } }""" ) assert data['search']['count'] == 13 (hit,) = data['search']['hits'] - assert hit == {'id': 33, 'score': pytest.approx(0.648349), 'sortkeys': {'year': 1971}, 'doc': {'amendment': '26'}} + assert hit == { + 'id': 33, + 'score': pytest.approx(0.648349), + 'sortkeys': {'year': 1971}, + 'doc': {'amendment': '26'}, + } def test_count(client): diff --git a/tests/test_rest.py b/tests/test_rest.py index fd0f242..0329064 100644 --- a/tests/test_rest.py +++ b/tests/test_rest.py @@ -47,5 +47,10 @@ def test_search(client): result = client.get('/search', params={'q': "text:right", 'count': 1, 'sort': '-year'}).json() assert result['count'] == 13 assert result['hits'] == [ - {'id': 33, 'score': None, 'sortkeys': {'year': 1971}, 'doc': {'amendment': '26', 'date': '1971-07-01'}}, + { + 'id': 33, + 'score': None, + 'sortkeys': {'year': 1971}, + 'doc': {'amendment': '26', 'date': '1971-07-01'}, + }, ]
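
For illustration, a minimal sketch of the dict union operators (PEP 584, Python 3.9+) that this patch adopts in place of `dict.update`, e.g. in `Document.dict` and the test fixtures; the variable names below are hypothetical and not part of the patch:

    # In-place union: right-hand side wins on key conflicts, mirroring dict.update
    settings = {'stored': True, 'tokenized': False}
    settings |= {'tokenized': True, 'omitNorms': True}
    assert settings == {'stored': True, 'tokenized': True, 'omitNorms': True}

    # Non-mutating union: builds a new dict, as in `return defaults | {...}` above
    defaults = {'article': '1'}
    merged = defaults | {'amendment': '2'}
    assert merged == {'article': '1', 'amendment': '2'} and 'amendment' not in defaults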