Skip to content

Commit

Permalink
Add separate colabfold_relax script to relax after a prediction was done
Browse files Browse the repository at this point in the history
  • Loading branch information
milot-mirdita committed Nov 9, 2023
1 parent 94ddcfb commit 8606689
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 30 deletions.
31 changes: 1 addition & 30 deletions colabfold/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
setup_logging,
CFMMCIFIO,
)
from colabfold.relax import relax_me

from Bio.PDB import MMCIFParser, PDBParser, MMCIF2Dict
from Bio.PDB.PDBIO import Select
Expand Down Expand Up @@ -299,36 +300,6 @@ def pad_input(
) # template_mask (4, 4) second value
return input_fix

def relax_me(
pdb_filename=None,
pdb_lines=None,
pdb_obj=None,
use_gpu=False,
max_iterations=0,
tolerance=2.39,
stiffness=10.0,
max_outer_iterations=3
):
if "relax" not in dir():
from alphafold.common import residue_constants
from alphafold.relax import relax

if pdb_obj is None:
if pdb_lines is None:
pdb_lines = Path(pdb_filename).read_text()
pdb_obj = protein.from_pdb_string(pdb_lines)

amber_relaxer = relax.AmberRelaxation(
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
exclude_residues=[],
max_outer_iterations=max_outer_iterations,
use_gpu=use_gpu)

relaxed_pdb_lines, _, _ = amber_relaxer.process(prot=pdb_obj)
return relaxed_pdb_lines

class file_manager:
def __init__(self, prefix: str, result_dir: Path):
self.prefix = prefix
Expand Down
108 changes: 108 additions & 0 deletions colabfold/relax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pathlib import Path

def relax_me(
pdb_filename=None,
pdb_lines=None,
pdb_obj=None,
use_gpu=False,
max_iterations=0,
tolerance=2.39,
stiffness=10.0,
max_outer_iterations=3
):
from alphafold.common import protein
from alphafold.relax import relax

if pdb_obj is None:
if pdb_lines is None:
pdb_lines = Path(pdb_filename).read_text()
pdb_obj = protein.from_pdb_string(pdb_lines)

amber_relaxer = relax.AmberRelaxation(
max_iterations=max_iterations,
tolerance=tolerance,
stiffness=stiffness,
exclude_residues=[],
max_outer_iterations=max_outer_iterations,
use_gpu=use_gpu
)

relaxed_pdb_lines, _, _ = amber_relaxer.process(prot=pdb_obj)
return relaxed_pdb_lines

def main():
from argparse import ArgumentParser
import os
import glob
from tqdm import tqdm

parser = ArgumentParser()
parser.add_argument("input",
default="input",
help="Can be one of the following: "
"Directory with PDB files or a single PDB file",
)
parser.add_argument("results", help="Directory to write the results to or single output PDB file")
parser.add_argument(
"--max-iterations",
type=int,
default=2000,
help="Maximum number of iterations for the relaxation process. AlphaFold2 sets this to unlimited (0), however, we found that this can lead to very long relaxation times for some inputs."
)
parser.add_argument(
"--tolerance",
type=float,
default=2.39,
help="tolerance level for the relaxation convergence"
)
parser.add_argument(
"--stiffness",
type=float,
default=10.0,
help="stiffness parameter for the relaxation"
)
parser.add_argument(
"--max-outer-iterations",
type=int,
default=3,
help="maximum number of outer iterations for the relaxation process"
)
parser.add_argument("--use-gpu",
default=False,
action="store_true",
help="run amber on GPU instead of CPU",
)
args = parser.parse_args()

input_path = Path(args.input)
output_path = Path(args.results)
if output_path.is_dir():
output_path.mkdir(parents=True, exist_ok=True)

if input_path.is_dir():
pdb_files = glob.glob(str(input_path / "*.pdb"))
else:
pdb_files = [str(input_path)]

if len(pdb_files) > 1:
pdb_files = tqdm(pdb_files, desc="Processing PDB files")

for pdb_file in pdb_files:
relaxed_pdb = relax_me(
pdb_filename=pdb_file,
use_gpu=args.use_gpu,
max_iterations=args.max_iterations,
tolerance=args.tolerance,
stiffness=args.stiffness,
max_outer_iterations=args.max_outer_iterations
)
if output_path.is_dir():
output_file = output_path / Path(pdb_file).name
else:
output_file = output_path

with open(output_file, 'w') as file:
file.write(relaxed_pdb)

if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ addopts = "--tb=short"
colabfold_batch = 'colabfold.batch:main'
colabfold_search = 'colabfold.mmseqs.search:main'
colabfold_split_msas = 'colabfold.mmseqs.split_msas:main'
colabfold_relax = 'colabfold.relax:main'

[tool.black]
# Format only the new package, don't touch the existing stuff
Expand Down

0 comments on commit 8606689

Please sign in to comment.