[MRG] distance utility functions to support ANI estimation (#1934)
* split distance utils and start updating tests for function return changes

* upd containment tests

* upd jaccard tests

* add p_nothing_in_common test; cleanup unused jaccard formulas

* pyflakes, formatting

* raise err with bad input to distance_to_identity

* add minhash fns and tests

* upd

* Apply suggestions from code review

Co-authored-by: C. Titus Brown <titus@idyll.org>

* try a dataclass

* better dataclasses; init revamp tests for aniresult classes

* mod all tests for aniresult dataclass

* init minhash changes for aniresult

* upd tests

* tiny testdata tests for var0.0

* test var_n_mutated directly

* test test_handle_seqlen_nkmers directly

* clarify comment

* dont populate fake CI for dist=0,dist=1 cases

* fix newline

* Update src/sourmash/distance_utils.py

Co-authored-by: C. Titus Brown <titus@idyll.org>

* fix corresponding test

Co-authored-by: C. Titus Brown <titus@idyll.org>
bluegenes and ctb committed Apr 15, 2022
1 parent 0d04cca commit 6f7eb06
Showing 4 changed files with 907 additions and 0 deletions.
315 changes: 315 additions & 0 deletions src/sourmash/distance_utils.py
@@ -0,0 +1,315 @@
"""
Utilities for jaccard/containment --> distance estimation
Equations from: https://github.com/KoslickiLab/mutation-rate-ci-calculator
Reference: https://doi.org/10.1101/2022.01.11.475870
"""
from dataclasses import dataclass, field
from scipy.optimize import brentq
from scipy.stats import norm as scipy_norm
from numpy import sqrt
from math import log, exp

from .logging import notify

def check_distance(dist):
    if not 0 <= dist <= 1:
        raise ValueError(f"Error: distance value {dist:.4f} is not between 0 and 1!")
    else:
        return dist

def check_prob_threshold(val, threshold=1e-3):
    """
    Check likelihood of no shared hashes based on chance alone (false neg).
    If too many exceed threshold, recommend user lower their scaled value.
    # !! when using this, keep count and recommend user lower scaled val
    """
    exceeds_threshold = False
    if threshold is not None and val > threshold:
        notify("WARNING: These sketches may have no hashes in common based on chance alone.")
        exceeds_threshold = True
    return val, exceeds_threshold

def check_jaccard_error(val, threshold=1e-4):
    exceeds_threshold = False
    if threshold is not None and val > threshold:
        notify(f"WARNING: Error on Jaccard distance point estimate is too high ({val:.4f}).")
        exceeds_threshold = True
    return val, exceeds_threshold

@dataclass
class ANIResult:
    """Base class for distance/ANI from k-mer containment."""
    dist: float
    p_nothing_in_common: float
    p_threshold: float = 1e-3
    p_exceeds_threshold: bool = field(init=False)

    def __post_init__(self):
        # check values
        self.dist = check_distance(self.dist)
        self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold)

    @property
    def ani(self):
        return 1 - self.dist
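
# Illustrative usage sketch (hypothetical values): an ANIResult is built from a
# distance point estimate plus the probability that the sketches share nothing
# by chance; ANI is simply 1 - distance.
#   >>> res = ANIResult(dist=0.05, p_nothing_in_common=1e-6)
#   >>> res.ani
#   0.95
#   >>> res.p_exceeds_threshold   # 1e-6 is below the default p_threshold of 1e-3
#   False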


@dataclass
class jaccardANIResult(ANIResult):
    """Class for distance/ANI from jaccard (includes jaccard_error)."""
    jaccard_error: float = None
    je_threshold: float = 1e-4

    def __post_init__(self):
        # check values
        self.dist = check_distance(self.dist)
        self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold)
        # check jaccard error
        if self.jaccard_error is not None:
            self.jaccard_error, self.je_exceeds_threshold = check_jaccard_error(self.jaccard_error, self.je_threshold)
        else:
            raise ValueError("Error: jaccard_error cannot be None.")


@dataclass
class ciANIResult(ANIResult):
    """
    Class for distance/ANI from containment: with confidence intervals.
    Set CI defaults to None, just in case CI can't be estimated for given sample.
    """
    dist_low: float = None
    dist_high: float = None

    def __post_init__(self):
        # check values
        self.dist = check_distance(self.dist)
        self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold)

        if self.dist_low is not None and self.dist_high is not None:
            self.dist_low = check_distance(self.dist_low)
            self.dist_high = check_distance(self.dist_high)

    @property
    def ani_low(self):
        if self.dist_high is None:
            return None
        return 1 - self.dist_high

    @property
    def ani_high(self):
        if self.dist_low is None:
            return None
        return 1 - self.dist_low
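
# Illustrative usage sketch (hypothetical values): note that the ANI bounds invert
# the distance bounds, i.e. ani_low comes from dist_high and ani_high from dist_low.
#   >>> res = ciANIResult(dist=0.05, p_nothing_in_common=1e-6, dist_low=0.04, dist_high=0.06)
#   >>> res.ani, res.ani_low, res.ani_high
#   (0.95, 0.94, 0.96)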


def r1_to_q(k, r1):
    r1 = float(r1)
    q = 1 - (1 - r1) ** k
    return float(q)
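
# For example (hypothetical values): with k=21 and mutation rate r1=0.05,
# q = 1 - 0.95**21 ≈ 0.66, i.e. roughly two thirds of k-mers are expected to
# contain at least one mutated position.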


def var_n_mutated(L, k, r1, *, q=None):
    # there are computational issues in the variance formula when r1 is very
    # small; for example, with L=10, k=2, r1=1e-6, standard precision gives
    # varN<0, which is nonsense. The reference implementation uses
    # higher-precision (mpf) arithmetic to get the correct answer of about
    # 0.000038; here we instead raise an error below if varN goes negative,
    # since this only seems to happen with tiny test data.
    if r1 == 0:
        return 0.0
    r1 = float(r1)
    if q is None:  # we assume that if q is provided, it is correct for r1
        q = r1_to_q(k, r1)
    varN = (
        L * (1 - q) * (q * (2 * k + (2 / r1) - 1) - 2 * k)
        + k * (k - 1) * (1 - q) ** 2
        + (2 * (1 - q) / (r1**2)) * ((1 + (k - 1) * (1 - q)) * r1 - q)
    )
    if varN < 0.0:  # this seems to happen only with super tiny test data
        raise ValueError("Error: varN <0.0!")
    return float(varN)


def exp_n_mutated(L, k, r1):
    q = r1_to_q(k, r1)
    return L * q
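
# For example (hypothetical values): with L=10000 k-mers, k=21, and r1=0.05,
# q ≈ 0.66 (see r1_to_q above), so roughly 6,600 k-mers are expected to be mutated.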


def exp_n_mutated_squared(L, k, p):
    return var_n_mutated(L, k, p) + exp_n_mutated(L, k, p) ** 2


def probit(p):
    return scipy_norm.ppf(p)
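
# probit maps a cumulative probability to a z-score; e.g. probit(0.975) ≈ 1.96,
# the familiar z value used below for a 95% confidence interval.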


def handle_seqlen_nkmers(ksize, *, sequence_len_bp=None, n_unique_kmers=None):
    if n_unique_kmers is not None:
        return n_unique_kmers
    elif sequence_len_bp is None:
        # both are None, raise ValueError
        raise ValueError("Error: distance estimation requires input of either 'sequence_len_bp' or 'n_unique_kmers'")
    else:
        n_unique_kmers = sequence_len_bp - (ksize - 1)
        return n_unique_kmers
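
# For example (hypothetical values): handle_seqlen_nkmers(21, sequence_len_bp=1000)
# returns 1000 - (21 - 1) = 980, the number of k-mers in a 1 kb sequence, while
# passing n_unique_kmers directly bypasses this approximation.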


def get_expected_log_probability(n_unique_kmers, ksize, mutation_rate, scaled_fraction):
    """helper function
    Note that scaled here needs to be between 0 and 1
    (e.g. scaled 1000 --> scaled_fraction 0.001)
    """
    exp_nmut = exp_n_mutated(n_unique_kmers, ksize, mutation_rate)
    try:
        return (n_unique_kmers - exp_nmut) * log(1.0 - scaled_fraction)
    except ValueError:  # log(0) when scaled_fraction == 1 (i.e. scaled == 1)
        return float("-inf")


def get_exp_probability_nothing_common(
    mutation_rate, ksize, scaled, *, n_unique_kmers=None, sequence_len_bp=None
):
    """
    Given parameters, calculate the expected probability that nothing will be common
    between a fracminhash sketch of an original sequence and a fracminhash sketch of a
    mutated sequence. If this is above a threshold, we should suspect that the two
    sketches may have nothing in common. The threshold needs to be set with proper insights.
    Arguments: n_unique_kmers, ksize, mutation_rate, scaled
    Returns: float - expected likelihood that nothing is common between sketches
    """
    n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp=sequence_len_bp, n_unique_kmers=n_unique_kmers)
    f_scaled = 1.0 / float(scaled)
    if mutation_rate == 1.0:
        return 1.0
    elif mutation_rate == 0.0:
        return 0.0
    return exp(
        get_expected_log_probability(n_unique_kmers, ksize, mutation_rate, f_scaled)
    )
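
# Illustrative call (hypothetical values): for a 10 kb sequence sketched at
# scaled=1000, only ~10 hashes are expected, so even a modest mutation rate makes
# an empty overlap plausible by chance:
#   >>> get_exp_probability_nothing_common(0.05, 21, 1000, n_unique_kmers=10000)
#   0.03...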


def containment_to_distance(
    containment,
    ksize,
    scaled,
    *,
    n_unique_kmers=None,
    sequence_len_bp=None,
    confidence=0.95,
    estimate_ci=False,
    prob_threshold=1e-3,
):
    """
    Containment --> distance (ANI), with optional confidence interval (one step).
    """
    sol1, sol2, point_estimate = None, None, None
    n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp=sequence_len_bp, n_unique_kmers=n_unique_kmers)
    if containment <= 0.0001:
        point_estimate = 1.0
    elif containment >= 0.9999:
        point_estimate = 0.0
    else:
        point_estimate = 1.0 - containment ** (1.0 / ksize)
        if estimate_ci:
            try:
                alpha = 1 - confidence
                z_alpha = probit(1 - alpha / 2)
                f_scaled = (
                    1.0 / scaled
                )  # these use scaled as a fraction between 0 and 1

                bias_factor = 1 - (1 - f_scaled) ** n_unique_kmers

                term_1 = (1.0 - f_scaled) / (
                    f_scaled * n_unique_kmers**3 * bias_factor**2
                )
                term_2 = lambda pest: n_unique_kmers * exp_n_mutated(
                    n_unique_kmers, ksize, pest
                ) - exp_n_mutated_squared(n_unique_kmers, ksize, pest)
                term_3 = lambda pest: var_n_mutated(n_unique_kmers, ksize, pest) / (
                    n_unique_kmers**2
                )

                var_direct = lambda pest: term_1 * term_2(pest) + term_3(pest)

                f1 = (
                    lambda pest: (1 - pest) ** ksize
                    + z_alpha * sqrt(var_direct(pest))
                    - containment
                )
                f2 = (
                    lambda pest: (1 - pest) ** ksize
                    - z_alpha * sqrt(var_direct(pest))
                    - containment
                )

                sol1 = brentq(f1, 0.0000001, 0.9999999)
                sol2 = brentq(f2, 0.0000001, 0.9999999)

            except ValueError as exc:
                # afaict, this only happens with extremely small test data
                notify(
                    "WARNING: Cannot estimate ANI confidence intervals from containment. Do your sketches contain enough hashes?"
                )
                notify(str(exc))
                sol1 = sol2 = None

    # Do this here, so that we don't need to reconvert distance <--> identity later.
    prob_nothing_in_common = get_exp_probability_nothing_common(
        point_estimate, ksize, scaled, n_unique_kmers=n_unique_kmers
    )
    return ciANIResult(point_estimate, prob_nothing_in_common, dist_low=sol2, dist_high=sol1, p_threshold=prob_threshold)
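
# Illustrative usage sketch (hypothetical parameter values):
#   >>> res = containment_to_distance(0.5, 21, 1000, n_unique_kmers=int(1e7),
#   ...                               estimate_ci=True)
#   >>> round(res.ani, 4)             # i.e. containment**(1/ksize)
#   0.9675
#   >>> res.ani_low, res.ani_high     # CI bounds from sol2/sol1, if brentq succeeded
# The result is a ciANIResult, so res.dist, res.p_nothing_in_common, and
# res.p_exceeds_threshold are also available.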


def jaccard_to_distance(
    jaccard,
    ksize,
    scaled,
    *,
    n_unique_kmers=None,
    sequence_len_bp=None,
    prob_threshold=1e-3,
    err_threshold=1e-4,
):
    """
    Given parameters, calculate a point estimate for the mutation rate from the jaccard index.
    Uses formulas derived mathematically to compute the point estimate. The formula uses
    approximations, therefore a tiny error is associated with it. A lower bound of that error
    is also returned. A high error indicates that the point estimate cannot be trusted.
    The error threshold is open to interpretation, but it is suggested that estimates with
    error > 1e-4 be handled with caution.
    Note that the error is NOT a mutation rate, and therefore cannot be considered in
    something like mut.rate +/- error.
    Arguments: jaccard, ksize, scaled, n_unique_kmers
    Returns: jaccardANIResult, containing the point estimate of the mutation rate (distance)
    and the lower bound of the error (as jaccard_error).
    Note: the point estimate does not consider the impact of scaled, but p_nothing_in_common
    can be useful for determining whether scaled is sufficient for these comparisons.
    """
    error_lower_bound = None
    n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp=sequence_len_bp, n_unique_kmers=n_unique_kmers)
    if jaccard <= 0.0001:
        point_estimate = 1.0
        error_lower_bound = 0.0
    elif jaccard >= 0.9999:
        point_estimate = 0.0
        error_lower_bound = 0.0
    else:
        point_estimate = 1.0 - (2.0 * jaccard / float(1 + jaccard)) ** (
            1.0 / float(ksize)
        )

        exp_n_mut = exp_n_mutated(n_unique_kmers, ksize, point_estimate)
        var_n_mut = var_n_mutated(n_unique_kmers, ksize, point_estimate)
        error_lower_bound = (
            1.0 * n_unique_kmers * var_n_mut / (n_unique_kmers + exp_n_mut) ** 3
        )
    prob_nothing_in_common = get_exp_probability_nothing_common(
        point_estimate, ksize, scaled, n_unique_kmers=n_unique_kmers
    )
    return jaccardANIResult(point_estimate, prob_nothing_in_common, jaccard_error=error_lower_bound, p_threshold=prob_threshold, je_threshold=err_threshold)
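
# Illustrative usage sketch (hypothetical parameter values):
#   >>> res = jaccard_to_distance(0.1, 21, 1000, n_unique_kmers=int(1e7))
#   >>> round(res.ani, 3)            # 1 - distance, where distance = 1 - (2*0.1/1.1)**(1/21)
#   0.922
#   >>> res.jaccard_error < 1e-4     # error lower bound, compared against je_threshold
#   True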