Skip to content

Commit

Permalink
Add custom codecs for RDKit Molecules and Biotite AtomArrays (#243)
Browse files Browse the repository at this point in the history
* Add custom codecs for RDKit Molecules and Biotite AtomArrays

* Update polaris/dataset/zarr/_codecs.py

Co-authored-by: Julien St-Laurent <jstlaurent@users.noreply.github.com>

* Addressed feedback from PR

---------

Co-authored-by: Julien St-Laurent <jstlaurent@users.noreply.github.com>
  • Loading branch information
cwognum and jstlaurent authored Jan 10, 2025
1 parent 0df06a0 commit ffca9da
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 3 deletions.
2 changes: 1 addition & 1 deletion polaris/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from loguru import logger

from ._version import __version__
from .loader import load_benchmark, load_dataset, load_competition
from .loader import load_benchmark, load_competition, load_dataset

__all__ = ["load_dataset", "load_benchmark", "load_competition", "__version__"]

Expand Down
5 changes: 3 additions & 2 deletions polaris/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from polaris.dataset._column import ColumnAnnotation, KnownContentType, Modality
from polaris.dataset._dataset import DatasetV1, DatasetV1 as Dataset
from polaris.dataset._dataset import DatasetV1
from polaris.dataset._dataset import DatasetV1 as Dataset
from polaris.dataset._dataset_v2 import DatasetV2
from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files
from polaris.dataset._subset import Subset

from polaris.dataset.zarr import codecs

__all__ = [
"create_dataset_from_file",
Expand Down
149 changes: 149 additions & 0 deletions polaris/dataset/zarr/codecs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import numpy as np
from fastpdb import struc
from numcodecs import MsgPack, register_codec
from numcodecs.vlen import VLenBytes
from rdkit import Chem


class RDKitMolCodec(VLenBytes):
"""
Codec for RDKit's Molecules.
Info: Binary strings for serialization
This class converts the molecules to binary strings (for ML purposes, this should be lossless).
This might not be the most storage efficient, but is fastest and easiest to maintain.
See this [Github Discussion](https://github.com/rdkit/rdkit/discussions/7235) for more info.
"""

codec_id = "rdkit_mol"

def encode(self, buf: np.ndarray):
"""
Encode a chunk of RDKit Mols to byte strings
"""
to_encode = np.empty(shape=len(buf), dtype=object)
for idx, mol in enumerate(buf):
if mol is None or (isinstance(mol, bytes) and len(mol) == 0):
continue
if not isinstance(mol, Chem.Mol):
raise ValueError(f"Expected an RDKitMol, but got {type(buf)} instead.")
props = Chem.PropertyPickleOptions.AllProps
to_encode[idx] = mol.ToBinary(props)

to_encode = np.array(to_encode, dtype=object)
return super().encode(to_encode)

def decode(self, buf, out=None):
"""Decode the variable length bytes encoded data into a RDKit Mol."""
dec = super().decode(buf, out)
for idx, mol in enumerate(dec):
if len(mol) == 0:
continue
dec[idx] = Chem.Mol(mol)

if out is not None:
np.copyto(out, dec)
return out
else:
return dec


class AtomArrayCodec(MsgPack):
"""
Codec for FastPDB (i.e. Biotite) Atom Arrays.
Info: Only the most essential structural information of a protein is retained
This conversion saves the 3D coordinates, chain ID, residue ID, insertion code, residue name, heteroatom indicator, atom name, element, atom ID, B-factor, occupancy, and charge.
Records such as CONECT (connectivity information), ANISOU (anisotropic Temperature Factors), HETATM (heteroatoms and ligands) are handled by `fastpdb`.
We believe this makes for a good _ML-ready_ format, but let us know if you require any other information to be saved.
Info: PDBs as ND-arrays using `biotite`
To save PDBs in a Polaris-compatible format, we convert them to ND-arrays using `fastpdb` and `biotite`.
We then save these ND-arrays to Zarr archives.
For more info, see [fastpdb](https://github.com/biotite-dev/fastpdb)
and [biotite](https://github.com/biotite-dev/biotite/blob/main/src/biotite/structure/atoms.py)
This codec is a subclass of the `MsgPack` codec from the `numcodecs`
"""

codec_id = "atom_array"

def encode(self, buf: np.ndarray):
"""
Encode a chunk of AtomArrays to a plain Python structure that MsgPack can encode
"""

to_pack = np.empty_like(buf)

for idx, atom_array in enumerate(buf):
# A chunk can have missing values
if atom_array is None:
continue

if not isinstance(atom_array, struc.AtomArray):
raise ValueError(f"Expected an AtomArray, but got {type(atom_array)} instead")

data = {
"coord": atom_array.coord,
"chain_id": atom_array.chain_id,
"res_id": atom_array.res_id,
"ins_code": atom_array.ins_code,
"res_name": atom_array.res_name,
"hetero": atom_array.hetero,
"atom_name": atom_array.atom_name,
"element": atom_array.element,
"atom_id": atom_array.atom_id,
"b_factor": atom_array.b_factor,
"occupancy": atom_array.occupancy,
"charge": atom_array.charge,
}
data = {k: v.tolist() for k, v in data.items()}
to_pack[idx] = data

return super().encode(to_pack)

def decode(self, buf, out=None):
"""Decode the MsgPack decoded data into a `fastpdb` AtomArray."""

dec = super().decode(buf, out)

structs = np.empty(shape=len(dec), dtype=object)

for idx, data in enumerate(dec):
if data is None:
continue

atom_array = []
array_length = len(data["coord"])

for ind in range(array_length):
atom = struc.Atom(
coord=data["coord"][ind],
chain_id=data["chain_id"][ind],
res_id=data["res_id"][ind],
ins_code=data["ins_code"][ind],
res_name=data["res_name"][ind],
hetero=data["hetero"][ind],
atom_name=data["atom_name"][ind],
element=data["element"][ind],
b_factor=data["b_factor"][ind],
occupancy=data["occupancy"][ind],
charge=data["charge"][ind],
atom_id=data["atom_id"][ind],
)
atom_array.append(atom)

# Note that this is a `fastpdb` AtomArray, not a NumPy array.
structs[idx] = struc.array(atom_array)

if out is not None:
np.copyto(out, structs)
return out
else:
return structs


register_codec(RDKitMolCodec)
register_codec(AtomArrayCodec)
11 changes: 11 additions & 0 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import datamol as dm
import zarr

from polaris.dataset.zarr.codecs import RDKitMolCodec


def test_rdkit_mol_codec():
mol = dm.to_mol("C1=CC=CC=C1")
arr = zarr.array([mol, mol], chunks=(2,), dtype=object, object_codec=RDKitMolCodec())
assert dm.same_mol(arr[0], mol)
assert dm.same_mol(arr[1], mol)

0 comments on commit ffca9da

Please sign in to comment.