From 940866df85eb9a64cf235486c34461594cbd8d29 Mon Sep 17 00:00:00 2001 From: "N. Tessa Pierce" Date: Wed, 27 Apr 2022 11:14:42 -0700 Subject: [PATCH] init with tax code from #1788 --- src/sourmash/tax/__main__.py | 30 +++-- src/sourmash/tax/tax_utils.py | 150 ++++++++++------------- tests/test-data/tax/test1.gather_ani.csv | 5 + tests/test_tax.py | 118 +++++++++++++----- tests/test_tax_utils.py | 106 +++++++++------- 5 files changed, 242 insertions(+), 167 deletions(-) create mode 100644 tests/test-data/tax/test1.gather_ani.csv diff --git a/src/sourmash/tax/__main__.py b/src/sourmash/tax/__main__.py index 01fc6bbe1c..1cee174eb1 100644 --- a/src/sourmash/tax/__main__.py +++ b/src/sourmash/tax/__main__.py @@ -184,7 +184,8 @@ def genome(args): best_at_rank, seen_perfect = tax_utils.summarize_gather_at(args.rank, tax_assign, gather_results, skip_idents=idents_missed, keep_full_identifiers=args.keep_full_identifiers, keep_identifier_versions = args.keep_identifier_versions, - best_only=True, seen_perfect=seen_perfect) + best_only=True, seen_perfect=seen_perfect, estimate_query_ani=True) + except ValueError as exc: error(f"ERROR: {str(exc)}") sys.exit(-1) @@ -194,19 +195,22 @@ def genome(args): status = 'nomatch' if sg.query_name in matched_queries: continue - if sg.fraction <= args.containment_threshold: + if args.ani_threshold and sg.query_ani_at_rank < args.ani_threshold: + status="below_threshold" + notify(f"WARNING: classifying query {sg.query_name} at desired rank {args.rank} does not meet query ANI/AAI threshold {args.ani_threshold}") + elif sg.fraction <= args.containment_threshold: # should this just be less than? status="below_threshold" notify(f"WARNING: classifying query {sg.query_name} at desired rank {args.rank} does not meet containment threshold {args.containment_threshold}") else: status="match" - classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank) + classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank, sg.query_ani_at_rank) classifications[args.rank].append(classif) matched_queries.add(sg.query_name) if "krona" in args.output_format: lin_list = display_lineage(sg.lineage).split(';') krona_results.append((sg.fraction, *lin_list)) else: - # classify to the match that passes the containment threshold. + # classify to the rank/match that passes the containment threshold. # To do - do we want to store anything for this match if nothing >= containment threshold? for rank in tax_utils.ascending_taxlist(include_strain=False): # gets best_at_rank for all queries in this gather_csv @@ -214,7 +218,7 @@ def genome(args): best_at_rank, seen_perfect = tax_utils.summarize_gather_at(rank, tax_assign, gather_results, skip_idents=idents_missed, keep_full_identifiers=args.keep_full_identifiers, keep_identifier_versions = args.keep_identifier_versions, - best_only=True, seen_perfect=seen_perfect) + best_only=True, seen_perfect=seen_perfect, estimate_query_ani=True) except ValueError as exc: error(f"ERROR: {str(exc)}") sys.exit(-1) @@ -223,18 +227,26 @@ def genome(args): status = 'nomatch' if sg.query_name in matched_queries: continue - if sg.fraction >= args.containment_threshold: + if args.ani_threshold and sg.query_ani_at_rank >= args.ani_threshold: + status="match" + elif sg.fraction >= args.containment_threshold: status = "match" - classif = ClassificationResult(sg.query_name, status, sg.rank, sg.fraction, sg.lineage, sg.query_md5, sg.query_filename, sg.f_weighted_at_rank, sg.bp_match_at_rank) + if status == "match": + classif = ClassificationResult(query_name=sg.query_name, status=status, rank=sg.rank, + fraction=sg.fraction, lineage=sg.lineage, + query_md5=sg.query_md5, query_filename=sg.query_filename, + f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank, + query_ani_at_rank= sg.query_ani_at_rank) classifications[sg.rank].append(classif) matched_queries.add(sg.query_name) continue - if rank == "superkingdom" and status == "nomatch": + elif rank == "superkingdom" and status == "nomatch": status="below_threshold" classif = ClassificationResult(query_name=sg.query_name, status=status, rank="", fraction=0, lineage="", query_md5=sg.query_md5, query_filename=sg.query_filename, - f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank) + f_weighted_at_rank=sg.f_weighted_at_rank, bp_match_at_rank=sg.bp_match_at_rank, + query_ani_at_rank=sg.query_ani_at_rank) classifications[sg.rank].append(classif) if not any([classifications, krona_results]): diff --git a/src/sourmash/tax/tax_utils.py b/src/sourmash/tax/tax_utils.py index d5a7161afc..0c69eeaf55 100644 --- a/src/sourmash/tax/tax_utils.py +++ b/src/sourmash/tax/tax_utils.py @@ -5,12 +5,7 @@ import csv from collections import namedtuple, defaultdict from collections import abc - -from sourmash import sqlite_utils -from sourmash.exceptions import IndexNotSupported - -import sqlite3 - +from sourmash.distance_utils import containment_to_distance __all__ = ['get_ident', 'ascending_taxlist', 'collect_gather_csvs', 'load_gather_results', 'check_and_load_gather_csvs', @@ -24,9 +19,9 @@ from sourmash.logging import notify from sourmash.sourmash_args import load_pathlist_from_file -QueryInfo = namedtuple("QueryInfo", "query_md5, query_filename, query_bp") -SummarizedGatherResult = namedtuple("SummarizedGatherResult", "query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank") -ClassificationResult = namedtuple("ClassificationResult", "query_name, status, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank") +QueryInfo = namedtuple("QueryInfo", "query_md5, query_filename, query_bp, query_hashes") +SummarizedGatherResult = namedtuple("SummarizedGatherResult", "query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank, query_ani_at_rank") +ClassificationResult = namedtuple("ClassificationResult", "query_name, status, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_match_at_rank, query_ani_at_rank") # Essential Gather column names that must be in gather_csv to allow `tax` summarization EssentialGatherColnames = ('query_name', 'name', 'f_unique_weighted', 'f_unique_to_query', 'unique_intersect_bp', 'remaining_bp', 'query_md5', 'query_filename') @@ -188,7 +183,7 @@ def find_match_lineage(match_ident, tax_assign, *, skip_idents = [], def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], keep_full_identifiers=False, keep_identifier_versions=False, best_only=False, - seen_perfect=set()): + seen_perfect=set(), estimate_query_ani=False): """ Summarize gather results at specified taxonomic rank """ @@ -198,7 +193,7 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], sum_uniq_to_query = defaultdict(lambda: defaultdict(float)) sum_uniq_bp = defaultdict(lambda: defaultdict(float)) query_info = {} - + ksize,scaled,query_nhashes=None,None,None for row in gather_results: # get essential gather info query_name = row['query_name'] @@ -207,13 +202,25 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], unique_intersect_bp = int(row['unique_intersect_bp']) query_md5 = row['query_md5'] query_filename = row['query_filename'] - # get query_bp - if query_name not in query_info.keys(): - query_bp = unique_intersect_bp + int(row['remaining_bp']) + if query_name not in query_info.keys(): #REMOVING THIS AFFECTS GATHER RESULTS!!! BUT query bp should always be same for same query? bug? + if "query_nhashes" in row.keys(): + query_nhashes = int(row["query_nhashes"]) + if "query_bp" in row.keys(): + query_bp = int(row["query_bp"]) + else: + query_bp = unique_intersect_bp + int(row['remaining_bp']) # store query info - query_info[query_name] = QueryInfo(query_md5=query_md5, query_filename=query_filename, query_bp=query_bp) - match_ident = row['name'] + query_info[query_name] = QueryInfo(query_md5=query_md5, query_filename=query_filename, query_bp=query_bp, query_hashes = query_nhashes) + + if estimate_query_ani and (not ksize or not scaled): # just need to set these once. BUT, if we have these, should we check for compatibility when loading the gather file? + if "ksize" in row.keys(): # ksize and scaled were added to gather results in same PR + ksize = int(row['ksize']) + scaled = int(row['scaled']) + else: + estimate_query_ani=False + notify("WARNING: Please run gather with sourmash >= 4.3 to estimate query ANI at rank. Continuing without ANI...") + match_ident = row['name'] # 100% match? are we looking at something in the database? if f_unique_to_query >= 1.0 and query_name not in seen_perfect: # only want to notify once, not for each rank @@ -225,27 +232,29 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], # get lineage for match lineage = find_match_lineage(match_ident, tax_assign, - skip_idents=skip_idents, - keep_full_identifiers=keep_full_identifiers, - keep_identifier_versions=keep_identifier_versions) + skip_idents=skip_idents, + keep_full_identifiers=keep_full_identifiers, + keep_identifier_versions=keep_identifier_versions) # ident was in skip_idents if not lineage: continue # summarize at rank! lineage = pop_to_rank(lineage, rank) - assert lineage[-1].rank == rank, (rank, lineage[-1]) + assert lineage[-1].rank == rank, lineage[-1] # record info sum_uniq_to_query[query_name][lineage] += f_unique_to_query sum_uniq_weighted[query_name][lineage] += f_uniq_weighted sum_uniq_bp[query_name][lineage] += unique_intersect_bp + # sort and store each as SummarizedGatherResult sum_uniq_to_query_sorted = [] for query_name, lineage_weights in sum_uniq_to_query.items(): qInfo = query_info[query_name] sumgather_items = list(lineage_weights.items()) sumgather_items.sort(key = lambda x: -x[1]) + query_ani = None if best_only: lineage, fraction = sumgather_items[0] if fraction > 1: @@ -254,13 +263,19 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], continue f_weighted_at_rank = sum_uniq_weighted[query_name][lineage] bp_intersect_at_rank = sum_uniq_bp[query_name][lineage] - sres = SummarizedGatherResult(query_name, rank, fraction, lineage, qInfo.query_md5, qInfo.query_filename, f_weighted_at_rank, bp_intersect_at_rank) + if estimate_query_ani: + query_ani = containment_to_distance(fraction, ksize, scaled, + n_unique_kmers= qInfo.query_hashes, sequence_len_bp= qInfo.query_bp, + return_identity=True)[0] + sres = SummarizedGatherResult(query_name, rank, fraction, lineage, qInfo.query_md5, + qInfo.query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani) sum_uniq_to_query_sorted.append(sres) else: total_f_weighted= 0.0 total_f_classified = 0.0 total_bp_classified = 0 for lineage, fraction in sumgather_items: + query_ani=None if fraction > 1: raise ValueError(f"The tax summary of query '{query_name}' is {fraction}, which is > 100% of the query!! This should not be possible. Please check that your input files come directly from a single gather run per query.") elif fraction == 0: @@ -270,16 +285,23 @@ def summarize_gather_at(rank, tax_assign, gather_results, *, skip_idents = [], total_f_weighted += f_weighted_at_rank bp_intersect_at_rank = int(sum_uniq_bp[query_name][lineage]) total_bp_classified += bp_intersect_at_rank - sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_intersect_at_rank) + if estimate_query_ani: + query_ani = containment_to_distance(fraction, ksize, scaled, + n_unique_kmers=qInfo.query_hashes, sequence_len_bp=qInfo.query_bp, + return_identity=True)[0] + sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, + query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani) sum_uniq_to_query_sorted.append(sres) # record unclassified lineage = () + query_ani=None fraction = 1.0 - total_f_classified if fraction > 0: f_weighted_at_rank = 1.0 - total_f_weighted bp_intersect_at_rank = qInfo.query_bp - total_bp_classified - sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, query_filename, f_weighted_at_rank, bp_intersect_at_rank) + sres = SummarizedGatherResult(query_name, rank, fraction, lineage, query_md5, + query_filename, f_weighted_at_rank, bp_intersect_at_rank, query_ani) sum_uniq_to_query_sorted.append(sres) return sum_uniq_to_query_sorted, seen_perfect @@ -628,27 +650,15 @@ def load(cls, filename, *, delimiter=',', force=False, class LineageDB_Sqlite(abc.Mapping): """ - A LineageDB based on a sqlite3 database with a 'sourmash_taxonomy' table. + A LineageDB based on a sqlite3 database with a 'taxonomy' table. """ # NOTE: 'order' is a reserved name in sql, so we have to use 'order_'. columns = ('superkingdom', 'phylum', 'order_', 'class', 'family', 'genus', 'species', 'strain') - table_name = 'sourmash_taxonomy' - def __init__(self, conn, *, table_name=None): + def __init__(self, conn): self.conn = conn - # provide for legacy support for pre-sourmash_internal days... - if table_name is not None: - self.table_name = table_name - - # check that the right table is there. - c = conn.cursor() - try: - c.execute(f'SELECT * FROM {self.table_name} LIMIT 1') - except (sqlite3.DatabaseError, sqlite3.OperationalError): - raise ValueError("not a taxonomy database") - # check: can we do a 'select' on the right table? self.__len__() c = conn.cursor() @@ -656,7 +666,7 @@ def __init__(self, conn, *, table_name=None): # get available ranks... ranks = set() for column, rank in zip(self.columns, taxlist(include_strain=True)): - query = f'SELECT COUNT({column}) FROM {self.table_name} WHERE {column} IS NOT NULL AND {column} != ""' + query = f'SELECT COUNT({column}) FROM taxonomy WHERE {column} IS NOT NULL AND {column} != ""' c.execute(query) cnt, = c.fetchone() if cnt: @@ -667,35 +677,14 @@ def __init__(self, conn, *, table_name=None): @classmethod def load(cls, location): - "load taxonomy information from an existing sqlite3 database" - conn = sqlite_utils.open_sqlite_db(location) - if not conn: - raise ValueError("not a sqlite taxonomy database") - - table_name = None - c = conn.cursor() + "load taxonomy information from a sqlite3 database" + import sqlite3 try: - info = sqlite_utils.get_sourmash_internal(c) - except sqlite3.OperationalError: - info = {} - - if 'SqliteLineage' in info: - if info['SqliteLineage'] != '1.0': - raise IndexNotSupported - - table_name = 'sourmash_taxonomy' - else: - # legacy support for old taxonomy DB, pre sourmash_internal. - try: - c.execute('SELECT * FROM taxonomy LIMIT 1') - table_name = 'taxonomy' - except sqlite3.OperationalError: - pass - - if table_name is None: - raise ValueError("not a sqlite taxonomy database") - - return cls(conn, table_name=table_name) + conn = sqlite3.connect(location) + db = cls(conn) + except sqlite3.DatabaseError: + raise ValueError("not a sqlite database") + return db def _make_tup(self, row): "build a tuple of LineagePairs for this sqlite row" @@ -705,7 +694,7 @@ def _make_tup(self, row): def __getitem__(self, ident): "Retrieve lineage for identifer" c = self.cursor - c.execute(f'SELECT superkingdom, phylum, class, order_, family, genus, species, strain FROM {self.table_name} WHERE ident=?', (ident,)) + c.execute('SELECT superkingdom, phylum, class, order_, family, genus, species, strain FROM taxonomy WHERE ident=?', (ident,)) # retrieve names list... names = c.fetchone() @@ -726,7 +715,7 @@ def __bool__(self): def __len__(self): "Return number of rows" c = self.conn.cursor() - c.execute(f'SELECT COUNT(DISTINCT ident) FROM {self.table_name}') + c.execute('SELECT COUNT(DISTINCT ident) FROM taxonomy') nrows, = c.fetchone() return nrows @@ -734,7 +723,7 @@ def __iter__(self): "Return all identifiers" # create new cursor so as to allow other operations c = self.conn.cursor() - c.execute(f'SELECT DISTINCT ident FROM {self.table_name}') + c.execute('SELECT DISTINCT ident FROM taxonomy') for ident, in c: yield ident @@ -743,12 +732,11 @@ def items(self): "return all items in the sqlite database" c = self.conn.cursor() - c.execute(f'SELECT DISTINCT ident, superkingdom, phylum, class, order_, family, genus, species, strain FROM {self.table_name}') + c.execute('SELECT DISTINCT ident, superkingdom, phylum, class, order_, family, genus, species, strain FROM taxonomy') for ident, *names in c: yield ident, self._make_tup(names) - class MultiLineageDB(abc.Mapping): "A wrapper for (dynamically) combining multiple lineage databases." @@ -845,23 +833,15 @@ def save(self, filename_or_fp, file_format): if is_filename: fp.close() - def _save_sqlite(self, filename, *, conn=None): - from sourmash import sqlite_utils - - if conn is None: - db = sqlite3.connect(filename) - else: - assert not filename - db = conn + def _save_sqlite(self, filename): + import sqlite3 + db = sqlite3.connect(filename) cursor = db.cursor() try: - sqlite_utils.add_sourmash_internal(cursor, 'SqliteLineage', '1.0') - - # CTB: could add 'IF NOT EXIST' here; would need tests, too. cursor.execute(""" - CREATE TABLE sourmash_taxonomy ( + CREATE TABLE taxonomy ( ident TEXT NOT NULL, superkingdom TEXT, phylum TEXT, @@ -879,7 +859,7 @@ class TEXT, raise ValueError(f"taxonomy table already exists in '{filename}'") # follow up and create index - cursor.execute("CREATE UNIQUE INDEX sourmash_taxonomy_ident ON sourmash_taxonomy(ident);") + cursor.execute("CREATE UNIQUE INDEX taxonomy_ident ON taxonomy(ident);") for ident, tax in self.items(): x = [ident, *[ t.name for t in tax ]] @@ -888,7 +868,7 @@ class TEXT, while len(x) < 9: x.append('') - cursor.execute('INSERT INTO sourmash_taxonomy (ident, superkingdom, phylum, class, order_, family, genus, species, strain) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', x) + cursor.execute('INSERT INTO taxonomy (ident, superkingdom, phylum, class, order_, family, genus, species, strain) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', x) db.commit() diff --git a/tests/test-data/tax/test1.gather_ani.csv b/tests/test-data/tax/test1.gather_ani.csv new file mode 100644 index 0000000000..48a09eb199 --- /dev/null +++ b/tests/test-data/tax/test1.gather_ani.csv @@ -0,0 +1,5 @@ +intersect_bp,f_orig_query,f_match,f_unique_to_query,f_unique_weighted,average_abund,median_abund,std_abund,name,filename,md5,f_match_orig,unique_intersect_bp,gather_result_rank,remaining_bp,query_name,query_md5,query_filename,ksize,scaled,query_nhashes +442000,0.08815317112086159,0.08438335242458954,0.08815317112086159,0.05815279361459521,1.6153846153846154,1.0,1.1059438185997785,"GCF_001881345.1 Escherichia coli strain=SF-596, ASM188134v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,683df1ec13872b4b98d59e98b355b52c,0.042779713511420826,442000,0,4572000,test1,md5,test1.sig,31,1000,5013970 +390000,0.07778220981252493,0.10416666666666667,0.07778220981252493,0.050496823586903404,1.5897435897435896,1.0,0.8804995294906566,"GCF_009494285.1 Prevotella copri strain=iAK1218, ASM949428v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,1266c86141e3a5603da61f57dd863ed0,0.052236806857755155,390000,1,4182000,test1,md5,test1.sig,31,1000,4571970 +138000,0.027522935779816515,0.024722321748477247,0.027522935779816515,0.015637726014008795,1.391304347826087,1.0,0.5702120455914782,"GCF_013368705.1 Bacteroides vulgatus strain=B33, ASM1336870v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,7d5f4ba1d01c8c3f7a520d19faded7cb,0.012648945921173235,138000,2,4044000,test1,md5,test1.sig,31,1000,4181970 +338000,0.06741124850418827,0.013789581205311542,0.010769844435580374,0.006515719172503665,1.4814814814814814,1.0,0.738886568268889,"GCF_003471795.1 Prevotella copri strain=AM16-54, ASM347179v1",/group/ctbrowngrp/gtdb/databases/ctb/gtdb-rs202.genomic.k31.sbt.zip,0ebd36ff45fc2810808789667f4aad84,0.04337782340862423,54000,3,3990000,test1,md5,test1.sig,31,1000,4327970 diff --git a/tests/test_tax.py b/tests/test_tax.py index 1faf6ce19a..3d19cfd740 100644 --- a/tests/test_tax.py +++ b/tests/test_tax.py @@ -9,9 +9,6 @@ from sourmash.tax import tax_utils from sourmash_tst_utils import SourmashCommandFailed -from sourmash import sqlite_utils -from sourmash.exceptions import IndexNotSupported - ## command line tests def test_run_sourmash_tax(): status, out, err = utils.runscript('sourmash', ['tax'], fail_ok=True) @@ -705,6 +702,18 @@ def test_genome_rank_stdout_0_db(runtmp): assert 'query_name,status,rank,fraction,lineage,query_md5,query_filename,f_weighted_at_rank,bp_match_at_rank' in c.last_result.out assert 'test1,match,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0' in c.last_result.out + # too stringent of containment threshold: + c.run_sourmash('tax', 'genome', '--gather-csv', g_csv, '--taxonomy-csv', + tax, '--rank', 'species', '--containment-threshold', '1.0') + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + + assert c.last_result.status == 0 + assert "WARNING: classifying query test1 at desired rank species does not meet containment threshold 1.0" in c.last_result.err + assert "test1,below_threshold,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0," in c.last_result.out + def test_genome_rank_csv_0(runtmp): # test basic genome - output csv @@ -1366,6 +1375,81 @@ def test_genome_over100percent_error(runtmp): assert "ERROR: The tax summary of query 'test1' is 1.1, which is > 100% of the query!!" in runtmp.last_result.err +def test_genome_ani_threshold_input_errors(runtmp): + c = runtmp + g_csv = utils.get_test_data('tax/test1.gather_ani.csv') + tax = utils.get_test_data('tax/test.taxonomy.csv') + below_threshold = "-1" + + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('tax', 'genome', '-g', tax, '--taxonomy-csv', tax, + '--ani-threshold', below_threshold) + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "ERROR: Argument must be >0 and <1" in str(exc.value) + + above_threshold = "1.1" + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', above_threshold) + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "ERROR: Argument must be >0 and <1" in str(exc.value) + + not_a_float = "str" + + with pytest.raises(SourmashCommandFailed) as exc: + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', not_a_float) + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "ERROR: Must be a floating point number" in str(exc.value) + + +def test_genome_ani_threshold(runtmp): + c = runtmp + g_csv = utils.get_test_data('tax/test1.gather_ani.csv') + tax = utils.get_test_data('tax/test.taxonomy.csv') + + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', "0.95") + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + + assert c.last_result.status == 0 + assert "WARNING: Please run gather with sourmash >= 4.3 to estimate query ANI at rank. Continuing without ANI..." not in c.last_result.err + assert 'query_name,status,rank,fraction,lineage,query_md5,query_filename,f_weighted_at_rank,bp_match_at_rank' in c.last_result.out + assert 'test1,match,family,0.116,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae,md5,test1.sig,0.073,582000.0,0.9328896594471843' in c.last_result.out + + # more lax threshold + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', "0.9") + + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + + assert c.last_result.status == 0 + assert 'test1,match,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0' in c.last_result.out + + # too stringent of threshold (using rank) + c.run_sourmash('tax', 'genome', '-g', g_csv, '--taxonomy-csv', tax, + '--ani-threshold', "1.0", '--rank', 'species') + print(c.last_result.status) + print(c.last_result.out) + print(c.last_result.err) + assert "WARNING: classifying query test1 at desired rank species does not meet query ANI/AAI threshold 1.0" in c.last_result.err + assert "test1,below_threshold,species,0.089,d__Bacteria;p__Bacteroidota;c__Bacteroidia;o__Bacteroidales;f__Bacteroidaceae;g__Prevotella;s__Prevotella copri,md5,test1.sig,0.057,444000.0,0.9247805047263588" in c.last_result.out + + def test_annotate_0(runtmp): # test annotate c = runtmp @@ -1822,31 +1906,3 @@ def test_tax_prepare_3_db_to_csv_empty_ranks_3(runtmp): keep_identifier_versions=False) assert set(db1) == set(db2) assert set(db1) == set(db3) - - -def test_tax_prepare_sqlite_lineage_version(runtmp): - # test bad sourmash_internals version for SqliteLineage - taxcsv = utils.get_test_data('tax/test.taxonomy.csv') - taxout = runtmp.output('out.db') - - runtmp.run_sourmash('tax', 'prepare', '-t', taxcsv, - '-o', taxout, '-F', 'sql') - assert os.path.exists(taxout) - - # set bad version - conn = sqlite_utils.open_sqlite_db(taxout) - c = conn.cursor() - c.execute("UPDATE sourmash_internal SET value='0.9' WHERE key='SqliteLineage'") - - conn.commit() - conn.close() - - with pytest.raises(IndexNotSupported): - db = tax_utils.MultiLineageDB.load([taxout]) - -def test_tax_prepare_sqlite_no_lineage(): - # no lineage table at all - sqldb = utils.get_test_data('sqlite/index.sqldb') - - with pytest.raises(ValueError): - db = tax_utils.MultiLineageDB.load([sqldb]) diff --git a/tests/test_tax_utils.py b/tests/test_tax_utils.py index 449e26f972..17e457b2b3 100644 --- a/tests/test_tax_utils.py +++ b/tests/test_tax_utils.py @@ -1,6 +1,7 @@ """ Tests for functions in taxonomy submodule. """ + import pytest from os.path import basename @@ -22,9 +23,11 @@ from sourmash.lca.lca_utils import LineagePair # utility functions for testing -def make_mini_gather_results(g_infolist): +def make_mini_gather_results(g_infolist, include_ksize_and_scaled=False): # make mini gather_results min_header = ["query_name", "name", "match_ident", "f_unique_to_query", "query_md5", "query_filename", "f_unique_weighted", "unique_intersect_bp", "remaining_bp"] + if include_ksize_and_scaled: + min_header.extend(['ksize', 'scaled']) gather_results = [] for g_info in g_infolist: inf = dict(zip(min_header, g_info)) @@ -351,30 +354,36 @@ def test_summarize_gather_at_0(): def test_summarize_gather_at_1(): """test two matches, diff f_unique_to_query""" # make mini gather_results - gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40'] - gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.1', '10', '90'] - g_res = make_mini_gather_results([gA,gB]) + ksize=31 + scaled=10 + gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40', ksize, scaled] + gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.1', '10', '90', ksize, scaled] + g_res = make_mini_gather_results([gA,gB], include_ksize_and_scaled=True) # make mini taxonomy gA_tax = ("gA", "a;b;c") gB_tax = ("gB", "a;b;d") taxD = make_mini_taxonomy([gA_tax,gB_tax]) # run summarize_gather_at and check results! - sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res) + sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res, estimate_query_ani=True) # superkingdom assert len(sk_sum) == 2 - print("superkingdom summarized gather: ", sk_sum[0]) + print("\nsuperkingdom summarized gather 0: ", sk_sum[0]) assert sk_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),) assert sk_sum[0].fraction == 0.7 assert sk_sum[0].bp_match_at_rank == 70 + print("superkingdom summarized gather 1: ", sk_sum[1]) assert sk_sum[1].lineage == () assert round(sk_sum[1].fraction, 1) == 0.3 assert sk_sum[1].bp_match_at_rank == 30 + assert sk_sum[0].query_ani_at_rank == 0.9885602934376099 + assert sk_sum[1].query_ani_at_rank == None # phylum - phy_sum, _ = summarize_gather_at("phylum", taxD, g_res) - print("phylum summarized gather: ", phy_sum[0]) + phy_sum, _ = summarize_gather_at("phylum", taxD, g_res, estimate_query_ani=False) + print("phylum summarized gather 0: ", phy_sum[0]) + print("phylum summarized gather 1: ", phy_sum[1]) assert len(phy_sum) == 2 assert phy_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),LineagePair(rank='phylum', name='b')) assert phy_sum[0].fraction == 0.7 @@ -383,8 +392,10 @@ def test_summarize_gather_at_1(): assert phy_sum[1].lineage == () assert round(phy_sum[1].fraction, 1) == 0.3 assert phy_sum[1].bp_match_at_rank == 30 + assert phy_sum[0].query_ani_at_rank == None + assert phy_sum[1].query_ani_at_rank == None # class - cl_sum, _ = summarize_gather_at("class", taxD, g_res) + cl_sum, _ = summarize_gather_at("class", taxD, g_res, estimate_query_ani=True) assert len(cl_sum) == 3 print("class summarized gather: ", cl_sum) assert cl_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'), @@ -393,6 +404,7 @@ def test_summarize_gather_at_1(): assert cl_sum[0].fraction == 0.6 assert cl_sum[0].f_weighted_at_rank == 0.5 assert cl_sum[0].bp_match_at_rank == 60 + assert cl_sum[0].query_ani_at_rank == 0.9836567776983505 assert cl_sum[1].rank == 'class' assert cl_sum[1].lineage == (LineagePair(rank='superkingdom', name='a'), @@ -401,8 +413,10 @@ def test_summarize_gather_at_1(): assert cl_sum[1].fraction == 0.1 assert cl_sum[1].f_weighted_at_rank == 0.1 assert cl_sum[1].bp_match_at_rank == 10 + assert cl_sum[1].query_ani_at_rank == 0.9284145445194744 assert cl_sum[2].lineage == () assert round(cl_sum[2].fraction, 1) == 0.3 + assert cl_sum[2].query_ani_at_rank == None def test_summarize_gather_at_perfect_match(): @@ -532,32 +546,38 @@ def test_summarize_gather_at_missing_fail(): def test_summarize_gather_at_best_only_0(): """test two matches, diff f_unique_to_query""" # make mini gather_results - gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40'] - gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.5', '10', '90'] - g_res = make_mini_gather_results([gA,gB]) + ksize =31 + scaled=10 + gA = ["queryA", "gA","0.5","0.6", "queryA_md5", "queryA.sig", '0.5', '60', '40', ksize, scaled] + gB = ["queryA", "gB","0.3","0.1", "queryA_md5", "queryA.sig", '0.5', '10', '90', ksize, scaled] + g_res = make_mini_gather_results([gA,gB],include_ksize_and_scaled=True) # make mini taxonomy gA_tax = ("gA", "a;b;c") gB_tax = ("gB", "a;b;d") taxD = make_mini_taxonomy([gA_tax,gB_tax]) # run summarize_gather_at and check results! - sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res, best_only=True) + sk_sum, _ = summarize_gather_at("superkingdom", taxD, g_res, best_only=True,estimate_query_ani=True) # superkingdom assert len(sk_sum) == 1 print("superkingdom summarized gather: ", sk_sum[0]) assert sk_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),) assert sk_sum[0].fraction == 0.7 assert sk_sum[0].bp_match_at_rank == 70 + print("superk ANI:",sk_sum[0].query_ani_at_rank) + assert sk_sum[0].query_ani_at_rank == 0.9885602934376099 # phylum - phy_sum, _ = summarize_gather_at("phylum", taxD, g_res, best_only=True) + phy_sum, _ = summarize_gather_at("phylum", taxD, g_res, best_only=True,estimate_query_ani=True) print("phylum summarized gather: ", phy_sum[0]) assert len(phy_sum) == 1 assert phy_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'),LineagePair(rank='phylum', name='b')) assert phy_sum[0].fraction == 0.7 assert phy_sum[0].bp_match_at_rank == 70 + print("phy ANI:",phy_sum[0].query_ani_at_rank) + assert phy_sum[0].query_ani_at_rank == 0.9885602934376099 # class - cl_sum, _ = summarize_gather_at("class", taxD, g_res, best_only=True) + cl_sum, _ = summarize_gather_at("class", taxD, g_res, best_only=True, estimate_query_ani=True) assert len(cl_sum) == 1 print("class summarized gather: ", cl_sum) assert cl_sum[0].lineage == (LineagePair(rank='superkingdom', name='a'), @@ -565,6 +585,8 @@ def test_summarize_gather_at_best_only_0(): LineagePair(rank='class', name='c')) assert cl_sum[0].fraction == 0.6 assert cl_sum[0].bp_match_at_rank == 60 + print("cl ANI:",cl_sum[0].query_ani_at_rank) + assert cl_sum[0].query_ani_at_rank == 0.9836567776983505 def test_summarize_gather_at_best_only_equal_choose_first(): @@ -597,12 +619,14 @@ def test_write_summary_csv(runtmp): sum_gather = {'superkingdom': [SummarizedGatherResult(query_name='queryA', rank='superkingdom', fraction=1.0, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=1.0, bp_match_at_rank=100, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryA', rank='phylum', fraction=1.0, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=1.0, bp_match_at_rank=100, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='b')))]} + LineagePair(rank='phylum', name='b')), + query_ani_at_rank=None)]} outs= runtmp.output("outsum.csv") with open(outs, 'w') as out_fp: @@ -610,9 +634,9 @@ def test_write_summary_csv(runtmp): sr = [x.rstrip().split(',') for x in open(outs, 'r')] print("gather_summary_results_from_file: \n", sr) - assert ['query_name', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank'] == sr[0] - assert ['queryA', 'superkingdom', '1.0', 'a', 'queryA_md5', 'queryA.sig', '1.0', '100'] == sr[1] - assert ['queryA', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100'] == sr[2] + assert ['query_name', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank', 'query_ani_at_rank'] == sr[0] + assert ['queryA', 'superkingdom', '1.0', 'a', 'queryA_md5', 'queryA.sig', '1.0', '100', ''] == sr[1] + assert ['queryA', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100',''] == sr[2] def test_write_classification(runtmp): @@ -620,7 +644,8 @@ def test_write_classification(runtmp): classif = ClassificationResult('queryA', 'match', 'phylum', 1.0, (LineagePair(rank='superkingdom', name='a'), LineagePair(rank='phylum', name='b')), - 'queryA_md5', 'queryA.sig', 1.0, 100) + 'queryA_md5', 'queryA.sig', 1.0, 100, + query_ani_at_rank=None) classification = {'phylum': [classif]} @@ -630,8 +655,8 @@ def test_write_classification(runtmp): sr = [x.rstrip().split(',') for x in open(outs, 'r')] print("gather_classification_results_from_file: \n", sr) - assert ['query_name', 'status', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank'] == sr[0] - assert ['queryA', 'match', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100'] == sr[1] + assert ['query_name', 'status', 'rank', 'fraction', 'lineage', 'query_md5', 'query_filename', 'f_weighted_at_rank', 'bp_match_at_rank', 'query_ani_at_rank'] == sr[0] + assert ['queryA', 'match', 'phylum', '1.0', 'a;b', 'queryA_md5', 'queryA.sig', '1.0', '100', ''] == sr[1] def test_make_krona_header_0(): @@ -816,21 +841,25 @@ def test_combine_sumgather_csvs_by_lineage(runtmp): sum_gather1 = {'superkingdom': [SummarizedGatherResult(query_name='queryA', rank='superkingdom', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=1.0, bp_match_at_rank=100, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryA', rank='phylum', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=0.5, bp_match_at_rank=50, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='b')))]} + LineagePair(rank='phylum', name='b')), + query_ani_at_rank=None)]} sum_gather2 = {'superkingdom': [SummarizedGatherResult(query_name='queryB', rank='superkingdom', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryB', rank='phylum', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='c')))]} + LineagePair(rank='phylum', name='c')), + query_ani_at_rank=None)]} # write summarized gather results csvs sg1= runtmp.output("sample1.csv") @@ -903,21 +932,25 @@ def test_combine_sumgather_csvs_by_lineage_improper_rank(runtmp): sum_gather1 = {'superkingdom': [SummarizedGatherResult(query_name='queryA', rank='superkingdom', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=0.5, bp_match_at_rank=50, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryA', rank='phylum', fraction=0.5, query_md5='queryA_md5', query_filename='queryA.sig', f_weighted_at_rank=0.5, bp_match_at_rank=50, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='b')))]} + LineagePair(rank='phylum', name='b')), + query_ani_at_rank=None)]} sum_gather2 = {'superkingdom': [SummarizedGatherResult(query_name='queryB', rank='superkingdom', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, - lineage=(LineagePair(rank='superkingdom', name='a'),))], + lineage=(LineagePair(rank='superkingdom', name='a'),), + query_ani_at_rank=None)], 'phylum': [SummarizedGatherResult(query_name='queryB', rank='phylum', fraction=0.7, query_md5='queryB_md5', query_filename='queryB.sig', f_weighted_at_rank=0.7, bp_match_at_rank=70, lineage=(LineagePair(rank='superkingdom', name='a'), - LineagePair(rank='phylum', name='c')))]} + LineagePair(rank='phylum', name='c')), + query_ani_at_rank=None)]} # write summarized gather results csvs sg1= runtmp.output("sample1.csv") @@ -967,17 +1000,6 @@ def test_tax_multi_load_files(runtmp): MultiLineageDB.load([runtmp.output('no-such-file')]) -def test_tax_sql_load_new_file(runtmp): - # test loading a newer-format sql file with sourmash_internals table - taxonomy_db = utils.get_test_data('sqlite/test.taxonomy.db') - - db = MultiLineageDB.load([taxonomy_db]) - print(list(db.keys())) - assert len(db) == 6 - assert 'strain' not in db.available_ranks - assert db['GCF_001881345'][0].rank == 'superkingdom' - - def test_tax_multi_load_files_shadowed(runtmp): # test loading various good and bad files taxonomy_csv = utils.get_test_data('tax/test.taxonomy.csv')