Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EXP] minhash unique covered bp #2027

Open
wants to merge 7 commits into
base: latest
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sourmash/minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,10 +918,10 @@ def std_abundance(self):
return None

@property
def covered_bp(self):
def unique_covered_bp(self):
if not self.scaled:
raise TypeError("can only calculate bp for scaled MinHashes")
return len(self.hashes) * self.scaled
return len(self.hashes) * self.scaled + (self.ksize - 1)


class FrozenMinHash(MinHash):
Expand Down
10 changes: 5 additions & 5 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def init_sigcomparison(self):
self.get_cmpinfo() # grab comparison metadata
self.intersect_bp = self.cmp.intersect_bp
self.max_containment = self.cmp.max_containment
self.query_bp = self.mh1.covered_bp
self.match_bp = self.mh2.covered_bp
self.query_bp = self.mh1.unique_covered_bp
self.match_bp = self.mh2.unique_covered_bp
self.threshold = self.threshold_bp
self.estimate_containment_ani()

Expand Down Expand Up @@ -454,7 +454,7 @@ def build_gather_result(self):
# this affects estimation of original query information, and requires us to pass in orig_query_len and orig_query_abunds.
# we also need to overwrite self.query_bp, self.query_n_hashes, and self.query_abundance
# todo: find a better solution?
self.query_bp = self.orig_query_len * self.query.minhash.scaled
self.query_bp = self.orig_query_len * self.query.minhash.scaled + self.ksize - 1
self.query_n_hashes = self.orig_query_len

# calculate intersection with query hashes:
Expand All @@ -473,7 +473,7 @@ def build_gather_result(self):
self.f_unique_to_query = len(self.gather_comparison.intersect_mh)/self.orig_query_len

# here, need to make sure to use the mh1_cmp (bc was downsampled to cmp_scaled)
self.remaining_bp = (self.gather_comparison.mh1_cmp.covered_bp - self.gather_comparison.intersect_bp)
self.remaining_bp = (self.gather_comparison.mh1_cmp.unique_covered_bp - self.gather_comparison.intersect_bp)

