diff --git a/CHANGES.md b/CHANGES.md index ca56471f3..e5e890556 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,12 +2,21 @@ ## __NEXT__ +### Features + +* filter: Added a new option `--query-columns` that allows specifying which columns are used in `--query` along with the expected data types. If unspecified, automatic detection of columns and types is attempted. [#1294][] (@victorlin) +* `augur.io.read_metadata`: A new optional `columns` argument allows specifying a subset of columns to load. The default behavior still loads all columns, so this is not a breaking change. [#1294][] (@victorlin) + ### Bug Fixes +* filter: The order of rows in `--output-metadata` and `--output-strains` now reflects the order in the original `--metadata`. [#1294][] (@victorlin) +* filter, frequencies, refine: Performance improvements to reading the input metadata file. [#1294][] (@victorlin) + * For filter, this comes with increased writing times for `--output-metadata` and `--output-strains`. However, net I/O time still decreased during testing of this change. * filter: Updated the help text of `--include` and `--include-where` to explicitly state that this can add strains that are missing an entry from `--sequences`. [#1389][] (@victorlin) * filter: Fixed the summary messages to properly reflect force-inclusion of strains that are missing an entry from `--sequences`. [#1389][] (@victorlin) * filter: Updated wording of summary messages. [#1389][] (@victorlin) +[#1294]: https://github.com/nextstrain/augur/pull/1294 [#1389]: https://github.com/nextstrain/augur/pull/1389 ## 24.1.0 (30 January 2024) diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index a4d53475f..f35ff1370 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -2,6 +2,7 @@ Filter and subsample a sequence set. """ from augur.dates import numeric_date_type, SUPPORTED_DATE_HELP_TEXT +from augur.filter.io import ACCEPTED_TYPES, column_type_pair from augur.io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN from augur.types import EmptyOutputReportingMethod from . import constants @@ -28,6 +29,11 @@ def register_arguments(parser): Uses Pandas Dataframe querying, see https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query for syntax. (e.g., --query "country == 'Colombia'" or --query "(country == 'USA' & (division == 'Washington'))")""" ) + metadata_filter_group.add_argument('--query-columns', type=column_type_pair, nargs="+", help=f""" + Use alongside --query to specify columns and data types in the format 'column:type', where type is one of ({','.join(ACCEPTED_TYPES)}). + Automatic type inference will be attempted on all unspecified columns used in the query. + Example: region:str coverage:float. 
+ """) metadata_filter_group.add_argument('--min-date', type=numeric_date_type, help=f"minimal cutoff for date, the cutoff date is inclusive; may be specified as: {SUPPORTED_DATE_HELP_TEXT}") metadata_filter_group.add_argument('--max-date', type=numeric_date_type, help=f"maximal cutoff for date, the cutoff date is inclusive; may be specified as: {SUPPORTED_DATE_HELP_TEXT}") metadata_filter_group.add_argument('--exclude-ambiguous-dates-by', choices=['any', 'day', 'month', 'year'], diff --git a/augur/filter/_run.py b/augur/filter/_run.py index e68bb0335..b81758dd3 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -15,13 +15,13 @@ DELIMITER as SEQUENCE_INDEX_DELIMITER, ) from augur.io.file import open_file -from augur.io.metadata import InvalidDelimiter, read_metadata +from augur.io.metadata import InvalidDelimiter, Metadata, read_metadata from augur.io.sequences import read_sequences, write_sequences from augur.io.print import print_err from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf from augur.types import EmptyOutputReportingMethod from . import include_exclude_rules -from .io import cleanup_outputs, read_priority_scores +from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs from .include_exclude_rules import apply_filters, construct_filters from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, create_queues_by_group, get_groups_for_subsampling @@ -133,16 +133,6 @@ def run(args): random_generator = np.random.default_rng(args.subsample_seed) priorities = defaultdict(random_generator.random) - # Setup metadata output. We track whether any records have been written to - # disk yet through the following variables, to control whether we write the - # metadata's header and open a new file for writing. - metadata_header = True - metadata_mode = "w" - - # Setup strain output. - if args.output_strains: - output_strains = open(args.output_strains, "w") - # Setup logging. output_log_writer = None if args.output_log: @@ -168,19 +158,23 @@ def run(args): filter_counts = defaultdict(int) try: - metadata_reader = read_metadata( - args.metadata, - delimiters=args.metadata_delimiters, - id_columns=args.metadata_id_columns, - chunk_size=args.metadata_chunk_size, - dtype="string", - ) + metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) except InvalidDelimiter: raise AugurError( f"Could not determine the delimiter of {args.metadata!r}. " f"Valid delimiters are: {args.metadata_delimiters!r}. " "This can be changed with --metadata-delimiters." ) + useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns) + + metadata_reader = read_metadata( + args.metadata, + delimiters=[metadata_object.delimiter], + columns=useful_metadata_columns, + id_columns=[metadata_object.id_column], + chunk_size=args.metadata_chunk_size, + dtype="string", + ) for metadata in metadata_reader: duplicate_strains = ( set(metadata.index[metadata.index.duplicated()]) | @@ -263,30 +257,6 @@ def run(args): priorities[strain], ) - # Always write out strains that are force-included. Additionally, if - # we are not grouping, write out metadata and strains that passed - # filters so far. 
- force_included_strains_to_write = distinct_force_included_strains - if not group_by: - force_included_strains_to_write = force_included_strains_to_write | seq_keep - - if args.output_metadata: - # TODO: wrap logic to write metadata into its own function - metadata.loc[list(force_included_strains_to_write)].to_csv( - args.output_metadata, - sep="\t", - header=metadata_header, - mode=metadata_mode, - ) - metadata_header = False - metadata_mode = "a" - - if args.output_strains: - # TODO: Output strains will no longer be ordered. This is a - # small breaking change. - for strain in force_included_strains_to_write: - output_strains.write(f"{strain}\n") - # In the worst case, we need to calculate sequences per group from the # requested maximum number of sequences and the number of sequences per # group. Then, we need to make a second pass through the metadata to find @@ -323,6 +293,7 @@ def run(args): metadata_reader = read_metadata( args.metadata, delimiters=args.metadata_delimiters, + columns=useful_metadata_columns, id_columns=args.metadata_id_columns, chunk_size=args.metadata_chunk_size, dtype="string", @@ -367,23 +338,6 @@ def run(args): # Construct a data frame of records to simplify metadata output. records.append(record) - if args.output_strains: - # TODO: Output strains will no longer be ordered. This is a - # small breaking change. - output_strains.write(f"{record.name}\n") - - # Write records to metadata output, if requested. - if args.output_metadata and len(records) > 0: - records = pd.DataFrame(records) - records.to_csv( - args.output_metadata, - sep="\t", - header=metadata_header, - mode=metadata_mode, - ) - metadata_header = False - metadata_mode = "a" - # Count and optionally log strains that were not included due to # subsampling. strains_filtered_by_subsampling = valid_strains - subsampled_strains @@ -442,14 +396,17 @@ def run(args): # Update the set of available sequence strains. sequence_strains = observed_sequence_strains + if args.output_metadata or args.output_strains: + write_metadata_based_outputs(args.metadata, args.metadata_delimiters, + args.metadata_id_columns, args.output_metadata, + args.output_strains, valid_strains) + # Calculate the number of strains that don't exist in either metadata or # sequences. num_excluded_by_lack_of_metadata = 0 if sequence_strains: num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains) - if args.output_strains: - output_strains.close() # Calculate the number of strains passed and filtered. total_strains_passed = len(valid_strains) diff --git a/augur/filter/include_exclude_rules.py b/augur/filter/include_exclude_rules.py index 87c93c6a4..cdf1c16ac 100644 --- a/augur/filter/include_exclude_rules.py +++ b/augur/filter/include_exclude_rules.py @@ -1,9 +1,10 @@ +import ast import json import operator import re import numpy as np import pandas as pd -from typing import Any, Callable, Dict, List, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple from augur.dates import is_date_ambiguous, get_numerical_dates from augur.errors import AugurError @@ -78,7 +79,7 @@ def filter_by_exclude(metadata, exclude_file) -> FilterFunctionReturn: return set(metadata.index.values) - excluded_strains -def _parse_filter_query(query): +def parse_filter_query(query): """Parse an augur filter-style query and return the corresponding column, operator, and value for the query. 
@@ -98,9 +99,9 @@ def _parse_filter_query(query): Examples -------- - >>> _parse_filter_query("property=value") + >>> parse_filter_query("property=value") ('property', <built-in function eq>, 'value') - >>> _parse_filter_query("property!=value") + >>> parse_filter_query("property!=value") ('property', <built-in function ne>, 'value') """ @@ -143,7 +144,7 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn: ['strain1', 'strain2'] """ - column, op, value = _parse_filter_query(exclude_where) + column, op, value = parse_filter_query(exclude_where) if column in metadata.columns: # Apply a test operator (equality or inequality) to values from the # column in the given query. This produces an array of boolean values we @@ -164,7 +165,7 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn: return filtered -def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn: +def filter_by_query(metadata: pd.DataFrame, query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn: """Filter metadata in the given pandas DataFrame with a query string and return the strain names that pass the filter. @@ -174,6 +175,8 @@ def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn: Metadata indexed by strain name query : str Query string for the dataframe. + column_types : dict of str to str, optional + Dict mapping column names to data types Examples -------- @@ -187,22 +190,42 @@ def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn: # Create a copy to prevent modification of the original DataFrame. metadata_copy = metadata.copy() - # Support numeric comparisons in query strings. - # - # The built-in data type inference when loading the DataFrame does not + if column_types is None: + column_types = {} + + # Set columns for type conversion. + variables = extract_variables(query) + if variables is not None: + columns = variables.intersection(metadata_copy.columns) + else: + # Column extraction failed. Apply type conversion to all columns. + columns = metadata_copy.columns + + # If a type is not explicitly provided, try converting the column to numeric. + # This should cover most use cases, since one common problem is that the + # built-in data type inference when loading the DataFrame does not # support nullable numeric columns, so numeric comparisons won't work on - # those columns. pd.to_numeric does proper conversion on those columns, and - # will not make any changes to columns with other values. - # - # TODO: Parse the query string and apply conversion only to columns used for - # numeric comparison. Pandas does not expose the API used to parse the query - # string internally, so this is non-trivial and requires a bit of - # reverse-engineering. Commit 2ead5b3e3306dc1100b49eb774287496018122d9 got - # halfway there but had issues so it was reverted. - # - # TODO: Try boolean conversion? - for column in metadata_copy.columns: - metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='ignore') + # those columns. pd.to_numeric does proper conversion on those columns, + # and will not make any changes to columns with other values. + for column in columns: + column_types.setdefault(column, 'numeric') + + # Convert data types before applying the query. 
+ for column, dtype in column_types.items(): + if dtype == 'numeric': + metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='ignore') + elif dtype == 'int': + try: + metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='integer') + except ValueError as e: + raise AugurError(f"Failed to convert value in column {column!r} to int. {e}") + elif dtype == 'float': + try: + metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='float') + except ValueError as e: + raise AugurError(f"Failed to convert value in column {column!r} to float. {e}") + elif dtype == 'str': + metadata_copy[column] = metadata_copy[column].astype('str', errors='ignore') try: return set(metadata_copy.query(query).index.values) @@ -492,7 +515,7 @@ def force_include_where(metadata, include_where) -> FilterFunctionReturn: set() """ - column, op, value = _parse_filter_query(include_where) + column, op, value = parse_filter_query(include_where) if column in metadata.columns: # Apply a test operator (equality or inequality) to values from the @@ -578,9 +601,13 @@ def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[Fi # Exclude strains by metadata, using pandas querying. if args.query: + kwargs = {"query": args.query} + if args.query_columns: + kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns} + exclude_by.append(( filter_by_query, - {"query": args.query} + kwargs )) # Filter by ambiguous dates. @@ -820,3 +847,31 @@ def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs): kwarg_list.append((key, value)) return json.dumps(kwarg_list) + + +def extract_variables(pandas_query: str): + """Try extracting all variable names used in a pandas query string. + + If successful, return the variable names as a set. Otherwise, nothing is returned. + + Examples + -------- + >>> extract_variables("var1 == 'value'") + {'var1'} + >>> sorted(extract_variables("var1 == 'value' & var2 == 10")) + ['var1', 'var2'] + >>> extract_variables("var1.str.startswith('prefix')") + {'var1'} + >>> extract_variables("this query is invalid") + """ + # Since Pandas' query grammar should be a subset of Python's, which uses the + # ast stdlib under the hood, we can try to parse queries with that as well. + # Errors may arise from invalid query syntax or any Pandas syntax not + # covered by Python (unlikely, but I'm not sure). In those cases, don't + # return anything. + try: + return set(node.id + for node in ast.walk(ast.parse(pandas_query)) + if isinstance(node, ast.Name)) + except: + return None diff --git a/augur/filter/io.py b/augur/filter/io.py index ee0985628..458b05511 100644 --- a/augur/filter/io.py +++ b/augur/filter/io.py @@ -1,8 +1,74 @@ +import argparse +import csv +from argparse import Namespace import os +import re +from typing import Sequence, Set import numpy as np from collections import defaultdict +from xopen import xopen from augur.errors import AugurError +from augur.io.metadata import Metadata, METADATA_DATE_COLUMN +from augur.io.print import print_err +from .constants import GROUP_BY_GENERATED_COLUMNS +from .include_exclude_rules import extract_variables, parse_filter_query + + +def get_useful_metadata_columns(args: Namespace, id_column: str, all_columns: Sequence[str]): + """Return a list of column names that are used in augur filter. + This allows reading only the necessary columns. + """ + + # Start with just the ID column. + columns = {id_column} + + # Add the date column if it is used. 
+    if (args.exclude_ambiguous_dates_by + or args.min_date + or args.max_date + or (args.group_by and GROUP_BY_GENERATED_COLUMNS.intersection(args.group_by))): + columns.add(METADATA_DATE_COLUMN) + + if args.group_by: + group_by_set = set(args.group_by) + requested_generated_columns = group_by_set & GROUP_BY_GENERATED_COLUMNS + + # Add columns used for grouping. + columns.update(group_by_set - requested_generated_columns) + + # Show warning for ignored columns. + ignored_columns = requested_generated_columns.intersection(set(all_columns)) + for col in sorted(ignored_columns): + print_err(f"WARNING: `--group-by {col}` uses a generated {col} value from the {METADATA_DATE_COLUMN!r} column. The custom '{col}' column in the metadata is ignored for grouping purposes.") + + # Add columns used in exclude queries. + if args.exclude_where: + for query in args.exclude_where: + column, op, value = parse_filter_query(query) + columns.add(column) + + # Add columns used in include queries. + if args.include_where: + for query in args.include_where: + column, op, value = parse_filter_query(query) + columns.add(column) + + # Add columns used in Pandas queries. + if args.query: + if args.query_columns: + # Use column names explicitly specified by the user. + for column, dtype in args.query_columns: + columns.add(column) + + # Attempt to automatically extract columns from the query. + variables = extract_variables(args.query) + if variables is not None: + columns.update(variables) + elif not args.query_columns: + raise AugurError("Could not infer columns from the pandas query. If the query is valid, please specify columns using --query-columns.") + + return list(columns) def read_priority_scores(fname): @@ -19,6 +85,68 @@ def constant_factory(value): raise AugurError(f"missing or malformed priority scores file {fname}") +def write_metadata_based_outputs(input_metadata_path: str, delimiters: Sequence[str], + id_columns: Sequence[str], output_metadata_path: str, + output_strains_path: str, ids_to_write: Set[str]): + """ + Write output metadata and/or strains file given input metadata information + and a set of IDs to write. + """ + input_metadata = Metadata(input_metadata_path, delimiters, id_columns) + + # Handle all outputs with one pass of metadata. This requires using + # conditionals both outside of and inside the loop through metadata rows. + + # Make these conditionally set variables available at this scope. + output_metadata_handle = None + output_metadata = None + output_strains = None + + # Set up output streams. + if output_metadata_path: + output_metadata_handle = xopen(output_metadata_path, "w") + output_metadata = csv.DictWriter(output_metadata_handle, fieldnames=input_metadata.columns, + delimiter="\t", lineterminator=os.linesep) + output_metadata.writeheader() + if output_strains_path: + output_strains = open(output_strains_path, "w") + + # Write outputs based on rows in the original metadata. + for row in input_metadata.rows(): + row_id = row[input_metadata.id_column] + if row_id in ids_to_write: + if output_metadata: + output_metadata.writerow(row) + if output_strains: + output_strains.write(row_id + '\n') + + # Close file handles. + if output_metadata_handle: + output_metadata_handle.close() + if output_strains: + output_strains.close() + + +# These are the types accepted in the following function. +ACCEPTED_TYPES = {'numeric', 'int', 'float', 'str'} + +def column_type_pair(input: str): + """Get a 2-tuple for column name to type. 
+ + Intended to be used as the argument type converter for argparse options that + take type maps in a 'column:type' format. + """ + + match = re.match(f"^(.+?):({'|'.join(ACCEPTED_TYPES)})$", input) + if not match: + raise argparse.ArgumentTypeError(f"Column data types must be in the format 'column:type', where type is one of ({','.join(ACCEPTED_TYPES)}).") + + column = match[1] + dtype = match[2] + + return (column, dtype) + + def cleanup_outputs(args): """Remove output files. Useful when terminating midway through a loop of metadata chunks.""" if args.output: diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py index 58ec9e8e4..a419f2d7b 100644 --- a/augur/filter/subsample.py +++ b/augur/filter/subsample.py @@ -106,11 +106,6 @@ def get_groups_for_subsampling(strains, metadata, group_by=None): if generated_columns_requested: - for col in sorted(generated_columns_requested): - if col in metadata.columns: - print_err(f"WARNING: `--group-by {col}` uses a generated {col} value from the {METADATA_DATE_COLUMN!r} column. The custom '{col}' column in the metadata is ignored for grouping purposes.") - metadata.drop(col, axis=1, inplace=True) - if METADATA_DATE_COLUMN not in metadata: # Set generated columns to 'unknown'. print_err(f"WARNING: A {METADATA_DATE_COLUMN!r} column could not be found to group-by {sorted(generated_columns_requested)}.") diff --git a/augur/frequencies.py b/augur/frequencies.py index 1acaccd00..3afb47860 100644 --- a/augur/frequencies.py +++ b/augur/frequencies.py @@ -10,7 +10,7 @@ from .frequency_estimators import get_pivots, alignment_frequencies, tree_frequencies from .frequency_estimators import AlignmentKdeFrequencies, TreeKdeFrequencies, TreeKdeFrequenciesError from .dates import numeric_date_type, SUPPORTED_DATE_HELP_TEXT, get_numerical_dates -from .io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, InvalidDelimiter, read_metadata +from .io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN, InvalidDelimiter, Metadata, read_metadata from .utils import write_json @@ -85,20 +85,24 @@ def format_frequencies(freq): def run(args): try: - # TODO: load only the ID, date, and --weights-attribute columns when - # read_metadata supports loading a subset of all columns. - metadata = read_metadata( - args.metadata, - delimiters=args.metadata_delimiters, - id_columns=args.metadata_id_columns, - dtype="string", - ) + metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) except InvalidDelimiter: raise AugurError( f"Could not determine the delimiter of {args.metadata!r}. " f"Valid delimiters are: {args.metadata_delimiters!r}. " "This can be changed with --metadata-delimiters." 
) + + columns_to_load = [metadata_object.id_column, METADATA_DATE_COLUMN] + if args.weights_attribute: + columns_to_load.append(args.weights_attribute) + metadata = read_metadata( + args.metadata, + delimiters=[metadata_object.delimiter], + columns=columns_to_load, + id_columns=[metadata_object.id_column], + dtype="string", + ) dates = get_numerical_dates(metadata, fmt='%Y-%m-%d') stiffness = args.stiffness inertia = args.inertia diff --git a/augur/io/metadata.py b/augur/io/metadata.py index 2ae7df167..f8be2f5ad 100644 --- a/augur/io/metadata.py +++ b/augur/io/metadata.py @@ -1,6 +1,6 @@ import csv import os -from typing import Iterable +from typing import Iterable, Sequence import pandas as pd import pyfastx import sys @@ -24,7 +24,7 @@ class InvalidDelimiter(Exception): pass -def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, id_columns=DEFAULT_ID_COLUMNS, chunk_size=None, dtype=None): +def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, columns=None, id_columns=DEFAULT_ID_COLUMNS, chunk_size=None, dtype=None): r"""Read metadata from a given filename and into a pandas `DataFrame` or `TextFileReader` object. @@ -35,6 +35,8 @@ def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, id_columns=DEFAU delimiters : list of str List of possible delimiters to check for between columns in the metadata. Only one delimiter will be inferred. + columns : list of str + List of columns to read. If unspecified, read all columns. id_columns : list of str List of possible id column names to check for, ordered by priority. Only one id column will be inferred. @@ -112,6 +114,20 @@ def read_metadata(metadata_file, delimiters=DEFAULT_DELIMITERS, id_columns=DEFAU # If we found a valid column to index the DataFrame, specify that column. kwargs["index_col"] = index_col + if columns is not None: + # Load a subset of the columns. + for requested_column in list(columns): + if requested_column not in chunk.columns: + # Ignore missing columns. Don't error since augur filter's + # --exclude-where allows invalid columns to be specified (they + # are just ignored). + print_err(f"WARNING: Column '{requested_column}' does not exist in the metadata file. Ignoring it.") + columns.remove(requested_column) + # NOTE: list()+remove() is not very efficient, but (1) it's easy + # to understand and (2) this is unlikely to be used with large + # lists. + kwargs["usecols"] = columns + if dtype is None: dtype = {} @@ -474,6 +490,86 @@ def write_records_to_tsv(records, output_file): tsv_writer.writerow(record) +class Metadata: + """Represents a metadata file.""" + + path: str + """Path to the file on disk.""" + + delimiter: str + """Inferred delimiter of metadata.""" + + columns: Sequence[str] + """Columns extracted from the first row in the metadata file.""" + + id_column: str + """Inferred ID column.""" + + def __init__(self, path: str, delimiters: Sequence[str], id_columns: Sequence[str]): + """ + Parameters + ---------- + path + Path of the metadata file. + delimiters + Possible delimiters to use, in order of precedence. + id_columns + Possible ID columns to use, in order of precedence. + """ + self.path = path + + # Infer the dialect. + self.delimiter = _get_delimiter(self.path, delimiters) + + # Infer the column names. + with self.open() as f: + reader = csv.reader(f, delimiter=self.delimiter) + try: + self.columns = next(reader) + except StopIteration: + raise AugurError(f"{self.path}: Expected a header row but it is empty.") + + # Infer the ID column. 
+        self.id_column = self._find_first(id_columns) + + def open(self, **kwargs): + """Open the file with auto-compression/decompression.""" + return open_file(self.path, **kwargs) + + def _find_first(self, columns: Sequence[str]): + """Return the first column in `columns` that is present in the metadata. + """ + for column in columns: + if column in self.columns: + return column + raise AugurError(f"{self.path}: None of ({columns!r}) are in the columns {tuple(self.columns)!r}.") + + def rows(self, strict: bool = True): + """Yield rows in a dictionary format. Empty lines are ignored. + + Parameters + ---------- + strict + If True, raise an error when a row contains more or fewer columns than expected. + """ + with self.open() as f: + reader = csv.DictReader(f, delimiter=self.delimiter, fieldnames=self.columns, restkey=None, restval=None) + + # Skip the header row. + next(reader) + + # NOTE: Empty lines are ignored by csv.DictReader. + # + for row in reader: + if strict: + if None in row.keys(): + raise AugurError(f"{self.path}: Line {reader.line_num} contains at least one extra column. The inferred delimiter is {self.delimiter!r}.") + if None in row.values(): + # This is distinct from a blank value (empty string). + raise AugurError(f"{self.path}: Line {reader.line_num} is missing at least one column. The inferred delimiter is {self.delimiter!r}.") + yield row + + def _get_delimiter(path: str, valid_delimiters: Iterable[str]): """Get the delimiter of a file given a list of valid delimiters.""" diff --git a/augur/refine.py b/augur/refine.py index 03ad2ca93..95717861f 100644 --- a/augur/refine.py +++ b/augur/refine.py @@ -6,7 +6,7 @@ from Bio import Phylo from .dates import get_numerical_dates from .dates.errors import InvalidYearBounds -from .io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN, InvalidDelimiter, read_metadata +from .io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN, InvalidDelimiter, Metadata, read_metadata from .utils import read_tree, write_json, InvalidTreeError from .errors import AugurError from treetime.vcf_utils import read_vcf @@ -213,21 +213,24 @@ def run(args): if args.metadata is None: print("ERROR: meta data with dates is required for time tree reconstruction", file=sys.stderr) return 1 + try: - # TODO: load only the ID and date columns when read_metadata - # supports loading a subset of all columns. - metadata = read_metadata( - args.metadata, - delimiters=args.metadata_delimiters, - id_columns=args.metadata_id_columns, - dtype="string", - ) + metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns) except InvalidDelimiter: raise AugurError( f"Could not determine the delimiter of {args.metadata!r}. " f"Valid delimiters are: {args.metadata_delimiters!r}. " "This can be changed with --metadata-delimiters." 
) + + metadata = read_metadata( + args.metadata, + delimiters=[metadata_object.delimiter], + columns=[metadata_object.id_column, METADATA_DATE_COLUMN], + id_columns=[metadata_object.id_column], + dtype="string", + ) + try: dates = get_numerical_dates(metadata, fmt=args.date_format, min_max_year=args.year_bounds) diff --git a/tests/functional/filter/cram/filter-output-contents.t b/tests/functional/filter/cram/filter-output-contents.t index f24cdc917..4f16002a4 100644 --- a/tests/functional/filter/cram/filter-output-contents.t +++ b/tests/functional/filter/cram/filter-output-contents.t @@ -33,11 +33,11 @@ Check that the row for a strain is identical between input and output metadata. Check the order of strains in the filtered strains file. $ cat filtered_strains.txt - EcEs062_16 - ZKC2/2016 Colombia/2016/ZC204Se - BRA/2016/FC_6706 + ZKC2/2016 DOM/2016/BB_0059 + BRA/2016/FC_6706 + EcEs062_16 Check that the order of strains in the metadata is the same as above. diff --git a/tests/functional/filter/cram/filter-query-columns.t b/tests/functional/filter/cram/filter-query-columns.t new file mode 100644 index 000000000..522e8eed6 --- /dev/null +++ b/tests/functional/filter/cram/filter-query-columns.t @@ -0,0 +1,55 @@ +Setup + + $ source "$TESTDIR"/_setup.sh + +Create metadata file for testing. + + $ cat >metadata.tsv <<~~ + > strain coverage category + > SEQ_1 0.94 A + > SEQ_2 0.95 B + > SEQ_3 0.96 C + > SEQ_4 + > ~~ + +Automatic inference works. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query "coverage >= 0.95 & category == 'B'" \ + > --output-strains filtered_strains.txt + 3 strains were dropped during filtering + \t3 were filtered out by the query: "coverage >= 0.95 & category == 'B'" (esc) + 1 strain passed all filters + +Specifying coverage:float explicitly also works. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query "coverage >= 0.95 & category == 'B'" \ + > --query-columns coverage:float \ + > --output-strains filtered_strains.txt + 3 strains were dropped during filtering + \t3 were filtered out by the query: "coverage >= 0.95 & category == 'B'" (esc) + 1 strain passed all filters + +Specifying coverage:float category:str also works. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query "coverage >= 0.95 & category == 'B'" \ + > --query-columns coverage:float category:str \ + > --output-strains filtered_strains.txt + 3 strains were dropped during filtering + \t3 were filtered out by the query: "coverage >= 0.95 & category == 'B'" (esc) + 1 strain passed all filters + +Specifying category:float does not work. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query "coverage >= 0.95 & category == 'B'" \ + > --query-columns category:float \ + > --output-strains filtered_strains.txt + ERROR: Failed to convert value in column 'category' to float. Unable to parse string "A" at position 0 + [2] diff --git a/tests/functional/filter/cram/filter-query-errors.t b/tests/functional/filter/cram/filter-query-errors.t index 5ccd1dfad..cb9862cf5 100644 --- a/tests/functional/filter/cram/filter-query-errors.t +++ b/tests/functional/filter/cram/filter-query-errors.t @@ -8,6 +8,7 @@ Using a pandas query with a nonexistent column results in a specific error. > --metadata "$TESTDIR/../data/metadata.tsv" \ > --query "invalid == 'value'" \ > --output-strains filtered_strains.txt > /dev/null + WARNING: Column 'invalid' does not exist in the metadata file. Ignoring it. ERROR: Query contains a column that does not exist in metadata. 
[2] @@ -40,7 +41,5 @@ However, other Pandas errors are not so helpful, so a link is provided for users > --metadata "$TESTDIR/../data/metadata.tsv" \ > --query "some bad syntax" \ > --output-strains filtered_strains.txt > /dev/null - ERROR: Internal Pandas error when applying query: - invalid syntax (<unknown>, line 1) - Ensure the syntax is valid per <https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query>. + ERROR: Could not infer columns from the pandas query. If the query is valid, please specify columns using --query-columns. [2] diff --git a/tests/functional/filter/cram/filter-query-numerical.t b/tests/functional/filter/cram/filter-query-numerical.t index e0b054603..5aeb142f8 100644 --- a/tests/functional/filter/cram/filter-query-numerical.t +++ b/tests/functional/filter/cram/filter-query-numerical.t @@ -34,6 +34,29 @@ The 'category' column will fail when used with a numerical comparison. Ensure the syntax is valid per <https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query>. [2] +With automatic type inference, the 'coverage' column isn't query-able with +string comparisons: + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query "coverage.str.endswith('.95')" \ + > --output-strains filtered_strains.txt > /dev/null + ERROR: Internal Pandas error when applying query: + Can only use .str accessor with string values! + Ensure the syntax is valid per <https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query>. + [2] + +However, that is still possible by explicitly specifying that it is a string column. + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --query "coverage.str.endswith('.95')" \ + > --query-columns coverage:str \ + > --output-strains filtered_strains.txt > /dev/null + + $ sort filtered_strains.txt + SEQ_2 + Create another metadata file for testing. $ cat >metadata.tsv <<~~ diff --git a/tests/functional/filter/cram/subsample-group-by-missing-error.t b/tests/functional/filter/cram/subsample-group-by-missing-error.t index 54ef691cc..9884c2b5f 100644 --- a/tests/functional/filter/cram/subsample-group-by-missing-error.t +++ b/tests/functional/filter/cram/subsample-group-by-missing-error.t @@ -15,6 +15,7 @@ Error on missing group-by columns. > --group-by year \ > --sequences-per-group 1 \ > --output-metadata metadata-filtered.tsv > /dev/null + WARNING: Column 'date' does not exist in the metadata file. Ignoring it. ERROR: The specified group-by categories (['year']) were not found. Note that using any of ['month', 'week', 'year'] requires a column called 'date'. [2] $ cat metadata-filtered.tsv @@ -26,6 +27,7 @@ Error on missing group-by columns. > --group-by invalid \ > --sequences-per-group 1 \ > --output-metadata metadata-filtered.tsv > /dev/null + WARNING: Column 'invalid' does not exist in the metadata file. Ignoring it. ERROR: The specified group-by categories (['invalid']) were not found. [2] $ cat metadata-filtered.tsv diff --git a/tests/io/test_metadata.py b/tests/io/test_metadata.py index 9d77067fc..28ff8e12a 100644 --- a/tests/io/test_metadata.py +++ b/tests/io/test_metadata.py @@ -4,7 +4,7 @@ from io import StringIO from augur.errors import AugurError -from augur.io.metadata import InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv +from augur.io.metadata import InvalidDelimiter, read_table_to_dict, read_metadata_with_sequences, write_records_to_tsv, Metadata from augur.types import DataErrorMethod @@ -513,3 +513,100 @@ def test_write_records_to_tsv_with_empty_records(self, tmpdir): write_records_to_tsv(iter([]), output_file) assert str(e_info.value) == f"Unable to write records to {output_file} because provided records were empty." 
+ + +def write_lines(tmpdir, lines): + path = str(tmpdir / "tmp") + with open(path, 'w') as f: + f.writelines(lines) + return path + + +class TestMetadataClass: + def test_attributes(self, metadata_file): + """All attributes are populated.""" + m = Metadata(metadata_file, delimiters=[',', '\t'], id_columns=['invalid', 'strain']) + assert m.path == metadata_file + assert m.delimiter == '\t' + assert m.columns == ['strain', 'country', 'date'] + assert m.id_column == 'strain' + + def test_invalid_delimiter(self, metadata_file): + """Failure to detect delimiter raises an error.""" + with pytest.raises(InvalidDelimiter): + Metadata(metadata_file, delimiters=[':'], id_columns=['strain']) + + def test_invalid_id_column(self, metadata_file): + """Failure to detect an ID column raises an error.""" + with pytest.raises(AugurError): + Metadata(metadata_file, delimiters=['\t'], id_columns=['strains']) + + def test_rows(self, metadata_file): + """Check Metadata.rows() output format.""" + m = Metadata(metadata_file, delimiters=['\t'], id_columns=['strain']) + assert list(m.rows()) == [ + {'country': 'USA', 'date': '2020-10-01', 'strain': 'SEQ_A'}, + {'country': 'USA', 'date': '2020-10-02', 'strain': 'SEQ_T'}, + {'country': 'USA', 'date': '2020-10-03', 'strain': 'SEQ_C'}, + {'country': 'USA', 'date': '2020-10-04', 'strain': 'SEQ_G'}, + ] + + def test_blank_lines(self, tmpdir): + """Check behavior of lines that are blank and have empty values. + + Blank lines are skipped. Lines with delimiters but empty values are still included when reading. + """ + path = write_lines(tmpdir, [ + 'a,b,c\n', + '1,2,3\n', + '\n', + '3,2,3\n', + ',,\n', + '5,2,3\n', + ]) + + m = Metadata(path, delimiters=',', id_columns=['a']) + assert list(m.rows()) == [ + {'a': '1', 'b': '2', 'c': '3'}, + {'a': '3', 'b': '2', 'c': '3'}, + {'a': '' , 'b': '' , 'c': '' }, + {'a': '5', 'b': '2', 'c': '3'} + ] + + def test_rows_strict_extra(self, tmpdir): + """Test behavior when reading rows with extra entries or delimiters.""" + path = write_lines(tmpdir, [ + 'a,b,c\n', + '1,2,3\n', + '2,2,3,4\n', + '3,2,3,\n', + ]) + + m = Metadata(path, delimiters=',', id_columns=['a']) + with pytest.raises(AugurError): + list(m.rows(strict=True)) + + assert list(m.rows(strict=False)) == [ + {'a': '1', 'b': '2', 'c': '3'}, + {'a': '2', 'b': '2', 'c': '3', None: ['4']}, + {'a': '3', 'b': '2', 'c': '3', None: ['']}, + ] + + def test_rows_strict_missing(self, tmpdir): + """Test behavior when reading rows with missing entries or delimiters.""" + path = write_lines(tmpdir, [ + 'a,b,c\n', + '1,2,3\n', + '2,2,\n', + '3,2\n', + ]) + + m = Metadata(path, delimiters=',', id_columns=['a']) + with pytest.raises(AugurError): + list(m.rows(strict=True)) + + assert list(m.rows(strict=False)) == [ + {'a': '1', 'b': '2', 'c': '3'}, + {'a': '2', 'b': '2', 'c': ''}, + {'a': '3', 'b': '2', 'c': None}, + ]
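
The following is a minimal sketch (not part of the diff) of how the new `Metadata` class and the `columns` argument of `read_metadata` compose, mirroring the flow used in `augur/refine.py` above. The `metadata.tsv` path and the 'date' column name are illustrative assumptions.

    # Sketch: infer file properties once, then load only the needed columns.
    # Assumes a local "metadata.tsv" with a 'strain' ID column and a 'date'
    # column; both names are illustrative.
    from augur.io.metadata import Metadata, read_metadata

    # One pass over the header infers the delimiter, columns, and ID column.
    metadata_object = Metadata("metadata.tsv", delimiters=[",", "\t"], id_columns=["strain"])

    # Load a DataFrame restricted to the needed columns, reusing the inferred
    # delimiter and ID column to skip re-detection.
    metadata = read_metadata(
        "metadata.tsv",
        delimiters=[metadata_object.delimiter],
        columns=[metadata_object.id_column, "date"],
        id_columns=[metadata_object.id_column],
        dtype="string",
    )

    # Alternatively, stream rows as dicts without building a DataFrame, as
    # write_metadata_based_outputs() does to preserve the original row order.
    for row in metadata_object.rows():
        print(row[metadata_object.id_column])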
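Similarly, a sketch of the type handling that `filter_by_query` now performs, reusing the toy values from the filter-query-columns.t test above; the expected results in the comments follow the conversion rules added in this diff.

    import pandas as pd
    from augur.filter.include_exclude_rules import filter_by_query

    # String-typed metadata, as produced by read_metadata(..., dtype="string").
    metadata = pd.DataFrame(
        {"coverage": ["0.94", "0.95"], "category": ["A", "B"]},
        index=["SEQ_1", "SEQ_2"],
    )

    # 'coverage' is extracted from the query and defaults to numeric
    # conversion, so the comparison works despite the string dtype.
    print(filter_by_query(metadata, "coverage >= 0.95"))  # {'SEQ_2'}

    # An explicit 'str' type (what --query-columns coverage:str sets up)
    # skips numeric conversion so .str accessors keep working.
    print(filter_by_query(metadata, "coverage.str.endswith('.95')",
                          column_types={"coverage": "str"}))  # {'SEQ_2'}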