Skip to content

Commit

Permalink
[MRG] Add ani utils to signature, sketchcomparison (#1966)
Browse files Browse the repository at this point in the history
* add sig ani functions

* add ani utils to sketchcomparison

* fix jaccard_ani options
  • Loading branch information
bluegenes committed Apr 20, 2022
1 parent 53a8fce commit 2d1fc69
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 9 deletions.
11 changes: 6 additions & 5 deletions src/sourmash/distance_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,14 @@ class ANIResult:
p_threshold: float = 1e-3
p_exceeds_threshold: bool = field(init=False)

def __post_init__(self):
def check_dist_and_p_threshold(self):
# check values
self.dist = check_distance(self.dist)
self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold)

def __post_init__(self):
self.check_dist_and_p_threshold()

@property
def ani(self):
return 1 - self.dist
Expand All @@ -62,8 +65,7 @@ class jaccardANIResult(ANIResult):

def __post_init__(self):
# check values
self.dist = check_distance(self.dist)
self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold)
self.check_dist_and_p_threshold()
# check jaccard error
if self.jaccard_error is not None:
self.jaccard_error, self.je_exceeds_threshold = check_jaccard_error(self.jaccard_error, self.je_threshold)
Expand All @@ -83,8 +85,7 @@ class ciANIResult(ANIResult):

def __post_init__(self):
# check values
self.dist = check_distance(self.dist)
self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold)
self.check_dist_and_p_threshold()

if self.dist_low is not None and self.dist_high is not None:
self.dist_low = check_distance(self.dist_low)
Expand Down
6 changes: 3 additions & 3 deletions src/sourmash/minhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ def jaccard(self, other, downsample=False):
return self._methodcall(lib.kmerminhash_similarity, other._get_objptr(), True, downsample)

def jaccard_ani(self, other, *, downsample=False, jaccard=None, prob_threshold=1e-3, err_threshold=1e-4):
"Calculate Jaccard --> ANI of two MinHash objects."
"Use jaccard to estimate ANI between two MinHash objects."
self_mh = self
other_mh = other
scaled = self.scaled
Expand Down Expand Up @@ -709,7 +709,7 @@ def contained_by(self, other, downsample=False):
return self.count_common(other, downsample) / len(self)

def containment_ani(self, other, *, downsample=False, containment=None, confidence=0.95, estimate_ci = False):
"Estimate ANI from containment with the other MinHash."
"Use containment to estimate ANI between two MinHash objects."
self_mh = self
other_mh = other
scaled = self.scaled
Expand Down Expand Up @@ -740,7 +740,7 @@ def max_containment(self, other, downsample=False):
return self.count_common(other, downsample) / min_denom

def max_containment_ani(self, other, *, downsample=False, max_containment=None, confidence=0.95, estimate_ci=False):
"Estimate ANI from containment with the other MinHash."
"Use max_containment to estimate ANI between two MinHash objects."
self_mh = self
other_mh = other
scaled = self.scaled
Expand Down
18 changes: 18 additions & 0 deletions src/sourmash/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,32 @@ def jaccard(self, other):
return self.minhash.similarity(other.minhash, ignore_abundance=True,
downsample=False)

def jaccard_ani(self, other, *, downsample=False, jaccard=None, prob_threshold=1e-3, err_threshold=1e-4):
    """Use jaccard to estimate ANI between two FracMinHash signatures.

    Thin wrapper: every keyword option is forwarded unchanged to
    ``self.minhash.jaccard_ani`` with ``other.minhash`` as the comparison
    sketch, and that call's result is returned as-is.
    """
    estimate = self.minhash.jaccard_ani
    options = dict(
        downsample=downsample,
        jaccard=jaccard,
        prob_threshold=prob_threshold,
        err_threshold=err_threshold,
    )
    return estimate(other.minhash, **options)

def contained_by(self, other, downsample=False):
    """Compute containment by the other signature. Note: ignores abundance.

    Delegates to ``self.minhash.contained_by`` with ``other.minhash``;
    ``downsample`` is passed through positionally.
    """
    compare = self.minhash.contained_by
    return compare(other.minhash, downsample)

def containment_ani(self, other, *, downsample=False, containment=None, confidence=0.95, estimate_ci=False):
    """Use containment to estimate ANI between two FracMinHash signatures.

    Thin wrapper: every keyword option is forwarded unchanged to
    ``self.minhash.containment_ani`` with ``other.minhash`` as the
    comparison sketch, and that call's result is returned as-is.
    """
    estimate = self.minhash.containment_ani
    options = dict(
        downsample=downsample,
        containment=containment,
        confidence=confidence,
        estimate_ci=estimate_ci,
    )
    return estimate(other.minhash, **options)

