diff --git a/RoseTTAFold2.ipynb b/RoseTTAFold2.ipynb index 72e2864a..0aa65fa6 100644 --- a/RoseTTAFold2.ipynb +++ b/RoseTTAFold2.ipynb @@ -57,11 +57,11 @@ "import os, time, sys\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\n", "\n", - "if params == \"RF2_jan24\" and not os.path.isfile(\"RF2_jan24.tgz\"):\n", + "if params == \"RF2_jan24\" and not os.path.isfile(f\"{params}.tgz\"):\n", " # send param download into background\n", " os.system(\"(apt-get install aria2; aria2c -q -x 16 https://files.ipd.uw.edu/dimaio/RF2_jan24.tgz) &\")\n", "\n", - "if params == \"RF2_apr23\" and not os.path.isfile(\"RF2_apr23.tgz\"):\n", + "if params == \"RF2_apr23\" and not os.path.isfile(f\"{params}.tgz\"):\n", " # send param download into background\n", " os.system(\"(apt-get install aria2; aria2c -q -x 16 https://files.ipd.uw.edu/dimaio/RF2_apr23.tgz) &\")\n", "\n", @@ -86,19 +86,18 @@ " os.makedirs(\"hhsuite\", exist_ok=True)\n", " os.system(f\"curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/\")\n", "\n", + "if not os.path.isfile(f\"{params}.pt\"):\n", + " time.sleep(5)\n", "\n", "if os.path.isfile(f\"{params}.tgz.aria2\"):\n", " print(\"downloading RoseTTAFold2 params\")\n", " while os.path.isfile(f\"{params}.tgz.aria2\"):\n", " time.sleep(5)\n", "\n", - "if params == \"RF2_jan24\":\n", - " model_params = f\"{params}.pt\"\n", - "if params == \"RF2_apr23\":\n", - " model_params = f\"weights/{params}.pt\"\n", - "\n", - "if not os.path.isfile(model_params):\n", + "if not os.path.isfile(f\"{params}.pt\"):\n", " os.system(f\"tar -zxvf {params}.tgz\")\n", + " if params == \"RF2_apr23\":\n", + " os.system(f\"mv weights/{params}.pt .\")\n", "\n", "if not \"IMPORTED\" in dir():\n", " if 'RoseTTAFold2/network' not in sys.path:\n", @@ -125,16 +124,16 @@ "\n", " IMPORTED = True\n", "\n", - "if not \"pred\" in dir() or model_params_sele != model_params:\n", + "if not \"pred\" in dir() or params_sele != params:\n", " from predict import Predictor\n", " print(\"compile RoseTTAFold2\")\n", "\n", " if (torch.cuda.is_available()):\n", - " pred = Predictor(model_params, torch.device(\"cuda:0\"))\n", + " pred = Predictor(f\"{params}.pt\", torch.device(\"cuda:0\"))\n", " else:\n", " print (\"WARNING: using CPU\")\n", - " pred = Predictor(model_params, torch.device(\"cpu\"))\n", - " model_params_sele = model_params\n", + " pred = Predictor(f\"{params}.pt\", torch.device(\"cpu\"))\n", + " params_sele = params\n", "\n", "def get_unique_sequences(seq_list):\n", " unique_seqs = list(OrderedDict.fromkeys(seq_list))\n",