Read a subset of metadata columns #1294

Merged
merged 10 commits on Feb 8, 2024
9 changes: 9 additions & 0 deletions CHANGES.md
Member Author

Real-world testing

I ran this against the metadata file produced by ncov-ingest (s3://nextstrain-data/files/ncov/open/metadata.tsv.zst), which has 8.5 million rows × 58 columns. I used the following command to subsample to 10 random sequences:

augur filter \
  --metadata metadata.tsv \
  --subsample-max-sequences 10 \
  --output-metadata out.tsv

This took 8m21s to run on master, and 6m21s with the changes from this PR. Here are the profiling results before and after, which I visualized in Snakeviz. A summary:

  • Time spent accessing in-memory DataFrames dropped from 494s to 258s.
  • Without these changes, to_csv takes just a fraction of a second because the metadata for the 10 sequences is already loaded into memory.
  • With these changes, it takes 109s to run through the metadata file to find the lines for the 10 wanted sequences.

The example command sees a net improvement in run time: although writing time increased due to 5173cb7, reading time decreased even more due to ac23e80.

This was a "best case scenario" for these changes, though, since no metadata columns were used, only the strain name. I should probably test with --group-by, --min-date, and other options that load additional columns to get a better picture.

I did not do any memory profiling. Memory usage is not an issue without these changes, and should be less of an issue with the changes.
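
For reference, a minimal sketch of how such before/after profiles can be captured and inspected with cProfile. The file names are placeholders, and this assumes augur.run() accepts an argv-style list (which is how augur's __main__ invokes it):

import cProfile
import pstats
from augur import run  # CLI dispatcher; assumed to take a list of command-line arguments

args = [
    "filter",
    "--metadata", "metadata.tsv",
    "--subsample-max-sequences", "10",
    "--output-metadata", "out.tsv",
]
# Write a profile file that snakeviz (or pstats, below) can read.
cProfile.runctx("run(args)", globals(), locals(), filename="filter.prof")
pstats.Stats("filter.prof").sort_stats("cumulative").print_stats(15)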

Member

Did you get round to doing more testing/profiling with --group-by and --min-date and whatnot?

Member Author

No, not yet. Still planning to do so before merging.

Member Author
@victorlin Feb 1, 2024

Tested using the ncov 100k subsample as input to an augur filter command I grabbed from an ncov build. Run time was 16s on master and 7.37s with these changes (cProfile files).

Summary:

  • Metadata reading dropped from 4.3s to 2.15s (pandas readers.py:read)
  • Accessing in-memory DataFrames dropped from 8.63s to 1.127s (pandas indexing.py)
  • Output writing increased from ~0s to 1.35s (write_metadata_based_outputs)
augur filter \
    --metadata ~/tmp/augur-filter/metadata.tsv.xz \
    --include defaults/include.txt \
    --exclude defaults/exclude.txt \
    --max-date 6M \
    --exclude-where 'region!=Asia' country=China country=India \
    --group-by country year month \
    --subsample-max-sequences 200 \
    --output-strains ~/tmp/augur-filter/strains.txt

Member Author
@victorlin Feb 1, 2024

I triggered an ncov GISAID trial run using a Docker image that includes these changes; it completed successfully in 6h 36m 55s. This is pretty much the same as another trial run 2 days before at 6h 40m 27s. I don't know how much variance there is between run times, and I don't want to compare against non-trial runs or older runs (those have additional Slack notifications and different input metadata sizes). So by this comparison alone, there doesn't seem to be a significant performance benefit for the ncov workflow with the GISAID configuration.

Member

Hmm.

ISTM that the last time I looked at ncov's execution profile, by far the slowest step was TreeTime, so it's not altogether surprising that filter's speed doesn't have a big impact in the context of a full build.

I grabbed the benchmarks/subsample_* files from those runs to get a little more granular insight into differences in wall clock time and max RSS for each subsample rule invocation.

avg(after - before) for wall clock time was -112s, so the changes shaved roughly 2 minutes off each subsample step on average. The equivalent for max RSS is -276 (MB, I believe).
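
For reference, a rough sketch of that comparison. The directory layout and file names are assumptions rather than the actual paths from those runs; Snakemake benchmark files are TSVs with wall clock seconds in the "s" column and max RSS (MB) in "max_rss":

from pathlib import Path
import pandas as pd

def load_benchmarks(directory):
    # One Snakemake benchmark TSV per subsample rule invocation.
    return {path.name: pd.read_csv(path, sep="\t") for path in Path(directory).glob("subsample_*")}

before = load_benchmarks("benchmarks_before")
after = load_benchmarks("benchmarks_after")
common = sorted(before.keys() & after.keys())

wall_clock_deltas = [after[name]["s"].iloc[0] - before[name]["s"].iloc[0] for name in common]
max_rss_deltas = [after[name]["max_rss"].iloc[0] - before[name]["max_rss"].iloc[0] for name in common]

print(f"avg wall clock delta: {sum(wall_clock_deltas) / len(common):+.0f} s")
print(f"avg max RSS delta: {sum(max_rss_deltas) / len(common):+.0f} MB")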

@@ -2,12 +2,21 @@

## __NEXT__

### Features

* filter: Added a new option `--query-columns` that allows specifying what columns are used in `--query` along with the expected data types. If unspecified, automatic detection of columns and types is attempted. [#1294][] (@victorlin)
* `augur.io.read_metadata`: A new optional `columns` argument allows specifying a subset of columns to load. The default behavior still loads all columns, so this is not a breaking change. [#1294][] (@victorlin)

### Bug Fixes

* filter: The order of rows in `--output-metadata` and `--output-strains` now reflects the order in the original `--metadata`. [#1294][] (@victorlin)
* filter, frequencies, refine: Performance improvements to reading the input metadata file. [#1294][] (@victorlin)
* For filter, this comes with increased writing times for `--output-metadata` and `--output-strains`. However, net I/O time still decreased during testing of this change.
* filter: Updated the help text of `--include` and `--include-where` to explicitly state that this can add strains that are missing an entry from `--sequences`. [#1389][] (@victorlin)
* filter: Fixed the summary messages to properly reflect force-inclusion of strains that are missing an entry from `--sequences`. [#1389][] (@victorlin)
* filter: Updated wording of summary messages. [#1389][] (@victorlin)

[#1294]: https://github.com/nextstrain/augur/pull/1294
[#1389]: https://github.com/nextstrain/augur/pull/1389

## 24.1.0 (30 January 2024)
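
To illustrate the new columns argument of augur.io.read_metadata described in the changelog entry above, here is a minimal sketch; the file and column names are hypothetical:

from augur.io.metadata import read_metadata

# Load only the listed columns (plus the id column, which becomes the index);
# per the changelog, omitting `columns` keeps the old behavior of loading everything.
metadata = read_metadata(
    "metadata.tsv",
    columns=["date", "region"],
    dtype="string",
)
print(metadata.columns.tolist())
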
6 changes: 6 additions & 0 deletions augur/filter/__init__.py
@@ -2,6 +2,7 @@
Filter and subsample a sequence set.
"""
from augur.dates import numeric_date_type, SUPPORTED_DATE_HELP_TEXT
from augur.filter.io import ACCEPTED_TYPES, column_type_pair
from augur.io.metadata import DEFAULT_DELIMITERS, DEFAULT_ID_COLUMNS, METADATA_DATE_COLUMN
from augur.types import EmptyOutputReportingMethod
from . import constants
@@ -28,6 +29,11 @@ def register_arguments(parser):
Uses Pandas Dataframe querying, see https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#indexing-query for syntax.
(e.g., --query "country == 'Colombia'" or --query "(country == 'USA' & (division == 'Washington'))")"""
)
metadata_filter_group.add_argument('--query-columns', type=column_type_pair, nargs="+", help=f"""
Use alongside --query to specify columns and data types in the format 'column:type', where type is one of ({','.join(ACCEPTED_TYPES)}).
Automatic type inference will be attempted on all unspecified columns used in the query.
Example: region:str coverage:float.
""")
metadata_filter_group.add_argument('--min-date', type=numeric_date_type, help=f"minimal cutoff for date, the cutoff date is inclusive; may be specified as: {SUPPORTED_DATE_HELP_TEXT}")
metadata_filter_group.add_argument('--max-date', type=numeric_date_type, help=f"maximal cutoff for date, the cutoff date is inclusive; may be specified as: {SUPPORTED_DATE_HELP_TEXT}")
metadata_filter_group.add_argument('--exclude-ambiguous-dates-by', choices=['any', 'day', 'month', 'year'],
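
A hedged usage example of the new --query-columns flag (the column names below are illustrative, not taken from this PR):

augur filter \
  --metadata metadata.tsv \
  --query "coverage >= 0.95 & region == 'Asia'" \
  --query-columns coverage:float region:str \
  --output-strains strains.txt
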
81 changes: 19 additions & 62 deletions augur/filter/_run.py
@@ -15,13 +15,13 @@
DELIMITER as SEQUENCE_INDEX_DELIMITER,
)
from augur.io.file import open_file
from augur.io.metadata import InvalidDelimiter, read_metadata
from augur.io.metadata import InvalidDelimiter, Metadata, read_metadata
from augur.io.sequences import read_sequences, write_sequences
from augur.io.print import print_err
from augur.io.vcf import is_vcf as filename_is_vcf, write_vcf
from augur.types import EmptyOutputReportingMethod
from . import include_exclude_rules
from .io import cleanup_outputs, read_priority_scores
from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs
from .include_exclude_rules import apply_filters, construct_filters
from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, create_queues_by_group, get_groups_for_subsampling

@@ -133,16 +133,6 @@ def run(args):
random_generator = np.random.default_rng(args.subsample_seed)
priorities = defaultdict(random_generator.random)

# Setup metadata output. We track whether any records have been written to
# disk yet through the following variables, to control whether we write the
# metadata's header and open a new file for writing.
metadata_header = True
metadata_mode = "w"

# Setup strain output.
if args.output_strains:
output_strains = open(args.output_strains, "w")

# Setup logging.
output_log_writer = None
if args.output_log:
@@ -168,19 +158,23 @@
filter_counts = defaultdict(int)

try:
metadata_reader = read_metadata(
args.metadata,
delimiters=args.metadata_delimiters,
id_columns=args.metadata_id_columns,
chunk_size=args.metadata_chunk_size,
dtype="string",
)
metadata_object = Metadata(args.metadata, args.metadata_delimiters, args.metadata_id_columns)
except InvalidDelimiter:
raise AugurError(
f"Could not determine the delimiter of {args.metadata!r}. "
f"Valid delimiters are: {args.metadata_delimiters!r}. "
"This can be changed with --metadata-delimiters."
)
useful_metadata_columns = get_useful_metadata_columns(args, metadata_object.id_column, metadata_object.columns)

metadata_reader = read_metadata(
args.metadata,
delimiters=[metadata_object.delimiter],
columns=useful_metadata_columns,
id_columns=[metadata_object.id_column],
chunk_size=args.metadata_chunk_size,
dtype="string",
)
for metadata in metadata_reader:
duplicate_strains = (
set(metadata.index[metadata.index.duplicated()]) |
@@ -263,30 +257,6 @@ def run(args):
priorities[strain],
)

# Always write out strains that are force-included. Additionally, if
# we are not grouping, write out metadata and strains that passed
# filters so far.
force_included_strains_to_write = distinct_force_included_strains
if not group_by:
force_included_strains_to_write = force_included_strains_to_write | seq_keep

if args.output_metadata:
# TODO: wrap logic to write metadata into its own function
metadata.loc[list(force_included_strains_to_write)].to_csv(
args.output_metadata,
sep="\t",
header=metadata_header,
mode=metadata_mode,
)
metadata_header = False
metadata_mode = "a"

if args.output_strains:
# TODO: Output strains will no longer be ordered. This is a
# small breaking change.
for strain in force_included_strains_to_write:
output_strains.write(f"{strain}\n")

# In the worst case, we need to calculate sequences per group from the
# requested maximum number of sequences and the number of sequences per
# group. Then, we need to make a second pass through the metadata to find
@@ -323,6 +293,7 @@ def run(args):
metadata_reader = read_metadata(
args.metadata,
delimiters=args.metadata_delimiters,
columns=useful_metadata_columns,
id_columns=args.metadata_id_columns,
chunk_size=args.metadata_chunk_size,
dtype="string",
@@ -367,23 +338,6 @@ def run(args):
# Construct a data frame of records to simplify metadata output.
records.append(record)

if args.output_strains:
# TODO: Output strains will no longer be ordered. This is a
# small breaking change.
output_strains.write(f"{record.name}\n")

# Write records to metadata output, if requested.
if args.output_metadata and len(records) > 0:
records = pd.DataFrame(records)
records.to_csv(
args.output_metadata,
sep="\t",
header=metadata_header,
mode=metadata_mode,
)
metadata_header = False
metadata_mode = "a"

# Count and optionally log strains that were not included due to
# subsampling.
strains_filtered_by_subsampling = valid_strains - subsampled_strains
@@ -442,14 +396,17 @@ def run(args):
# Update the set of available sequence strains.
sequence_strains = observed_sequence_strains

if args.output_metadata or args.output_strains:
write_metadata_based_outputs(args.metadata, args.metadata_delimiters,
args.metadata_id_columns, args.output_metadata,
args.output_strains, valid_strains)

# Calculate the number of strains that don't exist in either metadata or
# sequences.
num_excluded_by_lack_of_metadata = 0
if sequence_strains:
num_excluded_by_lack_of_metadata = len(sequence_strains - metadata_strains)

if args.output_strains:
output_strains.close()

# Calculate the number of strains passed and filtered.
total_strains_passed = len(valid_strains)
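
A small sketch (not part of the diff) of the Metadata helper as used above: it sniffs the delimiter and id column up front so that get_useful_metadata_columns can decide which columns to load before any rows are parsed. The file name and argument values are placeholders, and arguments are passed positionally as in the diff:

from augur.io.metadata import Metadata

metadata_object = Metadata("metadata.tsv", (",", "\t"), ("strain", "name"))
print(metadata_object.delimiter)   # detected delimiter
print(metadata_object.id_column)   # first id column found in the header
print(metadata_object.columns)     # all column names from the header
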
101 changes: 78 additions & 23 deletions augur/filter/include_exclude_rules.py
@@ -1,9 +1,10 @@
import ast
import json
import operator
import re
import numpy as np
import pandas as pd
from typing import Any, Callable, Dict, List, Set, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from augur.dates import is_date_ambiguous, get_numerical_dates
from augur.errors import AugurError
@@ -78,7 +79,7 @@ def filter_by_exclude(metadata, exclude_file) -> FilterFunctionReturn:
return set(metadata.index.values) - excluded_strains


def _parse_filter_query(query):
def parse_filter_query(query):
"""Parse an augur filter-style query and return the corresponding column,
operator, and value for the query.

@@ -98,9 +99,9 @@ def _parse_filter_query(query):

Examples
--------
>>> _parse_filter_query("property=value")
>>> parse_filter_query("property=value")
('property', <built-in function eq>, 'value')
>>> _parse_filter_query("property!=value")
>>> parse_filter_query("property!=value")
('property', <built-in function ne>, 'value')

"""
@@ -143,7 +144,7 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn:
['strain1', 'strain2']

"""
column, op, value = _parse_filter_query(exclude_where)
column, op, value = parse_filter_query(exclude_where)
if column in metadata.columns:
# Apply a test operator (equality or inequality) to values from the
# column in the given query. This produces an array of boolean values we
@@ -164,7 +165,7 @@ def filter_by_exclude_where(metadata, exclude_where) -> FilterFunctionReturn:
return filtered


def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn:
def filter_by_query(metadata: pd.DataFrame, query: str, column_types: Optional[Dict[str, str]] = None) -> FilterFunctionReturn:
"""Filter metadata in the given pandas DataFrame with a query string and return
the strain names that pass the filter.

@@ -174,6 +175,8 @@ def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn:
Metadata indexed by strain name
query : str
Query string for the dataframe.
column_types : dict, optional
Dict mapping column names to data types.

Examples
--------
@@ -187,22 +190,42 @@ def filter_by_query(metadata: pd.DataFrame, query: str) -> FilterFunctionReturn:
# Create a copy to prevent modification of the original DataFrame.
metadata_copy = metadata.copy()

# Support numeric comparisons in query strings.
#
# The built-in data type inference when loading the DataFrame does not
if column_types is None:
column_types = {}

# Set columns for type conversion.
variables = extract_variables(query)
if variables is not None:
columns = variables.intersection(metadata_copy.columns)
else:
# Column extraction failed. Apply type conversion to all columns.
columns = metadata_copy.columns

# If a type is not explicitly provided, try converting the column to numeric.
# This should cover most use cases, since one common problem is that the
# built-in data type inference when loading the DataFrame does not
# support nullable numeric columns, so numeric comparisons won't work on
# those columns. pd.to_numeric does proper conversion on those columns, and
# will not make any changes to columns with other values.
#
# TODO: Parse the query string and apply conversion only to columns used for
# numeric comparison. Pandas does not expose the API used to parse the query
# string internally, so this is non-trivial and requires a bit of
# reverse-engineering. Commit 2ead5b3e3306dc1100b49eb774287496018122d9 got
# halfway there but had issues so it was reverted.
#
# TODO: Try boolean conversion?
for column in metadata_copy.columns:
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='ignore')
# those columns. pd.to_numeric does proper conversion on those columns,
# and will not make any changes to columns with other values.
for column in columns:
column_types.setdefault(column, 'numeric')

# Convert data types before applying the query.
for column, dtype in column_types.items():
if dtype == 'numeric':
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='ignore')
elif dtype == 'int':
try:
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='integer')
except ValueError as e:
raise AugurError(f"Failed to convert value in column {column!r} to int. {e}")
elif dtype == 'float':
try:
metadata_copy[column] = pd.to_numeric(metadata_copy[column], errors='raise', downcast='float')
except ValueError as e:
raise AugurError(f"Failed to convert value in column {column!r} to float. {e}")
elif dtype == 'str':
metadata_copy[column] = metadata_copy[column].astype('str', errors='ignore')

try:
return set(metadata_copy.query(query).index.values)
@@ -492,7 +515,7 @@ def force_include_where(metadata, include_where) -> FilterFunctionReturn:
set()

"""
column, op, value = _parse_filter_query(include_where)
column, op, value = parse_filter_query(include_where)

if column in metadata.columns:
# Apply a test operator (equality or inequality) to values from the
@@ -578,9 +601,13 @@ def construct_filters(args, sequence_index) -> Tuple[List[FilterOption], List[Fi

# Exclude strains by metadata, using pandas querying.
if args.query:
kwargs = {"query": args.query}
if args.query_columns:
kwargs["column_types"] = {column: dtype for column, dtype in args.query_columns}

exclude_by.append((
filter_by_query,
{"query": args.query}
kwargs
))

# Filter by ambiguous dates.
@@ -820,3 +847,31 @@ def _filter_kwargs_to_str(kwargs: FilterFunctionKwargs):
kwarg_list.append((key, value))

return json.dumps(kwarg_list)


def extract_variables(pandas_query: str):
"""Try extracting all variable names used in a pandas query string.

If successful, return the variable names as a set. Otherwise, nothing is returned.

Examples
--------
>>> extract_variables("var1 == 'value'")
{'var1'}
>>> sorted(extract_variables("var1 == 'value' & var2 == 10"))
['var1', 'var2']
>>> extract_variables("var1.str.startswith('prefix')")
{'var1'}
>>> extract_variables("this query is invalid")
"""
# Since Pandas' query grammar should be a subset of Python's, which uses the
# ast stdlib under the hood, we can try to parse queries with that as well.
# Errors may arise from invalid query syntax or any Pandas syntax not
# covered by Python (unlikely, but I'm not sure). In those cases, don't
# return anything.
try:
return set(node.id
for node in ast.walk(ast.parse(pandas_query))
if isinstance(node, ast.Name))
except:
return None
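
To tie the pieces together, a hedged illustration (not part of the diff) of the type handling in filter_by_query: columns extracted from the query default to numeric conversion, while explicit --query-columns-style types override that. The column and strain names are made up:

import pandas as pd
from augur.filter.include_exclude_rules import filter_by_query

metadata = pd.DataFrame(
    {"coverage": ["0.95", "0.80"], "region": ["Asia", "Europe"]},
    index=pd.Index(["strain1", "strain2"], name="strain"),
)

# Automatic inference: 'coverage' is found in the query and coerced to numeric.
print(filter_by_query(metadata, "coverage >= 0.9"))
# expected: {'strain1'}

# Explicit types, mirroring: --query-columns coverage:float region:str
print(filter_by_query(metadata, "coverage >= 0.9 & region == 'Asia'",
                      column_types={"coverage": "float", "region": "str"}))
# expected: {'strain1'}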