Switched from black to ruff.
Line-length set to 100. Dict union operator.
coady committed Dec 17, 2023
1 parent 0b677ad commit cf46aae
Showing 16 changed files with 273 additions and 88 deletions.
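
Note: the "Dict union operator" in the commit message refers to PEP 584 | / |= syntax (Python 3.9+), adopted in the dict() helper of lupyne/engine/documents.py below in place of dict.update calls; the line-length of 100 is presumably set in the ruff configuration (pyproject.toml), which is not among the hunks loaded here. A minimal sketch of the dict union semantics, with made-up field names for illustration only (not part of this commit):

# Illustrative only: PEP 584 dict union operators (Python 3.9+)
defaults = {'title': '', 'year': 0}
fields = {'year': 2023}

merged = defaults | fields  # new dict; the right-hand operand wins on duplicate keys
defaults |= fields          # in-place merge, equivalent to defaults.update(fields)

assert merged == {'title': '', 'year': 2023}
assert defaults == merged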
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -25,7 +25,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: 3.x
- run: pip install black[jupyter] ruff mypy
- run: pip install ruff mypy
- run: make lint

docs:
2 changes: 1 addition & 1 deletion Makefile
@@ -3,8 +3,8 @@ check:
python -m pytest -s --cov-append --cov=lupyne.services tests/test_rest.py tests/test_graphql.py

lint:
black --check .
ruff .
ruff format --check .
mypy -p lupyne.engine

html:
19 changes: 14 additions & 5 deletions docs/examples.ipynb
@@ -66,8 +66,10 @@
"metadata": {},
"outputs": [],
"source": [
"indexer = engine.Indexer('tempIndex') # Indexer combines Writer and Searcher; StandardAnalyzer is the default\n",
"indexer.set('fieldname', engine.Field.Text, stored=True) # default indexed text settings for documents\n",
"# Indexer combines Writer and Searcher; StandardAnalyzer is the default\n",
"indexer = engine.Indexer('tempIndex')\n",
"# default indexed text settings for documents\n",
"indexer.set('fieldname', engine.Field.Text, stored=True)\n",
"indexer.add(fieldname=text) # add document\n",
"indexer.commit() # commit changes and refresh searcher\n",
"\n",
@@ -99,8 +101,15 @@
"from org.apache.lucene.queries import spans\n",
"\n",
"q1 = search.TermQuery(index.Term('text', 'lucene'))\n",
"q2 = search.PhraseQuery.Builder().add(index.Term('text', 'search')).add(index.Term('text', 'engine')).build()\n",
"search.BooleanQuery.Builder().add(q1, search.BooleanClause.Occur.MUST).add(q2, search.BooleanClause.Occur.MUST).build()"
"q2 = (\n",
" search.PhraseQuery.Builder()\n",
" .add(index.Term('text', 'search'))\n",
" .add(index.Term('text', 'engine'))\n",
" .build()\n",
")\n",
"search.BooleanQuery.Builder().add(q1, search.BooleanClause.Occur.MUST).add(\n",
" q2, search.BooleanClause.Occur.MUST\n",
").build()"
]
},
{
@@ -535,7 +544,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
"version": "3.12.1"
},
"vscode": {
"interpreter": {
4 changes: 3 additions & 1 deletion lupyne/engine/analyzers.py
@@ -139,7 +139,9 @@ def parse(self, query: str, field='', op='', parser=None, **attrs) -> search.Que
**attrs: additional attributes to set on the parser
"""
# parsers aren't thread-safe (nor slow), so create one each time
cls = queryparser.classic.QueryParser if isinstance(field, str) else queryparser.classic.MultiFieldQueryParser
cls = queryparser.classic.MultiFieldQueryParser
if isinstance(field, str):
cls = queryparser.classic.QueryParser
args = field, self
if isinstance(field, Mapping):
boosts = HashMap()
56 changes: 38 additions & 18 deletions lupyne/engine/documents.py
@@ -27,17 +27,25 @@ class Field(FieldType): # type: ignore
indexOptions = property(FieldType.indexOptions, FieldType.setIndexOptions)
omitNorms = property(FieldType.omitNorms, FieldType.setOmitNorms)
stored = property(FieldType.stored, FieldType.setStored)
storeTermVectorOffsets = property(FieldType.storeTermVectorOffsets, FieldType.setStoreTermVectorOffsets)
storeTermVectorPayloads = property(FieldType.storeTermVectorPayloads, FieldType.setStoreTermVectorPayloads)
storeTermVectorPositions = property(FieldType.storeTermVectorPositions, FieldType.setStoreTermVectorPositions)
storeTermVectorOffsets = property(
FieldType.storeTermVectorOffsets, FieldType.setStoreTermVectorOffsets
)
storeTermVectorPayloads = property(
FieldType.storeTermVectorPayloads, FieldType.setStoreTermVectorPayloads
)
storeTermVectorPositions = property(
FieldType.storeTermVectorPositions, FieldType.setStoreTermVectorPositions
)
storeTermVectors = property(FieldType.storeTermVectors, FieldType.setStoreTermVectors)
tokenized = property(FieldType.tokenized, FieldType.setTokenized)

properties = {name for name in locals() if not name.startswith('__')}
types = {int: 'long', float: 'double', str: 'string'}
types.update(NUMERIC='long', BINARY='string', SORTED='string', SORTED_NUMERIC='long', SORTED_SET='string')
types.update(
NUMERIC='long', BINARY='string', SORTED='string', SORTED_NUMERIC='long', SORTED_SET='string'
)
dimensions = property(
getattr(FieldType, 'pointDataDimensionCount', getattr(FieldType, 'pointDimensionCount', None)),
FieldType.pointDimensionCount,
lambda self, count: self.setDimensions(count, Long.BYTES),
)

@@ -54,17 +62,21 @@ def __init__(self, name: str, docValuesType='', indexOptions='', dimensions=0, *
self.indexOptions = getattr(index.IndexOptions, indexOptions.upper())
if docValuesType:
self.docValuesType = getattr(index.DocValuesType, docValuesType.upper())
self.docValueClass = getattr(document, docValuesType.title().replace('_', '') + 'DocValuesField')
name = docValuesType.title().replace('_', '')
self.docValueClass = getattr(document, name + 'DocValuesField')
if self.stored or self.indexed or self.dimensions:
settings = self.settings
del settings['docValuesType']
self.docValueLess = Field(self.name, **settings)
assert self.stored or self.indexed or self.docvalues or self.dimensions

@classmethod
def String(cls, name: str, tokenized=False, omitNorms=True, indexOptions='DOCS', **settings) -> 'Field':
def String(
cls, name: str, tokenized=False, omitNorms=True, indexOptions='DOCS', **settings
) -> 'Field':
"""Return Field with default settings for strings."""
return cls(name, tokenized=tokenized, omitNorms=omitNorms, indexOptions=indexOptions, **settings)
settings.update(tokenized=tokenized, omitNorms=omitNorms, indexOptions=indexOptions)
return cls(name, **settings)

@classmethod
def Text(cls, name: str, indexOptions='DOCS_AND_FREQS_AND_POSITIONS', **settings) -> 'Field':
@@ -128,8 +140,8 @@ def __init__(self, name: str, sep: str = '.', **settings):
def values(self, value: str) -> Iterator[str]:
"""Generate component field values in order."""
value = value.split(self.sep) # type: ignore
for index in range(1, len(value) + 1):
yield self.sep.join(value[:index])
for stop in range(1, len(value) + 1):
yield self.sep.join(value[:stop])

def items(self, *values: str) -> Iterator[document.Field]:
"""Generate indexed component fields."""
@@ -326,9 +338,8 @@ def dict(self, *names: str, **defaults) -> dict:
*names: names of multi-valued fields to return as a list
**defaults: include only given fields, using default values as necessary
"""
defaults.update((name, self[name]) for name in (defaults or self) if name in self)
defaults.update((name, self.getlist(name)) for name in names)
return defaults
defaults |= {name: self[name] for name in (defaults or self) if name in self}
return defaults | {name: self.getlist(name) for name in names}


class Hit(Document):
@@ -415,15 +426,19 @@ def highlights(self, query: search.Query, **fields: int) -> Iterator[dict]:
query: lucene Query
**fields: mapping of fields to maximum number of passages
"""
mapping = self.searcher.highlighter.highlightFields(list(fields), query, list(self.ids), list(fields.values()))
mapping = self.searcher.highlighter.highlightFields(
list(fields), query, list(self.ids), list(fields.values())
)
mapping = {field: lucene.JArray_string.cast_(mapping.get(field)) for field in fields}
return (dict(zip(mapping, values)) for values in zip(*mapping.values()))

def docvalues(self, field: str, type=None) -> dict:
"""Return mapping of docs to docvalues."""
return self.searcher.docvalues(field, type).select(self.ids)

def groupby(self, func: Callable, count: Optional[int] = None, docs: Optional[int] = None) -> 'Groups':
def groupby(
self, func: Callable, count: Optional[int] = None, docs: Optional[int] = None
) -> 'Groups':
"""Return ordered list of [Hits][lupyne.engine.documents.Hits] grouped by value of function applied to doc ids.
Optionally limit the number of groups and docs per group.
@@ -507,9 +522,14 @@ def __len__(self):
def __iter__(self):
return map(convert, self.allMatchingGroups)

def search(self, searcher, query: search.Query, count: Optional[int] = None, start: int = 0) -> Groups:
def search(
self, searcher, query: search.Query, count: Optional[int] = None, start: int = 0
) -> Groups:
"""Run query and return [Groups][lupyne.engine.documents.Groups]."""
if count is None:
count = sum(index.DocValues.getSorted(reader, self.field).valueCount for reader in searcher.readers) or 1
topgroups = super().search(searcher, query, start, count - start)
count = sum(
index.DocValues.getSorted(reader, self.field).valueCount
for reader in searcher.readers
)
topgroups = super().search(searcher, query, start, max(count - start, 1))
return Groups(searcher, topgroups.groups, topgroups.totalHitCount)
66 changes: 50 additions & 16 deletions lupyne/engine/indexers.py
@@ -148,7 +148,8 @@ def suggest(self, name: str, value, count: int = 1, **attrs) -> list:
checker = spell.DirectSpellChecker()
for attr in attrs:
setattr(checker, attr, attrs[attr])
return [word.string for word in checker.suggestSimilar(index.Term(name, value), count, self.indexReader)]
words = checker.suggestSimilar(index.Term(name, value), count, self.indexReader)
return [word.string for word in words]

def sortfield(self, name: str, type=None, reverse=False) -> search.SortField:
"""Return lucene SortField, deriving the the type from FieldInfos if necessary.
@@ -172,12 +173,15 @@ def docvalues(self, name: str, type=None) -> DocValues.Sorted:
name: field name
type: int or float for converting values
"""
type = {int: int, float: util.NumericUtils.sortableLongToDouble}.get(type, util.BytesRef.utf8ToString)
types = {int: int, float: util.NumericUtils.sortableLongToDouble}
type = types.get(type, util.BytesRef.utf8ToString)
docValuesType = self.fieldinfos[name].docValuesType.toString().title().replace('_', '')
method = getattr(index.MultiDocValues, f'get{docValuesType}Values')
return getattr(DocValues, docValuesType)(method(self.indexReader, name), len(self), type)

def copy(self, dest, query: search.Query = None, exclude: search.Query = None, merge: int = 0) -> int:
def copy(
self, dest, query: search.Query = None, exclude: search.Query = None, merge: int = 0
) -> int:
"""Copy the index to the destination directory.
Optimized to use hard links if the destination is a file system path.
@@ -221,34 +225,47 @@ def terms(self, name: str, value='', stop='', counts=False, distance=0, prefix=0
termsenum.seekCeil(util.BytesRef(value))
terms = itertools.chain([termsenum.term()], util.BytesRefIterator.cast_(termsenum))
terms = map(operator.methodcaller('utf8ToString'), terms)
predicate = partial(operator.gt, stop) if stop else operator.methodcaller('startswith', value)
predicate = (
partial(operator.gt, stop) if stop else operator.methodcaller('startswith', value)
)
if not distance:
terms = itertools.takewhile(predicate, terms) # type: ignore
return ((term, termsenum.docFreq()) for term in terms) if counts else terms

def docs(self, name: str, value, counts=False) -> Iterator:
"""Generate doc ids which contain given term, optionally with frequency counts."""
docsenum = index.MultiTerms.getTermPostingsEnum(self.indexReader, name, util.BytesRef(value))
docsenum = index.MultiTerms.getTermPostingsEnum(
self.indexReader, name, util.BytesRef(value)
)
docs = iter(docsenum.nextDoc, index.PostingsEnum.NO_MORE_DOCS) if docsenum else ()
return ((doc, docsenum.freq()) for doc in docs) if counts else iter(docs) # type: ignore

def positions(self, name: str, value, payloads=False, offsets=False) -> Iterator[tuple]:
"""Generate doc ids and positions which contain given term.
Optionally with offsets, or only ones with payloads."""
docsenum = index.MultiTerms.getTermPostingsEnum(self.indexReader, name, util.BytesRef(value))
docsenum = index.MultiTerms.getTermPostingsEnum(
self.indexReader, name, util.BytesRef(value)
)
for doc in iter(docsenum.nextDoc, index.PostingsEnum.NO_MORE_DOCS) if docsenum else (): # type: ignore
positions = (docsenum.nextPosition() for _ in range(docsenum.freq()))
if payloads:
positions = ((position, docsenum.payload.utf8ToString()) for position in positions if docsenum.payload)
positions = (
(position, docsenum.payload.utf8ToString())
for position in positions
if docsenum.payload
)
elif offsets:
positions = ((docsenum.startOffset(), docsenum.endOffset()) for position in positions)
positions = (
(docsenum.startOffset(), docsenum.endOffset()) for position in positions
)
yield doc, list(positions)

def vector(self, id, field):
terms = self.getTermVector(id, field)
termsenum = terms.iterator() if terms else index.TermsEnum.EMPTY
return termsenum, map(operator.methodcaller('utf8ToString'), util.BytesRefIterator.cast_(termsenum))
terms = map(operator.methodcaller('utf8ToString'), util.BytesRefIterator.cast_(termsenum))
return termsenum, terms

def termvector(self, id: int, field: str, counts=False) -> Iterator:
"""Generate terms for given doc id and field, optionally with frequency counts."""
@@ -390,7 +407,15 @@ def collector(self, count=None, sort=None, reverse=False, scores=False, mincount
return search.TopFieldCollector.create(sort, count, mincount)

def search(
self, query=None, count=None, sort=None, reverse=False, scores=False, mincount=1000, timeout=None, **parser
self,
query=None,
count=None,
sort=None,
reverse=False,
scores=False,
mincount=1000,
timeout=None,
**parser,
) -> Hits:
"""Run query and return [Hits][lupyne.engine.documents.Hits].
@@ -440,7 +465,9 @@ def facets(self, query, *fields: str, **query_map: dict) -> dict:
counts[facet] = {key: self.count(Query.all(query, queries[key])) for key in queries}
return counts

def groupby(self, field: str, query, count: Optional[int] = None, start: int = 0, **attrs) -> Groups:
def groupby(
self, field: str, query, count: Optional[int] = None, start: int = 0, **attrs
) -> Groups:
"""Return [Hits][lupyne.engine.documents.Hits] grouped by field
using a [GroupingSearch][lupyne.engine.documents.GroupingSearch]."""
return GroupingSearch(field, **attrs).search(self, self.parse(query), count, start)
@@ -450,7 +477,8 @@ def spellchecker(self, field: str) -> SpellChecker:
try:
return self.spellcheckers[field]
except KeyError:
return self.spellcheckers.setdefault(field, SpellChecker(self.terms(field, counts=True)))
spellchecker = SpellChecker(self.terms(field, counts=True))
return self.spellcheckers.setdefault(field, spellchecker)

def complete(self, field: str, prefix: str, count: Optional[int] = None) -> list:
"""Return ordered suggested words for prefix."""
@@ -475,7 +503,9 @@ class MultiSearcher(IndexSearcher):

def __init__(self, reader, analyzer=None):
super().__init__(reader, analyzer)
self.indexReaders = [index.DirectoryReader.cast_(context.reader()) for context in self.context.children()]
self.indexReaders = [
index.DirectoryReader.cast_(context.reader()) for context in self.context.children()
]
self.version = sum(reader.version for reader in self.indexReaders)

def __getattr__(self, name):
@@ -484,7 +514,8 @@ def __getattr__(self, name):
def openIfChanged(self):
readers = list(map(index.DirectoryReader.openIfChanged, self.indexReaders))
if any(readers):
return index.MultiReader([new or old.incRef() or old for new, old in zip(readers, self.indexReaders)])
readers = [new or old.incRef() or old for new, old in zip(readers, self.indexReaders)]
return index.MultiReader(readers)

@property
def timestamp(self):
@@ -513,7 +544,8 @@ def __init__(self, directory, mode: str = 'a', analyzer=None, version=None, **at
config.openMode = index.IndexWriterConfig.OpenMode.values()['wra'.index(mode)]
for name, value in attrs.items():
setattr(config, name, value)
self.policy = config.indexDeletionPolicy = index.SnapshotDeletionPolicy(config.indexDeletionPolicy)
self.policy = index.SnapshotDeletionPolicy(config.indexDeletionPolicy)
config.indexDeletionPolicy = self.policy
super().__init__(self.shared.directory(directory), config)
self.fields = {} # type: dict

@@ -568,7 +600,9 @@ def update(self, name: str, value='', document=(), **terms):
fields = list(doc.iterator())
types = [Field.cast_(field.fieldType()) for field in fields]
noindex = index.IndexOptions.NONE
if any(ft.stored() or ft.indexOptions() != noindex or Field.dimensions.fget(ft) for ft in types):
if any(
ft.stored() or ft.indexOptions() != noindex or Field.dimensions.fget(ft) for ft in types
):
self.updateDocument(term, doc)
elif fields:
self.updateDocValues(term, *fields)
12 changes: 9 additions & 3 deletions lupyne/engine/queries.py
@@ -59,7 +59,9 @@ def disjunct(cls, multiplier, *queries, **terms):
"""Return lucene DisjunctionMaxQuery from queries and terms."""
queries = list(queries)
for name, values in terms.items():
queries += (cls.term(name, value) for value in ([values] if isinstance(values, str) else values))
if isinstance(values, str):
values = [values]
queries += (cls.term(name, value) for value in values)
return cls(search.DisjunctionMaxQuery, Arrays.asList(queries), multiplier)

@classmethod
@@ -73,7 +75,10 @@ def span(cls, *term) -> 'SpanQuery':
def near(cls, name: str, *values, **kwargs) -> 'SpanQuery':
"""Return [SpanNearQuery][lupyne.engine.queries.SpanQuery.near] from terms.
Term values which supply another field name will be masked."""
spans = (cls.span(name, value) if isinstance(value, str) else cls.span(*value).mask(name) for value in values)
spans = (
cls.span(name, value) if isinstance(value, str) else cls.span(*value).mask(name)
for value in values
)
return SpanQuery.near(*spans, **kwargs)

@classmethod
@@ -265,7 +270,8 @@ def __getitem__(self, id: int):
class SortedNumeric(Sorted):
def __getitem__(self, id: int):
if self.docvalues.advanceExact(id):
return tuple(self.type(self.docvalues.nextValue()) for _ in range(self.docvalues.docValueCount()))
indices = range(self.docvalues.docValueCount())
return tuple(self.type(self.docvalues.nextValue()) for _ in indices)

class SortedSet(Sorted):
def __getitem__(self, id: int):