Skip to content

Commit

Permalink
Issue #61/migrate sdfilter (#75)
Browse files Browse the repository at this point in the history
* add get_field to FastSDMol

* refactor out common logic

* add sdfilter script to pyproject.toml

* add basic tests for sdfilter

* implement sdfilter
  • Loading branch information
ggutierrez-sunbright authored Feb 3, 2024
1 parent 08a3b41 commit 960ddd0
Show file tree
Hide file tree
Showing 13 changed files with 518 additions and 27 deletions.
1 change: 1 addition & 0 deletions rdock-utils/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ sdrmsd_old = "rdock_utils.sdrmsd_original:main"
sdrmsd = "rdock_utils.sdrmsd.main:main"
sdtether = "rdock_utils.sdtether.main:main"
sdtether_old = "rdock_utils.sdtether_original:main"
sdfilter = "rdock_utils.sdfilter.main:main"

[project.urls]
Repository = "https://github.com/CBDD/rDock.git"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Standard Library
import itertools
import logging
from io import StringIO
from typing import Any, TextIO
from typing import Any, Generator, Iterable, TextIO

logger = logging.getLogger("SDParser")

Expand Down Expand Up @@ -74,3 +75,30 @@ def write(self, dest: TextIO) -> None:
for field_name, field_value in self.data.items():
dest.write(self.str_field(field_name, field_value))
dest.write("$$$$")

def get_field(self, field_name: str) -> str | None:
if field_name.startswith("_TITLE"):
line_number = int(field_name[-1]) - 1
if 0 <= line_number < min(len(self.lines), 3):
return self.lines[line_number].strip()
return None
return self.data.get(field_name, None)

@property
def title(self) -> str:
return self.lines[0].strip()


def read_molecules(file: TextIO) -> Generator[FastSDMol, None, None]:
while True:
try:
mol = FastSDMol.read(file)
if mol is None:
break
yield mol
except ValueError as e:
logger.warning(f"error reading molecule: {e}")


def read_molecules_from_all_inputs(inputs: Iterable[TextIO]) -> Iterable[FastSDMol]:
return itertools.chain.from_iterable(read_molecules(source) for source in inputs)
8 changes: 8 additions & 0 deletions rdock-utils/rdock_utils/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .files import inputs_generator
from .SDFParser import FastSDMol, read_molecules, read_molecules_from_all_inputs
from .superpose3d import MolAlignmentData, Superpose3D, update_coordinates
from .types import (
AtomsMapping,
Expand All @@ -11,6 +13,12 @@
)