def max_containment(self, other, downsample=False):
    """Compute max containment w/other signature. Note: ignores abundance.

    Delegates to ``self.minhash.max_containment`` with ``other.minhash``;
    ``downsample`` is passed through positionally.
    """
    compare = self.minhash.max_containment
    return compare(other.minhash, downsample)

def max_containment_ani(self, other, *, downsample=False, max_containment=None, confidence=0.95, estimate_ci=False):
    """Use max containment to estimate ANI between two FracMinHash signatures.

    Thin wrapper: every keyword option is forwarded unchanged to
    ``self.minhash.max_containment_ani`` with ``other.minhash`` as the
    comparison sketch, and that call's result is returned as-is.
    """
    estimate = self.minhash.max_containment_ani
    options = dict(
        downsample=downsample,
        max_containment=max_containment,
        confidence=confidence,
        estimate_ci=estimate_ci,
    )
    return estimate(other.minhash, **options)

def add_sequence(self, sequence, force=False):
self._methodcall(lib.signature_add_sequence, to_bytes(sequence), force)

Expand Down
29 changes: 29 additions & 0 deletions src/sourmash/sketchcomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def intersect_mh(self):
def jaccard(self):
return self.mh1_cmp.jaccard(self.mh2_cmp)

@property
def jaccard_ani(self):
    """Result of ``mh1_cmp.jaccard_ani(mh2_cmp)``, called with that method's default options."""
    estimate = self.mh1_cmp.jaccard_ani
    return estimate(self.mh2_cmp)

@property
def angular_similarity(self):
# Note: this currently throws TypeError if self.ignore_abundance.
Expand All @@ -79,6 +83,8 @@ class FracMinHashComparison(BaseMinHashComparison):
"""Class for standard comparison between two scaled minhashes"""
cmp_scaled: int = None # optionally force scaled value for this comparison
threshold_bp: int = 0
estimate_ani_ci: bool = False
ani_confidence: int = 0.95

def __post_init__(self):
"Initialize ScaledComparison using values from provided FracMinHashes"
Expand All @@ -99,18 +105,41 @@ def intersect_bp(self):
def mh1_containment(self):
return self.mh1_cmp.contained_by(self.mh2_cmp)

@property
def mh1_containment_ani(self):
    """Result of ``mh1_cmp.containment_ani(mh2_cmp, ...)``.

    The call is configured from this comparison's ``ani_confidence`` and
    ``estimate_ani_ci`` fields.
    """
    estimate = self.mh1_cmp.containment_ani
    return estimate(
        self.mh2_cmp,
        confidence=self.ani_confidence,
        estimate_ci=self.estimate_ani_ci,
    )

@property
def mh2_containment(self):
    """Result of ``mh2_cmp.contained_by(mh1_cmp)``."""
    compare = self.mh2_cmp.contained_by
    return compare(self.mh1_cmp)

@property
def mh2_containment_ani(self):
    """Result of ``mh2_cmp.containment_ani(mh1_cmp, ...)``.

    The call is configured from this comparison's ``ani_confidence`` and
    ``estimate_ani_ci`` fields.
    """
    estimate = self.mh2_cmp.containment_ani
    return estimate(
        self.mh1_cmp,
        confidence=self.ani_confidence,
        estimate_ci=self.estimate_ani_ci,
    )

@property
def max_containment(self):
    """Result of ``mh1_cmp.max_containment(mh2_cmp)``."""
    compare = self.mh1_cmp.max_containment
    return compare(self.mh2_cmp)

@property
def max_containment_ani(self):
    """Result of ``mh1_cmp.max_containment_ani(mh2_cmp, ...)``.

    The call is configured from this comparison's ``ani_confidence`` and
    ``estimate_ani_ci`` fields.
    """
    estimate = self.mh1_cmp.max_containment_ani
    return estimate(
        self.mh2_cmp,
        confidence=self.ani_confidence,
        estimate_ci=self.estimate_ani_ci,
    )

@property
def avg_containment(self):
    """Arithmetic mean (``np.mean``) of ``mh1_containment`` and ``mh2_containment``."""
    average = np.mean
    both_directions = [self.mh1_containment, self.mh2_containment]
    return average(both_directions)

@property
def avg_containment_ani(self):
    """Returns single average_containment_ani value.

    This is ``np.mean`` over the ``.ani`` attribute of
    ``mh1_containment_ani`` and ``mh2_containment_ani``.
    """
    average = np.mean
    ani_values = [self.mh1_containment_ani.ani, self.mh2_containment_ani.ani]
    return average(ani_values)

