Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] do not report untrusted jaccard ANI #2011

Merged
merged 5 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/sourmash/distance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class jaccardANIResult(ANIResult):
"""Class for distance/ANI from jaccard (includes jaccard_error)."""
jaccard_error: float = None
je_threshold: float = 1e-4
return_ani_despite_threshold: bool = False

def __post_init__(self):
# check values
Expand All @@ -72,6 +73,13 @@ def __post_init__(self):
else:
raise ValueError("Error: jaccard_error cannot be None.")

@property
def ani(self):
# if jaccard error is too high (exceeds threshold), do not trust ANI estimate
if self.je_exceeds_threshold and not self.return_ani_despite_threshold:
return ""
return 1 - self.dist


@dataclass
class ciANIResult(ANIResult):
Expand Down
3 changes: 0 additions & 3 deletions src/sourmash/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,6 @@ def estimate_search_ani(self):
elif self.searchtype == SearchType.JACCARD:
self.cmp.estimate_jaccard_ani(jaccard=self.similarity)
self.ani = self.cmp.jaccard_ani
# Jaccard error was too high for ANI estimation.
# Just report, or do we want to do something else?
self.ani_untrustworthy = self.cmp.jaccard_ani_untrustworthy
# this can be set from any of the above
self.potential_false_negative = self.cmp.potential_false_negative

Expand Down
23 changes: 13 additions & 10 deletions tests/test_distance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,18 @@ def test_aniresult_bad_distance():


def test_jaccard_aniresult():
res = jaccardANIResult(0.4, 0.1, jaccard_error=0.03)
assert res.dist == 0.4
res = jaccardANIResult(0.4, 0.1, jaccard_error=0.03, return_ani_despite_threshold=True)
res2 = jaccardANIResult(0.4, 0.1, jaccard_error=0.03)
assert res.dist == res2.dist == 0.4
assert res.ani == 0.6
assert res.p_nothing_in_common == 0.1
assert res2.ani == ""
assert res.p_nothing_in_common == res2.p_nothing_in_common == 0.1
assert res.jaccard_error == 0.03
assert res.p_exceeds_threshold ==True
assert res.je_exceeds_threshold ==True
res2 = jaccardANIResult(0.4, 0.1, jaccard_error=0.03, je_threshold=0.1)
assert res2.je_exceeds_threshold ==False
res3 = jaccardANIResult(0.4, 0.1, jaccard_error=0.03, je_threshold=0.1)
assert res3.je_exceeds_threshold ==False
assert res3.ani == 0.6


def test_jaccard_aniresult_nojaccarderror():
Expand Down Expand Up @@ -260,8 +263,8 @@ def test_jaccard_to_distance_scaled():
res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers)
print(res)
# check results
assert res.dist == 0.019122659390482077
assert res.ani == 0.9808773406095179
assert round(res.dist, 3) == round(0.019122659390482077, 3)
assert res.ani == ""
assert res.p_exceeds_threshold == False
assert res.jaccard_error == 0.00018351337045518042
assert res.je_exceeds_threshold ==True
Expand All @@ -282,12 +285,12 @@ def test_jaccard_to_distance_k31():
res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers)
print(res)
# check results
assert res.ani == 0.9870056455892898
assert res.p_exceeds_threshold == False
assert res.je_exceeds_threshold ==True
assert res.ani == ""
assert res.p_exceeds_threshold == False
res2 = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers, err_threshold=0.1)
assert res2.ani == res.ani
assert res2.je_exceeds_threshold == False
assert res2.ani == 0.9870056455892898


def test_jaccard_to_distance_k31_2():
Expand Down
14 changes: 14 additions & 0 deletions tests/test_minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,6 +2914,20 @@ def test_jaccard_ANI():
assert (m1_jani_m2.ani, m1_jani_m2.p_nothing_in_common, m1_jani_m2.jaccard_error) == (0.9783711630110239, 0.0, 3.891666770716877e-07)


