diff --git a/src/sourmash/minhash.py b/src/sourmash/minhash.py
index a1ac06c5c3..da0a06704d 100644
--- a/src/sourmash/minhash.py
+++ b/src/sourmash/minhash.py
@@ -918,10 +918,10 @@ def std_abundance(self):
         return None
 
     @property
-    def covered_bp(self):
+    def unique_covered_bp(self):
         if not self.scaled:
             raise TypeError("can only calculate bp for scaled MinHashes")
-        return len(self.hashes) * self.scaled
+        return len(self.hashes) * self.scaled + (self.ksize - 1)
 
 
 class FrozenMinHash(MinHash):
diff --git a/src/sourmash/search.py b/src/sourmash/search.py
index 9535b1f68c..0f56d6b473 100644
--- a/src/sourmash/search.py
+++ b/src/sourmash/search.py
@@ -355,8 +355,8 @@ def init_sigcomparison(self):
         self.get_cmpinfo() # grab comparison metadata
         self.intersect_bp = self.cmp.intersect_bp
         self.max_containment = self.cmp.max_containment
-        self.query_bp = self.mh1.covered_bp
-        self.match_bp = self.mh2.covered_bp
+        self.query_bp = self.mh1.unique_covered_bp
+        self.match_bp = self.mh2.unique_covered_bp
         self.threshold = self.threshold_bp
         self.estimate_containment_ani()
 
@@ -454,7 +454,7 @@ def build_gather_result(self):
         # this affects estimation of original query information, and requires us to pass in orig_query_len and orig_query_abunds.
         # we also need to overwrite self.query_bp, self.query_n_hashes, and self.query_abundance
         # todo: find a better solution?
-        self.query_bp = self.orig_query_len * self.query.minhash.scaled
+        self.query_bp = self.orig_query_len * self.query.minhash.scaled + self.ksize - 1
         self.query_n_hashes = self.orig_query_len
 
         # calculate intersection with query hashes:
@@ -473,7 +473,7 @@ def build_gather_result(self):
         self.f_unique_to_query = len(self.gather_comparison.intersect_mh)/self.orig_query_len
 
         # here, need to make sure to use the mh1_cmp (bc was downsampled to cmp_scaled)
-        self.remaining_bp = (self.gather_comparison.mh1_cmp.covered_bp - self.gather_comparison.intersect_bp)
+        self.remaining_bp = (self.gather_comparison.mh1_cmp.unique_covered_bp - self.gather_comparison.intersect_bp)
 
         # calculate stats on abundances, if desired.
         self.average_abund, self.median_abund, self.std_abund = None, None, None
@@ -643,7 +643,7 @@ def __init__(self, query, counters, *,
         # track original query information for later usage?
         track_abundance = query.minhash.track_abundance and not ignore_abundance
         self.orig_query = query
-        self.orig_query_bp = len(query.minhash) * query.minhash.scaled
+        self.orig_query_bp = query.minhash.unique_covered_bp #len(query.minhash) * query.minhash.scaled
         self.orig_query_filename = query.filename
         self.orig_query_name = query.name
         self.orig_query_md5 = query.md5sum()[:8]
diff --git a/src/sourmash/sketchcomparison.py b/src/sourmash/sketchcomparison.py
index 1d50f7833d..a910468007 100644
--- a/src/sourmash/sketchcomparison.py
+++ b/src/sourmash/sketchcomparison.py
@@ -104,7 +104,7 @@ def pass_threshold(self):
 
     @property
     def intersect_bp(self):
-        return len(self.intersect_mh) * self.cmp_scaled
+        return (len(self.intersect_mh) * self.cmp_scaled) + (self.ksize - 1)
 
     @property
     def mh1_containment(self):
diff --git a/tests/test_minhash.py b/tests/test_minhash.py
index fdcaea60c9..0a37ba953b 100644
--- a/tests/test_minhash.py
+++ b/tests/test_minhash.py
@@ -2804,7 +2804,7 @@ def test_std_abundance(track_abundance):
     assert not mh2.std_abundance
 
 
-def test_covered_bp(track_abundance):
+def test_unique_covered_bp(track_abundance):
     "test covered_bp"
     mh1 = MinHash(0, 21, scaled=1, track_abundance=track_abundance)
     mh2 = MinHash(4, 21, track_abundance=track_abundance)
@@ -2813,9 +2813,9 @@ def test_unique_covered_bp(track_abundance):
     mh1.add_many((1, 2))
     mh2.add_many((1, 5))
 
-    assert mh1.covered_bp == 4  # hmmm...
+    assert mh1.unique_covered_bp == 24
 
     with pytest.raises(TypeError) as exc:
-        mh2.covered_bp
+        mh2.unique_covered_bp
 
     assert "can only calculate bp for scaled MinHashes" in str(exc)
diff --git a/tests/test_prefetch.py b/tests/test_prefetch.py
index 254b76d177..5efe2ed362 100644
--- a/tests/test_prefetch.py
+++ b/tests/test_prefetch.py
@@ -193,7 +193,7 @@ def test_prefetch_csv_out(runtmp, linear_gather):
     assert c.last_result.status == 0
     assert os.path.exists(csvout)
 
-    expected_intersect_bp = [2529000, 5177000]
+    expected_intersect_bp = [2529030, 5177030]
     with open(csvout, 'rt', newline="") as fp:
         r = csv.DictReader(fp)
         for (row, expected) in zip(r, expected_intersect_bp):
diff --git a/tests/test_search.py b/tests/test_search.py
index 0d765e0f96..eea527e24c 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -270,8 +270,6 @@ def test_scaledSearchResult():
     assert res.cmp_scaled == 1000
     assert res.query_abundance == ss47.minhash.track_abundance
     assert res.match_abundance == ss4763.minhash.track_abundance
-#    assert res.query_bp == len(ss47.minhash) * scaled
-#    assert res.match_bp == len(ss4763.minhash) * scaled
     assert res.ksize == 31
     assert res.moltype == 'DNA'
     assert res.query_filename == '47.fa'
@@ -376,7 +374,7 @@ def test_PrefetchResult():
 
     scaled = ss47.minhash.scaled
     intersect_mh = ss47.minhash.intersection(ss4763.minhash)
-    intersect_bp = len(intersect_mh) * scaled
+    intersect_bp = len(intersect_mh) * scaled + ss47.minhash.ksize - 1
     jaccard=ss4763.jaccard(ss47)
     max_containment=ss4763.max_containment(ss47)
     f_match_query=ss47.contained_by(ss4763)
@@ -393,8 +391,8 @@ def test_PrefetchResult():
     assert res.cmp_scaled == 1000
     assert res.query_abundance == ss47.minhash.track_abundance
     assert res.match_abundance == ss4763.minhash.track_abundance
-    assert res.query_bp == len(ss47.minhash) * scaled
-    assert res.match_bp == len(ss4763.minhash) * scaled
+    assert res.query_bp == len(ss47.minhash) * scaled + ss47.minhash.ksize - 1
+    assert res.match_bp == len(ss4763.minhash) * scaled + ss4763.minhash.ksize - 1
     assert res.ksize == 31
     assert res.moltype == 'DNA'
     assert res.query_filename == '47.fa'
@@ -447,9 +445,8 @@ def test_GatherResult():
     remaining_mh = ss4763.minhash.to_mutable()
     remaining_mh.remove_many(intersect_mh)
 
-    intersect_bp = len(intersect_mh) * scaled
+    intersect_bp = (len(intersect_mh) * scaled) + ss47.minhash.ksize - 1
     max_containment=ss4763.max_containment(ss47)
-    f_match_query = ss47.contained_by(ss4763)
     orig_query_abunds = ss47.minhash.hashes
     queryc_ani = ss47.containment_ani(ss4763)
     matchc_ani = ss4763.containment_ani(ss47)
@@ -472,8 +469,8 @@ def test_GatherResult():
     assert res.cmp_scaled == 1000
     assert res.query_abundance == ss47.minhash.track_abundance
     assert res.match_abundance == ss4763.minhash.track_abundance
-    assert res.query_bp == len(ss47.minhash) * scaled
-    assert res.match_bp == len(ss4763.minhash) * scaled
+    assert res.query_bp == ss47.minhash.unique_covered_bp
+    assert res.match_bp == ss4763.minhash.unique_covered_bp
     assert res.ksize == 31
     assert res.moltype == 'DNA'
     assert res.query_filename == 'podar-ref/47.fa'
diff --git a/tests/test_sketchcomparison.py b/tests/test_sketchcomparison.py
index e84926993b..753233fe0a 100644
--- a/tests/test_sketchcomparison.py
+++ b/tests/test_sketchcomparison.py
@@ -42,7 +42,7 @@ def test_FracMinHashComparison(track_abundance):
     assert cmp.jaccard == a.jaccard(b) == b.jaccard(a)
     intersect_mh = a.flatten().intersection(b.flatten())
     assert cmp.intersect_mh == intersect_mh == b.flatten().intersection(a.flatten())
-    assert cmp.intersect_bp == 4
+    assert cmp.intersect_bp == 24
     assert cmp.pass_threshold # default threshold is 0; this should pass
     if track_abundance:
         assert cmp.angular_similarity == a.angular_similarity(b) == b.angular_similarity(a)
@@ -100,7 +100,7 @@ def test_FracMinHashComparison_downsample(track_abundance):
     assert cmp.jaccard == ds_a.jaccard(ds_b) == ds_b.jaccard(ds_a)
     intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
     assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
-    assert cmp.intersect_bp == 8
+    assert cmp.intersect_bp == 28
     assert cmp.pass_threshold # default threshold is 0; this should pass
     if track_abundance:
         assert cmp.angular_similarity == ds_a.angular_similarity(ds_b) == ds_b.angular_similarity(ds_a)
@@ -158,7 +158,7 @@ def test_FracMinHashComparison_autodownsample(track_abundance):
     assert cmp.jaccard == ds_a.jaccard(ds_b) == ds_b.jaccard(ds_a)
     intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
     assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
-    assert cmp.intersect_bp == 8
+    assert cmp.intersect_bp == 28
     assert cmp.pass_threshold # default threshold is 0; this should pass
     if track_abundance:
         assert cmp.angular_similarity == ds_a.angular_similarity(ds_b) == ds_b.angular_similarity(ds_a)
@@ -215,7 +215,7 @@ def test_FracMinHashComparison_ignore_abundance(track_abundance):
     assert cmp.jaccard == a.jaccard(b) == b.jaccard(a)
     intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
     assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
-    assert cmp.intersect_bp == 8
+    assert cmp.intersect_bp == 28
     assert cmp.pass_threshold # default threshold is 0; this should pass
     # with ignore_abundance = True, all of these should not be usable. Do we want errors, or ""/None?
     with pytest.raises(TypeError) as exc:
@@ -252,7 +252,7 @@ def test_FracMinHashComparison_fail_threshold(track_abundance):
     ds_b = b.flatten().downsample(scaled=cmp_scaled)
 
     # build FracMinHashComparison
-    cmp = FracMinHashComparison(a, b, cmp_scaled = cmp_scaled, threshold_bp=10)
+    cmp = FracMinHashComparison(a, b, cmp_scaled = cmp_scaled, threshold_bp=40)
     assert cmp.mh1 == a
     assert cmp.mh2 == b
     assert cmp.ignore_abundance == False
@@ -266,8 +266,8 @@ def test_FracMinHashComparison_fail_threshold(track_abundance):
     assert cmp.jaccard == a.jaccard(b) == b.jaccard(a)
     intersect_mh = ds_a.flatten().intersection(ds_b.flatten())
     assert cmp.intersect_mh == intersect_mh == ds_b.flatten().intersection(ds_a.flatten())
-    assert cmp.intersect_bp == 8
-    assert not cmp.pass_threshold # threshold is 10; this should fail
+    assert cmp.intersect_bp == 28
+    assert not cmp.pass_threshold # threshold is 40; this should fail
 
 
 def test_FracMinHashComparison_potential_false_negative(track_abundance):
diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py
index a647530c63..59dfa66ed7 100644
--- a/tests/test_sourmash.py
+++ b/tests/test_sourmash.py
@@ -2872,8 +2872,8 @@ def test_gather_csv(runtmp, linear_gather, prefetch_gather):
         reader = csv.DictReader(fp)
         row = next(reader)
         print(row)
-        assert float(row['intersect_bp']) == 910
-        assert float(row['unique_intersect_bp']) == 910
+        assert float(row['intersect_bp']) == 940
+        assert float(row['unique_intersect_bp']) == 940
         assert float(row['remaining_bp']) == 0
         assert float(row['f_orig_query']) == 1.0
         assert float(row['f_unique_to_query']) == 1.0
@@ -2885,7 +2885,7 @@ def test_gather_csv(runtmp, linear_gather, prefetch_gather):
         assert row['query_filename'].endswith('short2.fa')
         assert row['query_name'] == 'tr1 4'
         assert row['query_md5'] == 'c9d5a795'
-        assert row['query_bp'] == '910'
+        assert row['query_bp'] == '940'
 
 
 def test_gather_abund_x_abund(runtmp, prefetch_gather, linear_gather):
@@ -2974,7 +2974,7 @@ def test_gather_multiple_sbts_save_prefetch_csv(runtmp, linear_gather):
     with open(runtmp.output('prefetch.csv')) as f:
         output = f.read()
         print((output,))
-        assert '870,0.925531914893617,0.9666666666666667' in output
+        assert '900,0.925531914893617,0.9666666666666667' in output
 
 
 def test_gather_multiple_sbts_save_prefetch_and_prefetch_csv(runtmp, linear_gather):
@@ -3004,7 +3004,7 @@ def test_gather_multiple_sbts_save_prefetch_and_prefetch_csv(runtmp, linear_gath
     with open(runtmp.output('prefetch.csv')) as f:
         output = f.read()
         print((output,))
-        assert '870,0.925531914893617,0.9666666666666667' in output
+        assert '900,0.925531914893617,0.9666666666666667' in output
 
     assert os.path.exists(runtmp.output('out.zip'))
 
@@ -3048,7 +3048,7 @@ def test_gather_file_output(runtmp, linear_gather, prefetch_gather):
     with open(runtmp.output('foo.out')) as f:
         output = f.read()
         print((output,))
-        assert '910,1.0,1.0' in output
+        assert '940,1.0,1.0' in output
 
 
 def test_gather_f_match_orig(runtmp, linear_gather, prefetch_gather):
@@ -3984,7 +3984,7 @@ def test_gather_with_picklist_exclude(runtmp, linear_gather, prefetch_gather):
     assert "found 9 matches total;" in out
     assert "4.9 Mbp 33.2% 100.0% NC_003198.1 Salmonella enterica subsp..." in out
     assert "1.6 Mbp 10.7% 100.0% NC_002163.1 Campylobacter jejuni subs..." in out
-    assert "4.8 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
+    assert "4.9 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
     assert "4.7 Mbp 5.2% 16.1% NC_006905.1 Salmonella enterica subsp..." in out
     assert "4.7 Mbp 4.0% 12.6% NC_011080.1 Salmonella enterica subsp..." in out
     assert "4.6 Mbp 2.9% 9.2% NC_011274.1 Salmonella enterica subsp..." in out
@@ -4030,7 +4030,7 @@ def test_gather_with_pattern_exclude(runtmp, linear_gather, prefetch_gather):
     assert "found 9 matches total;" in out
     assert "4.9 Mbp 33.2% 100.0% NC_003198.1 Salmonella enterica subsp..." in out
     assert "1.6 Mbp 10.7% 100.0% NC_002163.1 Campylobacter jejuni subs..." in out
-    assert "4.8 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
+    assert "4.9 Mbp 10.4% 31.3% NC_003197.2 Salmonella enterica subsp..." in out
     assert "4.7 Mbp 5.2% 16.1% NC_006905.1 Salmonella enterica subsp..." in out
     assert "4.7 Mbp 4.0% 12.6% NC_011080.1 Salmonella enterica subsp..." in out
     assert "4.6 Mbp 2.9% 9.2% NC_011274.1 Salmonella enterica subsp..." in out
@@ -4170,7 +4170,7 @@ def test_gather_deduce_moltype(runtmp, linear_gather, prefetch_gather):
     print(runtmp.last_result.out)
     print(runtmp.last_result.err)
 
-    assert '1.9 kbp 100.0% 100.0%' in runtmp.last_result.out
+    assert '2.0 kbp 100.0% 100.0%' in runtmp.last_result.out
 
 
 def test_gather_abund_1_1(runtmp, linear_gather, prefetch_gather):
@@ -4265,6 +4265,7 @@ def test_gather_abund_10_1(runtmp, prefetch_gather, linear_gather):
     remaining_bps = []
 
     for n, row in enumerate(r):
+        print(row)
         assert int(row['gather_result_rank']) == n
         overlap = float(row['intersect_bp'])
         remaining_bp = float(row['remaining_bp'])
@@ -4278,6 +4279,9 @@ def test_gather_abund_10_1(runtmp, prefetch_gather, linear_gather):
         average_abunds.append(average_abund)
         remaining_bps.append(remaining_bp)
 
+    query_sig = sourmash.load_one_signature(query)
+    query_mh = query_sig.minhash
+
     weighted_calc = []
     for (overlap, average_abund) in zip(overlaps, average_abunds):
         prod = overlap*average_abund
@@ -4285,13 +4289,12 @@ def test_gather_abund_10_1(runtmp, prefetch_gather, linear_gather):
 
     total_weighted = sum(weighted_calc)
     for prod, f_weighted in zip(weighted_calc, f_weighted_list):
+        fw_calc = prod/total_weighted
+        print(f"prod: {prod}, total_weighted: {total_weighted}, fw_calc: {fw_calc}, f_weighted: {f_weighted}")
         assert prod / total_weighted == f_weighted, (prod, f_weighted)
 
-    query_sig = sourmash.load_one_signature(query)
-    query_mh = query_sig.minhash
-
     total_bp_analyzed = sum(unique_overlaps) + remaining_bps[-1]
-    total_query_bp = len(query_mh) * query_mh.scaled
+    total_query_bp = query_mh.unique_covered_bp # len(query_mh) * query_mh.scaled
     assert total_bp_analyzed == total_query_bp
 
 
@@ -5625,8 +5628,8 @@ def test_gather_ani_csv(runtmp, linear_gather, prefetch_gather):
         print(row)
         assert gather_result_names == list(row.keys())
         assert gather_result_names_ci != list(row.keys())
-        assert float(row['intersect_bp']) == 910
-        assert float(row['unique_intersect_bp']) == 910
+        assert float(row['intersect_bp']) == 940
+        assert float(row['unique_intersect_bp']) == 940
         assert float(row['remaining_bp']) == 0
         assert float(row['f_orig_query']) == 1.0
         assert float(row['f_unique_to_query']) == 1.0
@@ -5638,7 +5641,7 @@ def test_gather_ani_csv(runtmp, linear_gather, prefetch_gather):
         assert row['query_filename'].endswith('short2.fa')
         assert row['query_name'] == 'tr1 4'
         assert row['query_md5'] == 'c9d5a795'
-        assert row['query_bp'] == '910'
+        assert row['query_bp'] == '940'
         assert row['query_containment_ani']== '1.0'
         assert row['match_containment_ani'] == '1.0'
         assert row['average_containment_ani'] == '1.0'
@@ -5672,8 +5675,8 @@ def test_gather_ani_csv_estimate_ci(runtmp, linear_gather, prefetch_gather):
         row = next(reader)
         print(row)
         assert gather_result_names == list(row.keys())
-        assert float(row['intersect_bp']) == 910
-        assert float(row['unique_intersect_bp']) == 910
+        assert float(row['intersect_bp']) == 940
+        assert float(row['unique_intersect_bp']) == 940
         assert float(row['remaining_bp']) == 0
         assert float(row['f_orig_query']) == 1.0
         assert float(row['f_unique_to_query']) == 1.0
@@ -5685,7 +5688,7 @@ def test_gather_ani_csv_estimate_ci(runtmp, linear_gather, prefetch_gather):
         assert row['query_filename'].endswith('short2.fa')
         assert row['query_name'] == 'tr1 4'
         assert row['query_md5'] == 'c9d5a795'
-        assert row['query_bp'] == '910'
+        assert row['query_bp'] == '940'
         assert row['query_containment_ani']== '1.0'
         assert row['query_containment_ani_low']== ''
        assert row['query_containment_ani_high']== ''
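A note on the arithmetic behind the updated expectations (illustration only, not part of the patch): each hash in a FracMinHash stands in for roughly `scaled` k-mers, and since a k-mer covers `ksize` bases, `n` consecutive k-mers cover `n + ksize - 1` bases. That is why every bp estimate above gains `ksize - 1`: +30 for the k=31 gather/prefetch sketches and +20 for the k=21 sketches in the minhash/sketchcomparison tests. A minimal sketch of the formula, with the helper name `estimate_unique_covered_bp` chosen here purely for illustration:

```python
def estimate_unique_covered_bp(n_hashes: int, scaled: int, ksize: int) -> int:
    """Mirror the patched calculation: n_hashes * scaled estimates the number
    of distinct k-mers represented by the sketch, and that many k-mers span
    an extra (ksize - 1) bases beyond the k-mer count itself."""
    return n_hashes * scaled + (ksize - 1)


# The same ksize-1 shift accounts for the changed expected values above.
KSIZE_31 = 31  # k=31 sketches in the gather/prefetch tests (shift of 30)
KSIZE_21 = 21  # k=21 sketches in test_unique_covered_bp / test_FracMinHashComparison (shift of 20)

assert 910 + (KSIZE_31 - 1) == 940            # test_gather_csv: intersect_bp, unique_intersect_bp, query_bp
assert 2529000 + (KSIZE_31 - 1) == 2529030    # test_prefetch_csv_out: expected_intersect_bp[0]
assert 870 + (KSIZE_31 - 1) == 900            # prefetch.csv overlap column in the SBT gather tests
assert 4 + (KSIZE_21 - 1) == 24               # test_unique_covered_bp (scaled=1)
assert 8 + (KSIZE_21 - 1) == 28               # downsampled intersect_bp in test_sketchcomparison
```

The patched `intersect_bp` applies the same formula to the intersection sketch at the common downsampled `scaled`, which is why the prefetch/gather CSV expectations and the `threshold_bp` used in the fail-threshold test shift by the same constant.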