def weighted_intersection(self, from_mh=None, from_abundD={}):
# map abundances to all intersection hashes.
abund_mh = self.intersect_mh.copy_and_clear()
Expand Down
121 changes: 121 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,124 @@ def test_max_containment_equal():
assert ss2.contained_by(ss1) == 1
assert ss1.max_containment(ss2) == 1
assert ss2.max_containment(ss1) == 1


def test_containment_ANI():
    """Check containment and max-containment ANI (with CI estimation) for '2.fa.sig' vs '2+63.fa.sig' at k=31."""
    path_2 = utils.get_test_data('2.fa.sig')
    path_2_63 = utils.get_test_data('2+63.fa.sig')
    sig_2 = sourmash.load_one_signature(path_2, ksize=31)
    sig_2_63 = sourmash.load_one_signature(path_2_63, ksize=31)

    # containment ANI in both directions
    fwd = sig_2.containment_ani(sig_2_63, estimate_ci=True)
    rev = sig_2_63.containment_ani(sig_2, estimate_ci=True)
    print("\nss1 contained by ss2", fwd)
    print("ss2 contained by ss1", rev)

    fwd_observed = (fwd.ani, fwd.ani_low, fwd.ani_high, fwd.p_nothing_in_common)
    assert fwd_observed == (1.0, None, None, 0.0)
    rev_observed = (round(rev.ani, 3), round(rev.ani_low, 3), round(rev.ani_high, 3))
    assert rev_observed == (0.966, 0.965, 0.967)

    # max containment ANI must not depend on which signature is the caller
    mc_fwd = sig_2.max_containment_ani(sig_2_63, estimate_ci=True)
    mc_rev = sig_2_63.max_containment_ani(sig_2, estimate_ci=True)
    print("ss1 max containment", mc_fwd)
    print("ss2 max containment", mc_rev)
    assert mc_fwd == mc_rev
    mc_observed = (mc_fwd.ani, mc_fwd.ani_low, mc_fwd.ani_high)
    assert mc_observed == (1.0, None, None)


def test_containment_ANI_precalc_containment():
    """Passing a precomputed (max) containment must give the same ANI result as computing it internally."""
    path_2 = utils.get_test_data('2.fa.sig')
    path_2_63 = utils.get_test_data('2+63.fa.sig')
    sig_2 = sourmash.load_one_signature(path_2, ksize=31)
    sig_2_63 = sourmash.load_one_signature(path_2_63, ksize=31)

    # containments computed up front, then fed back in as keyword arguments
    c_fwd = sig_2.contained_by(sig_2_63)
    c_rev = sig_2_63.contained_by(sig_2)
    c_max = max(c_fwd, c_rev)

    fwd_internal = sig_2.containment_ani(sig_2_63, estimate_ci=True)
    fwd_precalc = sig_2.containment_ani(sig_2_63, containment=c_fwd, estimate_ci=True)
    assert fwd_internal == fwd_precalc

    rev_internal = sig_2_63.containment_ani(sig_2)
    rev_precalc = sig_2_63.containment_ani(sig_2, containment=c_rev)
    assert rev_internal == rev_precalc

    assert sig_2.max_containment_ani(sig_2_63) == sig_2.max_containment_ani(sig_2_63, max_containment=c_max)
    assert sig_2.max_containment_ani(sig_2_63) == sig_2_63.max_containment_ani(sig_2, max_containment=c_max)


def test_containment_ANI_downsample():
    """Containment ANI with ``downsample=True`` must equal the result after manual downsampling.

    Also checks that comparing sketches with different ``scaled`` values
    without ``downsample=True`` raises ``ValueError`` with the expected message.
    """
    f2 = utils.get_test_data('2+63.fa.sig')
    f3 = utils.get_test_data('47+63.fa.sig')
    ss2 = sourmash.load_one_signature(f2, ksize=31)
    ss3 = sourmash.load_one_signature(f3, ksize=31)
    # check that downsampling works properly
    print(ss2.minhash.scaled)
    ss2.minhash = ss2.minhash.downsample(scaled=2000)
    assert ss2.minhash.scaled != ss3.minhash.scaled
    ds_s3c = ss2.containment_ani(ss3, downsample=True)
    ds_s4c = ss3.containment_ani(ss2, downsample=True)
    mc_w_ds_1 = ss2.max_containment_ani(ss3, downsample=True)
    mc_w_ds_2 = ss3.max_containment_ani(ss2, downsample=True)

    # The message check is handed to pytest.raises(match=...) so it is really
    # evaluated against the exception text. The previous
    # `assert "ValueError: ..." in e` tested a pytest ExceptionInfo object
    # (not a string) and included a "ValueError: " prefix that is not part
    # of the exception message itself.
    with pytest.raises(ValueError, match="mismatch in scaled"):
        ss2.containment_ani(ss3)

    with pytest.raises(ValueError, match="mismatch in scaled"):
        ss2.max_containment_ani(ss3)

    # downsample the second signature by hand; scaled values now agree
    ss3.minhash = ss3.minhash.downsample(scaled=2000)
    assert ss2.minhash.scaled == ss3.minhash.scaled
    ds_s3c_manual = ss2.containment_ani(ss3)
    ds_s4c_manual = ss3.containment_ani(ss2)
    ds_mc_manual = ss2.max_containment_ani(ss3)
    assert ds_s3c == ds_s3c_manual
    assert ds_s4c == ds_s4c_manual
    assert mc_w_ds_1 == mc_w_ds_2 == ds_mc_manual


