Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: re-establish tax gather reading flexibility #2986

Merged
merged 6 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions src/sourmash/tax/tax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections import abc, defaultdict
from itertools import zip_longest
from typing import NamedTuple
from dataclasses import dataclass, field, replace, asdict
from dataclasses import dataclass, field, replace, asdict, fields
import gzip

from sourmash import sqlite_utils, sourmash_args
Expand Down Expand Up @@ -742,7 +742,10 @@ def load_gather_results(
for n, row in enumerate(r):
# try reading each gather row into a TaxResult
try:
gatherRow = GatherRow(**row)
filt_row = filter_row(
row, GatherRow
) # filter row first to allow extra (unused) columns in csv
gatherRow = GatherRow(**filt_row)
except TypeError as exc:
raise ValueError(
f"'{gather_csv}' is missing columns needed for taxonomic summarization. Please run gather with sourmash >= 4.4."
Expand Down Expand Up @@ -1675,6 +1678,20 @@ def load(cls, locations, **kwargs):
return tax_assign


def filter_row(row, dataclass_type):
"""
Filter the row to only include keys that exist in the dataclass fields.
This allows extra columns to be passed in with the gather csv while still
taking advantage of the checks for required columns that come with dataclass
initialization.
"""
valid_keys = {field.name for field in fields(dataclass_type)}
# 'match_name' and 'name' should be interchangeable (sourmash 4.x)
if "match_name" in row.keys() and "name" not in row.keys():
row["name"] = row.pop("match_name")
return {k: v for k, v in row.items() if k in valid_keys}


@dataclass
class GatherRow:
"""
Expand All @@ -1689,7 +1706,8 @@ class GatherRow:

with sourmash_args.FileInputCSV(gather_csv) as r:
for row in enumerate(r):
gatherRow = GatherRow(**row)
filt_row = filter_row(row, GatherRow) # filter first to allow extra columns
gatherRow = GatherRow(**filt_row)
"""

# essential columns
Expand All @@ -1706,32 +1724,10 @@ class GatherRow:
ksize: int
scaled: int

# non-essential
intersect_bp: int = None
f_orig_query: float = None
f_match: float = None
average_abund: float = None
median_abund: float = None
std_abund: float = None
filename: str = None
md5: str = None
f_match_orig: float = None
gather_result_rank: str = None
moltype: str = None
# non-essential, but used if available
query_n_hashes: int = None
query_abundance: int = None
query_containment_ani: float = None
match_containment_ani: float = None
average_containment_ani: float = None
max_containment_ani: float = None
potential_false_negative: bool = None
n_unique_weighted_found: int = None
sum_weighted_found: int = None
total_weighted_hashes: int = None
query_containment_ani_low: float = None
query_containment_ani_high: float = None
match_containment_ani_low: float = None
match_containment_ani_high: float = None


@dataclass
Expand Down Expand Up @@ -1854,7 +1850,8 @@ class TaxResult(BaseTaxResult):

with sourmash_args.FileInputCSV(gather_csv) as r:
for row in enumerate(r):
gatherRow = GatherRow(**row)
filt_row = filter_row(row, GatherRow) # this filters any extra columns
gatherRow = GatherRow(**filt_row) # this checks for required columns and raises TypeError for any missing
# initialize TaxResult
tax_res = TaxResult(raw=gatherRow)

Expand Down
19 changes: 18 additions & 1 deletion tests/test_tax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
LineageDB,
LineageDB_Sqlite,
MultiLineageDB,
filter_row,
)


Expand Down Expand Up @@ -93,7 +94,8 @@ def make_GatherRow(gather_dict=None, exclude_cols=[]):
gatherD.update(gather_dict)
for col in exclude_cols:
gatherD.pop(col)
gatherRaw = GatherRow(**gatherD)
fgatherD = filter_row(gatherD, GatherRow)
gatherRaw = GatherRow(**fgatherD)
return gatherRaw


Expand Down Expand Up @@ -807,6 +809,21 @@ def test_GatherRow_old_gather():
assert "__init__() missing 1 required positional argument: 'query_bp'" in str(exc)


def test_GatherRow_match_name_not_name():
# gather contains match_name but not name column
gA = {"match_name": "gA.1 name"}
grow = make_GatherRow(gA, exclude_cols=["name"])
print(grow)
assert grow.name == "gA.1 name"


def test_GatherRow_extra_cols():
# gather contains extra columns
gA = {"not-a-col": "nope"}
grow = make_GatherRow(gA)
assert isinstance(grow, GatherRow)


def test_get_ident_default():
ident = "GCF_001881345.1"
n_id = get_ident(ident)
Expand Down
Loading