def test_jaccard_ANI_untrustworthy():
f1 = utils.get_test_data('2.fa.sig')
f2 = utils.get_test_data('2+63.fa.sig')
mh1 = sourmash.load_one_signature(f1, ksize=31).minhash
mh2 = sourmash.load_one_signature(f2).minhash

print("\nJACCARD_ANI", mh1.jaccard_ani(mh2))

m1_jani_m2 = mh1.jaccard_ani(mh2, err_threshold=1e-7)
assert m1_jani_m2.ani == ""
assert m1_jani_m2.je_exceeds_threshold==True
assert m1_jani_m2.je_threshold == 1e-7


def test_jaccard_ANI_precalc_jaccard():
f1 = utils.get_test_data('2.fa.sig')
f2 = utils.get_test_data('2+63.fa.sig')
Expand Down
14 changes: 14 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,20 @@ def test_jaccard_ANI():
assert (s1_jani_s2.ani, s1_jani_s2.p_nothing_in_common, s1_jani_s2.jaccard_error) == (0.9783711630110239, 0.0, 3.891666770716877e-07)


def test_jaccard_ANI_untrustworthy():
f1 = utils.get_test_data('2.fa.sig')
f2 = utils.get_test_data('2+63.fa.sig')
ss1 = sourmash.load_one_signature(f1, ksize=31)
ss2 = sourmash.load_one_signature(f2)

print("\nJACCARD_ANI", ss1.jaccard_ani(ss2))

s1_jani_s2 = ss1.jaccard_ani(ss2, err_threshold=1e-7)
assert s1_jani_s2.ani == ""
assert s1_jani_s2.je_exceeds_threshold==True
assert s1_jani_s2.je_threshold == 1e-7


def test_jaccard_ANI_precalc_jaccard():
f1 = utils.get_test_data('2.fa.sig')
f2 = utils.get_test_data('2+63.fa.sig')
Expand Down
96 changes: 61 additions & 35 deletions tests/test_sourmash.py
Original file line number Diff line number Diff line change
Expand Up @@ -5321,6 +5321,32 @@ def test_standalone_manifest_search_fail(runtmp):

@utils.in_tempdir
def test_search_ani_jaccard(c):
sig47 = utils.get_test_data('47.fa.sig')
sig4763 = utils.get_test_data('47+63.fa.sig')

c.run_sourmash('search', sig47, sig4763, '-o', 'xxx.csv')
print(c.last_result.status, c.last_result.out, c.last_result.err)

search_result_names = SearchResult.search_write_cols

csv_file = c.output('xxx.csv')

with open(csv_file) as fp:
reader = csv.DictReader(fp)
row = next(reader)
print(row)
assert search_result_names == list(row.keys())
assert float(row['similarity']) == 0.6564798376870403
assert row['filename'].endswith('47+63.fa.sig')
assert row['md5'] == '491c0a81b2cfb0188c0d3b46837c2f42'
assert row['query_filename'].endswith('47.fa')
assert row['query_name'] == 'NC_009665.1 Shewanella baltica OS185, complete genome'
assert row['query_md5'] == '09a08691'
assert row['ani'] == "0.992530907924384"


@utils.in_tempdir
def test_search_ani_jaccard_error_too_high(c):
testdata1 = utils.get_test_data('short.fa')
testdata2 = utils.get_test_data('short2.fa')
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', testdata1, testdata2)
Expand All @@ -5343,7 +5369,8 @@ def test_search_ani_jaccard(c):
assert row['query_filename'].endswith('short.fa')
assert row['query_name'] == ''
assert row['query_md5'] == '9191284a'
assert row['ani'] == "0.9987884602947684"
#assert row['ani'] == "0.9987884602947684"
assert row['ani'] == ""


@utils.in_tempdir
Expand Down Expand Up @@ -5522,54 +5549,53 @@ def test_search_ani_max_containment_estimate_ci(c):

