Skip to content

Commit

Permalink
Update EvoProtGrad class and download_esm_models function
Browse files (browse the repository at this point in the history)
  • Loading branch information
ragnorc committed Apr 12, 2024
1 parent 5c332a5 commit 14d66ef
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
12 changes: 6 additions & 6 deletions helix/evoprotgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .main import stub


def download_esm_models(slugs: list[str] = ["facebook/esm1b_t33_650M_UR50S", "facebook/esm2_t33_650M_UR50D", "facebook/esm2_t36_3B_UR50D"]):
def download_esm_models(slugs: list[str] = ["facebook/esm1b_t33_650M_UR50S", "facebook/esm2_t33_650M_UR50D", "facebook/esm2_t36_3B_UR50D", "facebook/esm2_t48_15B_UR50D"]):
from transformers import EsmForMaskedLM, AutoTokenizer
for slug in slugs:
EsmForMaskedLM.from_pretrained(slug)
Expand All @@ -19,7 +19,7 @@ def download_esm_models(slugs: list[str] = ["facebook/esm1b_t33_650M_UR50S", "fa
"pandas").run_function(download_esm_models)


@stub.cls(gpu='A10G', timeout=2000, image=image, allow_cross_region_volumes=True, concurrency_limit=9)
@stub.cls(gpu='A100', timeout=2000, image=image, allow_cross_region_volumes=True, concurrency_limit=9)
class EvoProtGrad:
def __init__(self, experts: list[str] = ["esm"], device: str = "cuda"):
from evo_prot_grad import get_expert
Expand Down Expand Up @@ -50,7 +50,7 @@ def evolve(self, sequence: str, n_steps: int = 100, parallel_chains: int = 10, m


@stub.local_entrypoint()
def get_evoprotgrad_variants(sequence: str, output_csv_file: str = None, output_fasta_file: str = None, experts: str = "esm", n_steps: int = 100, num_chains: int = 20, max_mutations: int = -1, random_seed: int = None, concurrency_limit: int = 30):
def get_evoprotgrad_variants(sequence: str, output_csv_file: str = None, output_fasta_file: str = None, experts: str = "esm", n_steps: int = 100, num_chains: int = 20, max_mutations: int = -1, random_seed: int = None, batch_size: int = 9):
from .evoprotgrad import EvoProtGrad
from helix.utils import dataframe_to_fasta, count_mutations

Expand All @@ -61,13 +61,13 @@ def get_evoprotgrad_variants(sequence: str, output_csv_file: str = None, output_
raise Exception(
"Must specify either output_csv_file or output_fasta_file")

num_calls = num_chains // concurrency_limit
remaining_chains = num_chains % concurrency_limit
num_calls = num_chains // batch_size
remaining_chains = num_chains % batch_size
print(
f"Running {num_chains} parallel chains in {num_calls+1} containers")

results = []
args = [(sequence, n_steps, concurrency_limit, max_mutations, random_seed)
args = [(sequence, n_steps, batch_size, max_mutations, random_seed)
for _ in range(num_calls)]
if remaining_chains > 0:
args.append((sequence, n_steps, remaining_chains,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "helixbio"
version = "0.1.9"
version = "0.2.0"
description = ""
authors = ["Ragnor Comerford <hello@ragnor.co>"]
readme = "README.md"
Expand Down

0 comments on commit 14d66ef

Please sign in to comment.