-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
murcko scaffold split #18
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# beignet subsets | ||
|
||
::: beignet.subsets.murcko_scaffold_split |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from ._murcko_scaffold_split import murcko_scaffold_split | ||
|
||
__all__ = [ | ||
"murcko_scaffold_split", | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import math | ||
import random | ||
from collections import defaultdict | ||
from typing import Sequence | ||
|
||
from torch.utils.data import Dataset, Subset | ||
|
||
# RDKit is an optional dependency. Record whether it could be imported so that
# functions needing it can raise a helpful error when called, instead of the
# whole module failing at import time.
try:
    from rdkit import Chem
    from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol

    _RDKit_AVAILABLE = True
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    _RDKit_AVAILABLE = False
    # Bind exactly the names the successful import would have bound, so the
    # module namespace has the same shape whether or not RDKit is installed.
    Chem, GetScaffoldForMol = None, None
|
||
|
||
def murcko_scaffold_split(
    dataset: Dataset,
    smiles: Sequence[str],
    test_size: float | int,
    *,
    seed: int = 0xDEADBEEF,
    shuffle: bool = True,
    include_chirality: bool = False,
) -> tuple[Subset, Subset]:
    """
    Split a dataset into train and test subsets that share no Murcko
    scaffold, using one SMILES string per dataset item.

    Position ``i`` of ``smiles`` is used as index ``i`` into ``dataset``.
    All items with the same scaffold are placed in the same subset, so
    for small or low-diversity datasets the test subset may end up
    smaller than the requested ``test_size``. SMILES strings that RDKit
    cannot parse are left out of both subsets.

    Parameters
    ----------
    dataset : Dataset
        The dataset to split.
    smiles : Sequence[str]
        SMILES strings, one per dataset item.
    test_size : float | int
        Target size of the test subset. A float must lie strictly
        between 0.0 and 1.0 and is read as a fraction of ``len(smiles)``
        (rounded up). An int must lie strictly between 0 and
        ``len(smiles)``.
    seed : int, optional
        Random seed used when shuffling, by default 0xDEADBEEF
    shuffle : bool, optional
        Whether to shuffle before assigning groups, by default True
    include_chirality : bool, optional
        Whether chirality is kept in the scaffold SMILES, by default False

    Returns
    -------
    tuple[Subset, Subset]
        The train subset followed by the test subset.

    Raises
    ------
    ImportError
        If RDKit is not installed.
    ValueError
        If ``test_size`` is outside the allowed range.

    References
    ----------
    - Bemis, G. W., & Murcko, M. A. (1996). The properties of known drugs.
        1. Molecular frameworks. Journal of medicinal chemistry, 39(15), 2887–2893.
        https://doi.org/10.1021/jm9602928
    - "RDKit: Open-source cheminformatics. https://www.rdkit.org"
    """
    # All of the chemistry lives in the index helper; this wrapper only turns
    # the two index lists into torch Subset views over the caller's dataset.
    split_indices = _murcko_scaffold_split_indices(
        smiles,
        test_size,
        seed=seed,
        shuffle=shuffle,
        include_chirality=include_chirality,
    )
    train_indices, test_indices = split_indices

    train_subset = Subset(dataset, train_indices)
    test_subset = Subset(dataset, test_indices)
    return train_subset, test_subset
|
||
|
||
def _murcko_scaffold_split_indices(
    smiles: Sequence[str],
    test_size: float | int,
    *,
    seed: int | None = 0xDEADBEEF,
    shuffle: bool = True,
    include_chirality: bool = False,
) -> tuple[list[int], list[int]]:
    """
    Compute train and test indices such that no Murcko scaffold appears
    on both sides of the split.

    Molecules are grouped by the SMILES of their Murcko scaffold. Whole
    groups are then assigned greedily: a group goes to the test set if it
    still fits within ``test_size``, otherwise to the train set. Because
    groups are never divided, the test set may be smaller than requested.

    Parameters
    ----------
    smiles : Sequence[str]
        SMILES strings; the returned indices are positions in this sequence.
        Strings that RDKit cannot parse are skipped and appear in neither
        returned list.
    test_size : float | int
        Target test-set size. A float must lie strictly between 0 and 1 and
        is converted to ``ceil(len(smiles) * test_size)``. An int must lie
        strictly between 0 and ``len(smiles)``.
    seed : int | None, optional
        Seed for a dedicated random generator used to shuffle the scaffold
        groups, by default 0xDEADBEEF. If None, the module-level ``random``
        state is used instead.
    shuffle : bool, optional
        Whether to shuffle the order of scaffold groups before assignment,
        by default True. Without shuffling, groups are visited in order of
        first appearance in ``smiles``.
    include_chirality : bool, optional
        Passed as ``isomericSmiles`` when writing the scaffold SMILES,
        by default False

    Returns
    -------
    tuple[list[int], list[int]]
        The train indices followed by the test indices.

    Raises
    ------
    ImportError
        If RDKit is not installed.
    ValueError
        If ``test_size`` is outside the allowed range.
    """
    if not _RDKit_AVAILABLE:
        raise ImportError(
            "This function requires RDKit to be installed (pip install rdkit)"
        )

    if (
        isinstance(test_size, int) and (test_size <= 0 or test_size >= len(smiles))
    ) or (isinstance(test_size, float) and (test_size <= 0 or test_size >= 1)):
        raise ValueError(
            f"Test_size should be a float in (0, 1) or an int < {len(smiles)}."
        )

    if isinstance(test_size, float):
        test_size = math.ceil(len(smiles) * test_size)

    # Group molecule indices by scaffold SMILES. Dict insertion order keeps
    # the groups in order of first appearance.
    scaffold_groups = defaultdict(list)

    for index, smiles_string in enumerate(smiles):
        molecule = Chem.MolFromSmiles(smiles_string)
        if molecule is None:
            # Unparseable SMILES: best-effort skip, as before.
            continue
        scaffold = Chem.MolToSmiles(
            GetScaffoldForMol(molecule), isomericSmiles=include_chirality
        )
        scaffold_groups[scaffold].append(index)

    # Shuffle a *list* of groups. random.shuffle needs a mutable sequence;
    # applied to the defaultdict itself it subscripts with integers, which
    # only inserts empty entries and never reorders the real groups.
    groups = list(scaffold_groups.values())

    if shuffle:
        generator = random.Random(seed) if seed is not None else random
        generator.shuffle(groups)

    train_indices: list[int] = []
    test_indices: list[int] = []

    for group in groups:
        # Scaffold groups are never divided: take the whole group into the
        # test set only if it still fits, otherwise it belongs to train.
        if len(test_indices) + len(group) <= test_size:
            test_indices.extend(group)
        else:
            train_indices.extend(group)

    return train_indices, test_indices
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from importlib.util import find_spec | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Make explicit. |
||
from unittest.mock import MagicMock, patch | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Use |
||
|
||
import pytest | ||
from beignet.subsets._murcko_scaffold_split import ( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
_murcko_scaffold_split_indices, | ||
murcko_scaffold_split, | ||
) | ||
from torch.utils.data import Dataset, Subset | ||
|
||
# True when an importable ``rdkit`` package can be located; used by the
# skipif markers below so the suite still runs without RDKit installed.
_RDKit_AVAILABLE = bool(find_spec("rdkit"))
|
||
|
||
@pytest.mark.skipif(not _RDKit_AVAILABLE, reason="RDKit is not available")
@patch("beignet.subsets._murcko_scaffold_split._murcko_scaffold_split_indices")
def test_murcko_scaffold_split(mock__murcko_scaffold_split_indices):
    """The public wrapper turns the helper's index lists into Subsets."""
    expected_train_indices, expected_test_indices = [0], [1]
    # The index helper is patched out, so only the wrapping is under test.
    mock__murcko_scaffold_split_indices.return_value = (
        expected_train_indices,
        expected_test_indices,
    )

    subsets = murcko_scaffold_split(
        dataset=MagicMock(spec=Dataset),
        smiles=["C", "C"],
        test_size=0.5,
        shuffle=False,
        seed=0,
    )
    train_subset, test_subset = subsets

    assert isinstance(train_subset, Subset)
    assert isinstance(test_subset, Subset)
    assert (train_subset.indices, test_subset.indices) == (
        expected_train_indices,
        expected_test_indices,
    )
|
||
|
||
@pytest.mark.skipif(not _RDKit_AVAILABLE, reason="RDKit is not available")
@pytest.mark.parametrize(
    "test_size, expected_train_idx, expected_test_idx",
    [
        pytest.param(0.5, [2, 3], [0, 1], id="test_size is float"),
        pytest.param(2, [2, 3], [0, 1], id="test_size is int"),
    ],
)
def test__murcko_scaffold_split_indices(
    test_size, expected_train_idx, expected_test_idx
):
    """Float and int ``test_size`` give the same scaffold-disjoint split."""
    # Two molecules per scaffold: indices 0-1 share one, indices 2-3 another.
    smiles = ["C1CCCCC1", "C1CCCCC1", "CCO", "CCO"]

    # The expected indices assume scaffold groups are visited in order of
    # first appearance, so shuffling is disabled explicitly rather than
    # relying on whatever order the default shuffle happens to produce.
    train_idx, test_idx = _murcko_scaffold_split_indices(
        smiles,
        test_size=test_size,
        shuffle=False,
    )
    assert train_idx == expected_train_idx
    assert test_idx == expected_test_idx
|
||
|
||
@pytest.mark.skipif(not _RDKit_AVAILABLE, reason="RDKit is not available")
@pytest.mark.parametrize(
    "smiles, test_size",
    [
        (["CCO"], 1.2),
        (["CCO"], -1),
        (["CCO"], 0),
        (["CCO"], 5),
    ],
    ids=[
        "test_size is float > 1",
        "test_size is negative",
        "test_size is 0",
        "test_size > len(smiles)",
    ],
)
def test__murcko_scaffold_split_indices_invalid_inputs(smiles, test_size):
    """Out-of-range ``test_size`` values are rejected with ValueError."""
    with pytest.raises(ValueError):
        _murcko_scaffold_split_indices(smiles, test_size=test_size)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make explicit.