Commit 53a8fce

[MRG] standardize and simplify search, prefetch, gather results by using dataclasses (#1955)

* add some search/gather/prefetch columns to enable ANI estimation

* fix introduced err

* init SearchResult dataclass

* define generic write_cols

* add prefetchresult class, clean up post_init repetitiveness later

* clean up

* add gatherresult dataclass

* rm unused line

* upd

* init searchresult tests

* use query_n_hashes; remove num

* add basic gatherresult test

* save in progress changes

* closer...

* closer still...

* handle num sketches; clean up unnecessary sig comparison cls

* use base classes properly to simplify

* add tests for multiple rounds of downsampling in prefetch and gather (#1956)

* split sketchcomparison to new file; clean up *Result

* add minhash tests

* init sketchcomparison tests

* test incompatible sketch comparisons

* test failing *Results

* test num SearchResult

* fix calcs for gather

* upd with suggestions from code review

Co-authored-by: C. Titus Brown <titus@idyll.org>
bluegenes and ctb authored Apr 20, 2022
1 parent efc700b commit 53a8fce
Showing 9 changed files with 1,615 additions and 186 deletions.
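The heart of the change: each of SearchResult, PrefetchResult, and GatherResult now declares the CSV columns it owns and exposes a writedict property, so callers stop copying _asdict() output and deleting the 'match'/'query' signature fields by hand. A minimal sketch of that pattern follows (field names and the ResultSketch class are illustrative, not the real sourmash columns or implementation):

    from dataclasses import dataclass, asdict

    @dataclass
    class ResultSketch:
        query_name: str
        match_name: str
        similarity: float
        # plain class attribute (no annotation), so dataclass ignores it as a field
        search_write_cols = ['query_name', 'match_name', 'similarity']

        @property
        def writedict(self):
            # emit only the declared columns -- no manual del d['match']
            d = asdict(self)
            return {k: d[k] for k in self.search_write_cols}

    row = ResultSketch('q1', 'm1', 0.97).writedict
    # {'query_name': 'q1', 'match_name': 'm1', 'similarity': 0.97}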
39 changes: 12 additions & 27 deletions src/sourmash/commands.py
@@ -16,7 +16,7 @@
 from .logging import notify, error, print_results, set_quiet
 from .sourmash_args import (FileOutput, FileOutputCSV,
                             SaveSignaturesToLocation)
-from .search import SearchResult, prefetch_database, PrefetchResult, GatherResult, calculate_prefetch_info
+from .search import prefetch_database, SearchResult, PrefetchResult, GatherResult
 from .index import LazyLinearIndex
 
 WATERMARK_SIZE = 10000
@@ -533,17 +533,13 @@ def search(args):
         notify("** reporting only one match because --best-only was set")
 
     if args.output:
-        fieldnames = SearchResult._fields
-
+        fieldnames = SearchResult.search_write_cols
         with FileOutputCSV(args.output) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)
 
             w.writeheader()
             for sr in results:
-                d = dict(sr._asdict())
-                del d['match']
-                del d['query']
-                w.writerow(d)
+                w.writerow(sr.writedict)
 
     # save matching signatures upon request
     if args.save_matches:
@@ -688,7 +684,7 @@ def gather(args):
     prefetch_csvout_fp = None
     prefetch_csvout_w = None
     if args.save_prefetch_csv:
-        fieldnames = PrefetchResult._fields
+        fieldnames = PrefetchResult.prefetch_write_cols
         prefetch_csvout_fp = FileOutput(args.save_prefetch_csv, 'wt').open()
         prefetch_csvout_w = csv.DictWriter(prefetch_csvout_fp, fieldnames=fieldnames)
         prefetch_csvout_w.writeheader()
@@ -717,12 +713,8 @@
             if prefetch_csvout_fp:
                 assert scaled
                 # calculate intersection stats and info
-                prefetch_result = calculate_prefetch_info(prefetch_query, found_sig, scaled, args.threshold_bp)
-                # remove match and query signatures; write result to prefetch csv
-                d = dict(prefetch_result._asdict())
-                del d['match']
-                del d['query']
-                prefetch_csvout_w.writerow(d)
+                prefetch_result = PrefetchResult(prefetch_query, found_sig, cmp_scaled=scaled, threshold_bp=args.threshold_bp)
+                prefetch_csvout_w.writerow(prefetch_result.writedict)
 
         counters.append(counter)
 
@@ -803,14 +795,12 @@ def gather(args):
 
     # save CSV?
     if found and args.output:
-        fieldnames = GatherResult._fields
+        fieldnames = GatherResult.gather_write_cols
         with FileOutputCSV(args.output) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)
             w.writeheader()
             for result in found:
-                d = dict(result._asdict())
-                del d['match']             # actual signature not in CSV.
-                w.writerow(d)
+                w.writerow(result.writedict)
 
     # save matching signatures?
     if found and args.save_matches:
@@ -970,14 +960,12 @@ def multigather(args):
 
         output_base = os.path.basename(query_filename)
         output_csv = output_base + '.csv'
-        fieldnames = GatherResult._fields
+        fieldnames = GatherResult.gather_write_cols
         with FileOutputCSV(output_csv) as fp:
             w = csv.DictWriter(fp, fieldnames=fieldnames)
             w.writeheader()
             for result in found:
-                d = dict(result._asdict())
-                del d['match']             # actual signature not output to CSV!
-                w.writerow(d)
+                w.writerow(result.writedict)
 
         output_matches = output_base + '.matches.sig'
         with open(output_matches, 'wt') as fp:
@@ -1174,7 +1162,7 @@ def prefetch(args):
     csvout_fp = None
     csvout_w = None
     if args.output:
-        fieldnames = PrefetchResult._fields
+        fieldnames = PrefetchResult.prefetch_write_cols
         csvout_fp = FileOutput(args.output, 'wt').open()
         csvout_w = csv.DictWriter(csvout_fp, fieldnames=fieldnames)
         csvout_w.writeheader()
@@ -1231,10 +1219,7 @@
 
             # output match info as we go
             if csvout_fp:
-                d = dict(result._asdict())
-                del d['match']             # actual signatures not in CSV.
-                del d['query']
-                csvout_w.writerow(d)
+                csvout_w.writerow(result.writedict)
 
             # output match signatures as we go (maybe)
             matches_out.add(match)
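With each result class carrying its own columns and writedict, the per-command CSV loops above all collapse into the same few lines; a generic helper in that style might look like this (a hypothetical sketch, not part of this PR):

    import csv

    def write_results_csv(fp, results, fieldnames):
        # works identically for SearchResult, PrefetchResult, and GatherResult:
        # each exposes its write columns plus a matching writedict property
        w = csv.DictWriter(fp, fieldnames=fieldnames)
        w.writeheader()
        for result in results:
            w.writerow(result.writedict)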
37 changes: 36 additions & 1 deletion src/sourmash/minhash.py
@@ -8,6 +8,9 @@ class FrozenMinHash - read-only MinHash class.
 from __future__ import unicode_literals, division
 from .distance_utils import jaccard_to_distance, containment_to_distance
 
+import numpy as np
+
+
 __all__ = ['get_minhash_default_seed',
            'get_minhash_max_hash',
            'hash_murmur',
@@ -686,6 +689,8 @@ def similarity(self, other, ignore_abundance=False, downsample=False):
 
     def angular_similarity(self, other):
         "Calculate the angular similarity."
+        if not (self.track_abundance and other.track_abundance):
+            raise TypeError("Error: Angular (cosine) similarity requires both sketches to track hash abundance.")
         return self._methodcall(lib.kmerminhash_angular_similarity,
                                 other._get_objptr())
 
@@ -854,7 +859,37 @@ def inflate(self, from_mh):
 
             return abund_mh
         else:
             raise ValueError("inflate operates on a flat MinHash and takes a MinHash object with track_abundance=True")
 
+    @property
+    def sum_abundances(self):
+        if self.track_abundance:
+            return sum(v for v in self.hashes.values())
+        return None
+
+    @property
+    def mean_abundance(self):
+        if self.track_abundance:
+            return np.mean(list(self.hashes.values()))
+        return None
+
+    @property
+    def median_abundance(self):
+        if self.track_abundance:
+            return np.median(list(self.hashes.values()))
+        return None
+
+    @property
+    def std_abundance(self):
+        if self.track_abundance:
+            return np.std(list(self.hashes.values()))
+        return None
+
+    @property
+    def covered_bp(self):
+        if not self.scaled:
+            raise TypeError("can only calculate bp for scaled MinHashes")
+        return len(self.hashes) * self.scaled
+
 
 class FrozenMinHash(MinHash):
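A quick sanity check of the new abundance properties and the angular_similarity guard (toy hash values; assumes a sourmash build that includes this commit):

    from sourmash import MinHash

    mh = MinHash(n=0, ksize=31, scaled=1000, track_abundance=True)
    mh.set_abundances({1: 3, 5: 7, 9: 2})   # hash -> abundance

    assert mh.sum_abundances == 12      # 3 + 7 + 2
    assert mh.mean_abundance == 4.0
    assert mh.median_abundance == 3.0
    assert mh.covered_bp == 3000        # 3 hashes * scaled (1000)

    flat = MinHash(n=0, ksize=31, scaled=1000)   # no abundances tracked
    # flat.angular_similarity(mh) now raises TypeError, per the check
    # added in the hunk above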
(diffs for the remaining 7 changed files are not shown)