def test_jaccard_ANI():
    """Jaccard ANI for '2.fa.sig' vs '2+63.fa.sig' must be symmetric and match the recorded values."""
    path_2 = utils.get_test_data('2.fa.sig')
    path_2_63 = utils.get_test_data('2+63.fa.sig')
    sig_2 = sourmash.load_one_signature(path_2, ksize=31)
    sig_2_63 = sourmash.load_one_signature(path_2_63)

    print("\nJACCARD_ANI", sig_2.jaccard_ani(sig_2_63))

    forward = sig_2.jaccard_ani(sig_2_63)
    backward = sig_2_63.jaccard_ani(sig_2)

    # same result regardless of which signature is the caller
    assert forward == backward
    observed = (forward.ani, forward.p_nothing_in_common, forward.jaccard_error)
    assert observed == (0.9783711630110239, 0.0, 3.891666770716877e-07)


def test_jaccard_ANI_precalc_jaccard():
    """A precomputed jaccard value must reproduce the internal result; a different value must not."""
    path_2 = utils.get_test_data('2.fa.sig')
    path_2_63 = utils.get_test_data('2+63.fa.sig')
    sig_2 = sourmash.load_one_signature(path_2, ksize=31)
    sig_2_63 = sourmash.load_one_signature(path_2_63)

    # jaccard computed up front, then fed back in as a keyword argument
    precalc = sig_2.jaccard(sig_2_63)
    print("\nJACCARD_ANI", sig_2.jaccard_ani(sig_2_63,jaccard=precalc))

    assert (
        sig_2.jaccard_ani(sig_2_63)
        == sig_2.jaccard_ani(sig_2_63, jaccard=precalc)
        == sig_2_63.jaccard_ani(sig_2, jaccard=precalc)
    )

    # a deliberately shifted jaccard must change the outcome
    shifted = precalc - 0.1
    assert sig_2.jaccard_ani(sig_2_63) != sig_2.jaccard_ani(sig_2_63, jaccard=shifted)


def test_jaccard_ANI_downsample():
    """Jaccard ANI with ``downsample=True`` must equal the result after manual downsampling.

    Also checks that comparing sketches with different ``scaled`` values
    without ``downsample=True`` raises ``ValueError`` with the expected message.
    """
    f1 = utils.get_test_data('2.fa.sig')
    f2 = utils.get_test_data('2+63.fa.sig')
    ss1 = sourmash.load_one_signature(f1, ksize=31)
    ss2 = sourmash.load_one_signature(f2)

    print(ss1.minhash.scaled)
    ss1.minhash = ss1.minhash.downsample(scaled=2000)
    assert ss1.minhash.scaled != ss2.minhash.scaled

    # The message check is handed to pytest.raises(match=...) so it is really
    # evaluated against the exception text. The previous
    # `assert "ValueError: ..." in e` tested a pytest ExceptionInfo object
    # (not a string) and included a "ValueError: " prefix that is not part
    # of the exception message itself.
    with pytest.raises(ValueError, match="mismatch in scaled"):
        ss1.jaccard_ani(ss2)

    ds_s1c = ss1.jaccard_ani(ss2, downsample=True)
    ds_s2c = ss2.jaccard_ani(ss1, downsample=True)

    # downsample the second signature by hand; scaled values now agree
    ss2.minhash = ss2.minhash.downsample(scaled=2000)
    assert ss1.minhash.scaled == ss2.minhash.scaled
    ds_j_manual = ss1.jaccard_ani(ss2)
    assert ds_s1c == ds_s2c == ds_j_manual
Loading

0 comments on commit 2d1fc69

Please sign in to comment.