diff --git a/src/sourmash/tax/__main__.py b/src/sourmash/tax/__main__.py index b97c84e39b..fe34ff9884 100644 --- a/src/sourmash/tax/__main__.py +++ b/src/sourmash/tax/__main__.py @@ -9,11 +9,11 @@ import re import sourmash -from ..sourmash_args import FileOutputCSV, FileOutput +from ..sourmash_args import FileOutputCSV, FileInputCSV, FileOutput from sourmash.logging import set_quiet, error, notify, print_results from . import tax_utils -from .tax_utils import MultiLineageDB, GatherRow, RankLineageInfo, LINLineageInfo +from .tax_utils import MultiLineageDB, RankLineageInfo, LINLineageInfo, AnnotateTaxResult usage=''' sourmash taxonomy [] - manipulate/work with taxonomy information. @@ -302,12 +302,13 @@ def annotate(args): set_quiet(args.quiet) - # first, load taxonomic_assignments try: + # first, load taxonomic_assignments tax_assign = MultiLineageDB.load(args.taxonomy_csv, keep_full_identifiers=args.keep_full_identifiers, keep_identifier_versions=args.keep_identifier_versions, force=args.force, lins=args.lins) + except ValueError as exc: error(f"ERROR: {str(exc)}") sys.exit(-1) @@ -316,36 +317,68 @@ def annotate(args): error(f'ERROR: No taxonomic assignments loaded from {",".join(args.taxonomy_csv)}. Exiting.') sys.exit(-1) - # get gather_csvs from args - gather_csvs = tax_utils.collect_gather_csvs(args.gather_csv, from_file=args.from_file) + # get csv from args + input_csvs = tax_utils.collect_gather_csvs(args.gather_csv, from_file=args.from_file) # handle each gather csv separately - for n, g_csv in enumerate(gather_csvs): - query_gather_results = tax_utils.check_and_load_gather_csvs(gather_csvs, tax_assign, force=args.force, - fail_on_missing_taxonomy=args.fail_on_missing_taxonomy, - keep_full_identifiers=args.keep_full_identifiers, - keep_identifier_versions = args.keep_identifier_versions, - lins=args.lins) - - if not query_gather_results: - continue - - out_base = os.path.basename(g_csv.rsplit('.csv')[0]) - this_outfile, limit_float = make_outfile(out_base, "annotate", output_dir=args.output_dir) - - header = [field.name for field in fields(GatherRow)] - with FileOutputCSV(this_outfile) as out_fp: - header.append("lineage") - w = csv.DictWriter(out_fp, header, delimiter=',') - w.writeheader() - - for gather_res in query_gather_results: - for taxres in gather_res.raw_taxresults: - gr = asdict(taxres.raw) - write_gr = {key: gr[key] for key in gr if key in header} - write_gr['lineage'] = taxres.lineageInfo.display_lineage(truncate_empty=True) - w.writerow(write_gr) + for n, in_csv in enumerate(input_csvs): + try: + # Check for a column we can use to find lineage information: + with FileInputCSV(in_csv) as r: + header = r.fieldnames + # check for empty file + if not header: + raise ValueError(f"Cannot read from '{in_csv}'. Is file empty?") + + # look for the column to match with taxonomic identifier + id_col = None + col_options = ['name', 'match_name', 'ident', 'accession'] + for colname in col_options: + if colname in header: + id_col = colname + break + + if not id_col: + raise ValueError(f"Cannot find taxonomic identifier column in '{in_csv}'. Tried: {', '.join(col_options)}") + + notify(f"Starting annotation on '{in_csv}'. Using ID column: '{id_col}'") + + # make output file for this input + out_base = os.path.basename(in_csv.rsplit('.csv')[0]) + this_outfile, _ = make_outfile(out_base, "annotate", output_dir=args.output_dir) + + out_header = header + ['lineage'] + + with FileOutputCSV(this_outfile) as out_fp: + w = csv.DictWriter(out_fp, out_header) + w.writeheader() + + n = 0 + n_missed = 0 + for n, row in enumerate(r): + # find lineage and write annotated row + taxres = AnnotateTaxResult(raw=row, id_col=id_col, lins=args.lins, + keep_full_identifiers=args.keep_full_identifiers, + keep_identifier_versions=args.keep_identifier_versions) + taxres.get_match_lineage(tax_assignments=tax_assign, fail_on_missing_taxonomy=args.fail_on_missing_taxonomy) + + if taxres.missed_ident: # could not assign taxonomy + n_missed+=1 + w.writerow(taxres.row_with_lineages()) + + rows_annotated = (n+1) - n_missed + if not rows_annotated: + raise ValueError(f"Could not annotate any rows from '{in_csv}'.") + else: + notify(f"Annotated {rows_annotated} of {n+1} total rows from '{in_csv}'.") + except ValueError as exc: + if args.force: + notify(str(exc)) + notify('--force is set. Attempting to continue to next file.') + else: + error(f"ERROR: {str(exc)}") + sys.exit(-1) def prepare(args): diff --git a/src/sourmash/tax/tax_utils.py b/src/sourmash/tax/tax_utils.py index 36aca3049c..4bd7ddd8d9 100644 --- a/src/sourmash/tax/tax_utils.py +++ b/src/sourmash/tax/tax_utils.py @@ -1566,8 +1566,80 @@ def __post_init__(self): def total_weighted_bp(self): return self.total_weighted_hashes * self.scaled + +@dataclass +class BaseTaxResult: + """ + Base class for sourmash taxonomic annotation. + """ + raw: dict # csv row + keep_full_identifiers: bool = False + keep_identifier_versions: bool = False + match_ident: str = field(init=False) + skipped_ident: bool = False + missed_ident: bool = False + match_lineage_attempted: bool = False + lins: bool = False + + def get_ident(self, id_col=None): + # split identifiers = split on whitespace + # keep identifiers = don't split .[12] from assembly accessions + "Hack and slash identifiers." + if id_col: + self.match_ident = self.raw[id_col] + else: + self.match_ident = self.raw.name + if not self.keep_full_identifiers: + self.match_ident = self.match_ident.split(' ')[0] + else: + #overrides version bc can't keep full without keeping version + self.keep_identifier_versions = True + if not self.keep_identifier_versions: + self.match_ident = self.match_ident.split('.')[0] + + + def get_match_lineage(self, tax_assignments, skip_idents=None, fail_on_missing_taxonomy=False): + if skip_idents and self.match_ident in skip_idents: + self.skipped_ident = True + else: + lin = tax_assignments.get(self.match_ident) + if lin: + if self.lins: + self.lineageInfo = LINLineageInfo(lineage = lin) + else: + self.lineageInfo = RankLineageInfo(lineage = lin) + else: + self.missed_ident=True + self.match_lineage_attempted = True + if self.missed_ident and fail_on_missing_taxonomy: + raise ValueError(f"Error: ident '{self.match_ident}' is not in the taxonomy database. Failing, as requested via --fail-on-missing-taxonomy") + + @dataclass -class TaxResult: +class AnnotateTaxResult(BaseTaxResult): + """ + Class to enable taxonomic annotation of any sourmash CSV. + """ + id_col: str = 'name' + + def __post_init__(self): + if self.id_col not in self.raw.keys(): + raise ValueError(f"ID column '{self.id_col}' not found.") + self.get_ident(id_col=self.id_col) + if self.lins: + self.lineageInfo = LINLineageInfo() + else: + self.lineageInfo = RankLineageInfo() + + def row_with_lineages(self): + lineage = self.lineageInfo.display_lineage(truncate_empty=True) + rl = {"lineage": lineage} + rl.update(self.raw) + return rl + + +@dataclass +class TaxResult(BaseTaxResult): """ Class to store taxonomic result of a single row from a gather CSV, including accessible query information (QueryInfo) and matched taxonomic lineage. TaxResult tracks whether @@ -1589,19 +1661,11 @@ class TaxResult: # get match lineage tax_res.get_match_lineage(taxD=taxonomic_assignments) - Uses RankLineageInfo to store lineage information; this may need to be modified in the future. + Use RankLineageInfo or LINLineageInfo to store lineage information. """ raw: GatherRow - keep_full_identifiers: bool = False - keep_identifier_versions: bool = False - query_name: str = field(init=False) query_info: QueryInfo = field(init=False) - match_ident: str = field(init=False) - skipped_ident: bool = False - missed_ident: bool = False - match_lineage_attempted: bool = False - lins: bool = False def __post_init__(self): self.get_ident() @@ -1624,35 +1688,6 @@ def __post_init__(self): else: self.lineageInfo = RankLineageInfo() - def get_ident(self): - # split identifiers = split on whitespace - # keep identifiers = don't split .[12] from assembly accessions - "Hack and slash identifiers." - self.match_ident = self.raw.name - if not self.keep_full_identifiers: - self.match_ident = self.raw.name.split(' ')[0] - else: - #overrides version bc can't keep full without keeping version - self.keep_identifier_versions = True - if not self.keep_identifier_versions: - self.match_ident = self.match_ident.split('.')[0] - - - def get_match_lineage(self, tax_assignments, skip_idents=None, fail_on_missing_taxonomy=False): - if skip_idents and self.match_ident in skip_idents: - self.skipped_ident = True - else: - lin = tax_assignments.get(self.match_ident) - if lin: - if self.lins: - self.lineageInfo = LINLineageInfo(lineage = lin) - else: - self.lineageInfo = RankLineageInfo(lineage = lin) - else: - self.missed_ident=True - self.match_lineage_attempted = True - if self.missed_ident and fail_on_missing_taxonomy: - raise ValueError(f"Error: ident '{self.match_ident}' is not in the taxonomy database. Failing, as requested via --fail-on-missing-taxonomy") @dataclass class SummarizedGatherResult: diff --git a/tests/test_tax.py b/tests/test_tax.py index 48d47664a7..56013528dd 100644 --- a/tests/test_tax.py +++ b/tests/test_tax.py @@ -2458,28 +2458,104 @@ def test_annotate_empty_gather_results(runtmp): with pytest.raises(SourmashCommandFailed) as exc: runtmp.run_sourmash('tax', 'annotate', '-g', g_csv, '--taxonomy-csv', tax) - assert f"Cannot read gather results from '{g_csv}'. Is file empty?" in str(exc.value) + assert f"Cannot read from '{g_csv}'. Is file empty?" in str(exc.value) assert runtmp.last_result.status == -1 -def test_annotate_bad_gather_header(runtmp): +def test_annotate_prefetch_or_other_header(runtmp): + tax = utils.get_test_data('tax/test.taxonomy.csv') + g_csv = utils.get_test_data('tax/test1.gather.csv') + + alt_csv = runtmp.output('g.csv') + for alt_col in ['match_name', 'ident', 'accession']: + #modify 'name' to other acceptable id_columns result + alt_g = [x.replace("name", alt_col) for x in open(g_csv, 'r')] + with open(alt_csv, 'w') as fp: + for line in alt_g: + fp.write(line) + + runtmp.run_sourmash('tax', 'annotate', '-g', alt_csv, '--taxonomy-csv', tax) + + assert runtmp.last_result.status == 0 + print(runtmp.last_result.out) + print(runtmp.last_result.err) + assert f"Starting annotation on '{alt_csv}'. Using ID column: '{alt_col}'" in runtmp.last_result.err + assert f"Annotated 4 of 4 total rows from '{alt_csv}'" in runtmp.last_result.err + + +def test_annotate_bad_header(runtmp): tax = utils.get_test_data('tax/test.taxonomy.csv') g_csv = utils.get_test_data('tax/test1.gather.csv') bad_g_csv = runtmp.output('g.csv') #creates bad gather result - bad_g = [x.replace("query_name", "nope") for x in open(g_csv, 'r')] + bad_g = [x.replace("name", "nope") for x in open(g_csv, 'r')] with open(bad_g_csv, 'w') as fp: for line in bad_g: fp.write(line) - print("bad_gather_results: \n", bad_g) + # print("bad_gather_results: \n", bad_g) with pytest.raises(SourmashCommandFailed) as exc: runtmp.run_sourmash('tax', 'annotate', '-g', bad_g_csv, '--taxonomy-csv', tax) - assert 'is missing columns needed for taxonomic summarization.' in str(exc.value) + assert f"ERROR: Cannot find taxonomic identifier column in '{bad_g_csv}'. Tried: name, match_name, ident, accession" in str(exc.value) + assert runtmp.last_result.status == -1 + print(runtmp.last_result.out) + print(runtmp.last_result.err) + + +def test_annotate_no_tax_matches(runtmp): + tax = utils.get_test_data('tax/test.taxonomy.csv') + g_csv = utils.get_test_data('tax/test1.gather.csv') + + bad_g_csv = runtmp.output('g.csv') + + #mess up tax idents + bad_g = [x.replace("GCF_", "GGG_") for x in open(g_csv, 'r')] + with open(bad_g_csv, 'w') as fp: + for line in bad_g: + fp.write(line) + # print("bad_gather_results: \n", bad_g) + + with pytest.raises(SourmashCommandFailed) as exc: + runtmp.run_sourmash('tax', 'annotate', '-g', bad_g_csv, '--taxonomy-csv', tax) + + assert f"ERROR: Could not annotate any rows from '{bad_g_csv}'" in str(exc.value) assert runtmp.last_result.status == -1 + print(runtmp.last_result.out) + print(runtmp.last_result.err) + + runtmp.run_sourmash('tax', 'annotate', '-g', bad_g_csv, '--taxonomy-csv', tax, '--force') + + assert runtmp.last_result.status == 0 + assert f"Could not annotate any rows from '{bad_g_csv}'" in runtmp.last_result.err + assert f"--force is set. Attempting to continue to next file." in runtmp.last_result.err + print(runtmp.last_result.out) + print(runtmp.last_result.err) + + +def test_annotate_missed_tax_matches(runtmp): + tax = utils.get_test_data('tax/test.taxonomy.csv') + g_csv = utils.get_test_data('tax/test1.gather.csv') + + bad_g_csv = runtmp.output('g.csv') + + with open(g_csv, 'r') as gather_lines, open(bad_g_csv, 'w') as fp: + for n, line in enumerate(gather_lines): + if n > 2: + # mess up tax idents of lines 3, 4 + line = line.replace("GCF_", "GGG_") + fp.write(line) + # print("bad_gather_results: \n", bad_g) + + runtmp.run_sourmash('tax', 'annotate', '-g', bad_g_csv, '--taxonomy-csv', tax) + + print(runtmp.last_result.out) + print(runtmp.last_result.err) + + assert runtmp.last_result.status == 0 + assert f"Annotated 2 of 4 total rows from '{bad_g_csv}'." in runtmp.last_result.err def test_annotate_empty_tax_lineage_input(runtmp): diff --git a/tests/test_tax_utils.py b/tests/test_tax_utils.py index 5c72cdb5e0..d97672bf14 100644 --- a/tests/test_tax_utils.py +++ b/tests/test_tax_utils.py @@ -13,7 +13,7 @@ from sourmash.tax.tax_utils import (ascending_taxlist, get_ident, load_gather_results, collect_gather_csvs, check_and_load_gather_csvs, LineagePair, QueryInfo, GatherRow, TaxResult, QueryTaxResult, - SummarizedGatherResult, ClassificationResult, + SummarizedGatherResult, ClassificationResult, AnnotateTaxResult, BaseLineageInfo, RankLineageInfo, LINLineageInfo, aggregate_by_lineage_at_rank, format_for_krona, write_krona, write_lineage_sample_frac, read_lingroups, @@ -424,6 +424,37 @@ def test_TaxResult_get_ident_default(): assert taxres.match_ident == "GCF_001881345" +def test_AnnotateTaxResult_get_ident_default(): + gA = {"name": "GCF_001881345.1"} # gather result with match name as GCF_001881345.1 + taxres = AnnotateTaxResult(raw=gA) + print(taxres.match_ident) + assert taxres.match_ident == "GCF_001881345" + + +def test_AnnotateTaxResult_get_ident_idcol(): + gA = {"name": "n1", "match_name": "n2", "ident": "n3", "accession": "n4"} # gather result with match name as GCF_001881345.1 + taxres = AnnotateTaxResult(raw=gA) + print(taxres.match_ident) + assert taxres.match_ident == "n1" + taxres = AnnotateTaxResult(raw=gA, id_col="match_name") + print(taxres.match_ident) + assert taxres.match_ident == "n2" + taxres = AnnotateTaxResult(raw=gA, id_col="ident") + print(taxres.match_ident) + assert taxres.match_ident == "n3" + taxres = AnnotateTaxResult(raw=gA, id_col="accession") + print(taxres.match_ident) + assert taxres.match_ident == "n4" + + +def test_AnnotateTaxResult_get_ident_idcol_fail(): + gA = {"name": "n1", "match_name": "n2", "ident": "n3", "accession": "n4"} # gather result with match name as GCF_001881345.1 + with pytest.raises(ValueError) as exc: + AnnotateTaxResult(raw=gA, id_col="NotACol") + print(str(exc)) + assert "ID column 'NotACol' not found." in str(exc) + + def test_get_ident_split_but_keep_version(): ident = "GCF_001881345.1 secondname" n_id = get_ident(ident, keep_identifier_versions=True) @@ -440,6 +471,16 @@ def test_TaxResult_get_ident_split_but_keep_version(): assert taxres.match_ident == "GCF_001881345.1" +def test_AnnotateTaxResult_get_ident_split_but_keep_version(): + gA = {"name": "GCF_001881345.1 secondname"} + taxres = AnnotateTaxResult(gA, keep_identifier_versions=True) + print("raw ident: ", taxres.raw['name']) + print("keep_full?: ", taxres.keep_full_identifiers) + print("keep_version?: ",taxres.keep_identifier_versions) + print("final ident: ", taxres.match_ident) + assert taxres.match_ident == "GCF_001881345.1" + + def test_get_ident_no_split(): ident = "GCF_001881345.1 secondname" n_id = get_ident(ident, keep_full_identifiers=True) @@ -456,6 +497,16 @@ def test_TaxResult_get_ident_keep_full(): assert taxres.match_ident == "GCF_001881345.1 secondname" +def test_AnnotateTaxResult_get_ident_keep_full(): + gA = {"name": "GCF_001881345.1 secondname"} + taxres = AnnotateTaxResult(gA, keep_full_identifiers=True) + print("raw ident: ", taxres.raw['name']) + print("keep_full?: ", taxres.keep_full_identifiers) + print("keep_version?: ",taxres.keep_identifier_versions) + print("final ident: ", taxres.match_ident) + assert taxres.match_ident == "GCF_001881345.1 secondname" + + def test_collect_gather_csvs(runtmp): g_csv = utils.get_test_data('tax/test1.gather.csv') from_file = runtmp.output("tmp-from-file.txt") @@ -1759,6 +1810,17 @@ def test_TaxResult_get_match_lineage_1(): assert taxres.lineageInfo.display_lineage() == "a;b;c" +def test_AnnotateTaxResult_get_match_lineage_1(): + gA_tax = ("gA", "a;b;c") + taxD = make_mini_taxonomy([gA_tax]) + + gA = {"name": "gA.1 name"} + taxres = AnnotateTaxResult(gA) + taxres.get_match_lineage(tax_assignments=taxD) + assert taxres.lineageInfo.display_lineage() == "a;b;c" + assert taxres.row_with_lineages() == {"name": "gA.1 name", "lineage": "a;b;c"} + + def test_TaxResult_get_match_lineage_skip_ident(): gA_tax = ("gA", "a;b;c") taxD = make_mini_taxonomy([gA_tax])