Skip to content
This repository has been archived by the owner on Aug 1, 2024. It is now read-only.

Commit

Permalink
Merge pull request #371 from YaoYinYing/main
Browse files Browse the repository at this point in the history
re-PR for pretrained ESM weights issue.
  • Loading branch information
tomsercu authored Nov 26, 2022
2 parents 4f126ca + ab1119c commit 74d25cb
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
60 changes: 60 additions & 0 deletions scripts/download_weights.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env bash
#
# Download the pretrained ESM weights listed in the repository README.
#
# Usage: bash <repo>/esm/scripts/download_weights.sh [/path/to/weights/]
#
# Files are stored in <weights dir>/checkpoints; the weights dir defaults to
# the current working directory. Re-run the command to resume incomplete
# downloads: a file is fetched again while it is missing or while aria2c's
# "<file>.aria2" control file still sits next to it.
#
# Exit status: 0 on success, 1 when a prerequisite is missing or a model
# weight fails to download. Missing contact-regression files are tolerated,
# because their URLs are only guessed from the model URLs.
#
# Requires: bash, awk, aria2c.

set -euo pipefail

# A user-set CDPATH would make `cd` print the target directory and pollute
# the command substitutions below.
unset CDPATH

# Print a message on stderr and abort.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

#######################################
# Print every model URL found in the README tables, one per line.
# Only table rows (lines starting with "|") that mention "fair-esm/models"
# are considered; the first "https..." token of such a row is the URL.
# Arguments: $1 - path to README.md
# Outputs:   URLs on stdout
#######################################
list_model_urls() {
  awk '
    /^[|]/ && /fair-esm\/models/ {
      # Turn the column separators into blanks so a URL never runs into
      # the neighbouring cell.
      gsub(/[|]/, " ")
      if (match($0, /https[^ \t]*/)) {
        print substr($0, RSTART, RLENGTH)
      }
    }
  ' "$1"
}

#######################################
# Download a URL into the current directory unless a finished copy exists.
# Arguments: $1 - URL to fetch
# Outputs:   progress messages on stdout
# Returns:   0 if the file is already complete, else the aria2c status
#######################################
fetch() {
  local url=$1
  local name=${url##*/}

  if [[ -f "$name" && ! -f "$name.aria2" ]]; then
    echo "Download complete: $name"
    return 0
  fi

  echo "Download not complete: $name"
  aria2c -x 10 "$url"
}

main() {
  local script_dir readme weights_root url regression_url

  command -v aria2c >/dev/null 2>&1 \
    || die "aria2c is required for downloading!"

  script_dir=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd -P) \
    || die "Cannot determine the script directory."
  readme="$script_dir/../README.md"
  [[ -r "$readme" ]] || die "Cannot read $readme"

  weights_root=${1:-$PWD}
  mkdir -p -- "$weights_root/checkpoints"
  # Resolve to an absolute path *before* changing directory; otherwise a
  # relative argument would later be interpreted relative to checkpoints/.
  weights_root=$(cd -- "$weights_root" && pwd -P) \
    || die "Cannot access $weights_root"
  cd -- "$weights_root/checkpoints" \
    || die "Cannot enter $weights_root/checkpoints"

  # The URL list is read on fd 3 so that aria2c cannot consume it via stdin.
  while IFS= read -r url <&3; do
    fetch "$url" || die "Failed to download $url"

    # Guess the matching contact-regression URL from the model URL.
    regression_url=${url/models/regression}
    regression_url="${regression_url%.pt}-contact-regression.pt"
    # Not every model ships regression weights, so errors are deliberately
    # silenced and ignored here.
    fetch "$regression_url" 2>/dev/null \
      || echo "Never mind. ${regression_url##*/} may not exist."
  done 3< <(list_model_urls "$readme")

  printf 'Your model directory is located at `%s`.\n' "$weights_root"
}

main "$@"
14 changes: 14 additions & 0 deletions scripts/esmfold_inference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@

from pathlib import Path
import sys,os
import argparse
import logging
import sys
Expand Down Expand Up @@ -83,6 +86,9 @@ def create_batched_sequence_datasest(
parser.add_argument(
"-o", "--pdb", help="Path to output PDB directory", type=Path, required=True
)
parser.add_argument(
"-m", "--model-dir", help="Parent path to Pretrained ESM data directory. ", type=Path, default=None
)
parser.add_argument(
"--num-recycles",
type=int,
Expand Down Expand Up @@ -121,7 +127,15 @@ def create_batched_sequence_datasest(
logger.info(f"Loaded {len(all_sequences)} sequences from {args.fasta}")

logger.info("Loading model")

# Use pre-downloaded ESM weights from the --model-dir directory, if one was given.
if args.model_dir is not None:
# if pretrained model path is available
torch.hub.set_dir(args.model_dir)

model = esm.pretrained.esmfold_v1()


model = model.eval()
model.set_chunk_size(args.chunk_size)

Expand Down

0 comments on commit 74d25cb

Please sign in to comment.