From ffca9da4d8e4c0f6457021a2eed23d3415ced770 Mon Sep 17 00:00:00 2001
From: Cas Wognum
Date: Fri, 10 Jan 2025 15:09:44 -0500
Subject: [PATCH] Add custom codecs for RDKit Molecules and Biotite AtomArrays (#243)

* Add custom codecs for RDKit Molecules and Biotite AtomArrays

* Update polaris/dataset/zarr/_codecs.py

Co-authored-by: Julien St-Laurent

* Addressed feedback from PR

---------

Co-authored-by: Julien St-Laurent
---
 polaris/__init__.py            |   2 +-
 polaris/dataset/__init__.py    |   5 +-
 polaris/dataset/zarr/codecs.py | 149 +++++++++++++++++++++++++++++++++
 tests/test_codecs.py           |  11 +++
 4 files changed, 164 insertions(+), 3 deletions(-)
 create mode 100644 polaris/dataset/zarr/codecs.py
 create mode 100644 tests/test_codecs.py

diff --git a/polaris/__init__.py b/polaris/__init__.py
index a3854a17..273793ba 100644
--- a/polaris/__init__.py
+++ b/polaris/__init__.py
@@ -4,7 +4,7 @@
 from loguru import logger
 
 from ._version import __version__
-from .loader import load_benchmark, load_dataset, load_competition
+from .loader import load_benchmark, load_competition, load_dataset
 
 __all__ = ["load_dataset", "load_benchmark", "load_competition", "__version__"]
 
diff --git a/polaris/dataset/__init__.py b/polaris/dataset/__init__.py
index 8084749c..253fda87 100644
--- a/polaris/dataset/__init__.py
+++ b/polaris/dataset/__init__.py
@@ -1,9 +1,10 @@
 from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality
-from polaris.dataset._dataset import DatasetV1, DatasetV1 as Dataset
+from polaris.dataset._dataset import DatasetV1
+from polaris.dataset._dataset import DatasetV1 as Dataset
 from polaris.dataset._dataset_v2 import DatasetV2
 from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files
 from polaris.dataset._subset import Subset
-
+from polaris.dataset.zarr import codecs
 
 __all__ = [
     "create_dataset_from_file",
diff --git a/polaris/dataset/zarr/codecs.py b/polaris/dataset/zarr/codecs.py
new file mode 100644
index 00000000..2ca118c6
--- /dev/null
+++ b/polaris/dataset/zarr/codecs.py
@@ -0,0 +1,149 @@
+import numpy as np
+from fastpdb import struc
+from numcodecs import MsgPack, register_codec
+from numcodecs.vlen import VLenBytes
+from rdkit import Chem
+
+
+class RDKitMolCodec(VLenBytes):
+    """
+    Codec for RDKit's Molecules.
+
+    Info: Binary strings for serialization
+        This class converts the molecules to binary strings (for ML purposes, this should be lossless).
+        This might not be the most storage efficient, but is fastest and easiest to maintain.
+        See this [Github Discussion](https://github.com/rdkit/rdkit/discussions/7235) for more info.
+
+    """
+
+    codec_id = "rdkit_mol"
+
+    def encode(self, buf: np.ndarray):
+        """
+        Encode a chunk of RDKit Mols to byte strings. Entries that are `None` or empty `bytes` are skipped.
+        """
+        to_encode = np.empty(shape=len(buf), dtype=object)
+        for idx, mol in enumerate(buf):
+            if mol is None or (isinstance(mol, bytes) and len(mol) == 0):
+                continue
+            if not isinstance(mol, Chem.Mol):
+                raise ValueError(f"Expected an RDKitMol, but got {type(mol)} instead.")
+            props = Chem.PropertyPickleOptions.AllProps
+            to_encode[idx] = mol.ToBinary(props)
+
+        to_encode = np.array(to_encode, dtype=object)
+        return super().encode(to_encode)
+
+    def decode(self, buf, out=None):
+        """Decode the variable length bytes encoded data into a RDKit Mol."""
+        dec = super().decode(buf, out)
+        for idx, mol in enumerate(dec):
+            if len(mol) == 0:  # Empty payload: nothing to deserialize, keep the entry as-is
+                continue
+            dec[idx] = Chem.Mol(mol)
+
+        if out is not None:
+            np.copyto(out, dec)
+            return out
+        else:
+            return dec
+
+
+class AtomArrayCodec(MsgPack):
+    """
+    Codec for FastPDB (i.e. Biotite) Atom Arrays.
+
+    Info: Only the most essential structural information of a protein is retained
+        This conversion saves the 3D coordinates, chain ID, residue ID, insertion code, residue name, heteroatom indicator, atom name, element, atom ID, B-factor, occupancy, and charge.
+        Records such as CONECT (connectivity information), ANISOU (anisotropic Temperature Factors), HETATM (heteroatoms and ligands) are handled by `fastpdb`.
+        We believe this makes for a good _ML-ready_ format, but let us know if you require any other information to be saved.
+
+
+    Info: PDBs as ND-arrays using `biotite`
+        To save PDBs in a Polaris-compatible format, we convert them to ND-arrays using `fastpdb` and `biotite`.
+        We then save these ND-arrays to Zarr archives.
+        For more info, see [fastpdb](https://github.com/biotite-dev/fastpdb)
+        and [biotite](https://github.com/biotite-dev/biotite/blob/main/src/biotite/structure/atoms.py)
+
+    This codec is a subclass of the `MsgPack` codec from the `numcodecs` package.
+    """
+
+    codec_id = "atom_array"
+
+    def encode(self, buf: np.ndarray):
+        """
+        Encode a chunk of AtomArrays to a plain Python structure that MsgPack can encode
+        """
+
+        to_pack = np.empty_like(buf)
+
+        for idx, atom_array in enumerate(buf):
+            # A chunk can have missing values
+            if atom_array is None:
+                continue
+
+            if not isinstance(atom_array, struc.AtomArray):
+                raise ValueError(f"Expected an AtomArray, but got {type(atom_array)} instead")
+
+            data = {
+                "coord": atom_array.coord,
+                "chain_id": atom_array.chain_id,
+                "res_id": atom_array.res_id,
+                "ins_code": atom_array.ins_code,
+                "res_name": atom_array.res_name,
+                "hetero": atom_array.hetero,
+                "atom_name": atom_array.atom_name,
+                "element": atom_array.element,
+                "atom_id": atom_array.atom_id,
+                "b_factor": atom_array.b_factor,
+                "occupancy": atom_array.occupancy,
+                "charge": atom_array.charge,
+            }
+            data = {k: v.tolist() for k, v in data.items()}
+            to_pack[idx] = data
+
+        return super().encode(to_pack)
+
+    def decode(self, buf, out=None):
+        """Decode the MsgPack decoded data into a `fastpdb` AtomArray."""
+
+        dec = super().decode(buf, out)
+
+        structs = np.empty(shape=len(dec), dtype=object)
+
+        for idx, data in enumerate(dec):
+            if data is None:
+                continue
+
+            atom_array = []
+            array_length = len(data["coord"])
+
+            for ind in range(array_length):
+                atom = struc.Atom(
+                    coord=data["coord"][ind],
+                    chain_id=data["chain_id"][ind],
+                    res_id=data["res_id"][ind],
+                    ins_code=data["ins_code"][ind],
+                    res_name=data["res_name"][ind],
+                    hetero=data["hetero"][ind],
+                    atom_name=data["atom_name"][ind],
+                    element=data["element"][ind],
+                    b_factor=data["b_factor"][ind],
+                    occupancy=data["occupancy"][ind],
+                    charge=data["charge"][ind],
+                    atom_id=data["atom_id"][ind],
+                )
+                atom_array.append(atom)
+
+            # Note that this is a `fastpdb` AtomArray, not a NumPy array.
+            structs[idx] = struc.array(atom_array)
+
+        if out is not None:
+            np.copyto(out, structs)
+            return out
+        else:
+            return structs
+
+
+register_codec(RDKitMolCodec)
+register_codec(AtomArrayCodec)
diff --git a/tests/test_codecs.py b/tests/test_codecs.py
new file mode 100644
index 00000000..c38e4dab
--- /dev/null
+++ b/tests/test_codecs.py
@@ -0,0 +1,11 @@
+import datamol as dm
+import zarr
+
+from polaris.dataset.zarr.codecs import RDKitMolCodec
+
+
+def test_rdkit_mol_codec():
+    mol = dm.to_mol("C1=CC=CC=C1")
+    arr = zarr.array([mol, mol], chunks=(2,), dtype=object, object_codec=RDKitMolCodec())
+    assert dm.same_mol(arr[0], mol)
+    assert dm.same_mol(arr[1], mol)