update uni-mol+ (#130)
* update oc20

* update readme

* Update README.md

* update readme
guolinke authored Jul 7, 2023
1 parent 1d09255 commit 37b0198
Showing 23 changed files with 1,295 additions and 196 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -30,12 +30,12 @@ Check this [subfolder](./unimol/) for more details.

Highly Accurate Quantum Chemical Property Prediction with Uni-Mol+
-------------------------------------------------------------------
[![arXiv](https://img.shields.io/badge/arXiv-2303.16982-00ff00.svg)](https://arxiv.org/abs/2303.16982) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/highly-accurate-quantum-chemical-property/graph-regression-on-pcqm4mv2-lsc)](https://paperswithcode.com/sota/graph-regression-on-pcqm4mv2-lsc?p=highly-accurate-quantum-chemical-property)
[![arXiv](https://img.shields.io/badge/arXiv-2303.16982-00ff00.svg)](https://arxiv.org/abs/2303.16982) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/highly-accurate-quantum-chemical-property/graph-regression-on-pcqm4mv2-lsc)](https://paperswithcode.com/sota/graph-regression-on-pcqm4mv2-lsc?p=highly-accurate-quantum-chemical-property) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/highly-accurate-quantum-chemical-property/initial-structure-to-relaxed-energy-is2re)](https://paperswithcode.com/sota/initial-structure-to-relaxed-energy-is2re?p=highly-accurate-quantum-chemical-property)

<p align="center"><img src="unimol_plus/figure/overview.png" width=80%></p>
<p align="center"><b>Schematic illustration of the Uni-Mol+ framework</b></p>

Uni-Mol+ is a model for quantum chemical property prediction. Firstly, given a 2D molecular graph, Uni-Mol+ generates an initial 3D conformation from inexpensive methods such as RDKit. Then, the initial conformation is iteratively optimized to its equilibrium conformation, and the optimized conformation is further used to predict the QC properties. In the PCQM4MV2 benchmark, Uni-Mol+ outperforms previous SOTA methods by a large margin.
Uni-Mol+ is a model for quantum chemical property prediction. Firstly, given a 2D molecular graph, Uni-Mol+ generates an initial 3D conformation from inexpensive methods such as RDKit. Then, the initial conformation is iteratively optimized to its equilibrium conformation, and the optimized conformation is further used to predict the QC properties. In the PCQM4MV2 and OC20 benchmarks, Uni-Mol+ outperforms previous SOTA methods by a large margin.

Check this [subfolder](./unimol_plus/) for more details.

@@ -50,6 +50,8 @@ Check this [subfolder](./unimol_tools/) for more details.

News
----
**Jul 7 2023**: We release a new version of Uni-Mol+, including the model settings for OC20 and improved performance on PCQM4MV2.

**Jun 9 2023**: We release Uni-Mol tools for property prediction, representation, and downstream tasks.

**Mar 16 2023**: We release Uni-Mol+, a model for quantum chemical property prediction.
61 changes: 51 additions & 10 deletions unimol_plus/README.md
@@ -1,29 +1,41 @@
Highly Accurate Quantum Chemical Property Prediction with Uni-Mol+
==================================================================
[![arXiv](https://img.shields.io/badge/arXiv-2303.16982-00ff00.svg)](https://arxiv.org/abs/2303.16982) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/highly-accurate-quantum-chemical-property/graph-regression-on-pcqm4mv2-lsc)](https://paperswithcode.com/sota/graph-regression-on-pcqm4mv2-lsc?p=highly-accurate-quantum-chemical-property)
[![arXiv](https://img.shields.io/badge/arXiv-2303.16982-00ff00.svg)](https://arxiv.org/abs/2303.16982) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/highly-accurate-quantum-chemical-property/graph-regression-on-pcqm4mv2-lsc)](https://paperswithcode.com/sota/graph-regression-on-pcqm4mv2-lsc?p=highly-accurate-quantum-chemical-property) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/highly-accurate-quantum-chemical-property/initial-structure-to-relaxed-energy-is2re)](https://paperswithcode.com/sota/initial-structure-to-relaxed-energy-is2re?p=highly-accurate-quantum-chemical-property)


<p align="center"><img src="figure/overview.png" width=80%></p>
<p align="center"><b>Schematic illustration of the Uni-Mol+ framework</b></p>

Uni-Mol+ is a model for quantum chemical property prediction. Firstly, given a 2D molecular graph, Uni-Mol+ generates an initial 3D conformation from inexpensive methods such as RDKit. Then, the initial conformation is iteratively optimized to its equilibrium conformation, and the optimized conformation is further used to predict the QC properties.
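
As a quick illustration of the "initial 3D conformation" step, the sketch below generates a conformer with RDKit. It is not the repository's preprocessing code; the molecule and the force-field cleanup are only illustrative assumptions.

```python
from rdkit import Chem
from rdkit.Chem import AllChem

# Hypothetical input: a 2D molecular graph, given here as a SMILES string.
mol = Chem.MolFromSmiles("CC(=O)Oc1ccccc1C(=O)O")
mol = Chem.AddHs(mol)

# Cheap initial 3D conformation; Uni-Mol+ then iteratively refines such a
# starting point toward the equilibrium conformation before predicting QC properties.
AllChem.EmbedMolecule(mol, randomSeed=42)
AllChem.MMFFOptimizeMolecule(mol)  # optional force-field relaxation

coords = mol.GetConformer().GetPositions()  # (num_atoms, 3) array of coordinates
print(coords.shape)
```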

In the [PCQM4MV2](https://ogb.stanford.edu/docs/lsc/leaderboards/#pcqm4mv2) bencmark, Uni-Mol+ outperforms previous SOTA methods by a large margin.
In the [PCQM4MV2](https://ogb.stanford.edu/docs/lsc/leaderboards/#pcqm4mv2) benchmark, Uni-Mol+ outperforms previous SOTA methods by a large margin.

| Model Settings | # Layers | # Param. | Validation MAE | Model Checkpoint |
|------------------|------------| ----------- |------------------|------------------|
| Uni-Mol+ | 12 | 52.4 M | 0.0708 | [link](https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/unimol_plus_base.pt) |
| Uni-Mol+ Large | 18 | 77 M | 0.0701 | [link](https://github.com/dptech-corp/Uni-Mol/releases/download/v0.1/unimol_plus_large.pt) |
| Uni-Mol+ | 12 | 52.4 M | 0.0696 | [link](https://github.com/dptech-corp/Uni-Mol/releases/download/v0.2/unimol_plus_pcq_base.pt) |
| Uni-Mol+ Large | 18 | 77 M | 0.0693 | [link](https://github.com/dptech-corp/Uni-Mol/releases/download/v0.2/unimol_plus_pcq_large.pt) |
| Uni-Mol+ Small | 6 | 27.7 M | 0.0714 | [link](https://github.com/dptech-corp/Uni-Mol/releases/download/v0.2/unimol_plus_pcq_small.pt) |


In the [OC20](https://opencatalystproject.org/leaderboard.html) IS2RE benchmark, Uni-Mol+ outperforms previous SOTA methods by a large margin.

| Model Settings | # Layers | # Param. | Validation Mean MAE | Test Mean MAE | Model Checkpoint |
|------------------|------------| ----------- |----------------------|----------------|------------------|
| Uni-Mol+ | 12 | 48.6 M | 0.4088 | 0.4143 | [link](https://github.com/dptech-corp/Uni-Mol/releases/download/v0.2/unimol_plus_oc20_base.pt) |


Dependencies
------------
- [Uni-Core](https://github.com/dptech-corp/Uni-Core), check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation).
- [Uni-Core](https://github.com/dptech-corp/Uni-Core) with pytorch > 2.0.0, check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation).
- rdkit==2022.09.3, install via `pip install rdkit==2022.09.3`
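
As an optional sanity check (not part of the repository), the following sketch verifies the two dependencies above, assuming only that PyTorch and rdkit are importable in the current environment:

```python
# Hypothetical environment check; package names follow the Dependencies list above.
from packaging import version
import torch
import rdkit

print("torch:", torch.__version__)
print("rdkit:", rdkit.__version__)

# Uni-Mol+ expects PyTorch > 2.0.0 (via Uni-Core) and rdkit==2022.09.3.
assert version.parse(torch.__version__.split("+")[0]) > version.parse("2.0.0")
assert rdkit.__version__ == "2022.09.3"
```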


Data Preparation
----------------

#### PCQM4MV2


First, download the data:

```bash
@@ -46,27 +58,56 @@ python get_3d_lmdb.py test-dev
python get_3d_lmdb.py test-challenge
```

#### OC20

First, follow [this page](https://github.com/Open-Catalyst-Project/ocp/blob/main/DATASET.md) to download the OC20 IS2RE data.

Second, clean the downloaded LMDB files with the following commands:

```bash
input_path="your_input_path" # The path to the original OC20 dataset
output_path="your_output_path"
python scripts/oc20_preprocess.py --input-path $input_path --out-path $output_path --split train
python scripts/oc20_preprocess.py --input-path $input_path --out-path $output_path --split valid
python scripts/oc20_preprocess.py --input-path $input_path --out-path $output_path --split test
```
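
Each cleaned split is written as `data.lmdb` under `your_output_path/<split_name>` (see `scripts/oc20_preprocess.py` below). As a minimal sketch, assuming the cleaned `val_id` split path used here, you can inspect one record like this:

```python
import lmdb
import pickle

# Hypothetical path: one of the cleaned splits produced by scripts/oc20_preprocess.py.
db_path = "your_output_path/val_id/data.lmdb"

env = lmdb.open(db_path, subdir=False, readonly=True, lock=False)
with env.begin() as txn:
    print("entries:", env.stat()["entries"])
    key, value = next(iter(txn.cursor()))
    sample = pickle.loads(value)
    # Train/valid records include pos_relaxed and y_relaxed; test records do not.
    print(sorted(sample.keys()))  # cell, pos, atomic_numbers, tags, sid, ...
env.close()
```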

Inference
---------
```bash
export data_path="your_data_path"
export results_path="your_result_path"
export weight_path="your_ckp_path"
export arch="unimol_plus_large" # or "unimol_plus_base" if you use 12-layer model
bash inference.sh test-dev # or other splits
export task=pcq # or "oc20" for oc20 task
export arch="unimol_plus_pcq_large" # or "unimol_plus_oc20_base" for oc20 arch
export batch_size=16
bash inference.sh test-dev # or other splits, OC20's test files are in ["test_id", "test_ood_ads", "test_ood_both", "test_ood_cat"]
```

Training
--------
After inference, you can use `scripts/make_pcq_test_dev_submission.py` (or `scripts/make_oc20_test_submission.py`) to generate the submission files for the leaderboard evaluation.
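
For example, `scripts/make_oc20_test_submission.py` takes the results folder and an output name, and writes a compressed `.npz` with one id/energy pair per OC20 test subset. A minimal sketch for checking the generated file (the output name is only an assumed example):

```python
import numpy as np

# Hypothetical output of: python scripts/make_oc20_test_submission.py $results_path oc20_submission
sub = np.load("oc20_submission.npz")

for prefix in ["id", "ood_ads", "ood_both", "ood_cat"]:
    ids = sub[f"{prefix}_ids"]
    energy = sub[f"{prefix}_energy"]
    assert len(ids) == len(energy)
    print(prefix, len(ids), "predictions")
```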

Training PCQM4MV2
-----------------
```bash
data_path="your_data_path"
save_dir="your_save_path"
lr=2e-4
batch_size=128 # per-GPU batch size 128; we use 8 GPUs by default
export arch="unimol_plus_large" # or "unimol_plus_base" if you use 12-layer model
export arch="unimol_plus_pcq_large" # or "unimol_plus_pcq_base" if you use 12-layer model
bash train_pcq.sh $data_path $save_dir $lr $batch_size
```

Training OC20
-------------
```bash
data_path="your_data_path"
save_dir="your_save_path"
lr=2e-4
batch_size=8 # per-GPU batch size 8; we use 8 GPUs by default
export arch="unimol_plus_oc20_base"
bash train_oc20.sh $data_path $save_dir $lr $batch_size
```

Citation
--------

21 changes: 0 additions & 21 deletions unimol_plus/inference.py
@@ -25,26 +25,6 @@
logger = logging.getLogger("unimol.inference")


def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor:
"""Computes per-residue pLDDT from logits.
Args:
logits: [num_res, num_bins] output from the PredictedLDDTHead.
Returns:
plddt: [num_res] per-residue pLDDT.
"""
num_bins = plddt_logits.shape[-1]
bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1)
bin_width = 1.0 / num_bins
bounds = torch.arange(
start=0.5 * bin_width, end=1.0, step=bin_width, device=plddt_logits.device
)
plddt = torch.sum(
bin_probs * bounds.view(*((1,) * len(bin_probs.shape[:-1])), *bounds.shape),
dim=-1,
)
return plddt


def masked_mean(mask, value, dim, eps=1e-10, keepdim=False):
mask = mask.expand(*value.shape)
return torch.sum(mask * value, dim=dim, keepdim=keepdim) / (
@@ -53,7 +33,6 @@ def masked_mean(mask, value, dim, eps=1e-10, keepdim=False):


def main(args):

assert (
args.batch_size is not None
), "Must specify batch size either with --batch-size"
11 changes: 6 additions & 5 deletions unimol_plus/inference.sh
@@ -4,16 +4,17 @@
[ -z "${OMPI_COMM_WORLD_SIZE}" ] && OMPI_COMM_WORLD_SIZE=1
[ -z "${OMPI_COMM_WORLD_RANK}" ] && OMPI_COMM_WORLD_RANK=0


[ -z "${arch}" ] && arch=unimol_plus_base
[ -z "${arch}" ] && arch=unimol_plus_pcq_base
[ -z "${task}" ] && task=pcq
[ -z "${batch_size}" ] && batch_size=128

mkdir -p $results_path

python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port $MASTER_PORT --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP \
torchrun --nproc_per_node=$n_gpu --nnodes=$OMPI_COMM_WORLD_SIZE --node_rank=$OMPI_COMM_WORLD_RANK --master_addr=$MASTER_IP --master_port=$MASTER_PORT \
./inference.py --user-dir ./unimol_plus/ $data_path --valid-subset $1 \
--results-path $results_path \
--num-workers 8 --ddp-backend=c10d --batch-size 128 \
--task unimol_plus --loss unimol_plus --arch $arch \
--num-workers 8 --ddp-backend=c10d --batch-size $batch_size \
--task $task --loss unimol_plus --arch $arch \
--path $weight_path \
--fp16 --fp16-init-scale 4 --fp16-scale-window 256 \
--log-interval 50 --log-format simple --label-prob 0.0
45 changes: 45 additions & 0 deletions unimol_plus/scripts/make_oc20_test_submission.py
@@ -0,0 +1,45 @@
import numpy as np
import torch
import pickle
import os, sys
import glob

input_folder = sys.argv[1]

subsets = ["test_id", "test_ood_ads", "test_ood_both", "test_ood_cat"]


def flatten(d, index):
res = []
for x in d:
res.extend(x[index])
return np.array(res)


def one_ckp(folder, subset):
s = f"{folder}/" + subset + "*.pkl"
files = sorted(glob.glob(s))
data = []
for file in files:
with open(file, "rb") as f:
try:
data.extend(pickle.load(f))
except Exception as e:
print("Error in file: ", file)
raise e

id = flatten(data, 0)
y_pred = flatten(data, 2)

return np.array(id), np.array(y_pred)


submission_file = {}

for subset in subsets:
id, y_pred = one_ckp(input_folder, subset)
prefix = "_".join(subset.split("_")[1:])
submission_file[prefix + "_ids"] = id
submission_file[prefix + "_energy"] = y_pred

np.savez_compressed(sys.argv[2], **submission_file)
104 changes: 104 additions & 0 deletions unimol_plus/scripts/oc20_preprocess.py
@@ -0,0 +1,104 @@
import os
import lmdb
import pickle
from multiprocessing import Pool
from tqdm import tqdm
import numpy as np
import argparse
from multiprocessing import cpu_count

nthreads = cpu_count()


def inner_read(cursor):
key, value = cursor
data = pickle.loads(value)

if "y_relaxed" in data:
ret_data = {
"cell": data["cell"].numpy().astype(np.float32),
"pos": data["pos"].numpy().astype(np.float32),
"atomic_numbers": data["atomic_numbers"].numpy().astype(np.int8),
"tags": data["tags"].numpy().astype(np.int8),
"pos_relaxed": data["pos_relaxed"].numpy().astype(np.float32),
"y_relaxed": data["y_relaxed"],
"sid": data["sid"],
}
else:
ret_data = {
"cell": data["cell"].numpy().astype(np.float32),
"pos": data["pos"].numpy().astype(np.float32),
"atomic_numbers": data["atomic_numbers"].numpy().astype(np.int8),
"tags": data["tags"].numpy().astype(np.int8),
"sid": data["sid"],
}
return data["sid"], pickle.dumps(ret_data)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="generate lmdb file")
parser.add_argument("--input-path", type=str, help="initial oc20 data file path")
parser.add_argument("--out-path", type=str, help="output path")
parser.add_argument("--split", type=str, help="train/valid/test")
args = parser.parse_args()

train_list = ["train"]
valid_list = ["val_id", "val_ood_ads", "val_ood_both", "val_ood_cat"]
test_list = ["test_id", "test_ood_ads", "test_ood_both", "test_ood_cat"]
path = args.input_path
out_path = args.out_path

if args.split == "train":
name_list = train_list
elif args.split == "valid":
name_list = valid_list
elif args.split == "test":
name_list = test_list

file_list = [os.path.join(path, name, "data.lmdb") for name in name_list]
with Pool(nthreads) as pool:
for filename, outname in zip(file_list, name_list):
i = 0
env = lmdb.open(
filename,
subdir=False,
readonly=True,
lock=False,
readahead=True,
meminit=False,
max_readers=nthreads,
map_size=int(1000e9),
)
txn = env.begin()

out_dir = os.path.join(out_path, outname)
if not os.path.exists(out_dir):
os.mkdir(out_dir)
outputfilename = os.path.join(out_dir, "data.lmdb")
try:
os.remove(outputfilename)
except:
pass

env_new = lmdb.open(
outputfilename,
subdir=False,
readonly=False,
lock=False,
readahead=False,
meminit=False,
max_readers=1,
map_size=int(1000e9),
)
txn_write = env_new.begin(write=True)
for inner_output in tqdm(
pool.imap(inner_read, txn.cursor()), total=env.stat()["entries"]
):
txn_write.put(f"{inner_output[0]}".encode("ascii"), inner_output[1])
i += 1
if i % 1000 == 0:
txn_write.commit()
txn_write = env_new.begin(write=True)
txn_write.commit()
env_new.close()
env.close()