diff --git a/RoseTTAFold2.ipynb b/RoseTTAFold2.ipynb index 699f9625..72e2864a 100644 --- a/RoseTTAFold2.ipynb +++ b/RoseTTAFold2.ipynb @@ -52,7 +52,7 @@ "source": [ "%%time\n", "#@title setup **RoseTTAFold2** (~1m)\n", - "params = \"RF2_jan24\" # @param [\"RF2_apr23\",\"RF2_jan24\"]\n", + "params = \"RF2_apr23\" # @param [\"RF2_apr23\",\"RF2_jan24\"]\n", "\n", "import os, time, sys\n", "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\n", @@ -92,7 +92,12 @@ " while os.path.isfile(f\"{params}.tgz.aria2\"):\n", " time.sleep(5)\n", "\n", - "if not os.path.isfile(f\"{params}.pt\"):\n", + "if params == \"RF2_jan24\":\n", + " model_params = f\"{params}.pt\"\n", + "if params == \"RF2_apr23\":\n", + " model_params = f\"weights/{params}.pt\"\n", + "\n", + "if not os.path.isfile(model_params):\n", " os.system(f\"tar -zxvf {params}.tgz\")\n", "\n", "if not \"IMPORTED\" in dir():\n", @@ -120,15 +125,16 @@ "\n", " IMPORTED = True\n", "\n", - "if not \"pred\" in dir():\n", + "if not \"pred\" in dir() or model_params_sele != model_params:\n", " from predict import Predictor\n", " print(\"compile RoseTTAFold2\")\n", - " model_params = f\"{params}.pt\"\n", + "\n", " if (torch.cuda.is_available()):\n", " pred = Predictor(model_params, torch.device(\"cuda:0\"))\n", " else:\n", " print (\"WARNING: using CPU\")\n", " pred = Predictor(model_params, torch.device(\"cpu\"))\n", + " model_params_sele = model_params\n", "\n", "def get_unique_sequences(seq_list):\n", " unique_seqs = list(OrderedDict.fromkeys(seq_list))\n", @@ -253,7 +259,7 @@ "use_dropout = False #@param {type:\"boolean\"}\n", "max_msa = 256 #@param [16, 32, 64, 128, 256, 512] {type:\"raw\"}\n", "random_seed = 0 #@param {type:\"integer\"}\n", - "num_models = 1 #@param [\"1\", \"2\", \"4\", \"8\", \"16\", \"32\"] {type:\"raw\"}\n", + "num_models = 1 #@param [\"1\", \"5\", \"10\", \"15\", \"20\", \"25\"] {type:\"raw\"}\n", "\n", "# process\n", "max_extra_msa = max_msa * 8\n", @@ -341,7 +347,7 @@ ], "metadata": { "cellView": "form", - "id": "_oJTZGgdeKkO" + "id": "Eh48KV70rQ03" }, "execution_count": null, "outputs": [] @@ -391,7 +397,7 @@ ], "metadata": { "cellView": "form", - "id": "53wdd2WX70o_" + "id": "3m0H-yCIrpc4" }, "execution_count": null, "outputs": [] @@ -409,6 +415,7 @@ "settings_path = f\"{jobname}/settings.txt\"\n", "with open(settings_path, \"w\") as text_file:\n", " text_file.write(f\"method=RoseTTAFold2\\n\")\n", + " text_file.write(f\"params={params}\\n\")\n", " text_file.write(f\"sequence={sequence}\\n\")\n", " text_file.write(f\"sym={sym}\\n\")\n", " text_file.write(f\"order={order}\\n\")\n",