diff --git a/src/sourmash/cli/compare.py b/src/sourmash/cli/compare.py index 0dda9a290a..f1f0c678c4 100644 --- a/src/sourmash/cli/compare.py +++ b/src/sourmash/cli/compare.py @@ -58,6 +58,10 @@ def subparser(subparsers): '--max-containment', action='store_true', help='calculate max containment instead of similarity' ) + subparser.add_argument( + '--estimate-ani', '--estimate-ANI', action='store_true', + help='return ANI estimated from jaccard, containment, or max containment; see https://doi.org/10.1101/2022.01.11.475870' + ) subparser.add_argument( '--from-file', help='a text file containing a list of files to load signatures from' diff --git a/src/sourmash/cli/utils.py b/src/sourmash/cli/utils.py index 1725518747..2063fd09de 100644 --- a/src/sourmash/cli/utils.py +++ b/src/sourmash/cli/utils.py @@ -67,10 +67,14 @@ def range_limited_float_type(arg): return f -def add_tax_threshold_arg(parser, default=0.1): +def add_tax_threshold_arg(parser, containment_default=0.1, ani_default=None): parser.add_argument( - '--containment-threshold', default=default, type=range_limited_float_type, - help=f'minimum containment threshold for classification; default={default}' + '--containment-threshold', default=containment_default, type=range_limited_float_type, + help=f'minimum containment threshold for classification; default={containment_default}', + ) + parser.add_argument( + '--ani-threshold', '--aai-threshold', default=ani_default, type=range_limited_float_type, + help=f'minimum ANI threshold (nucleotide gather) or AAI threshold (protein gather) for classification; default={ani_default}', ) diff --git a/src/sourmash/commands.py b/src/sourmash/commands.py index 2489d13852..08ae99c771 100644 --- a/src/sourmash/commands.py +++ b/src/sourmash/commands.py @@ -16,7 +16,7 @@ from .logging import notify, error, print_results, set_quiet from .sourmash_args import (FileOutput, FileOutputCSV, SaveSignaturesToLocation) -from .search import prefetch_database, PrefetchResult, calculate_prefetch_info +from .search import SearchResult, prefetch_database, PrefetchResult, GatherResult, calculate_prefetch_info from .index import LazyLinearIndex WATERMARK_SIZE = 10000 @@ -110,11 +110,20 @@ def compare(args): error('must use scaled signatures with --containment and --max-containment') sys.exit(-1) + # complain if --ani and not is_scaled + return_ani = False + if args.estimate_ani: + return_ani = True + + if return_ani and not is_scaled: + error('must use scaled signatures with --estimate-ani') + sys.exit(-1) + # notify about implicit --ignore-abundance: - if is_containment: + if is_containment or return_ani: track_abundances = any(( s.minhash.track_abundance for s in siglist )) if track_abundances: - notify('NOTE: --containment and --max-containment ignore signature abundances.') + notify('NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.') # if using --scaled, downsample appropriately printed_scaled_msg = False @@ -140,12 +149,12 @@ def compare(args): labeltext = [str(item) for item in siglist] if args.containment: - similarity = compare_serial_containment(siglist) + similarity = compare_serial_containment(siglist, return_ani=return_ani) elif args.max_containment: - similarity = compare_serial_max_containment(siglist) + similarity = compare_serial_max_containment(siglist, return_ani=return_ani) else: similarity = compare_all_pairs(siglist, args.ignore_abundance, - n_jobs=args.processes) + n_jobs=args.processes, return_ani=return_ani) if len(siglist) < 30: for i, E in 
enumerate(siglist): @@ -533,8 +542,7 @@ def search(args): notify("** reporting only one match because --best-only was set") if args.output: - fieldnames = ['similarity', 'name', 'filename', 'md5', - 'query_filename', 'query_name', 'query_md5'] + fieldnames = SearchResult._fields with FileOutputCSV(args.output) as fp: w = csv.DictWriter(fp, fieldnames=fieldnames) @@ -689,11 +697,7 @@ def gather(args): prefetch_csvout_fp = None prefetch_csvout_w = None if args.save_prefetch_csv: - fieldnames = ['intersect_bp', 'jaccard', - 'max_containment', 'f_query_match', 'f_match_query', - 'match_filename', 'match_name', 'match_md5', 'match_bp', - 'query_filename', 'query_name', 'query_md5', 'query_bp'] - + fieldnames = PrefetchResult._fields prefetch_csvout_fp = FileOutput(args.save_prefetch_csv, 'wt').open() prefetch_csvout_w = csv.DictWriter(prefetch_csvout_fp, fieldnames=fieldnames) prefetch_csvout_w.writeheader() @@ -808,13 +812,7 @@ def gather(args): # save CSV? if found and args.output: - fieldnames = ['intersect_bp', 'f_orig_query', 'f_match', - 'f_unique_to_query', 'f_unique_weighted', - 'average_abund', 'median_abund', 'std_abund', 'name', - 'filename', 'md5', 'f_match_orig', 'unique_intersect_bp', - 'gather_result_rank', 'remaining_bp', - 'query_filename', 'query_name', 'query_md5', 'query_bp'] - + fieldnames = GatherResult._fields with FileOutputCSV(args.output) as fp: w = csv.DictWriter(fp, fieldnames=fieldnames) w.writeheader() @@ -981,14 +979,7 @@ def multigather(args): output_base = os.path.basename(query_filename) output_csv = output_base + '.csv' - - fieldnames = ['intersect_bp', 'f_orig_query', 'f_match', - 'f_unique_to_query', 'f_unique_weighted', - 'average_abund', 'median_abund', 'std_abund', 'name', - 'filename', 'md5', 'f_match_orig', - 'unique_intersect_bp', 'gather_result_rank', - 'remaining_bp', 'query_filename', 'query_name', - 'query_md5', 'query_bp'] + fieldnames = GatherResult._fields with FileOutputCSV(output_csv) as fp: w = csv.DictWriter(fp, fieldnames=fieldnames) w.writeheader() @@ -1192,11 +1183,7 @@ def prefetch(args): csvout_fp = None csvout_w = None if args.output: - fieldnames = ['intersect_bp', 'jaccard', - 'max_containment', 'f_query_match', 'f_match_query', - 'match_filename', 'match_name', 'match_md5', 'match_bp', - 'query_filename', 'query_name', 'query_md5', 'query_bp'] - + fieldnames = PrefetchResult._fields csvout_fp = FileOutput(args.output, 'wt').open() csvout_w = csv.DictWriter(csvout_fp, fieldnames=fieldnames) csvout_w.writeheader() diff --git a/src/sourmash/compare.py b/src/sourmash/compare.py index fba5bbcb7b..a18ba38f06 100644 --- a/src/sourmash/compare.py +++ b/src/sourmash/compare.py @@ -9,7 +9,7 @@ from sourmash.np_utils import to_memmap -def compare_serial(siglist, ignore_abundance, downsample=False): +def compare_serial(siglist, ignore_abundance, downsample=False, return_ani=False): """Compare all combinations of signatures and return a matrix of similarities. Processes combinations serially on a single process. Best to use when there is few signatures. 
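The return_ani flag threaded through compare.py below turns each cell of the pairwise matrix into an ANI point estimate instead of a raw similarity or containment value. A minimal usage sketch, assuming scaled sketches (ANI estimation requires them) and two hypothetical signature files:

    import sourmash
    from sourmash.compare import compare_serial_containment

    # load scaled signatures from two hypothetical files
    sigs = list(sourmash.load_file_as_signatures("genome1.sig"))
    sigs += list(sourmash.load_file_as_signatures("genome2.sig"))

    # containment matrix (existing behavior) vs. ANI point estimates (new)
    containment_mat = compare_serial_containment(sigs)
    ani_mat = compare_serial_containment(sigs, return_ani=True)

This mirrors what sourmash compare --containment --estimate-ani writes to its CSV output.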
@@ -34,12 +34,15 @@ def compare_serial(siglist, ignore_abundance, downsample=False): similarities = np.ones((n, n)) for i, j in iterator: - similarities[i][j] = similarities[j][i] = siglist[i].similarity(siglist[j], ignore_abundance, downsample) + if return_ani: + similarities[i][j] = similarities[j][i] = siglist[i].jaccard_ani(siglist[j], downsample)[0] + else: + similarities[i][j] = similarities[j][i] = siglist[i].similarity(siglist[j], ignore_abundance, downsample) return similarities -def compare_serial_containment(siglist, downsample=False): +def compare_serial_containment(siglist, downsample=False, return_ani=False): """Compare all combinations of signatures and return a matrix of containments. Processes combinations serially on a single process. Best to only use when there are few signatures. @@ -55,13 +58,17 @@ def compare_serial_containment(siglist, downsample=False): containments = np.ones((n, n)) for i in range(n): for j in range(n): - containments[i][j] = siglist[j].contained_by(siglist[i], + if return_ani: + containments[i][j] = siglist[j].containment_ani(siglist[i], + downsample=downsample)[0] + else: + containments[i][j] = siglist[j].contained_by(siglist[i], downsample=downsample) return containments -def compare_serial_max_containment(siglist, downsample=False): +def compare_serial_max_containment(siglist, downsample=False, return_ani=False): """Compare all combinations of signatures and return a matrix of max_containments. Processes combinations serially on a single process. Best to only use when there are few signatures. @@ -77,22 +84,30 @@ def compare_serial_max_containment(siglist, downsample=False): containments = np.ones((n, n)) for i in range(n): for j in range(n): - containments[i][j] = siglist[j].max_containment(siglist[i], + if return_ani: + containments[i][j] = siglist[j].max_containment_ani(siglist[i], + downsample=downsample)[0] + else: + containments[i][j] = siglist[j].max_containment(siglist[i], downsample=downsample) return containments -def similarity_args_unpack(args, ignore_abundance, downsample): +def similarity_args_unpack(args, ignore_abundance, downsample, return_ani=False): """Helper function to unpack the arguments. Written to use in pool.imap as it can only be given one argument.""" sig1, sig2 = args - return sig1.similarity(sig2, + if return_ani: + return sig1.jaccard_ani(sig2, + downsample=downsample)[0] + else: + return sig1.similarity(sig2, ignore_abundance=ignore_abundance, downsample=downsample) -def get_similarities_at_index(index, ignore_abundance, downsample, siglist): +def get_similarities_at_index(index, ignore_abundance, downsample, siglist, return_ani=False): """Returns similarities of all the combinations of signature at index in the siglist with the rest of the indices starting at index + 1. 
Doesn't redundantly calculate signatures with all the other indices prior to @@ -114,14 +129,14 @@ def get_similarities_at_index(index, ignore_abundance, downsample, siglist): sig_iterator = itertools.product([siglist[index]], siglist[index + 1:]) func = partial(similarity_args_unpack, ignore_abundance=ignore_abundance, - downsample=downsample) + downsample=downsample, return_ani=return_ani) similarity_list = list(map(func, sig_iterator)) notify( f"comparison for index {index} done in {time.time() - startt:.5f} seconds", end='\r') return similarity_list -def compare_parallel(siglist, ignore_abundance, downsample, n_jobs): +def compare_parallel(siglist, ignore_abundance, downsample, n_jobs, return_ani=False): """Compare all combinations of signatures and return a matrix of similarities. Processes combinations parallely on number of processes given by n_jobs @@ -163,7 +178,8 @@ def compare_parallel(siglist, ignore_abundance, downsample, n_jobs): get_similarities_at_index, siglist=siglist, ignore_abundance=ignore_abundance, - downsample=downsample) + downsample=downsample, + return_ani=return_ani) notify("Created similarity func") # Initialize multiprocess.pool @@ -198,7 +214,7 @@ def compare_parallel(siglist, ignore_abundance, downsample, n_jobs): return np.memmap(filename, dtype=np.float64, shape=(length_siglist, length_siglist)) -def compare_all_pairs(siglist, ignore_abundance, downsample=False, n_jobs=None): +def compare_all_pairs(siglist, ignore_abundance, downsample=False, n_jobs=None, return_ani=False): """Compare all combinations of signatures and return a matrix of similarities. Processes combinations either serially or based on parallely on number of processes given by n_jobs @@ -216,7 +232,7 @@ def compare_all_pairs(siglist, ignore_abundance, downsample=False, n_jobs=None): :return: np.array similarity matrix """ if n_jobs is None or n_jobs == 1: - similarities = compare_serial(siglist, ignore_abundance, downsample) + similarities = compare_serial(siglist, ignore_abundance, downsample, return_ani) else: - similarities = compare_parallel(siglist, ignore_abundance, downsample, n_jobs) + similarities = compare_parallel(siglist, ignore_abundance, downsample, n_jobs, return_ani) return similarities diff --git a/src/sourmash/distance_utils.py b/src/sourmash/distance_utils.py deleted file mode 100644 index a6884515d7..0000000000 --- a/src/sourmash/distance_utils.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -Utilities for jaccard/containment --> distance estimation -Equations from: https://github.com/KoslickiLab/mutation-rate-ci-calculator -Reference: https://doi.org/10.1101/2022.01.11.475870 -""" -from dataclasses import dataclass, field -from scipy.optimize import brentq -from scipy.stats import norm as scipy_norm -from numpy import sqrt -from math import log, exp - -from .logging import notify - -def check_distance(dist): - if not 0 <= dist <= 1: - raise ValueError(f"Error: distance value {dist :.4f} is not between 0 and 1!") - else: - return dist - -def check_prob_threshold(val, threshold=1e-3): - """ - Check likelihood of no shared hashes based on chance alone (false neg). - If too many exceed threshold, recommend user lower their scaled value. - # !! 
when using this, keep count and recommend user lower scaled val - """ - exceeds_threshold = False - if threshold is not None and val > threshold: - notify("WARNING: These sketches may have no hashes in common based on chance alone.") - exceeds_threshold = True - return val, exceeds_threshold - -def check_jaccard_error(val, threshold=1e-4): - exceeds_threshold = False - if threshold is not None and val > threshold: - notify(f"WARNING: Error on Jaccard distance point estimate is too high ({val :.4f}).") - exceeds_threshold = True - return val, exceeds_threshold - -@dataclass -class ANIResult: - """Base class for distance/ANI from k-mer containment.""" - dist: float - p_nothing_in_common: float - p_threshold: float = 1e-3 - p_exceeds_threshold: bool = field(init=False) - - def __post_init__(self): - # check values - self.dist = check_distance(self.dist) - self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold) - - @property - def ani(self): - return 1 - self.dist - - -@dataclass -class jaccardANIResult(ANIResult): - """Class for distance/ANI from jaccard (includes jaccard_error).""" - jaccard_error: float = None - je_threshold: float = 1e-4 - - def __post_init__(self): - # check values - self.dist = check_distance(self.dist) - self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold) - # check jaccard error - if self.jaccard_error is not None: - self.jaccard_error, self.je_exceeds_threshold = check_jaccard_error(self.jaccard_error, self.je_threshold) - else: - raise ValueError("Error: jaccard_error cannot be None.") - - -@dataclass -class ciANIResult(ANIResult): - """ - Class for distance/ANI from containment: with confidence intervals. - - Set CI defaults to None, just in case CI can't be estimated for given sample. - """ - dist_low: float = None - dist_high: float = None - - def __post_init__(self): - # check values - self.dist = check_distance(self.dist) - self.p_nothing_in_common, self.p_exceeds_threshold = check_prob_threshold(self.p_nothing_in_common, self.p_threshold) - - if self.dist_low is not None and self.dist_high is not None: - self.dist_low = check_distance(self.dist_low) - self.dist_high = check_distance(self.dist_high) - - @property - def ani_low(self): - if self.dist_high is None: - return None - return 1 - self.dist_high - - @property - def ani_high(self): - if self.dist_low is None: - return None - return 1 - self.dist_low - - -def r1_to_q(k, r1): - r1 = float(r1) - q = 1 - (1 - r1) ** k - return float(q) - - -def var_n_mutated(L, k, r1, *, q=None): - # there are computational issues in the variance formula that we solve here - # by the use of higher-precision arithmetic; the problem occurs when r is - # very small; for example, with L=10,k=2,r1=1e-6 standard precision - # gives varN<0 which is nonsense; by using the mpf type, we get the correct - # answer which is about 0.000038. 
- if r1 == 0: - return 0.0 - r1 = float(r1) - if q == None: # we assume that if q is provided, it is correct for r1 - q = r1_to_q(k, r1) - varN = ( - L * (1 - q) * (q * (2 * k + (2 / r1) - 1) - 2 * k) - + k * (k - 1) * (1 - q) ** 2 - + (2 * (1 - q) / (r1**2)) * ((1 + (k - 1) * (1 - q)) * r1 - q) - ) - if varN < 0.0: # this seems to happen only with super tiny test data - raise ValueError("Error: varN <0.0!") - return float(varN) - - -def exp_n_mutated(L, k, r1): - q = r1_to_q(k, r1) - return L * q - - -def exp_n_mutated_squared(L, k, p): - return var_n_mutated(L, k, p) + exp_n_mutated(L, k, p) ** 2 - - -def probit(p): - return scipy_norm.ppf(p) - - -def handle_seqlen_nkmers(ksize, *, sequence_len_bp=None, n_unique_kmers=None): - if n_unique_kmers is not None: - return n_unique_kmers - elif sequence_len_bp is None: - # both are None, raise ValueError - raise ValueError("Error: distance estimation requires input of either 'sequence_len_bp' or 'n_unique_kmers'") - else: - n_unique_kmers = sequence_len_bp - (ksize - 1) - return n_unique_kmers - - -def get_expected_log_probability(n_unique_kmers, ksize, mutation_rate, scaled_fraction): - """helper function - Note that scaled here needs to be between 0 and 1 - (e.g. scaled 1000 --> scaled_fraction 0.001) - """ - exp_nmut = exp_n_mutated(n_unique_kmers, ksize, mutation_rate) - try: - return (n_unique_kmers - exp_nmut) * log(1.0 - scaled_fraction) - except: - return float("-inf") - - -def get_exp_probability_nothing_common( - mutation_rate, ksize, scaled, *, n_unique_kmers=None, sequence_len_bp=None -): - """ - Given parameters, calculate the expected probability that nothing will be common - between a fracminhash sketch of a original sequence and a fracminhash sketch of a mutated - sequence. If this is above a threshold, we should suspect that the two sketches may have - nothing in common. The threshold needs to be set with proper insights. 
- - Arguments: n_unique_kmers, ksize, mutation_rate, scaled - Returns: float - expected likelihood that nothing is common between sketches - """ - n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp=sequence_len_bp,n_unique_kmers=n_unique_kmers) - f_scaled = 1.0 / float(scaled) - if mutation_rate == 1.0: - return 1.0 - elif mutation_rate == 0.0: - return 0.0 - return exp( - get_expected_log_probability(n_unique_kmers, ksize, mutation_rate, f_scaled) - ) - - -def containment_to_distance( - containment, - ksize, - scaled, - *, - n_unique_kmers=None, - sequence_len_bp=None, - confidence=0.95, - estimate_ci=False, - prob_threshold=1e-3, -): - """ - Containment --> distance CI (one step) - """ - sol1, sol2, point_estimate = None, None, None - n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp = sequence_len_bp, n_unique_kmers=n_unique_kmers) - if containment <= 0.0001: - point_estimate = 1.0 - elif containment >= 0.9999: - point_estimate = 0.0 - else: - point_estimate = 1.0 - containment ** (1.0 / ksize) - if estimate_ci: - try: - alpha = 1 - confidence - z_alpha = probit(1 - alpha / 2) - f_scaled = ( - 1.0 / scaled - ) # these use scaled as a fraction between 0 and 1 - - bias_factor = 1 - (1 - f_scaled) ** n_unique_kmers - - term_1 = (1.0 - f_scaled) / ( - f_scaled * n_unique_kmers**3 * bias_factor**2 - ) - term_2 = lambda pest: n_unique_kmers * exp_n_mutated( - n_unique_kmers, ksize, pest - ) - exp_n_mutated_squared(n_unique_kmers, ksize, pest) - term_3 = lambda pest: var_n_mutated(n_unique_kmers, ksize, pest) / ( - n_unique_kmers**2 - ) - - var_direct = lambda pest: term_1 * term_2(pest) + term_3(pest) - - f1 = ( - lambda pest: (1 - pest) ** ksize - + z_alpha * sqrt(var_direct(pest)) - - containment - ) - f2 = ( - lambda pest: (1 - pest) ** ksize - - z_alpha * sqrt(var_direct(pest)) - - containment - ) - - sol1 = brentq(f1, 0.0000001, 0.9999999) - sol2 = brentq(f2, 0.0000001, 0.9999999) - - except ValueError as exc: - # afaict, this only happens with extremely small test data - notify( - "WARNING: Cannot estimate ANI confidence intervals from containment. Do your sketches contain enough hashes?" - ) - notify(str(exc)) - sol1 = sol2 = None - - # Do this here, so that we don't need to reconvert distance <--> identity later. - prob_nothing_in_common = get_exp_probability_nothing_common( - point_estimate, ksize, scaled, n_unique_kmers=n_unique_kmers - ) - return ciANIResult(point_estimate, prob_nothing_in_common, dist_low=sol2, dist_high=sol1, p_threshold=prob_threshold) - - -def jaccard_to_distance( - jaccard, - ksize, - scaled, - *, - n_unique_kmers=None, - sequence_len_bp=None, - prob_threshold=1e-3, - err_threshold=1e-4, -): - """ - Given parameters, calculate point estimate for mutation rate from jaccard index. - Uses formulas derived mathematically to compute the point estimate. The formula uses - approximations, therefore a tiny error is associated with it. A lower bound of that error - is also returned. A high error indicates that the point estimate cannot be trusted. - Threshold of the error is open to interpretation, but suggested that > 10^-4 should be - handled with caution. - - Note that the error is NOT a mutation rate, and therefore cannot be considered in - something like mut.rate +/- error. 
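For orientation, the closed-form point estimates used by containment_to_distance above and jaccard_to_distance here both reduce to inverting the expected k-mer survival rate. A rough standalone sanity check, not part of the module, with example numbers only:

    k = 31
    containment = 0.9
    jaccard = 0.8

    # containment C --> ANI ~= C ** (1/k), distance = 1 - ANI
    ani_from_containment = containment ** (1.0 / k)                    # ~0.9966
    # jaccard J --> ANI ~= (2J / (1 + J)) ** (1/k)
    ani_from_jaccard = (2.0 * jaccard / (1.0 + jaccard)) ** (1.0 / k)  # ~0.9962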
- - Arguments: jaccard, ksize, scaled, n_unique_kmers - # Returns: tuple (point_estimate_of_mutation_rate, lower_bound_of_error) - - # Returns: point_estimate_of_mutation_rate - - Note: point estimate does not consider impact of scaled, but p_nothing_in_common can be - useful for determining whether scaled is sufficient for these comparisons. - """ - error_lower_bound = None - n_unique_kmers = handle_seqlen_nkmers(ksize, sequence_len_bp=sequence_len_bp, n_unique_kmers=n_unique_kmers) - if jaccard <= 0.0001: - point_estimate = 1.0 - error_lower_bound = 0.0 - elif jaccard >= 0.9999: - point_estimate = 0.0 - error_lower_bound = 0.0 - else: - point_estimate = 1.0 - (2.0 * jaccard / float(1 + jaccard)) ** ( - 1.0 / float(ksize) - ) - - exp_n_mut = exp_n_mutated(n_unique_kmers, ksize, point_estimate) - var_n_mut = var_n_mutated(n_unique_kmers, ksize, point_estimate) - error_lower_bound = ( - 1.0 * n_unique_kmers * var_n_mut / (n_unique_kmers + exp_n_mut) ** 3 - ) - prob_nothing_in_common = get_exp_probability_nothing_common( - point_estimate, ksize, scaled, n_unique_kmers=n_unique_kmers - ) - return jaccardANIResult(point_estimate, prob_nothing_in_common, jaccard_error=error_lower_bound, p_threshold=prob_threshold, je_threshold=err_threshold) diff --git a/src/sourmash/search.py b/src/sourmash/search.py index 9867f9f697..9421c02a52 100644 --- a/src/sourmash/search.py +++ b/src/sourmash/search.py @@ -5,6 +5,8 @@ from enum import Enum import numpy as np +from sourmash.distance_utils import containment_to_distance, jaccard_to_distance + from .signature import SourmashSignature @@ -160,7 +162,9 @@ def collect(self, score, match): # generic SearchResult tuple. SearchResult = namedtuple('SearchResult', - 'similarity, match, md5, filename, name, query, query_filename, query_name, query_md5') + ['similarity', 'match', 'md5', 'filename', 'name', + 'query', 'query_filename', 'query_name', 'query_md5', + 'ksize', 'search_scaled', 'estimated_ani', 'ani_ci_low', 'ani_ci_high']) def format_bp(bp): @@ -193,7 +197,31 @@ def search_databases_with_flat_query(query, databases, **kwargs): results.sort(key=lambda x: -x[0]) x = [] + query_mh = query.minhash + ksize = query_mh.ksize + scaled = query_mh.scaled + search_scaled=0 + ani, ani_low, ani_high = None, None, None for (score, match, filename) in results: + match_mh = match.minhash + if scaled: # if scaled, we can get ANI estimates + search_scaled = max(query_mh.scaled, match_mh.scaled) + if search_scaled > query_mh.scaled: + query_mh = query_mh.downsample(scaled=search_scaled) + if search_scaled > match_mh.scaled: + match_mh = match_mh.downsample(scaled=search_scaled) + if kwargs.get('do_containment'): + ani, ani_low, ani_high = containment_to_distance(score, ksize, search_scaled, + n_unique_kmers=len(query_mh), return_identity=True) + elif kwargs.get('do_max_containment'): + min_n_kmers = min(len(query_mh), len(match_mh)) + ani, ani_low, ani_high = containment_to_distance(score, ksize, search_scaled, + n_unique_kmers=min_n_kmers, return_identity=True) + else: + avg_n_kmers = round((len(query_mh) + len(match_mh))/2) + ani, ani_low, ani_high = jaccard_to_distance(score, ksize, search_scaled, + n_unique_kmers=avg_n_kmers, return_identity=True) + x.append(SearchResult(similarity=score, match=match, md5=match.md5sum(), @@ -202,7 +230,12 @@ def search_databases_with_flat_query(query, databases, **kwargs): query=query, query_filename=query.filename, query_name=query.name, - query_md5=query.md5sum()[:8] + query_md5=query.md5sum()[:8], + ksize=ksize, + 
search_scaled=search_scaled, + estimated_ani=ani, + ani_ci_low=ani_low, + ani_ci_high=ani_high, )) return x @@ -226,7 +259,10 @@ def search_databases_with_abund_query(query, databases, **kwargs): results.sort(key=lambda x: -x[0]) x = [] + search_scaled=0 for (score, match, filename) in results: + if query.minhash.scaled: + search_scaled = min(query.minhash.scaled, match.minhash.scaled) x.append(SearchResult(similarity=score, match=match, md5=match.md5sum(), @@ -235,7 +271,12 @@ def search_databases_with_abund_query(query, databases, **kwargs): query=query, query_filename=query.filename, query_name=query.name, - query_md5=query.md5sum()[:8] + query_md5=query.md5sum()[:8], + ksize=query.minhash.ksize, + search_scaled=search_scaled, + estimated_ani=None, + ani_ci_low=None, + ani_ci_high=None, )) return x @@ -243,8 +284,13 @@ def search_databases_with_abund_query(query, databases, **kwargs): ### gather code ### -GatherResult = namedtuple('GatherResult', - 'intersect_bp, f_orig_query, f_match, f_unique_to_query, f_unique_weighted, average_abund, median_abund, std_abund, filename, name, md5, match, f_match_orig, unique_intersect_bp, gather_result_rank, remaining_bp, query_filename, query_name, query_md5, query_bp') +GatherResult = namedtuple('GatherResult', ['intersect_bp', 'f_orig_query', 'f_match', 'f_unique_to_query', + 'f_unique_weighted','average_abund', 'median_abund', 'std_abund', 'filename', + 'name', 'md5', 'match', 'f_match_orig', 'unique_intersect_bp', 'gather_result_rank', + 'remaining_bp', 'query_filename', 'query_name', 'query_md5', 'query_bp', 'ksize', + 'moltype', 'num', 'scaled', 'query_nhashes', 'query_abundance', 'match_ani', + 'match_ani_ci_low', 'match_ani_ci_high', 'query_ani', 'query_ani_ci_low', + 'query_ani_ci_high']) def _find_best(counters, query, threshold_bp): @@ -405,6 +451,16 @@ def __next__(self): # calculate fraction of subject match with orig query f_match_orig = found_mh.contained_by(orig_query_mh) + # calculate ani using match containment by query + match_ani, match_ani_low, match_ani_high = containment_to_distance(f_match_orig, found_mh.ksize, scaled, + n_unique_kmers=(len(found_mh) * scaled), + return_identity=True) + + # calculate ani using query containment by match -- useful for genome classification + orig_query_ani, query_ani_low, query_ani_high = containment_to_distance(f_orig_query, orig_query_mh.ksize, scaled, + n_unique_kmers=(orig_query_len * scaled), + return_identity=True) + # calculate scores weighted by abundances f_unique_weighted = sum((orig_query_abunds[k] for k in intersect_mh.hashes )) f_unique_weighted /= sum_abunds @@ -453,6 +509,18 @@ def __next__(self): query_filename=self.orig_query_filename, query_name=self.orig_query_name, query_md5=self.orig_query_md5, + ksize = self.orig_query_mh.ksize, + moltype = self.orig_query_mh.moltype, + num = self.orig_query_mh.num, + scaled = scaled, + query_nhashes=len(self.orig_query_mh), + query_abundance=self.orig_query_mh.track_abundance, + match_ani=match_ani, + match_ani_ci_low=match_ani_low, + match_ani_ci_high=match_ani_high, + query_ani=orig_query_ani, + query_ani_ci_low=query_ani_low, + query_ani_ci_high=query_ani_high, ) self.result_n += 1 self.query = new_query @@ -466,7 +534,14 @@ def __next__(self): ### PrefetchResult = namedtuple('PrefetchResult', - 'intersect_bp, jaccard, max_containment, f_query_match, f_match_query, match, match_filename, match_name, match_md5, match_bp, query, query_filename, query_name, query_md5, query_bp') + ['intersect_bp', 'jaccard', 'max_containment', 
'f_query_match', + 'f_match_query', 'match', 'match_filename', 'match_name', + 'match_md5', 'match_bp', 'query', 'query_filename', 'query_name', + 'query_md5', 'query_bp', 'ksize', 'moltype', 'num', 'scaled', + 'query_nhashes', 'query_abundance','jaccard_ani', 'jaccard_ani_ci_low', + 'jaccard_ani_ci_high', 'max_containment_ani', 'query_ani', + 'query_ani_ci_low', 'query_ani_ci_high', 'match_ani', + 'match_ani_ci_low', "match_ani_ci_high"]) def calculate_prefetch_info(query, match, scaled, threshold_bp): @@ -488,13 +563,20 @@ def calculate_prefetch_info(query, match, scaled, threshold_bp): f_query_match = db_mh.contained_by(query_mh) f_match_query = query_mh.contained_by(db_mh) max_containment = max(f_query_match, f_match_query) + jaccard=db_mh.jaccard(query_mh) + + # passing in jaccard/containment avoids recalc (but it better be the right one :) + jaccard_ani, jaccard_ani_low, jaccard_ani_high = query.jaccard_ani(match, jaccard=jaccard) + query_ani, query_ani_low, query_ani_high = query.containment_ani(match, containment=f_match_query) + match_ani, match_ani_low, match_ani_high = match.containment_ani(query, containment=f_query_match) + max_containment_ani = max(query_ani, match_ani) # build a result namedtuple result = PrefetchResult( intersect_bp=len(intersect_mh) * scaled, query_bp = len(query_mh) * scaled, match_bp = len(db_mh) * scaled, - jaccard=db_mh.jaccard(query_mh), + jaccard=jaccard, max_containment=max_containment, f_query_match=f_query_match, f_match_query=f_match_query, @@ -505,7 +587,23 @@ query=query, query_filename=query.filename, query_name=query.name, - query_md5=query.md5sum()[:8] + query_md5=query.md5sum()[:8], + ksize = query_mh.ksize, + moltype = query_mh.moltype, + num = query_mh.num, + scaled = scaled, + query_nhashes=len(query_mh), + query_abundance=query_mh.track_abundance, + jaccard_ani=jaccard_ani, + jaccard_ani_ci_low=jaccard_ani_low, + jaccard_ani_ci_high=jaccard_ani_high, + max_containment_ani=max_containment_ani, + query_ani=query_ani, + query_ani_ci_low=query_ani_low, + query_ani_ci_high=query_ani_high, + match_ani=match_ani, + match_ani_ci_low=match_ani_low, + match_ani_ci_high=match_ani_high, ) return result diff --git a/src/sourmash/signature.py b/src/sourmash/signature.py index 386fdb1733..c8adffb145 100644 --- a/src/sourmash/signature.py +++ b/src/sourmash/signature.py @@ -10,6 +10,7 @@ from .logging import error from . import MinHash from .minhash import to_bytes, FrozenMinHash +from .distance_utils import jaccard_to_distance, containment_to_distance from ._lowlevel import ffi, lib from .utils import RustObject, rustcall, decode_str @@ -142,14 +143,73 @@ def jaccard(self, other): return self.minhash.similarity(other.minhash, ignore_abundance=True, downsample=False) + def jaccard_ani(self, other, downsample=False, jaccard=None, return_err_and_p_nothing_in_common=False):  # return_ci=False, confidence=0.95 + "Estimate ANI from Jaccard similarity with the other MinHash signature."
+ self_mh = self.minhash + other_mh = other.minhash + scaled = self_mh.scaled + if downsample: + scaled = max(self_mh.scaled, other_mh.scaled) + self_mh = self.minhash.downsample(scaled=scaled) + other_mh = other.minhash.downsample(scaled=scaled) + if jaccard is None: + jaccard = self_mh.similarity(other_mh, ignore_abundance=True) + avg_scaled_kmers = round((len(self_mh) + len(other_mh))/2) + avg_n_kmers = avg_scaled_kmers * scaled # would be better if hll estimate + #j_ani,ani_low,ani_high = jaccard_to_distance_orig(jaccard, self_mh.ksize, + # scaled, n_unique_kmers=avg_n_kmers, + # confidence=confidence, return_identity=True) + j_ani,err,prob_nothing_in_common = jaccard_to_distance(jaccard, self_mh.ksize, + scaled, n_unique_kmers=avg_n_kmers, + return_identity=True) #confidence=confidence, + if return_err_and_p_nothing_in_common: + return j_ani,err,prob_nothing_in_common + return j_ani + def contained_by(self, other, downsample=False): "Compute containment by the other signature. Note: ignores abundance." return self.minhash.contained_by(other.minhash, downsample) + def containment_ani(self, other, downsample=False, containment=None, confidence=0.95): + "Estimate ANI from containment with the other MinHash signature." + self_mh = self.minhash + other_mh = other.minhash + scaled = self_mh.scaled + if downsample: + scaled = max(self_mh.scaled, other_mh.scaled) + self_mh = self.minhash.downsample(scaled=scaled) + other_mh = other.minhash.downsample(scaled=scaled) + if containment is None: + containment = self_mh.contained_by(other_mh) + n_kmers = len(self_mh) * scaled # would be better if hll estimate + c_ani,ani_low,ani_high = containment_to_distance(containment, self_mh.ksize, + scaled, n_unique_kmers=n_kmers, + confidence=confidence, return_identity=True) + return c_ani, ani_low, ani_high + def max_containment(self, other, downsample=False): "Compute max containment w/other signature. Note: ignores abundance." return self.minhash.max_containment(other.minhash, downsample) + def max_containment_ani(self, other, downsample=False, max_containment=None, confidence=0.95): + "Estimate ANI from max containment w/other signature. Note: ignores abundance." 
+ self_mh = self.minhash + other_mh = other.minhash + scaled = self_mh.scaled + if downsample: + scaled = max(self_mh.scaled, other_mh.scaled) + self_mh = self.minhash.downsample(scaled=scaled) + other_mh = other.minhash.downsample(scaled=scaled) + if max_containment is None: + max_containment = self_mh.max_containment(other_mh) + # max_containment will always use smaller sig as denominator + min_n_hashes = min(len(self_mh), len(other_mh)) + n_kmers = min_n_hashes * scaled + c_ani,ani_low,ani_high = containment_to_distance(max_containment, self_mh.ksize, + scaled, n_unique_kmers=n_kmers, + confidence=confidence, return_identity=True) + return c_ani, ani_low, ani_high + def add_sequence(self, sequence, force=False): self._methodcall(lib.signature_add_sequence, to_bytes(sequence), force) diff --git a/src/sourmash/tax/__main__.py b/src/sourmash/tax/__main__.py index 01fc6bbe1c..1cee174eb1 100644 --- a/src/sourmash/tax/__main__.py +++ b/src/sourmash/tax/__main__.py @@ -184,7 +184,8 @@ def genome(args): best_at_rank, seen_perfect = tax_utils.summarize_gather_at(args.rank, tax_assign, gather_results, skip_idents=idents_missed, keep_full_identifiers=args.keep_full_identifiers, keep_identifier_versions = args.keep_identifier_versions, - best_only=True, seen_perfect=seen_perfect) + best_only=True, seen_perfect=seen_perfect, estimate_query_ani=True) + except ValueError as exc: error(f"ERROR: {str(exc)}") sys.exit(-1) @@ -194,19 +195,22 @@ def genome(args): status = 'nomatch' if sg.query_name in matched_queries: continue - if sg.fraction <= args.containment_threshold: + if args.ani_threshold and sg.query_ani_at_rank < args.ani_threshold: + status="below_threshold" + notify(f"WARNING: classifying query {sg.query_name} at desired rank {args.rank} does not meet query ANI/AAI threshold {args.ani_threshold}") + elif sg.fraction <= args.containment_threshold: # should this just be less than? status="below_threshold" notify(f"WARNING: classifying query {sg.query_name} at desired rank {args.rank} does not meet containment threshold {args.containment_threshold}") else: status="match" - classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank) + classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank, sg.query_ani_at_rank) classifications[args.rank].append(classif) matched_queries.add(sg.query_name) if "krona" in args.output_format: lin_list = display_lineage(sg.lineage).split(';') krona_results.append((sg.fraction, *lin_list)) else: - # classify to the match that passes the containment threshold. + # classify to the rank/match that passes the containment threshold. # To do - do we want to store anything for this match if nothing >= containment threshold? 
for rank in tax_utils.ascending_taxlist(include_strain=False): # gets best_at_rank for all queries in this gather_csv @@ -214,7 +218,7 @@ def genome(args): best_at_rank, seen_perfect = tax_utils.summarize_gather_at(rank, tax_assign, gather_results, skip_idents=idents_missed, keep_full_identifiers=args.keep_full_identifiers, keep_identifier_versions = args.keep_identifier_versions, - best_only=True, seen_perfect=seen_perfect) + best_only=True, seen_perfect=seen_perfect, estimate_query_ani=True) except ValueError as exc: error(f"ERROR: {str(exc)}") sys.exit(-1) @@ -223,18 +227,26 @@ def genome(args): status = 'nomatch' if sg.query_name in matched_queries: continue - if sg.fraction >= args.containment_threshold: + if args.ani_threshold and sg.query_ani_at_rank >= args.ani_threshold: + status="match" + elif sg.fraction >= args.containment_threshold: status = "match" - classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank) + if status == "match": + classif = ClassificationResult(query_name=sg.query_name, status=status, rank=sg.rank, + fraction=sg.fraction, lineage=sg.lineage, + query_md5=sg.query_md5, query_filename=sg.query_filename, + f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank, + query_ani_at_rank= sg.query_ani_at_rank) classifications[sg.rank].append(classif) matched_queries.add(sg.query_name) continue - if rank == "superkingdom" and status == "nomatch": + elif rank == "superkingdom" and status == "nomatch": status="below_threshold" classif = ClassificationResult(query_name=sg.query_name, status=status, rank="", fraction=0, lineage="", query_md5=sg.query_md5, query_filename=sg.query_filename, - f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank) + f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank, + query_ani_at_rank=sg.query_ani_at_rank) classifications[sg.rank].append(classif) if not any([classifications, krona_results]): diff --git a/src/sourmash/tax/tax_utils.py b/src/sourmash/tax/tax_utils.py index 4d9bac4965..0c69eeaf55 100644 --- a/src/sourmash/tax/tax_utils.py +++ b/src/sourmash/tax/tax_utils.py @@ -5,6 +5,7 @@ import csv from collections import namedtuple, defaultdict from collections import abc +from sourmash.distance_utils import containment_to_distance __all__ = ['get_ident', 'ascending_taxlist', 'collect_gather_csvs', 'load_gather_results', 'check_and_load_gather_csvs', @@ -18,9 +19,9 @@ from sourmash.logging import notify from sourmash.sourmash_args import load_pathlist_from_file -QueryInfo = namedtuple("QueryInfo", "query_md5, query_filename, query_bp") -SummarizedGatherResult = namedtuple("SummarizedGatherResult", "query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank") -ClassificationResult = namedtuple("ClassificationResult", "query_name, status, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank") +QueryInfo = namedtuple("QueryInfo", "query_md5, query_filename, query_bp, query_hashes") +SummarizedGatherResult = namedtuple("SummarizedGatherResult", "query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank, query_ani_at_rank") +ClassificationResult = namedtuple("ClassificationResult", "query_name, status, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank, query_ani_at_rank") # Essential Gather column names that 
must be in gather_csv to allow `tax` summarization EssentialGatherColnames = ('query_name', 'name', 'f_unique_weighted', 'f_unique_to_query', 'unique_intersect_bp', 'remaining_bp', 'query_md5', 'query_filename') @@ -182,7 +183,7 @@ def find_match_lineage(match_ident, tax_assign, *, skip_idents = [], def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], keep_full_identifiers=False, keep_identifier_versions=False, best_only=False, - seen_perfect=set()): + seen_perfect=set(), estimate_query_ani=False): """ Summarize gather results at specified taxonomic rank """ @@ -192,7 +193,7 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], sum_uniq_to_query = defaultdict(lambda: defaultdict(float)) sum_uniq_bp = defaultdict(lambda: defaultdict(float)) query_info = {} - + ksize,scaled,query_nhashes=None,None,None for row in gather_results: # get essential gather info query_name = row['query_name'] @@ -201,13 +202,25 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], unique_intersect_bp = int(row['unique_intersect_bp']) query_md5 = row['query_md5'] query_filename = row['query_filename'] - # get query_bp - if query_name not in query_info.keys(): - query_bp = unique_intersect_bp + int(row['remaining_bp']) + if query_name not in query_info.keys(): #REMOVING THIS AFFECTS GATHER RESULTS!!! BUT query bp should always be same for same query? bug? + if "query_nhashes" in row.keys(): + query_nhashes = int(row["query_nhashes"]) + if "query_bp" in row.keys(): + query_bp = int(row["query_bp"]) + else: + query_bp = unique_intersect_bp + int(row['remaining_bp']) # store query info - query_info[query_name] = QueryInfo(query_md5=query_md5, query_filename=query_filename, query_bp=query_bp) - match_ident = row['name'] + query_info[query_name] = QueryInfo(query_md5=query_md5, query_filename=query_filename, query_bp=query_bp, query_hashes = query_nhashes) + if estimate_query_ani and (not ksize or not scaled): # just need to set these once. BUT, if we have these, should we check for compatibility when loading the gather file? + if "ksize" in row.keys(): # ksize and scaled were added to gather results in same PR + ksize = int(row['ksize']) + scaled = int(row['scaled']) + else: + estimate_query_ani=False + notify("WARNING: Please run gather with sourmash >= 4.3 to estimate query ANI at rank. Continuing without ANI...") + + match_ident = row['name'] # 100% match? are we looking at something in the database? 
if f_unique_to_query >= 1.0 and query_name not in seen_perfect: # only want to notify once, not for each rank @@ -219,9 +232,9 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], # get lineage for match lineage = find_match_lineage(match_ident, tax_assign, - skip_idents=skip_idents, - keep_full_identifiers=keep_full_identifiers, - keep_identifier_versions=keep_identifier_versions) + skip_idents=skip_idents, + keep_full_identifiers=keep_full_identifiers, + keep_identifier_versions=keep_identifier_versions) # ident was in skip_idents if not lineage: continue @@ -234,12 +247,14 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], sum_uniq_weighted[query_name][lineage] += f_uniq_weighted sum_uniq_bp[query_name][lineage] += unique_intersect_bp + # sort and store each as SummarizedGatherResult sum_uniq_to_query_sorted = [] for query_name, lineage_weights in sum_uniq_to_query.items(): qInfo = query_info[query_name] sumgather_items = list(lineage_weights.items()) sumgather_items.sort(key = lambda x: -x[1]) + query_ani = None if best_only: lineage, fraction = sumgather_items[0] if fraction > 1: @@ -248,13 +263,19 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], continue f_weighted_at_rank = sum_uniq_weighted[query_name][lineage] bp_intersect_at_rank = sum_uniq_bp[query_name][lineage] - sres = SummarizedGatherResult(query_name, rank, fraction, lineage, qInfo.query_md5, qInfo.query_filename, f_weighted_at_rank, bp_intersect_at_rank) + if estimate_query_ani: + query_ani = containment_to_distance(fraction, ksize, scaled, + n_unique_kmers= qInfo.query_hashes, sequence_len_bp= qInfo.query_bp, + return_identity=True)[0] + sres = SummarizedGatherResult(query_name, rank, fraction, lineage, qInfo.query_md5, + qInfo.query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani) sum_uniq_to_query_sorted.append(sres) else: total_f_weighted= 0.0 total_f_classified = 0.0 total_bp_classified = 0 for lineage, fraction in sumgather_items: + query_ani=None if fraction > 1: raise ValueError(f"The tax summary of query '{query_name}' is {fraction}, which is > 100% of the query!! This should not be possible. 
Please check that your input files come directly from a single gather run per query.") elif fraction == 0: @@ -264,16 +285,23 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], total_f_weighted += f_weighted_at_rank bp_intersect_at_rank = int(sum_uniq_bp[query_name][lineage]) total_bp_classified += bp_intersect_at_rank - sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_intersect_at_rank) + if estimate_query_ani: + query_ani = containment_to_distance(fraction, ksize, scaled, + n_unique_kmers=qInfo.query_hashes, sequence_len_bp=qInfo.query_bp, + return_identity=True)[0] + sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, + query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani) sum_uniq_to_query_sorted.append(sres) # record unclassified lineage = () + query_ani=None fraction = 1.0 - total_f_classified if fraction > 0: f_weighted_at_rank = 1.0 - total_f_weighted bp_intersect_at_rank = qInfo.query_bp - total_bp_classified - sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_intersect_at_rank) + sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, + query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani) sum_uniq_to_query_sorted.append(sres) return sum_uniq_to_query_sorted, seen_perfect diff --git a/tests/test-data/tax/test1.gather_ani.csv b/tests/test-data/tax/test1.gather_ani.csv new file mode 100644 index 0000000000..48a09eb199 --- /dev/null +++ b/tests/test-data/tax/test1.gather_ani.csv @@ -0,0 +1,5 @@ +intersect_bp,f_orig_query,f_match,f_unique_to_query,f_unique_weighted,average_abund,median_abund,std_abund,name,filename,md5,f_match_orig,unique_intersect_bp,gather_result_rank,remaining_bp,query_name,query_md5,query_filename,ksize,scaled,query_nhashes +442000,0.08815317112086159,0.08438335242458954,0.08815317112086159,0.05815279361459521,1.6153846153846154,1.0,1.1059438185997785,"GCF_001881345.1 Escherichia coli strain=SF-596, ASM188134v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,683df1ec13872b4b98d59e98b355b52c,0.042779713511420826,442000,0,4572000,test1,md5,test1.sig,31,1000,5013970 +390000,0.07778220981252493,0.10416666666666667,0.07778220981252493,0.050496823586903404,1.5897435897435896,1.0,0.8804995294906566,"GCF_009494285.1 Prevotella copri strain=iAK1218, ASM949428v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,1266c86141e3a5603da61f57dd863ed0,0.052236806857755155,390000,1,4182000,test1,md5,test1.sig,31,1000,4571970 +138000,0.027522935779816515,0.024722321748477247,0.027522935779816515,0.015637726014008795,1.391304347826087,1.0,0.5702120455914782,"GCF_013368705.1 Bacteroides vulgatus strain=B33, ASM1336870v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,7d5f4ba1d01c8c3f7a520d19faded7cb,0.012648945921173235,138000,2,4044000,test1,md5,test1.sig,31,1000,4181970 +338000,0.06741124850418827,0.013789581205311542,0.010769844435580374,0.006515719172503665,1.4814814814814814,1.0,0.738886568268889,"GCF_003471795.1 Prevotella copri strain=AM16-54, ASM347179v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,0ebd36ff45fc2810808789667f4aad84,0.04337782340862423,54000,3,3990000,test1,md5,test1.sig,31,1000,4327970 diff --git a/tests/test_compare.py b/tests/test_compare.py index 5c7b6eee6b..9bac10a607 100644 --- a/tests/test_compare.py +++ b/tests/test_compare.py @@ 
-6,7 +6,7 @@ import sourmash from sourmash.compare import (compare_all_pairs, compare_parallel, - compare_serial) + compare_serial, compare_serial_containment, compare_serial_max_containment) import sourmash_tst_utils as utils @@ -19,6 +19,16 @@ def siglist(): sigs.extend(sourmash.load_file_as_signatures(filename)) return sigs +@pytest.fixture() +def scaled_siglist(): + demo_path = utils.get_test_data("scaled") + filenames = sorted(glob.glob(os.path.join(demo_path, "*.sig"))) + sigs = [] + for filename in filenames: + these_sigs = sourmash.load_file_as_signatures(filename) + scaled_sigs = [s for s in these_sigs if s.minhash.scaled != 0] + sigs.extend(scaled_sigs) + return sigs @pytest.fixture() def ignore_abundance(track_abundance): @@ -59,3 +69,60 @@ def test_compare_all_pairs(siglist, ignore_abundance): similarities_parallel = compare_all_pairs(siglist, ignore_abundance, downsample=False, n_jobs=2) similarities_serial = compare_serial(siglist, ignore_abundance, downsample=False) np.testing.assert_array_equal(similarities_parallel, similarities_serial) + + +def test_compare_serial_jaccardANI(scaled_siglist, ignore_abundance): + similarities = compare_serial(scaled_siglist, ignore_abundance, downsample=False, return_ani=True) + + true_similarities = np.array( + [[1., 0.942, 0.988, 0.986, 0.], + [0.942, 1., 0.960, 0., 0.], + [0.988, 0.960, 1., 0., 0.], + [0.986, 0., 0., 1., 0.], + [0., 0., 0., 0., 1.]]) + + np.testing.assert_array_almost_equal(similarities, true_similarities, decimal=3) + + +def test_compare_parallel_jaccardANI(scaled_siglist, ignore_abundance): + similarities = compare_parallel(scaled_siglist, ignore_abundance, downsample=False, n_jobs=2, return_ani=True) + + true_containment = np.array( + [[1., 0.942, 0.988, 0.986, 0.], + [0.942, 1., 0.960, 0., 0.], + [0.988, 0.960, 1., 0., 0.], + [0.986, 0., 0., 1., 0.], + [0., 0., 0., 0., 1.]]) + + np.testing.assert_array_almost_equal(similarities, true_containment, decimal=3) + + +def test_compare_all_pairs_jaccardANI(scaled_siglist, ignore_abundance): + similarities_parallel = compare_all_pairs(scaled_siglist, ignore_abundance, downsample=False, n_jobs=2, return_ani=True) + similarities_serial = compare_serial(scaled_siglist, ignore_abundance, downsample=False, return_ani=True) + np.testing.assert_array_equal(similarities_parallel, similarities_serial) + + +def test_compare_serial_containmentANI(scaled_siglist, ignore_abundance): + containment = compare_serial_containment(scaled_siglist, ignore_abundance, return_ani=True) + + true_containment = np.array( + [[1., 1., 1., 1., 0.], + [0.92391599, 1., 0.94383993, 0., 0.], + [0.97889056, 1., 1., 0., 0.], + [0.97685474, 0., 0., 1., 0.], + [0., 0., 0., 0., 1.]]) + + np.testing.assert_array_almost_equal(containment, true_containment, decimal=3) + + # check max_containment ANI + max_containment = compare_serial_max_containment(scaled_siglist, ignore_abundance, return_ani=True) + + true_max_containment = np.array( + [[1., 1., 1., 1., 0.], + [1., 1., 1., 0., 0.], + [1., 1., 1., 0., 0.], + [1., 0., 0., 1., 0.], + [0., 0., 0., 0., 1.,]]) + + np.testing.assert_array_almost_equal(max_containment, true_max_containment, decimal=3) diff --git a/tests/test_signature.py b/tests/test_signature.py index 92467b5c5e..41affb9c2f 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -426,3 +426,147 @@ def test_max_containment_equal(): assert ss2.contained_by(ss1) == 1 assert ss1.max_containment(ss2) == 1 assert ss2.max_containment(ss1) == 1 + + +def test_containment_ANI(): + f1 = 
utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + f3 = utils.get_test_data('47+63.fa.sig') + ss1 = load_one_signature(f1, ksize=31) + ss2 = load_one_signature(f2, ksize=31) + ss3 = load_one_signature(f3, ksize=31) + + print("\nss1 contained by ss2", ss1.containment_ani(ss2)) + print("ss2 contained by ss1",ss2.containment_ani(ss1)) + print("ss1 max containment", ss1.max_containment_ani(ss2)) + print("ss2 max containment", ss2.max_containment_ani(ss1)) + + print("\nss2 contained by ss3", ss2.containment_ani(ss3)) + print("ss3 contained by ss2",ss3.containment_ani(ss2)) + + print("\nss2 contained by ss3, CI 90%", ss2.containment_ani(ss3, confidence=0.9)) + print("ss3 contained by ss2, CI 99%",ss3.containment_ani(ss2, confidence=0.99)) + + assert ss1.containment_ani(ss2) == (1.0, 1.0, 1.0) + assert ss2.containment_ani(ss1) == (0.9658183324254062, 0.9648452889933389, 0.966777042966207) + assert ss1.max_containment_ani(ss2) == (1.0, 1.0, 1.0) + assert ss2.max_containment_ani(ss1) == (1.0, 1.0, 1.0) + + # containment 1 is special case. check max containment for non 0/1 values: + assert ss2.containment_ani(ss3) == (0.9866751346467802, 0.9861576758035308, 0.9871770716451368) + assert ss3.containment_ani(ss2) == (0.9868883523107224, 0.986374049720872, 0.9873870188726516) + assert ss2.max_containment_ani(ss3) == (0.9868883523107224, 0.986374049720872, 0.9873870188726516) + assert ss3.max_containment_ani(ss2) == (0.9868883523107224, 0.986374049720872, 0.9873870188726516) + assert ss2.max_containment_ani(ss3)[0] == max(ss2.containment_ani(ss3)[0], ss3.containment_ani(ss2)[0]) + + # check confidence + assert ss2.containment_ani(ss3, confidence=0.9) == (0.9866751346467802, 0.986241913154108, 0.9870974232542815) + assert ss3.containment_ani(ss2, confidence=0.99) == (0.9868883523107224, 0.9862092287269876, 0.987540474733798) + + +def test_containment_ANI_precalc_containment(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + f3 = utils.get_test_data('47+63.fa.sig') + ss1 = load_one_signature(f1, ksize=31) + ss2 = load_one_signature(f2, ksize=31) + ss3 = load_one_signature(f3, ksize=31) + + # precalc containments and assert same results + s1c = ss1.contained_by(ss2) + s2c = ss2.contained_by(ss1) + s3c = ss2.contained_by(ss3) + s4c = ss3.contained_by(ss2) + mc = max(s1c, s2c) + assert ss1.containment_ani(ss2, containment=s1c) == (1.0, 1.0, 1.0) + assert ss2.containment_ani(ss1, containment=s2c) == (0.9658183324254062, 0.9648452889933389, 0.966777042966207) + assert ss1.max_containment_ani(ss2, max_containment=mc) == (1.0, 1.0, 1.0) + assert ss2.max_containment_ani(ss1, max_containment=mc) == (1.0, 1.0, 1.0) + + assert ss2.containment_ani(ss3, containment=s3c) == (0.9866751346467802, 0.9861576758035308, 0.9871770716451368) + assert ss3.containment_ani(ss2, containment=s4c) == (0.9868883523107224, 0.986374049720872, 0.9873870188726516) + assert ss3.containment_ani(ss2, containment=s4c, confidence=0.99) == (0.9868883523107224, 0.9862092287269876, 0.987540474733798) + assert ss2.max_containment_ani(ss3, max_containment=s4c) == (0.9868883523107224, 0.986374049720872, 0.9873870188726516) + assert ss3.max_containment_ani(ss2, max_containment=s4c) == (0.9868883523107224, 0.986374049720872, 0.9873870188726516) + + +def test_containment_ANI_downsample(): + f2 = utils.get_test_data('2+63.fa.sig') + f3 = utils.get_test_data('47+63.fa.sig') + ss2 = load_one_signature(f2, ksize=31) + ss3 = load_one_signature(f3, ksize=31) + # check that downsampling works 
properly + print(ss2.minhash.scaled) + ss2.minhash = ss2.minhash.downsample(scaled=2000) + assert ss2.minhash.scaled != ss3.minhash.scaled + ds_s3c = ss2.containment_ani(ss3, downsample=True) + ds_s4c = ss3.containment_ani(ss2, downsample=True) + mc_w_ds_1 = ss2.max_containment_ani(ss3, downsample=True) + mc_w_ds_2 = ss3.max_containment_ani(ss2, downsample=True) + + with pytest.raises(ValueError) as e: + ss2.containment_ani(ss3) + assert "ValueError: mismatch in scaled; comparison fail" in e + + with pytest.raises(ValueError) as e: + ss2.max_containment_ani(ss3) + assert "ValueError: mismatch in scaled; comparison fail" in e + + ss3.minhash = ss3.minhash.downsample(scaled=2000) + assert ss2.minhash.scaled == ss3.minhash.scaled + ds_s3c_manual = ss2.containment_ani(ss3) + ds_s4c_manual = ss3.containment_ani(ss2) + ds_mc_manual = ss2.max_containment_ani(ss3) + assert ds_s3c == ds_s3c_manual + assert ds_s4c == ds_s4c_manual + assert mc_w_ds_1 == mc_w_ds_2 == ds_mc_manual + + +def test_jaccard_ANI(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + ss1 = load_one_signature(f1, ksize=31) + ss2 = load_one_signature(f2) + + print("\nJACCARD_ANI", ss1.jaccard_ani(ss2)) + print("\nJACCARD_ANI 90% CI", ss1.jaccard_ani(ss2, confidence=0.9)) + print("\nJACCARD_ANI 99% CI", ss1.jaccard_ani(ss2, confidence=0.99)) + + assert ss1.jaccard_ani(ss2) == (0.9783711630110239, 0.9776381521132318, 0.9790929734698974) + assert ss1.jaccard_ani(ss2, confidence=0.9) == (0.9783711630110239, 0.9777567290812516, 0.9789777082973189) + assert ss1.jaccard_ani(ss2, confidence=0.99) == (0.9783711630110239, 0.9774056164150094, 0.9793173653983231) + + +def test_jaccard_ANI_precalc_jaccard(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + ss1 = load_one_signature(f1, ksize=31) + ss2 = load_one_signature(f2) + # precalc jaccard and assert same result + jaccard = ss1.jaccard(ss2) + print("\nJACCARD_ANI", ss1.jaccard_ani(ss2,jaccard=jaccard)) + + assert ss1.jaccard_ani(ss2, jaccard=jaccard) == (0.9783711630110239, 0.9776381521132318, 0.9790929734698974) + assert ss1.jaccard_ani(ss2, jaccard=jaccard, confidence=0.9) == (0.9783711630110239, 0.9777567290812516, 0.9789777082973189) + + +def test_jaccard_ANI_downsample(): + f1 = utils.get_test_data('2.fa.sig') + f2 = utils.get_test_data('2+63.fa.sig') + ss1 = load_one_signature(f1, ksize=31) + ss2 = load_one_signature(f2) + # check that downsampling works properly + print(ss2.minhash.scaled) + ss1.minhash = ss1.minhash.downsample(scaled=2000) + assert ss1.minhash.scaled != ss2.minhash.scaled + ds_s1c = ss1.jaccard_ani(ss2, downsample=True) + ds_s2c = ss2.jaccard_ani(ss1, downsample=True) + + with pytest.raises(ValueError) as e: + ss1.jaccard_ani(ss2) + assert "ValueError: mismatch in scaled; comparison fail" in e + + ss2.minhash = ss2.minhash.downsample(scaled=2000) + assert ss1.minhash.scaled == ss2.minhash.scaled + ds_j_manual = ss1.jaccard_ani(ss2) + assert ds_s1c == ds_s2c == ds_j_manual diff --git a/tests/test_sourmash.py b/tests/test_sourmash.py index c395019c17..955efc7301 100644 --- a/tests/test_sourmash.py +++ b/tests/test_sourmash.py @@ -473,6 +473,86 @@ def test_compare_containment(c): assert containment == mat_val, (i, j) +@utils.in_tempdir +def test_compare_containment_ani(c): + import numpy + + testdata_glob = utils.get_test_data('scaled/*.sig') + testdata_sigs = glob.glob(testdata_glob) + + c.run_sourmash('compare', '--containment', '-k', '31', + '--estimate-ani', '--csv', 'output.csv', *testdata_sigs) 
+ + # load the matrix output of compare --containment --estimate-ani + with open(c.output('output.csv'), 'rt') as fp: + r = iter(csv.reader(fp)) + headers = next(r) + + mat = numpy.zeros((len(headers), len(headers))) + for i, row in enumerate(r): + for j, val in enumerate(row): + mat[i][j] = float(val) + + print(mat) + + # load in all the input signatures + idx_to_sig = dict() + for idx, filename in enumerate(testdata_sigs): + ss = sourmash.load_one_signature(filename, ksize=31) + idx_to_sig[idx] = ss + + # check explicit containment against output of compare + for i in range(len(idx_to_sig)): + ss_i = idx_to_sig[i] + for j in range(len(idx_to_sig)): + ss_j = idx_to_sig[j] + containment_ani, ci_low, ci_high = ss_j.containment_ani(ss_i) + containment_ani = round(containment_ani, 3) + mat_val = round(mat[i][j], 3) + + assert containment_ani == mat_val, (i, j) + + +@utils.in_tempdir +def test_compare_jaccard_ani(c): + import numpy + + testdata_glob = utils.get_test_data('scaled/*.sig') + testdata_sigs = glob.glob(testdata_glob) + + c.run_sourmash('compare', '-k', '31', '--estimate-ani', + '--csv', 'output.csv', *testdata_sigs) + + # load the matrix output of compare --estimate-ani + with open(c.output('output.csv'), 'rt') as fp: + r = iter(csv.reader(fp)) + headers = next(r) + + mat = numpy.zeros((len(headers), len(headers))) + for i, row in enumerate(r): + for j, val in enumerate(row): + mat[i][j] = float(val) + + print(mat) + + # load in all the input signatures + idx_to_sig = dict() + for idx, filename in enumerate(testdata_sigs): + ss = sourmash.load_one_signature(filename, ksize=31) + idx_to_sig[idx] = ss + + # check explicit containment against output of compare + for i in range(len(idx_to_sig)): + ss_i = idx_to_sig[i] + for j in range(len(idx_to_sig)): + ss_j = idx_to_sig[j] + jaccard_ani, ci_low, ci_high = ss_j.jaccard_ani(ss_i) + jaccard_ani = round(jaccard_ani, 3) + mat_val = round(mat[i][j], 3) + + assert jaccard_ani == mat_val, (i, j) + + @utils.in_tempdir def test_compare_max_containment(c): import numpy @@ -513,6 +593,46 @@ def test_compare_max_containment(c): assert containment == mat_val, (i, j) +@utils.in_tempdir +def test_compare_max_containment_ani(c): + import numpy + + testdata_glob = utils.get_test_data('scaled/*.sig') + testdata_sigs = glob.glob(testdata_glob) + + c.run_sourmash('compare', '--max-containment', '-k', '31', + '--estimate-ani', '--csv', 'output.csv', *testdata_sigs) + + # load the matrix output of compare --max-containment --estimate-ani + with open(c.output('output.csv'), 'rt') as fp: + r = iter(csv.reader(fp)) + headers = next(r) + + mat = numpy.zeros((len(headers), len(headers))) + for i, row in enumerate(r): + for j, val in enumerate(row): + mat[i][j] = float(val) + + print(mat) + + # load in all the input signatures + idx_to_sig = dict() + for idx, filename in enumerate(testdata_sigs): + ss = sourmash.load_one_signature(filename, ksize=31) + idx_to_sig[idx] = ss + + # check explicit containment against output of compare + for i in range(len(idx_to_sig)): + ss_i = idx_to_sig[i] + for j in range(len(idx_to_sig)): + ss_j = idx_to_sig[j] + containment_ani, ci_low, ci_high = ss_j.max_containment_ani(ss_i) + containment_ani = round(containment_ani, 3) + mat_val = round(mat[i][j], 3) + + assert containment_ani == mat_val, (i, j) + + @utils.in_tempdir def test_compare_max_containment_and_containment(c): testdata_glob = utils.get_test_data('scaled/*.sig') @@ -536,7 +656,20 @@ def test_compare_containment_abund_flatten(c): print(c.last_result.out) 
print(c.last_result.err) - assert 'NOTE: --containment and --max-containment ignore signature abundances.' in \ + assert 'NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.' in \ + c.last_result.err + + +@utils.in_tempdir +def test_compare_ani_abund_flatten(c): + s47 = utils.get_test_data('track_abund/47.fa.sig') + s63 = utils.get_test_data('track_abund/63.fa.sig') + + c.run_sourmash('compare', '--estimate-ani', '-k', '31', s47, s63) + print(c.last_result.out) + print(c.last_result.err) + + assert 'NOTE: --containment, --max-containment, and --estimate-ani ignore signature abundances.' in \ c.last_result.err @@ -554,6 +687,29 @@ def test_compare_containment_require_scaled(c): assert c.last_result.status != 0 +@utils.in_tempdir +def test_compare_ANI_require_scaled(c): + s47 = utils.get_test_data('num/47.fa.sig') + s63 = utils.get_test_data('num/63.fa.sig') + + # containment and estimate ANI will give this error + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('compare', '--containment', '--estimate-ani', '-k', '31', s47, s63, + fail_ok=True) + assert 'must use scaled signatures with --containment and --max-containment' in \ + c.last_result.err + assert c.last_result.status != 0 + + # jaccard + estimate ANI will give this error + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('compare', '--estimate-ani', '-k', '31', s47, s63, + fail_ok=True) + + assert 'must use scaled signatures with --estimate-ani' in \ + c.last_result.err + assert c.last_result.status != 0 + + @utils.in_tempdir def test_do_plot_comparison(c): testdata1 = utils.get_test_data('short.fa') @@ -1428,12 +1584,19 @@ def test_search_containment_abund_ignore(runtmp): r = csv.DictReader(fp) row = next(r) similarity = row['similarity'] + estimated_ani = row['estimated_ani'] print(f'search output: similarity is {similarity}') + print(f'search output: ani is {estimated_ani}') print(mh1.contained_by(mh2)) assert float(similarity) == mh1.contained_by(mh2) assert float(similarity) == 0.25 + print(runtmp.last_result.err) + assert "WARNING: Cannot estimate ANI. Are your minhashes big enough?" in runtmp.last_result.err + assert "Error: varN <0.0!" in runtmp.last_result.err + assert estimated_ani == "" + def test_search_containment_sbt(runtmp): # search with --containment in an SBT @@ -5208,7 +5371,7 @@ def test_gather_scaled_1(runtmp, linear_gather, prefetch_gather): assert "1.0 kbp 100.0% 100.0%" in runtmp.last_result.out assert "found 1 matches total;" in runtmp.last_result.out - + def test_standalone_manifest_search(runtmp): # test loading/searching a manifest file from the command line. sig47 = utils.get_test_data('47.fa.sig') @@ -5262,3 +5425,166 @@ def test_standalone_manifest_search_fail(runtmp): # ...and now use for a search! 
with pytest.raises(SourmashCommandFailed): runtmp.sourmash('search', sig47, mf) + + +@utils.in_tempdir +def test_search_ani_jaccard(c): + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', testdata1, testdata2) + + c.run_sourmash('search', 'short.fa.sig', 'short2.fa.sig', '-o', 'xxx.csv') + print(c.last_result.status, c.last_result.out, c.last_result.err) + + csv_file = c.output('xxx.csv') + + with open(csv_file) as fp: + reader = csv.DictReader(fp) + row = next(reader) + print(row) + assert float(row['similarity']) == 0.9288577154308617 + assert row['filename'].endswith('short2.fa.sig') + assert row['md5'] == 'bf752903d635b1eb83c53fe4aae951db' + assert row['query_filename'].endswith('short.fa') + assert row['query_name'] == '' + assert row['query_md5'] == '9191284a' + assert row['estimated_ani'] == "0.9987884602947684" + + +@utils.in_tempdir +def test_search_ani_empty_abund(c): + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=10,abund', testdata1, testdata2) + + c.run_sourmash('search', 'short.fa.sig', 'short2.fa.sig', '-o', 'xxx.csv') + print(c.last_result.status, c.last_result.out, c.last_result.err) + + csv_file = c.output('xxx.csv') + + with open(csv_file) as fp: + reader = csv.DictReader(fp) + row = next(reader) + print(row) + assert float(row['similarity']) == 0.8224046424612483 + assert row['md5'] == 'c9d5a795eeaaf58e286fb299133e1938' + assert row['filename'].endswith('short2.fa.sig') + assert row['query_filename'].endswith('short.fa') + assert row['query_name'] == '' + assert row['query_md5'] == 'b5cc464c' + assert row['estimated_ani'] == "" + + +@utils.in_tempdir +def test_search_ani_containment(c): + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', testdata1, testdata2) + + c.run_sourmash('search', '--containment', 'short.fa.sig', 'short2.fa.sig', '-o', 'xxx.csv') + print(c.last_result.status, c.last_result.out, c.last_result.err) + + csv_file = c.output('xxx.csv') + + with open(csv_file) as fp: + reader = csv.DictReader(fp) + row = next(reader) + print(row) + assert float(row['similarity']) == 0.9556701030927836 + assert row['filename'].endswith('short2.fa.sig') + assert row['md5'] == 'bf752903d635b1eb83c53fe4aae951db' + assert row['query_filename'].endswith('short.fa') + assert row['query_name'] == '' + assert row['query_md5'] == '9191284a' + assert row['estimated_ani'] == "0.9985384076863009" + + # search other direction + c.run_sourmash('search', '--containment', 'short2.fa.sig', 'short.fa.sig', '-o', 'xxxx.csv') + print(c.last_result.status, c.last_result.out, c.last_result.err) + + csv_file = c.output('xxxx.csv') + + with open(csv_file) as fp: + reader = csv.DictReader(fp) + row = next(reader) + print(row) + assert float(row['similarity']) == 0.9706806282722513 + assert row['filename'].endswith('short.fa.sig') + assert row['md5'] == '9191284a3a23a913d8d410f3d53ce8f0' + assert row['query_filename'].endswith('short2.fa') + assert row['query_name'] == '' + assert row['query_md5'] == 'bf752903' + assert row['estimated_ani'] == "0.9990405323606487" + + +@utils.in_tempdir +def test_search_ani_max_containment(c): + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', testdata1, testdata2) + + 
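+    # Back-of-envelope consistency check (an inference from the asserted numbers,
+    # not a spec): the estimated_ani below matches
+    # max_containment ** (1 / ksize) = 0.9706806282722513 ** (1/31) ~= 0.9990405,
+    # and the jaccard-based search above is likewise consistent with
+    # (2*J / (1 + J)) ** (1/31) for J = 0.9288577... (~0.9987885).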
c.run_sourmash('search', '--max-containment', 'short.fa.sig', 'short2.fa.sig', '-o', 'xxx.csv') + print(c.last_result.status, c.last_result.out, c.last_result.err) + + csv_file = c.output('xxx.csv') + + with open(csv_file) as fp: + reader = csv.DictReader(fp) + row = next(reader) + print(row) + assert float(row['similarity']) == 0.9706806282722513 + assert row['filename'].endswith('short2.fa.sig') + assert row['md5'] == 'bf752903d635b1eb83c53fe4aae951db' + assert row['query_filename'].endswith('short.fa') + assert row['query_name'] == '' + assert row['query_md5'] == '9191284a' + assert row['estimated_ani'] == "0.9990405323606487" + + +@utils.in_tempdir +def test_search_jaccard_ani_downsample(c): + testdata1 = utils.get_test_data('short.fa') + testdata2 = utils.get_test_data('short2.fa') + sig1_out = c.output('short.fa.sig') + sig2_out = c.output('short2.fa.sig') + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=2', '--force', testdata1, '-o', sig1_out) + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', '--force', testdata1, '-o', sig2_out) + sig1 = sourmash.load_one_signature(sig1_out) + sig2 = sourmash.load_one_signature(sig2_out) + print(f"SCALED: sig1: {sig1.minhash.scaled}, sig2: {sig2.minhash.scaled}") # if don't change name, just reads prior sigfile!!? + + sig1F = c.output('sig1.sig') + sig2F = c.output('sig2.sig') + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=2', '--force', testdata1, '-o', sig1F) + c.run_sourmash('sketch', 'dna', '-p', 'k=31,scaled=1', '--force', testdata2, '-o', sig2F) + + sig1 = sourmash.load_one_signature(sig1F) + sig2 = sourmash.load_one_signature(sig2F) + print(f"SCALED: sig1: {sig1.minhash.scaled}, sig2: {sig2.minhash.scaled}") + + c.run_sourmash('search', sig1F, sig2F, '-o', 'xdx.csv') + print(c.last_result.status, c.last_result.out, c.last_result.err) + + csv_file = c.output('xdx.csv') + + with open(csv_file) as fp: + reader = csv.DictReader(fp) + row = next(reader) + print(row) + assert float(row['similarity']) == 0.9296066252587992 + assert row['md5'] == 'bf752903d635b1eb83c53fe4aae951db' + assert row['filename'].endswith('sig2.sig') + assert row['query_filename'].endswith('short.fa') + assert row['query_name'] == '' + assert row['query_md5'] == '8f74b0b8' + assert row['estimated_ani'] == "0.9988019200011651" + + #downsample manually and assert same ANI + mh1 = sig1.minhash + mh2 = sig2.minhash + mh2_sc2 = mh2.downsample(scaled=mh1.scaled) + print("SCALED:", mh1.scaled, mh2_sc2.scaled) + ani= mh1.jaccard_ani(mh2_sc2) + print(ani) + assert ani == (0.9988019200011651, 0.9980440877843673, 0.9991807844672298) + diff --git a/tests/test_tax.py b/tests/test_tax.py index 93702847bb..3d19cfd740 100644 --- a/tests/test_tax.py +++ b/tests/test_tax.py @@ -702,6 +702,18 @@ def test_genome_rank_stdout_0_db(runtmp): assert 'query_name,status,rank,fraction,lineage,query_md5,query_filename,f_weighted_at_rank,bp_match_at_rank' in c.last_result.out assert 'test1,match,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0' in c.last_result.out + # too stringent of containment threshold: + c.run_sourmash('tax', 'genome', '--gather-csv', g_csv, '--taxonomy-csv', + tax, '--rank', 'species', '--containment-threshold', '1.0') + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + + assert c.last_result.status == 0 + assert "WARNING: classifying query test1 at desired rank species does not meet containment threshold 1.0" in 
c.last_result.err + assert "test1,below_threshold,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0," in c.last_result.out + def test_genome_rank_csv_0(runtmp): # test basic genome - output csv @@ -1363,6 +1375,81 @@ def test_genome_over100percent_error(runtmp): assert "ERROR: The tax summary of query 'test1' is 1.1, which is > 100% of the query!!" in runtmp.last_result.err +def test_genome_ani_threshold_input_errors(runtmp): + c = runtmp + g_csv = utils.get_test_data('tax/test1.gather_ani.csv') + tax = utils.get_test_data('tax/test.taxonomy.csv') + below_threshold = "-1" + + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('tax', 'genome', '-g', tax, '--taxonomy-csv', tax, + '--ani-threshold', below_threshold) + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "ERROR: Argument must be >0 and <1" in str(exc.value) + + above_threshold = "1.1" + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', above_threshold) + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "ERROR: Argument must be >0 and <1" in str(exc.value) + + not_a_float = "str" + + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', not_a_float) + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "ERROR: Must be a floating point number" in str(exc.value) + + +def test_genome_ani_threshold(runtmp): + c = runtmp + g_csv = utils.get_test_data('tax/test1.gather_ani.csv') + tax = utils.get_test_data('tax/test.taxonomy.csv') + + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', "0.95") + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + + assert c.last_result.status == 0 + assert "WARNING: Please run gather with sourmash >= 4.3 to estimate query ANI at rank. Continuing without ANI..." 
not in c.last_result.err + assert 'query_name,status,rank,fraction,lineage,query_md5,query_filename,f_weighted_at_rank,bp_match_at_rank' in c.last_result.out + assert 'test1,match,family,0.116,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae,md5,test1.sig,0.073,582000.0,0.9328896594471843' in c.last_result.out + + # more lax threshold + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', "0.9") + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + + assert c.last_result.status == 0 + assert 'test1,match,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0' in c.last_result.out + + # too stringent of threshold (using rank) + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', "1.0", '--rank', 'species') + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "WARNING: classifying query test1 at desired rank species does not meet query ANI/AAI threshold 1.0" in c.last_result.err + assert "test1,below_threshold,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0,0.9247805047263588" in c.last_result.out + + def test_annotate_0(runtmp): # test annotate c = runtmp diff --git a/tests/test_tax_utils.py b/tests/test_tax_utils.py index dce1d6d9c2..17e457b2b3 100644 --- a/tests/test_tax_utils.py +++ b/tests/test_tax_utils.py @@ -1,6 +1,7 @@ """ Tests for functions in taxonomy submodule. """ + import pytest from os.path import basename @@ -22,9 +23,11 @@ from sourmash.lca.lca_utils import LineagePair # utility functions for testing -def make_mini_gather_results(g_infolist): +def make_mini_gather_results(g_infolist, include_ksize_and_scaled=False): # make mini gather_results min_header = ["query_name", "name", "match_ident", "f_unique_to_query", "query_md5", "query_filename", "f_unique_weighted", "unique_intersect_bp", "remaining_bp"] + if include_ksize_and_scaled: + min_header.extend(['ksize', 'scaled']) gather_results = [] for g_info in g_infolist: inf = dict(zip(min_header, g_info)) @@ -351,30 +354,36 @@ def test_summarize_gather_at_0(): def test_summarize_gather_at_1(): """test two matches, diff f_unique_to_query""" # make mini gather_results - gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40'] - gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.1', '10', '90'] - g_res = make_mini_gather_results([gA,gB]) + ksize=31 + scaled=10 + gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40', ksize, scaled] + gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.1', '10', '90', ksize, scaled] + g_res = make_mini_gather_results([gA,gB], include_ksize_and_scaled=True) # make mini taxonomy gA_tax = ("gA", "a;b;c") gB_tax = ("gB", "a;b;d") taxD = make_mini_taxonomy([gA_tax,gB_tax]) # run summarize_gather_at and check results! 
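+    # With estimate_query_ani=True, the query_ani_at_rank values asserted below are
+    # consistent with fraction ** (1 / ksize) (back-calculated, not a spec):
+    # 0.7 ** (1/31) ~= 0.98856, 0.6 ** (1/31) ~= 0.98366, 0.1 ** (1/31) ~= 0.92841;
+    # unassigned remainders report None.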
- sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res) + sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res, estimate_query_ani=True) # superkingdom assert len(sk_sum) == 2 - print("superkingdom summarized gather: ", sk_sum[0]) + print("\nsuperkingdom summarized gather 0: ", sk_sum[0]) assert sk_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),) assert sk_sum[0].fraction == 0.7 assert sk_sum[0].bp_match_at_rank == 70 + print("superkingdom summarized gather 1: ", sk_sum[1]) assert sk_sum[1].lineage == () assert round(sk_sum[1].fraction, 1) == 0.3 assert sk_sum[1].bp_match_at_rank == 30 + assert sk_sum[0].query_ani_at_rank == 0.9885602934376099 + assert sk_sum[1].query_ani_at_rank == None # phylum - phy_sum, _ = summarize_gather_at("phylum", taxD, g_res) - print("phylum summarized gather: ", phy_sum[0]) + phy_sum, _ = summarize_gather_at("phylum", taxD, g_res, estimate_query_ani=False) + print("phylum summarized gather 0: ", phy_sum[0]) + print("phylum summarized gather 1: ", phy_sum[1]) assert len(phy_sum) == 2 assert phy_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),LineagePair(rank='phylum', name='b')) assert phy_sum[0].fraction == 0.7 @@ -383,8 +392,10 @@ def test_summarize_gather_at_1(): assert phy_sum[1].lineage == () assert round(phy_sum[1].fraction, 1) == 0.3 assert phy_sum[1].bp_match_at_rank == 30 + assert phy_sum[0].query_ani_at_rank == None + assert phy_sum[1].query_ani_at_rank == None # class - cl_sum, _ = summarize_gather_at("class", taxD, g_res) + cl_sum, _ = summarize_gather_at("class", taxD, g_res, estimate_query_ani=True) assert len(cl_sum) == 3 print("class summarized gather: ", cl_sum) assert cl_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'), @@ -393,6 +404,7 @@ def test_summarize_gather_at_1(): assert cl_sum[0].fraction == 0.6 assert cl_sum[0].f_weighted_at_rank == 0.5 assert cl_sum[0].bp_match_at_rank == 60 + assert cl_sum[0].query_ani_at_rank == 0.9836567776983505 assert cl_sum[1].rank == 'class' assert cl_sum[1].lineage == (LineagePair(rank='superkingdom', name='a'), @@ -401,8 +413,10 @@ def test_summarize_gather_at_1(): assert cl_sum[1].fraction == 0.1 assert cl_sum[1].f_weighted_at_rank == 0.1 assert cl_sum[1].bp_match_at_rank == 10 + assert cl_sum[1].query_ani_at_rank == 0.9284145445194744 assert cl_sum[2].lineage == () assert round(cl_sum[2].fraction, 1) == 0.3 + assert cl_sum[2].query_ani_at_rank == None def test_summarize_gather_at_perfect_match(): @@ -532,32 +546,38 @@ def test_summarize_gather_at_missing_fail(): def test_summarize_gather_at_best_only_0(): """test two matches, diff f_unique_to_query""" # make mini gather_results - gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40'] - gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.5', '10', '90'] - g_res = make_mini_gather_results([gA,gB]) + ksize =31 + scaled=10 + gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40', ksize, scaled] + gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.5', '10', '90', ksize, scaled] + g_res = make_mini_gather_results([gA,gB],include_ksize_and_scaled=True) # make mini taxonomy gA_tax = ("gA", "a;b;c") gB_tax = ("gB", "a;b;d") taxD = make_mini_taxonomy([gA_tax,gB_tax]) # run summarize_gather_at and check results! 
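+    # best_only=True keeps only the top lineage per rank, so a single result (and a
+    # single ANI) is expected at each rank below; the values again track
+    # fraction ** (1/31): 0.7 -> ~0.98856 (superkingdom, phylum), 0.6 -> ~0.98366 (class).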
- sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res, best_only=True) + sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res, best_only=True,estimate_query_ani=True) # superkingdom assert len(sk_sum) == 1 print("superkingdom summarized gather: ", sk_sum[0]) assert sk_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),) assert sk_sum[0].fraction == 0.7 assert sk_sum[0].bp_match_at_rank == 70 + print("superk ANI:",sk_sum[0].query_ani_at_rank) + assert sk_sum[0].query_ani_at_rank == 0.9885602934376099 # phylum - phy_sum, _ = summarize_gather_at("phylum", taxD, g_res, best_only=True) + phy_sum, _ = summarize_gather_at("phylum", taxD, g_res, best_only=True,estimate_query_ani=True) print("phylum summarized gather: ", phy_sum[0]) assert len(phy_sum) == 1 assert phy_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),LineagePair(rank='phylum', name='b')) assert phy_sum[0].fraction == 0.7 assert phy_sum[0].bp_match_at_rank == 70 + print("phy ANI:",phy_sum[0].query_ani_at_rank) + assert phy_sum[0].query_ani_at_rank == 0.9885602934376099 # class - cl_sum, _ = summarize_gather_at("class", taxD, g_res, best_only=True) + cl_sum, _ = summarize_gather_at("class", taxD, g_res, best_only=True, estimate_query_ani=True) assert len(cl_sum) == 1 print("class summarized gather: ", cl_sum) assert cl_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'), @@ -565,6 +585,8 @@ def test_summarize_gather_at_best_only_0(): LineagePair(rank='class', name='c')) assert cl_sum[0].fraction == 0.6 assert cl_sum[0].bp_match_at_rank == 60 + print("cl ANI:",cl_sum[0].query_ani_at_rank) + assert cl_sum[0].query_ani_at_rank == 0.9836567776983505 def test_summarize_gather_at_best_only_equal_choose_first(): @@ -597,12 +619,14 @@ def test_write_summary_csv(runtmp): sum_gather = {'superkingdom': [SummarizedGatherResult(query_name='queryA', rank='superkingdom', fraction=1.0, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=1.0, bp_match_at_rank=100, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryA', rank='phylum', fraction=1.0, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=1.0, bp_match_at_rank=100, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='b')))]} + LineagePair(rank='phylum', name='b')), + query_ani_at_rank=None)]} outs= runtmp.output("outsum.csv") with open(outs, 'w') as out_fp: @@ -610,9 +634,9 @@ def test_write_summary_csv(runtmp): sr = [x.rstrip().split(',') for x in open(outs, 'r')] print("gather_summary_results_from_file: \n", sr) - assert ['query_name', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank'] == sr[0] - assert ['queryA', 'superkingdom', '1.0', 'a', 'queryA_md5', 'queryA.sig', '1.0', '100'] == sr[1] - assert ['queryA', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100'] == sr[2] + assert ['query_name', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank', 'query_ani_at_rank'] == sr[0] + assert ['queryA', 'superkingdom', '1.0', 'a', 'queryA_md5', 'queryA.sig', '1.0', '100', ''] == sr[1] + assert ['queryA', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100',''] == sr[2] def test_write_classification(runtmp): @@ -620,7 +644,8 @@ def test_write_classification(runtmp): classif = ClassificationResult('queryA', 
'match', 'phylum', 1.0, (LineagePair(rank='superkingdom', name='a'), LineagePair(rank='phylum', name='b')), - 'queryA_md5', 'queryA.sig', 1.0, 100) + 'queryA_md5', 'queryA.sig', 1.0, 100, + query_ani_at_rank=None) classification = {'phylum': [classif]} @@ -630,8 +655,8 @@ def test_write_classification(runtmp): sr = [x.rstrip().split(',') for x in open(outs, 'r')] print("gather_classification_results_from_file: \n", sr) - assert ['query_name', 'status', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank'] == sr[0] - assert ['queryA', 'match', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100'] == sr[1] + assert ['query_name', 'status', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank', 'query_ani_at_rank'] == sr[0] + assert ['queryA', 'match', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100', ''] == sr[1] def test_make_krona_header_0(): @@ -816,21 +841,25 @@ def test_combine_sumgather_csvs_by_lineage(runtmp): sum_gather1 = {'superkingdom': [SummarizedGatherResult(query_name='queryA', rank='superkingdom', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=1.0, bp_match_at_rank=100, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryA', rank='phylum', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=0.5, bp_match_at_rank=50, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='b')))]} + LineagePair(rank='phylum', name='b')), + query_ani_at_rank=None)]} sum_gather2 = {'superkingdom': [SummarizedGatherResult(query_name='queryB', rank='superkingdom', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryB', rank='phylum', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='c')))]} + LineagePair(rank='phylum', name='c')), + query_ani_at_rank=None)]} # write summarized gather results csvs sg1= runtmp.output("sample1.csv") @@ -903,21 +932,25 @@ def test_combine_sumgather_csvs_by_lineage_improper_rank(runtmp): sum_gather1 = {'superkingdom': [SummarizedGatherResult(query_name='queryA', rank='superkingdom', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=0.5, bp_match_at_rank=50, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryA', rank='phylum', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=0.5, bp_match_at_rank=50, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='b')))]} + LineagePair(rank='phylum', name='b')), + query_ani_at_rank=None)]} sum_gather2 = {'superkingdom': [SummarizedGatherResult(query_name='queryB', rank='superkingdom', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, - 
lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryB', rank='phylum', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='c')))]} + LineagePair(rank='phylum', name='c')), + query_ani_at_rank=None)]} # write summarized gather results csvs sg1= runtmp.output("sample1.csv")
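+    # (As in the fixtures above, SummarizedGatherResult now includes a
+    # query_ani_at_rank field, so these hand-built results set it to None explicitly.)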