-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add get_field to FastSDMol * refactor out common logic * add sdfilter script to pyproject.toml * add basic tests for sdfilter * implement sdfilter
- Loading branch information
1 parent
08a3b41
commit 960ddd0
Showing
13 changed files
with
518 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
import sys | ||
from typing import Generator, TextIO | ||
|
||
|
||
def inputs_generator(inputs: list[str]) -> Generator[TextIO, None, None]: | ||
if not inputs: | ||
yield sys.stdin | ||
else: | ||
for infile in inputs: | ||
with open(infile, "r") as f: | ||
yield f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Standard Library | ||
import logging | ||
import operator | ||
import shlex | ||
from collections import Counter | ||
from pathlib import Path | ||
from typing import Any, Callable, Iterable | ||
|
||
# Local imports | ||
from rdock_utils.common.SDFParser import FastSDMol | ||
|
||
logger = logging.getLogger("sdfilter") | ||
|
||
|
||
def get_casted_operands(operand1: str | int, operand2: str | int) -> tuple[str, str] | tuple[float, float]: | ||
try: | ||
return (float(operand1), float(operand2)) | ||
except ValueError: | ||
return (str(operand1), str(operand2)) | ||
|
||
|
||
OPERATION_MAP: dict[str, Callable[[Any, Any], bool]] = { | ||
"==": operator.eq, | ||
"!=": operator.ne, | ||
"<": operator.lt, | ||
">": operator.gt, | ||
"<=": operator.le, | ||
">=": operator.ge, | ||
"eq": operator.eq, | ||
"ne": operator.ne, | ||
"lt": operator.lt, | ||
"gt": operator.gt, | ||
"le": operator.le, | ||
"ge": operator.ge, | ||
"in": lambda x, y: x in y, | ||
"not_in": lambda x, y: x not in y, | ||
} | ||
|
||
|
||
class ExpressionContext: | ||
def __init__(self, summary_field: str | None = None): | ||
self.summary_field = summary_field | ||
self.summary_counter: Counter[Any] = Counter() | ||
self.record = 0 | ||
|
||
def _get_symbol_as_field(self, symbol: str, molecule: FastSDMol) -> str: | ||
value = molecule.get_field(symbol[1:]) | ||
if value is None: | ||
logger.warning(f"field {symbol} not found in record {self.record}, assuming literal string") | ||
return symbol | ||
return value | ||
|
||
def get_operand_raw_value(self, operand: str, molecule: FastSDMol) -> str | int: | ||
if operand.startswith("$"): | ||
return self.get_symbol_value(operand, molecule) | ||
else: | ||
return operand | ||
|
||
def get_symbol_value(self, symbol: str, molecule: FastSDMol) -> str | int: | ||
match symbol: | ||
case "$_REC": | ||
return self.record | ||
case "$_COUNT": | ||
if self.summary_field is None: | ||
raise ValueError("summary field not provided") | ||
summary_field_value = molecule.get_field(self.summary_field) | ||
return self.summary_counter[summary_field_value] | ||
case _: | ||
return self._get_symbol_as_field(symbol, molecule) | ||
|
||
|
||
class FilterExpression: | ||
def __init__(self, operand1: str, operator: str, operand2: str, context: ExpressionContext): | ||
self.operand1 = operand1 | ||
self.operator = operator | ||
self.operand2 = operand2 | ||
self.context = context | ||
|
||
def evaluate(self, molecule: FastSDMol) -> bool: | ||
raw_operand1 = self.context.get_operand_raw_value(self.operand1, molecule) | ||
raw_operand2 = self.context.get_operand_raw_value(self.operand2, molecule) | ||
operand1, operand2 = get_casted_operands(raw_operand1, raw_operand2) | ||
return OPERATION_MAP[self.operator](operand1, operand2) | ||
|
||
|
||
def create_filters(filter_str: str, context: ExpressionContext) -> list[FilterExpression]: | ||
tokens = shlex.split(filter_str) | ||
if len(tokens) == 1 and (path := Path(filter_str)).is_file(): | ||
with open(path, "r") as f: | ||
return [filter for line in f for filter in create_filters(line.strip(), context)] | ||
elif len(tokens) != 3: | ||
raise ValueError(f"invalid filter: {filter_str}") | ||
|
||
if tokens[1] not in OPERATION_MAP: | ||
raise ValueError(f"invalid operator: {tokens[1]}. expected: {OPERATION_MAP.keys()}") | ||
|
||
return [FilterExpression(tokens[0], tokens[1], tokens[2], context)] | ||
|
||
|
||
def molecules_with_context(molecules: Iterable[FastSDMol], context: ExpressionContext) -> Iterable[FastSDMol]: | ||
for molecule in molecules: | ||
context.record += 1 | ||
if context.summary_field is not None: | ||
context.summary_counter[molecule.get_field(context.summary_field)] += 1 | ||
yield molecule |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from rdock_utils.common import inputs_generator, read_molecules_from_all_inputs | ||
|
||
from .filter import ExpressionContext, create_filters, molecules_with_context | ||
from .parser import get_config | ||
|
||
|
||
def main(argv: list[str] | None = None) -> None: | ||
config = get_config(argv) | ||
inputs = inputs_generator(config.infile) | ||
context = ExpressionContext(config.summary_field) | ||
filters = create_filters(config.filter, context) | ||
molecules = molecules_with_context(read_molecules_from_all_inputs(inputs), context) | ||
for molecule in molecules: | ||
if any(filter.evaluate(molecule) for filter in filters): | ||
print(repr(molecule)) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Standard Library | ||
import argparse | ||
import logging | ||
from dataclasses import dataclass | ||
|
||
logger = logging.getLogger("sdfilter") | ||
|
||
|
||
@dataclass | ||
class SDFilterConfig: | ||
filter: str | ||
summary_field: str | None | ||
infile: list[str] | ||
|
||
|
||
def get_parser() -> argparse.ArgumentParser: | ||
parser = argparse.ArgumentParser(description="Filters SD records by data fields") | ||
filter_help = ( | ||
"Filters can be provided as a string or in a file, one per line. All filters are OR'd together.\n" | ||
"Filters follow the format:\n" | ||
"'$<DataField> <Operator> <Value>'\n" | ||
"where valid operators are: '==', '!=', '<', '>', '<=', and '>=' for general values,\n" | ||
"'in' and 'not_in' for strings, and 'eq', 'ne', 'lt', 'gt', 'le', and 'ge' \n" | ||
"for strings for perl version retro-compatibility.\n" | ||
"_REC (record number), _TITLE1, _TITLE2, and _TITLE3 are provided as a pseudo-data field\n" | ||
"rdock-utils provides expanded functionality, where two data fields can be compared\n" | ||
"using the following syntax:\n" | ||
"'$<DataField1> <Operator> $<DataField2>'\n" | ||
"also, any combination of literal filters and filter files can be provided\n" | ||
"filter files including other filters are supported as well, so be careful with recursion\n" | ||
) | ||
parser.add_argument("-f", "--filter", type=str, help=filter_help, required=True) | ||
s_help = "If -s option is used, _COUNT (#occurrences of DataField) is provided as a pseudo-data field" | ||
parser.add_argument("-s", type=str, default=None, help=s_help) | ||
infile_help = "input file[s] to be processed. if not provided, stdin is used." | ||
parser.add_argument("infile", type=str, nargs="*", help=infile_help) | ||
return parser | ||
|
||
|
||
def get_config(argv: list[str] | None = None) -> SDFilterConfig: | ||
parser = get_parser() | ||
args = parser.parse_args(argv) | ||
return SDFilterConfig(filter=args.filter, summary_field=args.s, infile=args.infile) |
Oops, something went wrong.