__all__ = [
# -- files --
"inputs_generator",
# -- SDFParser --
"FastSDMol",
"read_molecules",
"read_molecules_from_all_inputs",
# -- superpose3d --
"update_coordinates",
"MolAlignmentData",
Expand Down
11 changes: 11 additions & 0 deletions rdock-utils/rdock_utils/common/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import sys
from typing import Generator, TextIO


def inputs_generator(inputs: list[str]) -> Generator[TextIO, None, None]:
if not inputs:
yield sys.stdin
else:
for infile in inputs:
with open(infile, "r") as f:
yield f
30 changes: 4 additions & 26 deletions rdock-utils/rdock_utils/sdfield.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# Standard Library
import argparse
import sys
from logging import getLogger
from typing import Iterable, TextIO

# Local imports
from .parser import FastSDMol
from .common import inputs_generator, read_molecules_from_all_inputs

logger = getLogger("sdfield")

Expand All @@ -22,33 +20,13 @@ def get_parser() -> argparse.ArgumentParser:
return parser


def inputs_generator(inputs: list[str]) -> Iterable[TextIO]:
if not inputs:
yield sys.stdin
else:
for infile in inputs:
yield open(infile, "r")


def read_molecules(file: TextIO) -> Iterable[FastSDMol]:
while True:
try:
mol = FastSDMol.read(file)
if mol is None:
break
yield mol
except ValueError as e:
logger.warning(f"error reading molecule: {e}")


def main(argv: list[str] | None = None) -> None:
parser = get_parser()
args = parser.parse_args(argv)
inputs = inputs_generator(args.infile)
for source in inputs:
for molecule in read_molecules(source):
molecule.data[args.fieldname] = args.value
print(repr(molecule))
for molecule in read_molecules_from_all_inputs(inputs):
molecule.data[args.fieldname] = args.value
print(repr(molecule))


if __name__ == "__main__":
Expand Down
Empty file.
105 changes: 105 additions & 0 deletions rdock-utils/rdock_utils/sdfilter/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Standard Library
import logging
import operator
import shlex
from collections import Counter
from pathlib import Path
from typing import Any, Callable, Iterable

# Local imports
from rdock_utils.common.SDFParser import FastSDMol

logger = logging.getLogger("sdfilter")


def get_casted_operands(operand1: str | int, operand2: str | int) -> tuple[str, str] | tuple[float, float]:
try:
return (float(operand1), float(operand2))
except ValueError:
return (str(operand1), str(operand2))


OPERATION_MAP: dict[str, Callable[[Any, Any], bool]] = {
"==": operator.eq,
"!=": operator.ne,
"<": operator.lt,
">": operator.gt,
"<=": operator.le,
">=": operator.ge,
"eq": operator.eq,
"ne": operator.ne,
"lt": operator.lt,
"gt": operator.gt,
"le": operator.le,
"ge": operator.ge,
"in": lambda x, y: x in y,
"not_in": lambda x, y: x not in y,
}


class ExpressionContext:
def __init__(self, summary_field: str | None = None):
self.summary_field = summary_field
self.summary_counter: Counter[Any] = Counter()
self.record = 0

def _get_symbol_as_field(self, symbol: str, molecule: FastSDMol) -> str:
value = molecule.get_field(symbol[1:])
if value is None:
logger.warning(f"field {symbol} not found in record {self.record}, assuming literal string")
return symbol
return value

def get_operand_raw_value(self, operand: str, molecule: FastSDMol) -> str | int:
if operand.startswith("$"):
return self.get_symbol_value(operand, molecule)
else:
return operand

def get_symbol_value(self, symbol: str, molecule: FastSDMol) -> str | int:
match symbol:
case "$_REC":
return self.record
case "$_COUNT":
if self.summary_field is None:
raise ValueError("summary field not provided")
summary_field_value = molecule.get_field(self.summary_field)
return self.summary_counter[summary_field_value]
case _:
return self._get_symbol_as_field(symbol, molecule)


class FilterExpression:
def __init__(self, operand1: str, operator: str, operand2: str, context: ExpressionContext):
self.operand1 = operand1
self.operator = operator
self.operand2 = operand2
self.context = context

def evaluate(self, molecule: FastSDMol) -> bool:
raw_operand1 = self.context.get_operand_raw_value(self.operand1, molecule)
raw_operand2 = self.context.get_operand_raw_value(self.operand2, molecule)
operand1, operand2 = get_casted_operands(raw_operand1, raw_operand2)
return OPERATION_MAP[self.operator](operand1, operand2)


def create_filters(filter_str: str, context: ExpressionContext) -> list[FilterExpression]:
tokens = shlex.split(filter_str)
if len(tokens) == 1 and (path := Path(filter_str)).is_file():
with open(path, "r") as f:
return [filter for line in f for filter in create_filters(line.strip(), context)]
elif len(tokens) != 3:
raise ValueError(f"invalid filter: {filter_str}")

if tokens[1] not in OPERATION_MAP:
raise ValueError(f"invalid operator: {tokens[1]}. expected: {OPERATION_MAP.keys()}")

return [FilterExpression(tokens[0], tokens[1], tokens[2], context)]


def molecules_with_context(molecules: Iterable[FastSDMol], context: ExpressionContext) -> Iterable[FastSDMol]:
for molecule in molecules:
context.record += 1
if context.summary_field is not None:
context.summary_counter[molecule.get_field(context.summary_field)] += 1
yield molecule
19 changes: 19 additions & 0 deletions rdock-utils/rdock_utils/sdfilter/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from rdock_utils.common import inputs_generator, read_molecules_from_all_inputs

from .filter import ExpressionContext, create_filters, molecules_with_context
from .parser import get_config


def main(argv: list[str] | None = None) -> None:
config = get_config(argv)
inputs = inputs_generator(config.infile)
context = ExpressionContext(config.summary_field)
filters = create_filters(config.filter, context)
molecules = molecules_with_context(read_molecules_from_all_inputs(inputs), context)
for molecule in molecules:
if any(filter.evaluate(molecule) for filter in filters):
print(repr(molecule))


if __name__ == "__main__":
main()
43 changes: 43 additions & 0 deletions rdock-utils/rdock_utils/sdfilter/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Standard Library
import argparse
import logging
from dataclasses import dataclass

logger = logging.getLogger("sdfilter")


@dataclass
class SDFilterConfig:
filter: str
summary_field: str | None
infile: list[str]


def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Filters SD records by data fields")
filter_help = (
"Filters can be provided as a string or in a file, one per line. All filters are OR'd together.\n"
"Filters follow the format:\n"
"'$<DataField> <Operator> <Value>'\n"
"where valid operators are: '==', '!=', '<', '>', '<=', and '>=' for general values,\n"
"'in' and 'not_in' for strings, and 'eq', 'ne', 'lt', 'gt', 'le', and 'ge' \n"
"for strings for perl version retro-compatibility.\n"
"_REC (record number), _TITLE1, _TITLE2, and _TITLE3 are provided as a pseudo-data field\n"
"rdock-utils provides expanded functionality, where two data fields can be compared\n"
"using the following syntax:\n"
"'$<DataField1> <Operator> $<DataField2>'\n"
"also, any combination of literal filters and filter files can be provided\n"
"filter files including other filters are supported as well, so be careful with recursion\n"
)
parser.add_argument("-f", "--filter", type=str, help=filter_help, required=True)
s_help = "If -s option is used, _COUNT (#occurrences of DataField) is provided as a pseudo-data field"
parser.add_argument("-s", type=str, default=None, help=s_help)
infile_help = "input file[s] to be processed. if not provided, stdin is used."
parser.add_argument("infile", type=str, nargs="*", help=infile_help)
return parser


def get_config(argv: list[str] | None = None) -> SDFilterConfig:
parser = get_parser()
args = parser.parse_args(argv)
return SDFilterConfig(filter=args.filter, summary_field=args.s, infile=args.infile)
Loading

0 comments on commit 960ddd0

Please sign in to comment.