# calculate stats on abundances, if desired.
self.average_abund, self.median_abund, self.std_abund = None, None, None
Expand Down Expand Up @@ -643,7 +643,7 @@ def __init__(self, query, counters, *,
# track original query information for later usage?
track_abundance = query.minhash.track_abundance and not ignore_abundance
self.orig_query = query
self.orig_query_bp = len(query.minhash) * query.minhash.scaled
self.orig_query_bp = query.minhash.unique_covered_bp #len(query.minhash) * query.minhash.scaled
self.orig_query_filename = query.filename
self.orig_query_name = query.name
self.orig_query_md5 = query.md5sum()[:8]
Expand Down
2 changes: 1 addition & 1 deletion src/sourmash/sketchcomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def pass_threshold(self):

@property
def intersect_bp(self):
return len(self.intersect_mh) * self.cmp_scaled
return (len(self.intersect_mh) * self.cmp_scaled) + (self.ksize - 1)

@property
def mh1_containment(self):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2804,7 +2804,7 @@ def test_std_abundance(track_abundance):
assert not mh2.std_abundance


def test_covered_bp(track_abundance):
def test_unique_covered_bp(track_abundance):
"test covered_bp"
mh1 = MinHash(0, 21, scaled=1, track_abundance=track_abundance)
mh2 = MinHash(4, 21, track_abundance=track_abundance)
Expand All @@ -2813,9 +2813,9 @@ def test_covered_bp(track_abundance):
mh1.add_many((1, 2))
mh2.add_many((1, 5))

assert mh1.covered_bp == 4 # hmmm...
assert mh1.unique_covered_bp == 24
with pytest.raises(TypeError) as exc:
mh2.covered_bp
mh2.unique_covered_bp
assert "can only calculate bp for scaled MinHashes" in str(exc)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def test_prefetch_csv_out(runtmp, linear_gather):
assert c.last_result.status == 0
assert os.path.exists(csvout)

expected_intersect_bp = [2529000, 5177000]
expected_intersect_bp = [2529030, 5177030]
with open(csvout, 'rt', newline="") as fp:
r = csv.DictReader(fp)
for (row, expected) in zip(r, expected_intersect_bp):
Expand Down
15 changes: 6 additions & 9 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,6 @@ def test_scaledSearchResult():
assert res.cmp_scaled == 1000
assert res.query_abundance == ss47.minhash.track_abundance
assert res.match_abundance == ss4763.minhash.track_abundance
# assert res.query_bp == len(ss47.minhash) * scaled
# assert res.match_bp == len(ss4763.minhash) * scaled
assert res.ksize == 31
assert res.moltype == 'DNA'
assert res.query_filename == '47.fa'
Expand Down Expand Up @@ -376,7 +374,7 @@ def test_PrefetchResult():
scaled = ss47.minhash.scaled

intersect_mh = ss47.minhash.intersection(ss4763.minhash)
intersect_bp = len(intersect_mh) * scaled
intersect_bp = len(intersect_mh) * scaled + ss47.minhash.ksize - 1
jaccard=ss4763.jaccard(ss47)
max_containment=ss4763.max_containment(ss47)
f_match_query=ss47.contained_by(ss4763)
Expand All @@ -393,8 +391,8 @@ def test_PrefetchResult():
assert res.cmp_scaled == 1000
assert res.query_abundance == ss47.minhash.track_abundance
assert res.match_abundance == ss4763.minhash.track_abundance
assert res.query_bp == len(ss47.minhash) * scaled
assert res.match_bp == len(ss4763.minhash) * scaled
assert res.query_bp == len(ss47.minhash) * scaled + ss47.minhash.ksize - 1
assert res.match_bp == len(ss4763.minhash) * scaled + ss4763.minhash.ksize - 1
assert res.ksize == 31
assert res.moltype == 'DNA'
assert res.query_filename == '47.fa'
Expand Down Expand Up @@ -447,9 +445,8 @@ def test_GatherResult():
remaining_mh = ss4763.minhash.to_mutable()
remaining_mh.remove_many(intersect_mh)

intersect_bp = len(intersect_mh) * scaled
intersect_bp = (len(intersect_mh) * scaled) + ss47.minhash.ksize - 1
max_containment=ss4763.max_containment(ss47)
f_match_query = ss47.contained_by(ss4763)
orig_query_abunds = ss47.minhash.hashes
queryc_ani = ss47.containment_ani(ss4763)
matchc_ani = ss4763.containment_ani(ss47)
Expand All @@ -472,8 +469,8 @@ def test_GatherResult():
assert res.cmp_scaled == 1000
assert res.query_abundance == ss47.minhash.track_abundance
assert res.match_abundance == ss4763.minhash.track_abundance
assert res.query_bp == len(ss47.minhash) * scaled
assert res.match_bp == len(ss4763.minhash) * scaled
assert res.query_bp == ss47.minhash.unique_covered_bp
assert res.match_bp == ss4763.minhash.unique_covered_bp
assert res.ksize == 31
assert res.moltype == 'DNA'
assert res.query_filename == 'podar-ref/47.fa'
Expand Down
14 changes: 7 additions & 7 deletions tests/test_sketchcomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_FracMinHashComparison(track_abundance):
assert cmp.jaccard == a.jaccard(b) == b.jaccard(a)
intersect_mh = a.flatten().intersection(b.flatten())
assert cmp.intersect_mh == intersect_mh == b.flatten().intersection(a.flatten())
assert cmp.intersect_bp == 4
assert cmp.intersect_bp == 24
assert cmp.pass_threshold # default threshold is 0; this should pass
if track_abundance:
assert cmp.angular_similarity == a.angular_similarity(b) == b.angular_similarity(a)
Expand Down Expand Up @@ -100,7 +100,7 @@ def test_FracMinHashComparison_downsample(track_abundance):
assert cmp.jaccard == ds_a.jaccard(ds_b) == ds_b.jaccard(ds_a)
intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
assert cmp.intersect_bp == 8
assert cmp.intersect_bp == 28
assert cmp.pass_threshold # default threshold is 0; this should pass
if track_abundance:
assert cmp.angular_similarity == ds_a.angular_similarity(ds_b) == ds_b.angular_similarity(ds_a)
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_FracMinHashComparison_autodownsample(track_abundance):
assert cmp.jaccard == ds_a.jaccard(ds_b) == ds_b.jaccard(ds_a)
intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
assert cmp.intersect_bp == 8
assert cmp.intersect_bp == 28
assert cmp.pass_threshold # default threshold is 0; this should pass
if track_abundance:
assert cmp.angular_similarity == ds_a.angular_similarity(ds_b) == ds_b.angular_similarity(ds_a)
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_FracMinHashComparison_ignore_abundance(track_abundance):
assert cmp.jaccard == a.jaccard(b) == b.jaccard(a)
intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
assert cmp.intersect_bp == 8
assert cmp.intersect_bp == 28
assert cmp.pass_threshold # default threshold is 0; this should pass
# with ignore_abundance = True, all of these should not be usable. Do we want errors, or ""/None?
with pytest.raises(TypeError) as exc:
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_FracMinHashComparison_fail_threshold(track_abundance):
ds_b = b.flatten().downsample(scaled=cmp_scaled)

# build FracMinHashComparison
cmp = FracMinHashComparison(a, b, cmp_scaled = cmp_scaled, threshold_bp=10)
cmp = FracMinHashComparison(a, b, cmp_scaled = cmp_scaled, threshold_bp=40)
assert cmp.mh1 == a
assert cmp.mh2 == b
assert cmp.ignore_abundance == False
Expand All @@ -266,8 +266,8 @@ def test_FracMinHashComparison_fail_threshold(track_abundance):
assert cmp.jaccard == a.jaccard(b) == b.jaccard(a)
intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
assert cmp.intersect_bp == 8
assert not cmp.pass_threshold # threshold is 10; this should fail
assert cmp.intersect_bp == 28
assert not cmp.pass_threshold # threshold is 40; this should fail


def test_FracMinHashComparison_potential_false_negative(track_abundance):
Expand Down
41 changes: 22 additions & 19 deletions tests/test_sourmash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2872,8 +2872,8 @@ def test_gather_csv(runtmp, linear_gather, prefetch_gather):
reader = csv.DictReader(fp)
row = next(reader)
print(row)
assert float(row['intersect_bp']) == 910
assert float(row['unique_intersect_bp']) == 910
assert float(row['intersect_bp']) == 940
assert float(row['unique_intersect_bp']) == 940
assert float(row['remaining_bp']) == 0
assert float(row['f_orig_query']) == 1.0
assert float(row['f_unique_to_query']) == 1.0
Expand All @@ -2885,7 +2885,7 @@ def test_gather_csv(runtmp, linear_gather, prefetch_gather):
assert row['query_filename'].endswith('short2.fa')
assert row['query_name'] == 'tr1 4'
assert row['query_md5'] == 'c9d5a795'
assert row['query_bp'] == '910'
assert row['query_bp'] == '940'


def test_gather_abund_x_abund(runtmp, prefetch_gather, linear_gather):
Expand Down Expand Up @@ -2974,7 +2974,7 @@ def test_gather_multiple_sbts_save_prefetch_csv(runtmp, linear_gather):
with open(runtmp.output('prefetch.csv')) as f:
output = f.read()
print((output,))
assert '870,0.925531914893617,0.9666666666666667' in output
assert '900,0.925531914893617,0.9666666666666667' in output


def test_gather_multiple_sbts_save_prefetch_and_prefetch_csv(runtmp, linear_gather):
Expand Down Expand Up @@ -3004,7 +3004,7 @@ def test_gather_multiple_sbts_save_prefetch_and_prefetch_csv(runtmp, linear_gath
with open(runtmp.output('prefetch.csv')) as f:
output = f.read()
print((output,))
assert '870,0.925531914893617,0.9666666666666667' in output
assert '900,0.925531914893617,0.9666666666666667' in output
assert os.path.exists(runtmp.output('out.zip'))


Expand Down Expand Up @@ -3048,7 +3048,7 @@ def test_gather_file_output(runtmp, linear_gather, prefetch_gather):
with open(runtmp.output('foo.out')) as f:
output = f.read()
print((output,))
assert '910,1.0,1.0' in output
assert '940,1.0,1.0' in output


def test_gather_f_match_orig(runtmp, linear_gather, prefetch_gather):
Expand Down Expand Up @@ -3984,7 +3984,7 @@ def test_gather_with_picklist_exclude(runtmp, linear_gather, prefetch_gather):
assert "found 9 matches total;" in out
assert "4.9 Mbp 33.2% 100.0% NC_003198.1 Salmonella enterica subsp..." in out
assert "1.6 Mbp 10.7% 100.0% NC_002163.1 Campylobacter jejuni subs..." in out
assert "4.8 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
assert "4.9 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
assert "4.7 Mbp 5.2% 16.1% NC_006905.1 Salmonella enterica subsp..." in out
assert "4.7 Mbp 4.0% 12.6% NC_011080.1 Salmonella enterica subsp..." in out
assert "4.6 Mbp 2.9% 9.2% NC_011274.1 Salmonella enterica subsp..." in out
Expand Down Expand Up @@ -4030,7 +4030,7 @@ def test_gather_with_pattern_exclude(runtmp, linear_gather, prefetch_gather):
assert "found 9 matches total;" in out
assert "4.9 Mbp 33.2% 100.0% NC_003198.1 Salmonella enterica subsp..." in out
assert "1.6 Mbp 10.7% 100.0% NC_002163.1 Campylobacter jejuni subs..." in out
assert "4.8 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
assert "4.9 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
assert "4.7 Mbp 5.2% 16.1% NC_006905.1 Salmonella enterica subsp..." in out
assert "4.7 Mbp 4.0% 12.6% NC_011080.1 Salmonella enterica subsp..." in out
assert "4.6 Mbp 2.9% 9.2% NC_011274.1 Salmonella enterica subsp..." in out
Expand Down Expand Up @@ -4170,7 +4170,7 @@ def test_gather_deduce_moltype(runtmp, linear_gather, prefetch_gather):
print(runtmp.last_result.out)
print(runtmp.last_result.err)

assert '1.9 kbp 100.0% 100.0%' in runtmp.last_result.out
assert '2.0 kbp 100.0% 100.0%' in runtmp.last_result.out


def test_gather_abund_1_1(runtmp, linear_gather, prefetch_gather):
Expand Down Expand Up @@ -4265,6 +4265,7 @@ def test_gather_abund_10_1(runtmp, prefetch_gather, linear_gather):
remaining_bps = []

for n, row in enumerate(r):
print(row)
assert int(row['gather_result_rank']) == n
overlap = float(row['intersect_bp'])
remaining_bp = float(row['remaining_bp'])
Expand All @@ -4278,20 +4279,22 @@ def test_gather_abund_10_1(runtmp, prefetch_gather, linear_gather):
average_abunds.append(average_abund)
remaining_bps.append(remaining_bp)

query_sig = sourmash.load_one_signature(query)
query_mh = query_sig.minhash

weighted_calc = []
for (overlap, average_abund) in zip(overlaps, average_abunds):
prod = overlap*average_abund
weighted_calc.append(prod)

total_weighted = sum(weighted_calc)
for prod, f_weighted in zip(weighted_calc, f_weighted_list):
fw_calc = prod/total_weighted
print(f"prod: {prod}, total_weighted: {total_weighted}, fw_calc: {fw_calc}, f_weighted: {f_weighted}")
assert prod / total_weighted == f_weighted, (prod, f_weighted)

query_sig = sourmash.load_one_signature(query)
query_mh = query_sig.minhash

total_bp_analyzed = sum(unique_overlaps) + remaining_bps[-1]
total_query_bp = len(query_mh) * query_mh.scaled
total_query_bp = query_mh.unique_covered_bp # len(query_mh) * query_mh.scaled
assert total_bp_analyzed == total_query_bp


Expand Down Expand Up @@ -5625,8 +5628,8 @@ def test_gather_ani_csv(runtmp, linear_gather, prefetch_gather):
print(row)
assert gather_result_names == list(row.keys())
assert gather_result_names_ci != list(row.keys())
assert float(row['intersect_bp']) == 910
assert float(row['unique_intersect_bp']) == 910
assert float(row['intersect_bp']) == 940
assert float(row['unique_intersect_bp']) == 940
assert float(row['remaining_bp']) == 0
assert float(row['f_orig_query']) == 1.0
assert float(row['f_unique_to_query']) == 1.0
Expand All @@ -5638,7 +5641,7 @@ def test_gather_ani_csv(runtmp, linear_gather, prefetch_gather):
assert row['query_filename'].endswith('short2.fa')
assert row['query_name'] == 'tr1 4'
assert row['query_md5'] == 'c9d5a795'
assert row['query_bp'] == '910'
assert row['query_bp'] == '940'
assert row['query_containment_ani']== '1.0'
assert row['match_containment_ani'] == '1.0'
assert row['average_containment_ani'] == '1.0'
Expand Down Expand Up @@ -5672,8 +5675,8 @@ def test_gather_ani_csv_estimate_ci(runtmp, linear_gather, prefetch_gather):
row = next(reader)
print(row)
assert gather_result_names == list(row.keys())
assert float(row['intersect_bp']) == 910
assert float(row['unique_intersect_bp']) == 910
assert float(row['intersect_bp']) == 940
assert float(row['unique_intersect_bp']) == 940
assert float(row['remaining_bp']) == 0
assert float(row['f_orig_query']) == 1.0
assert float(row['f_unique_to_query']) == 1.0
Expand All @@ -5685,7 +5688,7 @@ def test_gather_ani_csv_estimate_ci(runtmp, linear_gather, prefetch_gather):
assert row['query_filename'].endswith('short2.fa')
assert row['query_name'] == 'tr1 4'
assert row['query_md5'] == 'c9d5a795'
assert row['query_bp'] == '910'
assert row['query_bp'] == '940'
assert row['query_containment_ani']== '1.0'
assert row['query_containment_ani_low']== ''
assert row['query_containment_ani_high']== ''
Expand Down