-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_dataset.py
119 lines (100 loc) · 3.85 KB
/
prepare_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import pandas as pd
import argparse
import rdkit.Chem as Chem
import os
import rdkit.Chem.MolStandardize.rdMolStandardize as rdMolStandardize
from profis.utils.finger import sparse2dense, smiles2sparse_KRFP, smiles2sparse_ECFP
from profis.utils.vectorizer import SMILESVectorizer
def prepare(data_path, gen_ecfp=False, gen_krfp=False, to_dense=True):
    """
    Prepare the dataset by standardizing the SMILES strings, checking for compatibility with
    the model's token alphabet and (optionally) generating fingerprints.

    Args:
        data_path (str): Path to the dataset (read with pandas.read_csv).
        gen_ecfp (bool): Generate ECFP fingerprints.
        gen_krfp (bool): Generate KRFP fingerprints.
        to_dense (bool): Convert sparse fingerprints to dense fingerprints.

    Raises:
        ValueError: If both fingerprint types are requested, if an "fps" column
            already exists while one would be generated, or if a required
            column is missing.
        FileNotFoundError: If ``data_path`` does not exist.
    """
    # handle input
    if gen_ecfp and gen_krfp:
        raise ValueError("Please choose only one fingerprint type to generate")
    # Load the dataset
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"File {data_path} not found")
    df = pd.read_csv(data_path)
    # Check the column names
    df.columns = df.columns.str.lower()
    if "fps" in df.columns and (gen_ecfp or gen_krfp):
        raise ValueError("Fingerprint column already exists in the dataset")
    # Bug fix: "fps" is only required when we are NOT about to generate it —
    # the original always demanded it, which made fingerprint generation
    # impossible (the check above forbids a pre-existing "fps" column).
    required = ["smiles", "activity"]
    if not (gen_ecfp or gen_krfp):
        required.append("fps")
    for name in required:
        if name not in df.columns:
            raise ValueError(f"Column {name} not found in the dataset")
    print(f"Loaded data from {data_path}")
    # Standardize the SMILES strings and drop rows that fail standardization
    df["smiles"] = df["smiles"].apply(try_standardize_smiles)
    df = df.dropna(subset=["smiles"])
    # Check for compatibility with the model's token alphabet
    vectorizer = SMILESVectorizer()
    df["is_compatible"] = df["smiles"].apply(
        check_if_alphabet_compatibile, vectorizer=vectorizer
    )
    df = df[df["is_compatible"]]
    # Generate fingerprints
    if gen_ecfp:
        df["fps"] = df["smiles"].apply(smiles2sparse_ECFP)
    elif gen_krfp:
        df["fps"] = df["smiles"].apply(smiles2sparse_KRFP)
    # Convert sparse fingerprints to dense fingerprints
    if to_dense or gen_ecfp or gen_krfp:
        df["fps"] = df["fps"].apply(sparse2dense)
    # Save the processed dataset.
    # Bug fix: the input is read as CSV, so the original
    # data_path.replace(".parquet", "_processed.parquet") was a no-op for
    # ".csv" inputs and wrote the parquet output over the input path.
    root, _ext = os.path.splitext(data_path)
    out_path = root + "_processed.parquet"
    df.to_parquet(out_path, index=False)
    print(f"Processed data saved to {out_path}")
    return
def try_standardize_smiles(smiles):
    """
    Standardize a SMILES string with RDKit (cleanup + uncharging).

    Args:
        smiles (str): SMILES string.

    Returns:
        str or None: Canonical standardized SMILES, or None if the input is
        invalid or cannot be standardized.
    """
    uncharger = rdMolStandardize.Uncharger()
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals unparsable input by returning None
            # (no exception) — check explicitly instead of relying on a
            # downstream crash in Cleanup().
            return None
        mol = rdMolStandardize.Cleanup(mol)
        uncharged_mol = uncharger.uncharge(mol)
        return Chem.MolToSmiles(uncharged_mol)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # no longer swallowed; any RDKit failure still maps to None.
        return None
def check_if_alphabet_compatibile(smiles, vectorizer):
    """
    Check if the SMILES string is compatible with the model's token alphabet.

    Compatibility means the vectorizer can encode the string without raising
    ValueError; the vectorized result itself is discarded.

    Args:
        smiles (str): SMILES string.
        vectorizer (SMILESVectorizer): SMILES vectorizer.

    Returns:
        bool: True if the SMILES string is compatible, False otherwise.
    """
    try:
        vectorizer.vectorize(smiles)
        return True
    except ValueError:
        return False
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, help="Path to the dataset")
    parser.add_argument(
        "--gen_ecfp", action="store_true", help="Generate ECFP fingerprints"
    )
    parser.add_argument(
        "--gen_krfp", action="store_true", help="Generate KRFP fingerprints"
    )
    parser.add_argument(
        "--to_dense",
        action="store_true",
        help="Convert sparse fingerprints to dense fingerprints",
    )
    args = parser.parse_args()
    # Bug fix: the original called prepare(args.data_path, args.skip_fingerprint),
    # but no --skip_fingerprint flag is defined (AttributeError at runtime) and
    # the three flags that ARE parsed were never forwarded.
    prepare(
        args.data_path,
        gen_ecfp=args.gen_ecfp,
        gen_krfp=args.gen_krfp,
        to_dense=args.to_dense,
    )