From a3389bf51e2005d073b0319c18d582f9dbd9f836 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sun, 3 Apr 2022 09:10:05 -0700 Subject: [PATCH 1/2] add LCA database test for tricky ordering --- src/sourmash/lca/lca_db.py | 5 ++++- tests/test_lca.py | 44 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/src/sourmash/lca/lca_db.py b/src/sourmash/lca/lca_db.py index fb9119def4..8f88d0c11f 100644 --- a/src/sourmash/lca/lca_db.py +++ b/src/sourmash/lca/lca_db.py @@ -510,10 +510,13 @@ def find(self, search_fn, query, **kwargs): score = search_fn.score_fn(query_size, shared_size, subj_size, total_size) - # note to self: even with JaccardSearchBestOnly, this will + # CTB note to self: even with JaccardSearchBestOnly, this will # still iterate over & score all signatures. We should come # up with a protocol by which the JaccardSearch object can # signal that it is done, or something. + # For example, see test_lca_jaccard_ordering, where + # for containment we could be done early, but for Jaccard we + # cannot. if search_fn.passes(score): if search_fn.collect(score, subj): if passes_all_picklists(subj, self.picklists): diff --git a/tests/test_lca.py b/tests/test_lca.py index 03a4ff6650..990de48864 100644 --- a/tests/test_lca.py +++ b/tests/test_lca.py @@ -2670,3 +2670,47 @@ def test_lca_index_with_picklist_exclude(runtmp): assert len(siglist) == 9 for ss in siglist: assert 'Thermotoga' not in ss.name + + +def test_lca_jaccard_ordering(): + # this tests a tricky situation where for three sketches A, B, C, + # |A intersect B| is greater than |A intersect C| + # _but_ + # |A jaccard B| is less than |A intersect B| + a = sourmash.MinHash(ksize=31, n=0, scaled=2) + b = a.copy_and_clear() + c = a.copy_and_clear() + + a.add_many([1, 2, 3, 4]) + b.add_many([1, 2, 3] + list(range(10, 30))) + c.add_many([1, 5]) + + def _intersect(x, y): + return x.intersection_and_union_size(y)[0] + + print('a intersect b:', _intersect(a, b)) + print('a intersect c:', _intersect(a, c)) + print('a jaccard b:', a.jaccard(b)) + print('a jaccard c:', a.jaccard(c)) + assert _intersect(a, b) > _intersect(a, c) + assert a.jaccard(b) < a.jaccard(c) + + # thresholds to use: + assert a.jaccard(b) < 0.15 + assert a.jaccard(c) > 0.15 + + # now - make signatures, try out :) + ss_a = sourmash.SourmashSignature(a, name='A') + ss_b = sourmash.SourmashSignature(b, name='B') + ss_c = sourmash.SourmashSignature(c, name='C') + + db = sourmash.lca.LCA_Database(ksize=31, scaled=2) + db.insert(ss_a) + db.insert(ss_b) + db.insert(ss_c) + + sr = db.search(ss_a, threshold=0.15) + print(sr) + assert len(sr) == 2 + assert sr[0].signature == ss_a + assert sr[1].signature == ss_c From 628d72206a50a58cc5a9048b2807b721544f4bb3 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Sun, 3 Apr 2022 09:14:40 -0700 Subject: [PATCH 2/2] add test for jaccard ordering to SBTs --- tests/test_sbt.py | 47 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/test_sbt.py b/tests/test_sbt.py index cb5b043c91..bbe49e85fd 100644 --- a/tests/test_sbt.py +++ b/tests/test_sbt.py @@ -921,7 +921,7 @@ def test_gather_single_return(c): sig47 = load_one_signature(sig47file, ksize=31) sig63 = load_one_signature(sig63file, ksize=31) - # construct LCA Database + # construct SBT Database factory = GraphFactory(31, 1e5, 4) tree = SBT(factory, d=2) @@ -937,6 +937,51 @@ def test_gather_single_return(c): assert results[0][0] == 1.0 +def test_sbt_jaccard_ordering(runtmp): + # this tests a tricky situation where for three sketches A, B, C, + # |A intersect B| is greater than |A intersect C| + # _but_ + # |A jaccard B| is less than |A intersect B| + a = sourmash.MinHash(ksize=31, n=0, scaled=2) + b = a.copy_and_clear() + c = a.copy_and_clear() + + a.add_many([1, 2, 3, 4]) + b.add_many([1, 2, 3] + list(range(10, 30))) + c.add_many([1, 5]) + + def _intersect(x, y): + return x.intersection_and_union_size(y)[0] + + print('a intersect b:', _intersect(a, b)) + print('a intersect c:', _intersect(a, c)) + print('a jaccard b:', a.jaccard(b)) + print('a jaccard c:', a.jaccard(c)) + assert _intersect(a, b) > _intersect(a, c) + assert a.jaccard(b) < a.jaccard(c) + + # thresholds to use: + assert a.jaccard(b) < 0.15 + assert a.jaccard(c) > 0.15 + + # now - make signatures, try out :) + ss_a = sourmash.SourmashSignature(a, name='A') + ss_b = sourmash.SourmashSignature(b, name='B') + ss_c = sourmash.SourmashSignature(c, name='C') + + factory = GraphFactory(31, 1e5, 4) + db = SBT(factory, d=2) + db.insert(ss_a) + db.insert(ss_b) + db.insert(ss_c) + + sr = db.search(ss_a, threshold=0.15) + print(sr) + assert len(sr) == 2 + assert sr[0].signature == ss_a + assert sr[1].signature == ss_c + + def test_sbt_protein_command_index(runtmp): c = runtmp