diff --git a/src/sourmash/distance_utils.py b/src/sourmash/distance_utils.py new file mode 100644 index 0000000000..a6884515d7 --- /dev/null +++ b/src/sourmash/distance_utils.py @@ -0,0 +1,315 @@ +""" +Utilities for jaccard/containment --> distance estimation +Equations from: https://github.com/KoslickiLab/mutation-rate-ci-calculator +Reference: https://doi.org/10.1101/2022.01.11.475870 +""" +from dataclasses import dataclass, field +from scipy.optimize import brentq +from scipy.stats import norm as scipy_norm +from numpy import sqrt +from math import log, exp + +from .logging import notify + +def check_distance(dist): + if not 0 <= dist <= 1: + raise ValueError(f"Error: distance value {dist :.4f} is not between 0 and 1!") + else: + return dist + +def check_prob_threshold(val, threshold=1e-3): + """ + Check likelihood of no shared hashes based on chance alone (false neg). + If too many exceed threshold, recommend user lower their scaled value. + # !! when using this, keep count and recommend user lower scaled val + """ + exceeds_threshold = False + if threshold is not None and val > threshold: + notify("WARNING: These sketches may have no hashes in common based on chance alone.") + exceeds_threshold = True + return val, exceeds_threshold + +def check_jaccard_error(val, threshold=1e-4): + exceeds_threshold = False + if threshold is not None and val > threshold: + notify(f"WARNING: Error on Jaccard distance point estimate is too high ({val :.4f}).") + exceeds_threshold = True + return val, exceeds_threshold + +@dataclass +class ANIResult: + """Base class for distance/ANI from k-mer containment.""" + dist: float + p_nothing_in_common: float + p_threshold: float = 1e-3 + p_exceeds_threshold: bool = field(init=False) + + def __post_init__(self): + # check values + self.dist = check_distance(self.dist) + self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold) + + @property + def ani(self): + return 1 - self.dist + + +@dataclass +class jaccardANIResult(ANIResult): + """Class for distance/ANI from jaccard (includes jaccard_error).""" + jaccard_error: float = None + je_threshold: float = 1e-4 + + def __post_init__(self): + # check values + self.dist = check_distance(self.dist) + self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold) + # check jaccard error + if self.jaccard_error is not None: + self.jaccard_error, self.je_exceeds_threshold = check_jaccard_error(self.jaccard_error, self.je_threshold) + else: + raise ValueError("Error: jaccard_error cannot be None.") + + +@dataclass +class ciANIResult(ANIResult): + """ + Class for distance/ANI from containment: with confidence intervals. + + Set CI defaults to None, just in case CI can't be estimated for given sample. 
+ """ + dist_low: float = None + dist_high: float = None + + def __post_init__(self): + # check values + self.dist = check_distance(self.dist) + self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold) + + if self.dist_low is not None and self.dist_high is not None: + self.dist_low = check_distance(self.dist_low) + self.dist_high = check_distance(self.dist_high) + + @property + def ani_low(self): + if self.dist_high is None: + return None + return 1 - self.dist_high + + @property + def ani_high(self): + if self.dist_low is None: + return None + return 1 - self.dist_low + + +def r1_to_q(k, r1): + r1 = float(r1) + q = 1 - (1 - r1) ** k + return float(q) + + +def var_n_mutated(L, k, r1, *, q=None): + # there are computational issues in the variance formula that we solve here + # by the use of higher-precision arithmetic; the problem occurs when r is + # very small; for example, with L=10,k=2,r1=1e-6 standard precision + # gives varN<0 which is nonsense; by using the mpf type, we get the correct + # answer which is about 0.000038. + if r1 == 0: + return 0.0 + r1 = float(r1) + if q == None: # we assume that if q is provided, it is correct for r1 + q = r1_to_q(k, r1) + varN = ( + L * (1 - q) * (q * (2 * k + (2 / r1) - 1) - 2 * k) + + k * (k - 1) * (1 - q) ** 2 + + (2 * (1 - q) / (r1**2)) * ((1 + (k - 1) * (1 - q)) * r1 - q) + ) + if varN < 0.0: # this seems to happen only with super tiny test data + raise ValueError("Error: varN <0.0!") + return float(varN) + + +def exp_n_mutated(L, k, r1): + q = r1_to_q(k, r1) + return L * q + + +def exp_n_mutated_squared(L, k, p): + return var_n_mutated(L, k, p) + exp_n_mutated(L, k, p) ** 2 + + +def probit(p): + return scipy_norm.ppf(p) + + +def handle_seqlen_nkmers(ksize, *, sequence_len_bp=None, n_unique_kmers=None): + if n_unique_kmers is not None: + return n_unique_kmers + elif sequence_len_bp is None: + # both are None, raise ValueError + raise ValueError("Error: distance estimation requires input of either 'sequence_len_bp' or 'n_unique_kmers'") + else: + n_unique_kmers = sequence_len_bp - (ksize - 1) + return n_unique_kmers + + +def get_expected_log_probability(n_unique_kmers, ksize, mutation_rate, scaled_fraction): + """helper function + Note that scaled here needs to be between 0 and 1 + (e.g. scaled 1000 --> scaled_fraction 0.001) + """ + exp_nmut = exp_n_mutated(n_unique_kmers, ksize, mutation_rate) + try: + return (n_unique_kmers - exp_nmut) * log(1.0 - scaled_fraction) + except: + return float("-inf") + + +def get_exp_probability_nothing_common( + mutation_rate, ksize, scaled, *, n_unique_kmers=None, sequence_len_bp=None +): + """ + Given parameters, calculate the expected probability that nothing will be common + between a fracminhash sketch of a original sequence and a fracminhash sketch of a mutated + sequence. If this is above a threshold, we should suspect that the two sketches may have + nothing in common. The threshold needs to be set with proper insights. 
+ + Arguments: n_unique_kmers, ksize, mutation_rate, scaled + Returns: float - expected likelihood that nothing is common between sketches + """ + n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp=sequence_len_bp,n_unique_kmers=n_unique_kmers) + f_scaled = 1.0 / float(scaled) + if mutation_rate == 1.0: + return 1.0 + elif mutation_rate == 0.0: + return 0.0 + return exp( + get_expected_log_probability(n_unique_kmers, ksize, mutation_rate, f_scaled) + ) + + +def containment_to_distance( + containment, + ksize, + scaled, + *, + n_unique_kmers=None, + sequence_len_bp=None, + confidence=0.95, + estimate_ci=False, + prob_threshold=1e-3, +): + """ + Containment --> distance CI (one step) + """ + sol1, sol2, point_estimate = None, None, None + n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp = sequence_len_bp, n_unique_kmers=n_unique_kmers) + if containment <= 0.0001: + point_estimate = 1.0 + elif containment >= 0.9999: + point_estimate = 0.0 + else: + point_estimate = 1.0 - containment ** (1.0 / ksize) + if estimate_ci: + try: + alpha = 1 - confidence + z_alpha = probit(1 - alpha / 2) + f_scaled = ( + 1.0 / scaled + ) # these use scaled as a fraction between 0 and 1 + + bias_factor = 1 - (1 - f_scaled) ** n_unique_kmers + + term_1 = (1.0 - f_scaled) / ( + f_scaled * n_unique_kmers**3 * bias_factor**2 + ) + term_2 = lambda pest: n_unique_kmers * exp_n_mutated( + n_unique_kmers, ksize, pest + ) - exp_n_mutated_squared(n_unique_kmers, ksize, pest) + term_3 = lambda pest: var_n_mutated(n_unique_kmers, ksize, pest) / ( + n_unique_kmers**2 + ) + + var_direct = lambda pest: term_1 * term_2(pest) + term_3(pest) + + f1 = ( + lambda pest: (1 - pest) ** ksize + + z_alpha * sqrt(var_direct(pest)) + - containment + ) + f2 = ( + lambda pest: (1 - pest) ** ksize + - z_alpha * sqrt(var_direct(pest)) + - containment + ) + + sol1 = brentq(f1, 0.0000001, 0.9999999) + sol2 = brentq(f2, 0.0000001, 0.9999999) + + except ValueError as exc: + # afaict, this only happens with extremely small test data + notify( + "WARNING: Cannot estimate ANI confidence intervals from containment. Do your sketches contain enough hashes?" + ) + notify(str(exc)) + sol1 = sol2 = None + + # Do this here, so that we don't need to reconvert distance <--> identity later. + prob_nothing_in_common = get_exp_probability_nothing_common( + point_estimate, ksize, scaled, n_unique_kmers=n_unique_kmers + ) + return ciANIResult(point_estimate, prob_nothing_in_common, dist_low=sol2, dist_high=sol1, p_threshold=prob_threshold) + + +def jaccard_to_distance( + jaccard, + ksize, + scaled, + *, + n_unique_kmers=None, + sequence_len_bp=None, + prob_threshold=1e-3, + err_threshold=1e-4, +): + """ + Given parameters, calculate point estimate for mutation rate from jaccard index. + Uses formulas derived mathematically to compute the point estimate. The formula uses + approximations, therefore a tiny error is associated with it. A lower bound of that error + is also returned. A high error indicates that the point estimate cannot be trusted. + Threshold of the error is open to interpretation, but suggested that > 10^-4 should be + handled with caution. + + Note that the error is NOT a mutation rate, and therefore cannot be considered in + something like mut.rate +/- error. 
+ + Arguments: jaccard, ksize, scaled, n_unique_kmers + # Returns: tuple (point_estimate_of_mutation_rate, lower_bound_of_error) + + # Returns: point_estimate_of_mutation_rate + + Note: point estimate does not consider impact of scaled, but p_nothing_in_common can be + useful for determining whether scaled is sufficient for these comparisons. + """ + error_lower_bound = None + n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp=sequence_len_bp, n_unique_kmers=n_unique_kmers) + if jaccard <= 0.0001: + point_estimate = 1.0 + error_lower_bound = 0.0 + elif jaccard >= 0.9999: + point_estimate = 0.0 + error_lower_bound = 0.0 + else: + point_estimate = 1.0 - (2.0 * jaccard / float(1 + jaccard)) ** ( + 1.0 / float(ksize) + ) + + exp_n_mut = exp_n_mutated(n_unique_kmers, ksize, point_estimate) + var_n_mut = var_n_mutated(n_unique_kmers, ksize, point_estimate) + error_lower_bound = ( + 1.0 * n_unique_kmers * var_n_mut / (n_unique_kmers + exp_n_mut) ** 3 + ) + prob_nothing_in_common = get_exp_probability_nothing_common( + point_estimate, ksize, scaled, n_unique_kmers=n_unique_kmers + ) + return jaccardANIResult(point_estimate, prob_nothing_in_common, jaccard_error=error_lower_bound, p_threshold=prob_threshold, je_threshold=err_threshold) diff --git a/src/sourmash/minhash.py b/src/sourmash/minhash.py index 87819fac00..d2323ae8ab 100644 --- a/src/sourmash/minhash.py +++ b/src/sourmash/minhash.py @@ -6,6 +6,7 @@ class MinHash - core MinHash class. class FrozenMinHash - read-only MinHash class. """ from __future__ import unicode_literals, division +from .distance_utils import jaccard_to_distance, containment_to_distance __all__ = ['get_minhash_default_seed', 'get_minhash_max_hash', @@ -646,6 +647,25 @@ def jaccard(self, other, downsample=False): raise TypeError(err) return self._methodcall(lib.kmerminhash_similarity, other._get_objptr(), True, downsample) + def jaccard_ani(self, other, *, downsample=False, jaccard=None, prob_threshold=1e-3, err_threshold=1e-4): + "Calculate Jaccard --> ANI of two MinHash objects." + self_mh = self + other_mh = other + scaled = self.scaled + if downsample: + scaled = max(self_mh.scaled, other_mh.scaled) + self_mh = self.downsample(scaled=scaled) + other_mh = other.downsample(scaled=scaled) + if jaccard is None: + jaccard = self_mh.similarity(other_mh, ignore_abundance=True) + avg_scaled_kmers = round((len(self_mh) + len(other_mh))/2) + avg_n_kmers = avg_scaled_kmers * scaled # would be better if hll estimate - see #1798 + j_aniresult = jaccard_to_distance(jaccard, self_mh.ksize, scaled, + n_unique_kmers=avg_n_kmers, + prob_threshold = prob_threshold, + err_threshold = err_threshold) + return j_aniresult + def similarity(self, other, ignore_abundance=False, downsample=False): """Calculate similarity of two sketches. @@ -683,6 +703,25 @@ def contained_by(self, other, downsample=False): return self.count_common(other, downsample) / len(self) + def containment_ani(self, other, *, downsample=False, containment=None, confidence=0.95, estimate_ci = False): + "Estimate ANI from containment with the other MinHash." 
+ self_mh = self + other_mh = other + scaled = self.scaled + if downsample: + scaled = max(self_mh.scaled, other_mh.scaled) + self_mh = self.downsample(scaled=scaled) + other_mh = other.downsample(scaled=scaled) + if containment is None: + containment = self_mh.contained_by(other_mh) + n_kmers = len(self_mh) * scaled # would be better if hll estimate - see #1798 + + c_aniresult = containment_to_distance(containment, self_mh.ksize, self_mh.scaled, + n_unique_kmers=n_kmers, confidence=confidence, + estimate_ci = estimate_ci) + return c_aniresult + + def max_containment(self, other, downsample=False): """ Calculate maximum containment. @@ -695,6 +734,25 @@ def max_containment(self, other, downsample=False): return self.count_common(other, downsample) / min_denom + def max_containment_ani(self, other, *, downsample=False, max_containment=None, confidence=0.95, estimate_ci=False): + "Estimate ANI from containment with the other MinHash." + self_mh = self + other_mh = other + scaled = self.scaled + if downsample: + scaled = max(self_mh.scaled, other_mh.scaled) + self_mh = self.downsample(scaled=scaled) + other_mh = other.downsample(scaled=scaled) + if max_containment is None: + max_containment = self_mh.max_containment(other_mh) + min_n_kmers = min(len(self_mh), len(other_mh)) + n_kmers = min_n_kmers * scaled # would be better if hll estimate - see #1798 + + c_aniresult = containment_to_distance(max_containment, self_mh.ksize, scaled, + n_unique_kmers=n_kmers,confidence=confidence, + estimate_ci = estimate_ci) + return c_aniresult + def __add__(self, other): if not isinstance(other, MinHash): raise TypeError("can only add MinHash objects to MinHash objects!") diff --git a/tests/test_distance_utils.py b/tests/test_distance_utils.py new file mode 100644 index 0000000000..ebc4dd56d8 --- /dev/null +++ b/tests/test_distance_utils.py @@ -0,0 +1,395 @@ +""" +Tests for distance utils. +""" +import pytest +from sourmash.distance_utils import (containment_to_distance, get_exp_probability_nothing_common, + handle_seqlen_nkmers, jaccard_to_distance, + ANIResult, ciANIResult, jaccardANIResult, var_n_mutated) + +def test_aniresult(): + res = ANIResult(0.4, 0.1) + assert res.dist == 0.4 + assert res.ani == 0.6 + assert res.p_nothing_in_common == 0.1 + assert res.p_exceeds_threshold ==True + # check that they're equivalent + res2 = ANIResult(0.4, 0.1) + assert res == res2 + res3 = ANIResult(0.5, 0) + assert res != res3 + assert res3.p_exceeds_threshold ==False + +def test_aniresult_bad_distance(): + """ + Fail if distance is not between 0 and 1. + """ + with pytest.raises(Exception) as exc: + ANIResult(1.1, 0.1) + print("\n", str(exc.value)) + assert "distance value 1.1000 is not between 0 and 1!" in str(exc.value) + with pytest.raises(Exception) as exc: + ANIResult(-0.1, 0.1) + print("\n", str(exc.value)) + assert "distance value -0.1000 is not between 0 and 1!" in str(exc.value) + + +def test_jaccard_aniresult(): + res = jaccardANIResult(0.4, 0.1, jaccard_error=0.03) + assert res.dist == 0.4 + assert res.ani == 0.6 + assert res.p_nothing_in_common == 0.1 + assert res.jaccard_error == 0.03 + assert res.p_exceeds_threshold ==True + assert res.je_exceeds_threshold ==True + res2 = jaccardANIResult(0.4, 0.1, jaccard_error=0.03, je_threshold=0.1) + assert res2.je_exceeds_threshold ==False + + +def test_jaccard_aniresult_nojaccarderror(): + #jaccard error is None + with pytest.raises(Exception) as exc: + jaccardANIResult(0.4, 0.1, None) + print("\n", str(exc.value)) + assert "Error: jaccard_error cannot be None." 
in str(exc.value) + + +def test_ci_aniresult(): + res = ciANIResult(0.4, 0.1, dist_low=0.3,dist_high=0.5) + print(res) + assert res.dist == 0.4 + assert res.ani == 0.6 + assert res.p_nothing_in_common == 0.1 + assert res.ani_low == 0.5 + assert res.ani_high == 0.7 + res2 = ciANIResult(0.4, 0.1, dist_low=0.3,dist_high=0.5) + assert res == res2 + res3 = ciANIResult(0.4, 0.2, dist_low=0.3, dist_high=0.5) + assert res != res3 + + +def test_containment_to_distance_zero(): + contain = 0 + scaled = 1 + nkmers = 10000 + ksize=21 + res = containment_to_distance(contain,ksize,scaled, n_unique_kmers=nkmers, estimate_ci=True) + print(res) + # check results + exp_dist,exp_low,exp_high,pnc = 1.0,None,None,1.0 + exp_id, exp_idlow,exp_idhigh,pnc = 0.0,None,None,1.0 + assert res.dist == exp_dist + assert res.dist_low == exp_low + assert res.dist_high == exp_high + assert res.p_nothing_in_common == pnc + assert res.ani == exp_id + assert res.ani_low == exp_idlow + assert res.ani_high == exp_idhigh + # check without returning ci + res2 = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers) + print(res2) + exp_res = ciANIResult(dist=1.0, p_nothing_in_common=1.0, p_threshold=0.001) + assert res2 == exp_res + + +def test_containment_to_distance_one(): + contain = 1 + scaled = 1 + nkmers = 10000 + ksize=21 + res = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers,estimate_ci=True) + print(res) + exp_dist, exp_low,exp_high,pnc = 0.0,None,None,0.0 + exp_id, exp_idlow,exp_idhigh,pnc = 1.0,None,None,0.0 + assert res.dist == exp_dist + assert res.dist_low == exp_low + assert res.dist_high == exp_high + assert res.p_nothing_in_common == pnc + assert res.ani == exp_id + assert res.ani_low == exp_idlow + assert res.ani_high == exp_idhigh + + # check without returning ci + res = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers) + assert res.dist == exp_dist + assert res.ani == exp_id + assert res.p_nothing_in_common == pnc + assert res.ani_low == None + assert res.ani_high == None + + +def test_containment_to_distance_scaled1(): + contain = 0.5 + scaled = 1 + nkmers = 10000 + ksize=21 + res = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers,estimate_ci=True) + print(res) + # check results + assert res.dist == 0.032468221476108394 + assert res.ani == 0.9675317785238916 + assert res.dist_low == 0.028709912966405623 + assert res.ani_high == 0.9712900870335944 + assert res.dist_high == 0.03647860197289783 + assert res.ani_low == 0.9635213980271021 + assert res.p_nothing_in_common == 0.0 + # without returning ci + res2 = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers) + assert (res2.dist,res2.ani,res2.p_nothing_in_common) == (0.032468221476108394, 0.9675317785238916, 0.0) + assert (res2.dist,res2.ani,res2.p_nothing_in_common) == (res.dist, res.ani, res.p_nothing_in_common) + + +def test_containment_to_distance_scaled100(): + contain = 0.1 + scaled = 100 + nkmers = 10000 + ksize=31 + res = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers,estimate_ci=True) + print(res) + # check results + assert res.dist == 0.07158545548052564 + assert res.dist_low == 0.05320779238601372 + assert res.dist_high == 0.09055547672455365 + assert res.p_nothing_in_common == 4.3171247410658655e-05 + assert res.p_exceeds_threshold == False + + +def test_containment_to_distance_scaled100_2(): + contain = 0.5 + scaled = 100 + nkmers = 10000 + ksize=21 + res= containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers,estimate_ci=True) + 
print(res) + # check results + assert res.dist == 0.032468221476108394 + assert res.dist_low == 0.023712063916639017 + assert res.dist_high == 0.04309960543965866 + assert res.p_exceeds_threshold == False + + +def test_containment_to_distance_k10(): + contain = 0.5 + scaled = 100 + nkmers = 10000 + ksize=10 + res = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers,estimate_ci=True) + print(res) + # check results + assert res.dist == 0.06696700846319259 + assert res.dist_low == 0.04982777541057476 + assert res.dist_high == 0.08745108232411622 + assert res.p_exceeds_threshold == False + + +def test_containment_to_distance_confidence(): + contain = 0.1 + scaled = 100 + nkmers = 10000 + ksize=31 + confidence=0.99 + res = containment_to_distance(contain,ksize,scaled,confidence=confidence,n_unique_kmers=nkmers, estimate_ci=True) + print(res) + # check results + assert res.dist == 0.07158545548052564 + assert res.dist_low == 0.04802880300938562 + assert res.dist_high == 0.09619930040790341 + assert res.p_exceeds_threshold == False + confidence=0.90 + res2 = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers,confidence=confidence, estimate_ci=True) + print(res2) + # check results + assert res2.dist == res.dist + assert res2.dist_low == 0.05599435479247415 + assert res2.dist_high == 0.08758718871990222 + assert res.p_exceeds_threshold == False + + +def test_nkmers_to_bp_containment(): + containment = 0.1 + scaled = 100 + bp_len = 10030 + ksize=31 + nkmers = handle_seqlen_nkmers(ksize, sequence_len_bp= bp_len) + print("nkmers_from_bp:", nkmers) + confidence=0.99 + kmer_res = containment_to_distance(containment,ksize,scaled,confidence=confidence,n_unique_kmers=nkmers,estimate_ci=True) + bp_res = containment_to_distance(containment,ksize,scaled,confidence=confidence,sequence_len_bp=bp_len,estimate_ci=True) + print(f"\nkDIST: {kmer_res}") + print(f"\nbpDIST:,{bp_res}") + # check results + assert kmer_res==bp_res + assert kmer_res.dist == 0.07158545548052564 + assert kmer_res.dist_low == 0.04802880300938562 + assert kmer_res.dist_high == 0.09619930040790341 + + +def test_jaccard_to_distance_zero(): + jaccard = 0 + scaled = 1 + nkmers = 10000 + ksize=21 + res= jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers) + print(res) + # check results + assert res.dist == 1.0 + assert res.ani == 0.0 + assert res.p_nothing_in_common == 1.0 + assert res.jaccard_error == 0.0 + + +def test_jaccard_to_distance_one(): + jaccard = 1 + scaled = 1 + nkmers = 10000 + ksize=21 + res= jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers) + print(res) + # check results + assert res.dist == 0.0 + assert res.ani == 1.0 + assert res.p_nothing_in_common == 0.0 + assert res.jaccard_error == 0.0 + + +def test_jaccard_to_distance_scaled(): + # scaled value doesn't impact point estimate or jaccard error, just p_nothing_in_common + jaccard = 0.5 + scaled = 1 + nkmers = 10000 + ksize=21 + res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers) + print(res) + # check results + assert res.dist == 0.019122659390482077 + assert res.ani == 0.9808773406095179 + assert res.p_exceeds_threshold == False + assert res.jaccard_error == 0.00018351337045518042 + assert res.je_exceeds_threshold ==True + scaled = 100 + res2 = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers) + print(res2) + assert res2.dist == res.dist + assert res2.jaccard_error == res.jaccard_error + assert res2.p_nothing_in_common != res.p_nothing_in_common + assert res2.p_exceeds_threshold ==False + 
+ +def test_jaccard_to_distance_k31(): + jaccard = 0.5 + scaled = 100 + nkmers = 10000 + ksize=31 + res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers) + print(res) + # check results + assert res.ani == 0.9870056455892898 + assert res.p_exceeds_threshold == False + assert res.je_exceeds_threshold ==True + res2 = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers, err_threshold=0.1) + assert res2.ani == res.ani + assert res2.je_exceeds_threshold == False + + +def test_jaccard_to_distance_k31_2(): + jaccard = 0.1 + scaled = 100 + nkmers = 10000 + ksize=31 + res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers) + print(res) + # check results + assert res.ani == 0.9464928391768298 + assert res.p_exceeds_threshold == False + assert res.je_exceeds_threshold == False + + +def test_nkmers_to_bp_jaccard(): + jaccard = 0.1 + scaled = 100 + bp_len = 10030 + ksize=31 + nkmers = handle_seqlen_nkmers(ksize, sequence_len_bp= bp_len) + print("nkmers_from_bp:", nkmers) + kmer_res = jaccard_to_distance(jaccard,ksize,scaled,n_unique_kmers=nkmers) + bp_res = jaccard_to_distance(jaccard,ksize,scaled,sequence_len_bp=bp_len) + print(f"\nkmer_res: {kmer_res}") + print(f"\nbp_res: {bp_res}") + # check results + assert kmer_res == bp_res + assert kmer_res.dist == 0.0535071608231702 + assert kmer_res.p_exceeds_threshold == False + assert kmer_res.je_exceeds_threshold == False + + +def test_exp_prob_nothing_common(): + dist = 0.25 + ksize = 31 + scaled = 10 + bp_len = 1000030 + nkmers = handle_seqlen_nkmers(ksize, sequence_len_bp= bp_len) + print("nkmers_from_bp:", nkmers) + + nkmers_pnc = get_exp_probability_nothing_common(dist,ksize,scaled,n_unique_kmers=nkmers) + print(f"prob nothing in common: {nkmers_pnc}") + bp_pnc = get_exp_probability_nothing_common(dist,ksize,scaled,sequence_len_bp=bp_len) + assert nkmers_pnc == bp_pnc == 7.437016945722123e-07 + + +def test_containment_to_distance_tinytestdata_var0(): + """ + tiny test data to trigger the following: + WARNING: Cannot estimate ANI confidence intervals from containment. Do your sketches contain enough hashes? + Error: varN <0.0! + """ + contain = 0.9 + scaled = 1 + nkmers = 4 + ksize=31 + res = containment_to_distance(contain,ksize,scaled,n_unique_kmers=nkmers, estimate_ci=True) + print(res) + # check results + assert res.dist == 0.003392957179023992 + assert res.dist_low == None + assert res.dist_high == None + assert res.ani_low == None + assert res.ani_high == None + assert res.p_exceeds_threshold == False + + +def test_var_n_mutated(): + # check 0 + r = 0 + ksize = 31 + nkmers = 200 + var_n_mut = var_n_mutated(nkmers,ksize,r) + print(f"var_n_mutated: {var_n_mut}") + assert var_n_mut == 0 + # check var 0.0 valuerror + r = 10 + ksize = 31 + nkmers = 200 + with pytest.raises(ValueError) as exc: + var_n_mut = var_n_mutated(nkmers,ksize,r) + assert "Error: varN <0.0!" 
in str(exc) + # check successful + r = 0.4 + ksize = 31 + nkmers = 200000 + var_n_mut = var_n_mutated(nkmers,ksize,r) + print(f"var_n_mutated: {var_n_mut}") + assert var_n_mut == 0.10611425440741508 + + +def test_handle_seqlen_nkmers(): + bp_len = 10030 + ksize=31 + # convert seqlen to nkmers + nkmers = handle_seqlen_nkmers(ksize, sequence_len_bp= bp_len) + assert nkmers == 10000 + # if nkmers is provided, just use that + nkmers = handle_seqlen_nkmers(ksize, sequence_len_bp= bp_len, n_unique_kmers= bp_len) + assert nkmers == 10030 + # if neither seqlen or nkmers provided, complain + with pytest.raises(ValueError) as exc: + nkmers = handle_seqlen_nkmers(ksize) + assert("Error: distance estimation requires input of either 'sequence_len_bp' or 'n_unique_kmers'") in str(exc) diff --git a/tests/test_minhash.py b/tests/test_minhash.py index be9adf171b..4d9124c501 100644 --- a/tests/test_minhash.py +++ b/tests/test_minhash.py @@ -2650,3 +2650,142 @@ def test_containment(track_abundance): assert mh1.contained_by(mh2) == 1/4 assert mh2.contained_by(mh1) == 1/2 + + +def test_containment_ANI(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + mh1 = sourmash.load_one_signature(f1, ksize=31).minhash + mh2 = sourmash.load_one_signature(f2, ksize=31).minhash + + m1_cont_m2 = mh1.containment_ani(mh2, estimate_ci =True) + m2_cont_m1 = mh2.containment_ani(mh1, estimate_ci =True) + print("\nmh1 contained by mh2", m1_cont_m2) + print("mh2 contained by mh1", m2_cont_m1) + + assert (m1_cont_m2.ani, m1_cont_m2.ani_low, m1_cont_m2.ani_high, m1_cont_m2.p_nothing_in_common) == (1.0, None, None, 0.0) + assert (round(m2_cont_m1.ani,3), round(m2_cont_m1.ani_low,3), round(m2_cont_m1.ani_high,3)) == (0.966, 0.965, 0.967) + + m1_mc_m2 = mh1.max_containment_ani(mh2, estimate_ci =True) + m2_mc_m1 = mh2.max_containment_ani(mh1, estimate_ci =True) + print("mh1 max containment", m1_mc_m2) + print("mh2 max containment", m2_mc_m1) + assert m1_mc_m2 == m2_mc_m1 + assert (m1_mc_m2.ani, m1_mc_m2.ani_low, m1_mc_m2.ani_high) == (1.0,None,None) + + +def test_containment_ANI_precalc_containment(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + mh1 = sourmash.load_one_signature(f1, ksize=31).minhash + mh2 = sourmash.load_one_signature(f2, ksize=31).minhash + # precalc containments and assert same results + s1c = mh1.contained_by(mh2) + s2c = mh2.contained_by(mh1) + mc = max(s1c, s2c) + + assert mh1.containment_ani(mh2, estimate_ci=True) == mh1.containment_ani(mh2, containment=s1c, estimate_ci=True) + assert mh2.containment_ani(mh1) == mh2.containment_ani(mh1, containment=s2c) + assert mh1.max_containment_ani(mh2) == mh1.max_containment_ani(mh2, max_containment=mc) + assert mh1.max_containment_ani(mh2) == mh2.max_containment_ani(mh1, max_containment=mc) + + +def test_containment_ANI_downsample(): + f2 = utils.get_test_data('2+63.fa.sig') + f3 = utils.get_test_data('47+63.fa.sig') + mh2 = sourmash.load_one_signature(f2, ksize=31).minhash + mh3 = sourmash.load_one_signature(f3, ksize=31).minhash + # check that downsampling works properly + print(mh2.scaled) + mh2 = mh2.downsample(scaled=2000) + assert mh2.scaled != mh3.scaled + ds_s3c = mh2.containment_ani(mh3, downsample=True) + ds_s4c = mh3.containment_ani(mh2, downsample=True) + mc_w_ds_1 = mh2.max_containment_ani(mh3, downsample=True) + mc_w_ds_2 = mh3.max_containment_ani(mh2, downsample=True) + + with pytest.raises(ValueError) as e: + mh2.containment_ani(mh3) + assert "ValueError: mismatch in scaled; comparison 
fail" in e + + with pytest.raises(ValueError) as e: + mh2.max_containment_ani(mh3) + assert "ValueError: mismatch in scaled; comparison fail" in e + + mh3 = mh3.downsample(scaled=2000) + assert mh2.scaled == mh3.scaled + ds_s3c_manual = mh2.containment_ani(mh3) + ds_s4c_manual = mh3.containment_ani(mh2) + ds_mc_manual = mh2.max_containment_ani(mh3) + assert ds_s3c == ds_s3c_manual + assert ds_s4c == ds_s4c_manual + assert mc_w_ds_1 == mc_w_ds_2 == ds_mc_manual + + +def test_jaccard_ANI(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + mh1 = sourmash.load_one_signature(f1, ksize=31).minhash + mh2 = sourmash.load_one_signature(f2).minhash + + print("\nJACCARD_ANI", mh1.jaccard_ani(mh2)) + + m1_jani_m2 = mh1.jaccard_ani(mh2) + m2_jani_m1 = mh2.jaccard_ani(mh1) + + assert m1_jani_m2 == m2_jani_m1 + assert (m1_jani_m2.ani, m1_jani_m2.p_nothing_in_common, m1_jani_m2.jaccard_error) == (0.9783711630110239, 0.0, 3.891666770716877e-07) + + +def test_jaccard_ANI_precalc_jaccard(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + mh1 = sourmash.load_one_signature(f1, ksize=31).minhash + mh2 = sourmash.load_one_signature(f2).minhash + # precalc jaccard and assert same result + jaccard = mh1.jaccard(mh2) + print("\nJACCARD_ANI", mh1.jaccard_ani(mh2,jaccard=jaccard)) + + assert mh1.jaccard_ani(mh2) == mh1.jaccard_ani(mh2, jaccard=jaccard) == mh2.jaccard_ani(mh1, jaccard=jaccard) + wrong_jaccard = jaccard - 0.1 + assert mh1.jaccard_ani(mh2) != mh1.jaccard_ani(mh2, jaccard=wrong_jaccard) + + +def test_jaccard_ANI_downsample(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + mh1 = sourmash.load_one_signature(f1, ksize=31).minhash + mh2 = sourmash.load_one_signature(f2).minhash + + print(mh1.scaled) + mh1 = mh1.downsample(scaled=2000) + assert mh1.scaled != mh2.scaled + with pytest.raises(ValueError) as e: + mh1.jaccard_ani(mh2) + assert "ValueError: mismatch in scaled; comparison fail" in e + + ds_s1c = mh1.jaccard_ani(mh2, downsample=True) + ds_s2c = mh2.jaccard_ani(mh1, downsample=True) + + mh2 = mh2.downsample(scaled=2000) + assert mh1.scaled == mh2.scaled + ds_j_manual = mh1.jaccard_ani(mh2) + assert ds_s1c == ds_s2c == ds_j_manual + +def test_containment_ani_ci_tiny_testdata(): + """ + tiny test data to trigger the following: + WARNING: Cannot estimate ANI confidence intervals from containment. Do your sketches contain enough hashes? + Error: varN <0.0! + """ + mh1 = MinHash(0, 21, scaled=1, track_abundance=False) + mh2 = MinHash(0, 21, scaled=1, track_abundance=False) + + mh1.add_many((1, 3, 4)) + mh2.add_many((1, 2, 3, 4)) + + m2_cani_m1 = mh2.containment_ani(mh1, estimate_ci=True) + print(m2_cani_m1) + assert m2_cani_m1.ani == 0.986394259982259 + assert m2_cani_m1.ani_low == None + assert m2_cani_m1.ani_high == None