Skip to content

Commit

Permalink
[MRG] update Index protocol tests to include tests for peek and `…
Browse files Browse the repository at this point in the history
…consume` (#2111)

* move most CounterGather tests over to index protocol tests

* add LinearIndex wrapper

* getting closer

* fix a bunch of the tests

* fix call to 'peek'

* adjust 'counter.add' call signature

* add CounterGather_LCA

* move CounterGather.calc_threshold into search.py

* minor refactoring

* resolve downsampling for linear index wrapper

* fix downsampling for LCA-based CounterGather

* fix location foo

* fix remaining test

* minor cleanup

* add doc

* test multiple identical matches

* adjust LinearIndex implementation to skip identical matches

* cleanup protocol tests

* revert LCA_Database fix

* cleanup
  • Loading branch information
ctb committed Jul 16, 2022
1 parent f34bd17 commit 0bc9dbd
Show file tree
Hide file tree
Showing 4 changed files with 735 additions and 593 deletions.
84 changes: 43 additions & 41 deletions src/sourmash/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
import sourmash
from abc import abstractmethod, ABC
from collections import namedtuple, Counter
from collections import defaultdict

from ..search import make_jaccard_search_query, make_gather_query
from ..manifest import CollectionManifest
from ..logging import debug_literal
from ..signature import load_signatures, save_signatures
from sourmash.search import (make_jaccard_search_query, make_gather_query,
calc_threshold_from_bp)
from sourmash.manifest import CollectionManifest
from sourmash.logging import debug_literal
from sourmash.signature import load_signatures, save_signatures

# generic return tuple for Index.search and Index.gather
IndexSearchResult = namedtuple('Result', 'score, signature, location')
Expand Down Expand Up @@ -277,8 +277,13 @@ def gather(self, query, threshold_bp=None, **kwargs):

return results[:1]

def peek(self, query_mh, threshold_bp=0):
"Mimic CounterGather.peek() on top of Index. Yes, this is backwards."
def peek(self, query_mh, *, threshold_bp=0):
"""Mimic CounterGather.peek() on top of Index.
This is implemented for situations where we don't want to use
'prefetch' functionality. It is a light wrapper around the
'gather'/search-by-containment method.
"""
from sourmash import SourmashSignature

# build a signature to use with self.gather...
Expand Down Expand Up @@ -323,7 +328,7 @@ def counter_gather(self, query, threshold_bp, **kwargs):
# find all matches and construct a CounterGather object.
counter = CounterGather(prefetch_query.minhash)
for result in self.prefetch(prefetch_query, threshold_bp, **kwargs):
counter.add(result.signature, result.location)
counter.add(result.signature, location=result.location)

# tada!
return counter
Expand Down Expand Up @@ -701,31 +706,42 @@ def select(self, **kwargs):


class CounterGather:
"""
Track and summarize matches for efficient 'gather' protocol. This
could be used downstream of prefetch (for example).
"""This is an ancillary class that is used to implement "fast
gather", post-prefetch. It tracks and summarize matches for
efficient min-set-cov/'gather'.
The class constructor takes a query MinHash that must be scaled, and
then takes signatures that have overlaps with the query (via 'add').
After all overlapping signatures have been loaded, the 'peek'
method is then used at each stage of the 'gather' procedure to
find the best match, and the 'consume' method is used to remove
a match from this counter.
The public interface is `peek(...)` and `consume(...)` only.
This particular implementation maintains a collections.Counter that
is used to quickly find the best match when 'peek' is called, but
other implementations are possible ;).
"""
def __init__(self, query_mh):
"Constructor - takes a query FracMinHash."
if not query_mh.scaled:
raise ValueError('gather requires scaled signatures')

# track query
self.orig_query_mh = query_mh.copy().flatten()
self.scaled = query_mh.scaled

# track matching signatures & their locations
# use these to track loaded matches & their locations
self.siglist = []
self.locations = []

# ...and overlaps with query
# ...and also track overlaps with the progressive query
self.counter = Counter()

# cannot add matches once query has started.
# fence to make sure we do add matches once query has started.
self.query_started = 0

def add(self, ss, location=None, require_overlap=True):
def add(self, ss, *, location=None, require_overlap=True):
"Add this signature in as a potential match."
if self.query_started:
raise ValueError("cannot add more signatures to counter after peek/consume")
Expand All @@ -748,26 +764,11 @@ def downsample(self, scaled):
"Track highest scaled across all possible matches."
if scaled > self.scaled:
self.scaled = scaled
return self.scaled

def calc_threshold(self, threshold_bp, scaled, query_size):
# CTB: this code doesn't need to be in this class.
threshold = 0.0
n_threshold_hashes = 0

if threshold_bp:
# if we have a threshold_bp of N, then that amounts to N/scaled
# hashes:
n_threshold_hashes = float(threshold_bp) / scaled

# that then requires the following containment:
threshold = n_threshold_hashes / query_size

return threshold, n_threshold_hashes

def peek(self, cur_query_mh, threshold_bp=0):
def peek(self, cur_query_mh, *, threshold_bp=0):
"Get next 'gather' result for this database, w/o changing counters."
self.query_started = 1
scaled = cur_query_mh.scaled

# empty? nothing to search.
counter = self.counter
Expand All @@ -777,38 +778,39 @@ def peek(self, cur_query_mh, threshold_bp=0):
siglist = self.siglist
assert siglist

self.downsample(scaled)
scaled = self.scaled
scaled = self.downsample(cur_query_mh.scaled)
cur_query_mh = cur_query_mh.downsample(scaled=scaled)

if not cur_query_mh: # empty query? quit.
return []

# CTB: could probably remove this check unless debug requested.
if cur_query_mh.contained_by(self.orig_query_mh, downsample=True) < 1:
raise ValueError("current query not a subset of original query")

# are we setting a threshold?
threshold, n_threshold_hashes = self.calc_threshold(threshold_bp,
scaled,
len(cur_query_mh))
threshold, n_threshold_hashes = calc_threshold_from_bp(threshold_bp,
scaled,
len(cur_query_mh))
# is it too high to ever match? if so, exit.
if threshold > 1.0:
return []

# Find the best match -
# Find the best match using the internal Counter.
most_common = counter.most_common()
dataset_id, match_size = most_common[0]

# below threshold? no match!
if match_size < n_threshold_hashes:
return []

## at this point, we must have a legitimate match above threshold!
## at this point, we have a legitimate match above threshold!

# pull match and location.
match = siglist[dataset_id]

# calculate containment
# CTB: this check is probably redundant with intersect_mh calc, below.
cont = cur_query_mh.contained_by(match.minhash, downsample=True)
assert cont
assert cont >= threshold
Expand All @@ -822,7 +824,7 @@ def peek(self, cur_query_mh, threshold_bp=0):
return (IndexSearchResult(cont, match, location), intersect_mh)

def consume(self, intersect_mh):
"Remove the given hashes from this counter."
"Maintain the internal counter by removing the given hashes."
self.query_started = 1

if not intersect_mh:
Expand Down
21 changes: 20 additions & 1 deletion src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,25 @@
from .sketchcomparison import FracMinHashComparison, NumMinHashComparison


def calc_threshold_from_bp(threshold_bp, scaled, query_size):
"""
Convert threshold_bp (threshold in estimated bp) to
fraction of query & minimum number of hashes needed.
"""
threshold = 0.0
n_threshold_hashes = 0

if threshold_bp:
# if we have a threshold_bp of N, then that amounts to N/scaled
# hashes:
n_threshold_hashes = float(threshold_bp) / scaled

# that then requires the following containment:
threshold = n_threshold_hashes / query_size

return threshold, n_threshold_hashes


class SearchType(Enum):
JACCARD = 1
CONTAINMENT = 2
Expand Down Expand Up @@ -621,7 +640,7 @@ def _find_best(counters, query, threshold_bp):

# find the best score across multiple counters, without consuming
for counter in counters:
result = counter.peek(query.minhash, threshold_bp)
result = counter.peek(query.minhash, threshold_bp=threshold_bp)
if result:
(sr, intersect_mh) = result

Expand Down
Loading

0 comments on commit 0bc9dbd

Please sign in to comment.