@utils.in_tempdir
def test_search_jaccard_ani_downsample(c):
testdata1 = utils.get_test_data('short.fa')
testdata2 = utils.get_test_data('short2.fa')
sig1_out = c.output('short.fa.sig')
sig2_out = c.output('short2.fa.sig')
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=2', '--force', testdata1, '-o', sig1_out)
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', '--force', testdata1, '-o', sig2_out)
sig1 = sourmash.load_one_signature(sig1_out)
sig2 = sourmash.load_one_signature(sig2_out)
print(f"SCALED: sig1: {sig1.minhash.scaled}, sig2: {sig2.minhash.scaled}") # if don't change name, just reads prior sigfile!!?

sig1F = c.output('sig1.sig')
sig2F = c.output('sig2.sig')
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=2', '--force', testdata1, '-o', sig1F)
c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', '--force', testdata2, '-o', sig2F)

sig1 = sourmash.load_one_signature(sig1F)
sig2 = sourmash.load_one_signature(sig2F)
print(f"SCALED: sig1: {sig1.minhash.scaled}, sig2: {sig2.minhash.scaled}")

c.run_sourmash('search', sig1F, sig2F, '-o', 'xdx.csv')
sig47 = utils.get_test_data('47.fa.sig')
sig4763 = utils.get_test_data('47+63.fa.sig')
ss47 = sourmash.load_one_signature(sig47)
ss4763 = sourmash.load_one_signature(sig4763)
print(f"SCALED: sig1: {ss47.minhash.scaled}, sig2: {ss4763.minhash.scaled}")

c.run_sourmash('search', sig47, sig4763, '-o', 'xxx.csv')
print(c.last_result.status, c.last_result.out, c.last_result.err)

csv_file = c.output('xdx.csv')
search_result_names = SearchResult.search_write_cols
search_result_names_ci = SearchResult.search_write_cols_ci

csv_file = c.output('xxx.csv')

with open(csv_file) as fp:
reader = csv.DictReader(fp)
row = next(reader)
print(row)
assert search_result_names == list(row.keys())
assert search_result_names_ci != list(row.keys())
assert float(row['similarity']) == 0.9296066252587992
assert row['md5'] == 'bf752903d635b1eb83c53fe4aae951db'
assert row['filename'].endswith('sig2.sig')
assert row['query_filename'].endswith('short.fa')
assert row['query_name'] == ''
assert row['query_md5'] == '8f74b0b8'
assert row['ani'] == "0.9988019200011651"
assert float(row['similarity']) == 0.6564798376870403
assert row['filename'].endswith('47+63.fa.sig')
assert row['md5'] == '491c0a81b2cfb0188c0d3b46837c2f42'
assert row['query_filename'].endswith('47.fa')
assert row['query_name'] == 'NC_009665.1 Shewanella baltica OS185, complete genome'
assert row['query_md5'] == '09a08691'
assert row['ani'] == "0.992530907924384"

# downsample one and check similarity and ANI
ds_sig47 = c.output("ds_sig47.sig")
c.run_sourmash('sig', "downsample", sig47, "--scaled", "2000", '-o', ds_sig47)
c.run_sourmash('search', ds_sig47, sig4763, '-o', 'xxx.csv')

csv_file = c.output('xxx.csv')
with open(csv_file) as fp:
reader = csv.DictReader(fp)
row = next(reader)
print(row)
assert round(float(row['similarity']), 3) == round(0.6634517766497462, 3)
assert round(float(row['ani']), 3) == round(0.992530907924384, 3)

#downsample manually and assert same ANI
mh1 = sig1.minhash
mh2 = sig2.minhash
mh2_sc2 = mh2.downsample(scaled=mh1.scaled)
print("SCALED:", mh1.scaled, mh2_sc2.scaled)
ani_info = mh1.jaccard_ani(mh2_sc2)
ss47_ds = signature.load_one_signature(ds_sig47)
print("SCALED:", ss47_ds.minhash.scaled, ss4763.minhash.scaled)
ani_info = ss47_ds.jaccard_ani(ss4763, downsample=True)
print(ani_info)
assert ani_info.ani == 0.9988019200011651
assert round(ani_info.ani, 3) == round(0.992530907924384, 3)


def test_gather_ani_csv(runtmp, linear_gather, prefetch_gather):
Expand Down