diff --git a/ESMFold.ipynb b/ESMFold.ipynb index 65e9100f..74c15533 100644 --- a/ESMFold.ipynb +++ b/ESMFold.ipynb @@ -59,30 +59,33 @@ "%%time\n", "#@title install\n", "#@markdown install ESMFold, OpenFold and download Params (~2min 30s)\n", - "\n", + "version = \"1\" # @param [\"0\", \"1\"]\n", + "model_name = \"esmfold_v0.model\" if version == \"0\" else \"esmfold_v0.model\"\n", "import os, time\n", - "if not os.path.isfile(\"esmfold.model\"):\n", + "if not os.path.isfile(model_name):\n", " # download esmfold params\n", " os.system(\"apt-get install aria2 -qq\")\n", - " os.system(\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &\")\n", + " os.system(f\"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &\")\n", "\n", - " # install libs\n", - " os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n", - " os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n", + " if not os.path.isfile(\"finished_install\"):\n", + " # install libs\n", + " os.system(\"pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol\")\n", + " os.system(\"pip install -q git+https://github.com/NVIDIA/dllogger.git\")\n", "\n", - " # install openfold\n", - " commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n", - " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", + " # install openfold\n", + " commit = \"6908936b68ae89f67755240e2f588c09ec31d4c8\"\n", + " os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n", "\n", - " # install esmfold\n", - " os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n", + " # install esmfold\n", + " os.system(f\"pip install -q git+https://github.com/sokrypton/esm.git\")\n", + " os.system(\"touch finished_install\")\n", "\n", " # wait for Params to finish downloading...\n", - " if not os.path.isfile(\"esmfold.model\"):\n", - " # backup source!\n", - " os.system(\"aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model\")\n", + " if not os.path.isfile(model_name):\n", + " print(\"ERROR: downloading esmfold params\")\n", " else:\n", - " while os.path.isfile(\"esmfold.model.aria2\"):\n", + " print(\"waiting for param download...\")\n", + " while os.path.isfile(f\"{model_name}.aria2\"):\n", " time.sleep(5)" ] }, @@ -94,14 +97,16 @@ "from string import ascii_uppercase, ascii_lowercase\n", "import hashlib, re, os\n", "import numpy as np\n", + "import torch\n", "from jax.tree_util import tree_map\n", "import matplotlib.pyplot as plt\n", "from scipy.special import softmax\n", + "import gc\n", "\n", "def parse_output(output):\n", " pae = (output[\"aligned_confidence_probs\"][0] * np.arange(64)).mean(-1) * 31\n", " plddt = output[\"plddt\"][0,:,1]\n", - " \n", + "\n", " bins = np.append(0,np.linspace(2.3125,21.6875,63))\n", " sm_contacts = softmax(output[\"distogram_logits\"],-1)[0]\n", " sm_contacts = sm_contacts[...,bins<8].sum(-1)\n", @@ -128,7 +133,7 @@ "if copies == \"\" or copies <= 0: copies = 1\n", "sequence = \":\".join([sequence] * copies)\n", "num_recycles = 3 #@param [\"0\", \"1\", \"2\", \"3\", \"6\", \"12\", \"24\"] {type:\"raw\"}\n", - "chain_linker = 25 \n", + "chain_linker = 25\n", "\n", "ID = jobname+\"_\"+get_hash(sequence)[:5]\n", "seqs = sequence.split(\":\")\n", @@ -141,10 +146,17 @@ "elif len(u_seqs) == 1: mode = \"homo\"\n", "else: mode = \"hetero\"\n", "\n", - "if \"model\" not in dir():\n", - " import torch\n", - " model = torch.load(\"esmfold.model\")\n", + "if \"model\" not in dir() or model_name != model_name_:\n", + " if \"model\" in dir():\n", + " # delete old model from memory\n", + " del model\n", + " gc.collect()\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + "\n", + " model = torch.load(model_name)\n", " model.eval().cuda().requires_grad_(False)\n", + " model_name_ = model_name\n", "\n", "# optimized for Tesla T4\n", "if length > 700:\n", @@ -193,7 +205,7 @@ " size=(800,480), hbondCutoff=4.0,\n", " Ls=None,\n", " animate=False):\n", - " \n", + "\n", " if chains is None:\n", " chains = 1 if Ls is None else len(Ls)\n", " view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js', width=size[0], height=size[1])\n", @@ -215,7 +227,7 @@ " view.addStyle({'and':[{'resn':\"GLY\"},{'atom':'CA'}]},\n", " {'sphere':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", " view.addStyle({'and':[{'resn':\"PRO\"},{'atom':['C','O'],'invert':True}]},\n", - " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}}) \n", + " {'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n", " if show_mainchains:\n", " BB = ['C','O','N','CA']\n", " view.addStyle({'atom':BB},{'stick':{'colorscheme':f\"WhiteCarbon\",'radius':0.3}})\n",