This repository has been archived by the owner on Jun 18, 2024. It is now read-only.

Make bulk inference easier & output internal embeddings #1

Open · wants to merge 19 commits into base: main
14 changes: 11 additions & 3 deletions alphafold/data/templates.py
@@ -726,9 +726,17 @@ def _process_single_hit(
   cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
   logging.info('Reading PDB entry from %s. Query: %s, template: %s',
                cif_path, query_sequence, template_sequence)
-  # Fail if we can't find the mmCIF file.
-  with open(cif_path, 'r') as cif_file:
-    cif_string = cif_file.read()
+  try:
+    with open(cif_path, 'r') as cif_file:
+      cif_string = cif_file.read()
+  except FileNotFoundError:
+    # Don't fail hard if we can't find the mmCIF file.
+    # Changed this after seeing a failure because no PDB
+    # template file was found for 7byz, apparently because it was
+    # deprecated: https://www.rcsb.org/structure/removed/7byz.
+    # Not sure why this doesn't get handled above in _assess_hhsearch_hit.
+    return SingleHitResult(features=None, error=None, warning=f"Can't find CIF path {cif_path}!")
+
 
   parsing_result = mmcif_parsing.parse(
       file_id=hit_pdb_code, mmcif_string=cif_string)
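With this change, a missing mmCIF file surfaces as a warning on the returned SingleHitResult instead of an exception. A minimal sketch of how a caller might consume the soft failure; the helper name and the `results` list are illustrative, not part of the diff:

from absl import logging

def collect_template_features(results):
  """Keeps usable hits; `results` is an assumed list of SingleHitResult."""
  kept = []
  for result in results:
    if result.warning:
      logging.warning(result.warning)
    if result.features is None:
      continue  # mmCIF file missing or deprecated; skip this hit.
    kept.append(result.features)
  return kept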
10 changes: 6 additions & 4 deletions alphafold/model/features.py
@@ -43,9 +43,10 @@ def make_data_config(
 
 def tf_example_to_features(tf_example: tf.train.Example,
                            config: ml_collections.ConfigDict,
-                           random_seed: int = 0) -> FeatureDict:
+                           random_seed: int = 0,
+                           num_res: int = None) -> FeatureDict:
   """Converts tf_example to numpy feature dictionary."""
-  num_res = int(tf_example.features.feature['seq_length'].int64_list.value[0])
+  num_res = num_res or int(tf_example.features.feature['seq_length'].int64_list.value[0])
   cfg, feature_names = make_data_config(config, num_res=num_res)
 
   if 'deletion_matrix_int' in set(tf_example.features.feature):
@@ -75,10 +76,11 @@ def tf_example_to_features(tf_example: tf.train.Example,
 
 def np_example_to_features(np_example: FeatureDict,
                            config: ml_collections.ConfigDict,
-                           random_seed: int = 0) -> FeatureDict:
+                           random_seed: int = 0,
+                           num_res: int = None) -> FeatureDict:
   """Preprocesses NumPy feature dict using TF pipeline."""
   np_example = dict(np_example)
-  num_res = int(np_example['seq_length'][0])
+  num_res = num_res or int(np_example['seq_length'][0])
   cfg, feature_names = make_data_config(config, num_res=num_res)
 
   if 'deletion_matrix_int' in np_example:
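The num_res override is what makes bulk inference cheap: padding every input to one shared length means the jitted model compiles once instead of once per distinct sequence length. A rough sketch, assuming `np_examples` is a list of raw feature dicts from the data pipeline and `model_config` is the usual model config:

from alphafold.model import features

# Pad/crop every example to the longest sequence in the batch so that all
# processed feature arrays share one shape.
max_res = max(int(ex['seq_length'][0]) for ex in np_examples)
processed = [
    features.np_example_to_features(
        np_example=ex,
        config=model_config,
        random_seed=0,
        num_res=max_res)
    for ex in np_examples
]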
16 changes: 12 additions & 4 deletions alphafold/model/model.py
@@ -49,8 +49,10 @@ class RunModel:
   """Container for JAX model."""
 
   def __init__(self,
+               name: str,
                config: ml_collections.ConfigDict,
                params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
+    self.name = name
     self.config = config
     self.params = params
 
@@ -60,7 +62,8 @@ def _forward_fn(batch):
           batch,
           is_training=False,
           compute_loss=False,
-          ensemble_representations=True)
+          ensemble_representations=True,
+          return_representations=True)
 
     self.apply = jax.jit(hk.transform(_forward_fn).apply)
     self.init = jax.jit(hk.transform(_forward_fn).init)
@@ -87,13 +90,16 @@ def init_params(self, feat: features.FeatureDict, random_seed: int = 0):
   def process_features(
       self,
       raw_features: Union[tf.train.Example, features.FeatureDict],
-      random_seed: int) -> features.FeatureDict:
+      random_seed: int,
+      num_res: int = None) -> features.FeatureDict:
     """Processes features to prepare for feeding them into the model.
 
     Args:
       raw_features: The output of the data pipeline either as a dict of NumPy
         arrays or as a tf.train.Example.
       random_seed: The random seed to use when processing the features.
+      num_res: Number of residues to crop/pad to. If absent, defaults to
+        the exact number of residues.
 
     Returns:
       A dict of NumPy feature arrays suitable for feeding into the model.
@@ -102,12 +108,14 @@ def tf_example_to_features(tf_example: tf.train.Example,
       return features.np_example_to_features(
           np_example=raw_features,
           config=self.config,
-          random_seed=random_seed)
+          random_seed=random_seed,
+          num_res=num_res)
     else:
       return features.tf_example_to_features(
           tf_example=raw_features,
           config=self.config,
-          random_seed=random_seed)
+          random_seed=random_seed,
+          num_res=num_res)
 
   def eval_shape(self, feat: features.FeatureDict) -> jax.ShapeDtypeStruct:
     self.init_params(feat)
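Combined with return_representations=True above, the intended payoff looks roughly like the sketch below. The 'representations' output key and the 'single'/'pair' entry names follow stock AlphaFold and are assumptions about this fork's output, as are the variable names:

from alphafold.model import model

model_runner = model.RunModel('model_1', model_config, model_params)
processed = model_runner.process_features(
    raw_features, random_seed=0, num_res=max_res)
result = model_runner.predict(processed)
# Internal embeddings, assuming they are exposed as in stock AlphaFold:
single = result['representations']['single']  # per-residue embedding
pair = result['representations']['pair']      # residue-pair embedding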
5 changes: 4 additions & 1 deletion docker/Dockerfile
@@ -65,9 +65,12 @@ RUN wget -q -P /app/alphafold/alphafold/common/ \
 # Install pip packages.
 RUN pip3 install --upgrade pip \
     && pip3 install -r /app/alphafold/requirements.txt \
-    && pip3 install --upgrade jax jaxlib==0.1.69+cuda${CUDA/./} -f \
+    && pip3 install --upgrade jax==0.2.18 jaxlib==0.1.69+cuda${CUDA/./} -f \
       https://storage.googleapis.com/jax-releases/jax_releases.html
 
+
+
+
 # Apply OpenMM patch.
 WORKDIR /opt/conda/lib/python3.7/site-packages
 RUN patch -p0 < /app/alphafold/docker/openmm.patch
54 changes: 47 additions & 7 deletions docker/run_docker.py
@@ -17,6 +17,8 @@
 import os
 import signal
 from typing import Tuple
+import glob
+import pathlib
 
 from absl import app
 from absl import flags
@@ -92,6 +94,8 @@
                   'All FASTA paths must have a unique basename as the '
                   'basename is used to name the output directories for '
                   'each prediction.')
+flags.DEFINE_string('fastas_dir', None, 'Directory containing FASTA inputs; '
+                    'an alternative to fasta_paths.')
 flags.DEFINE_string('max_template_date', None, 'Maximum template release date '
                     'to consider (ISO-8601 format - i.e. YYYY-MM-DD). '
                     'Important if folding historical test sets.')
@@ -106,6 +110,18 @@
                      'to obtain a timing that excludes the compilation time, '
                      'which should be more indicative of the time required for '
                      'inferencing many proteins.')
+flags.DEFINE_integer('random_seed', None, 'The random seed for the data '
+                     'pipeline. By default, this is randomly generated. Note '
+                     'that even if this is set, AlphaFold may still not be '
+                     'deterministic, because processes like GPU inference are '
+                     'nondeterministic.')
+flags.DEFINE_boolean('features_only', False, 'Only search MSAs/templates and '
+                     'dump features. Useful when running feature generation in '
+                     'parallel before doing GPU inference in serial.')
+flags.DEFINE_integer('num_workers', 1, 'Number of parallel runs of feature '
+                     'generation. Only applies if features_only is set.')
+flags.DEFINE_boolean('do_relax', True, 'Whether to run Amber relaxation. '
+                     'If the sequence contains Xs, Amber will fail.')

FLAGS = flags.FLAGS

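Taken together, these flags support a two-phase bulk workflow: run CPU-bound feature generation in parallel first, then run GPU inference serially over the same inputs. A hypothetical pair of invocations (paths and values illustrative):

python3 docker/run_docker.py --fastas_dir=/data/fastas \
    --max_template_date=2021-07-14 --features_only=true --num_workers=8
python3 docker/run_docker.py --fastas_dir=/data/fastas \
    --max_template_date=2021-07-14 --do_relax=false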
@@ -118,7 +134,7 @@ def _create_mount(mount_name: str, path: str) -> Tuple[types.Mount, str]:
   target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, mount_name)
   logging.info('Mounting %s -> %s', source_path, target_path)
   mount = types.Mount(target_path, source_path, type='bind', read_only=True)
-  return mount, os.path.join(target_path, os.path.basename(path))
+  return mount, target_path, os.path.join(target_path, os.path.basename(path))
 
 
 def main(argv):
@@ -128,11 +144,29 @@ def main(argv):
   mounts = []
   command_args = []
 
-  # Mount each fasta path as a unique target directory.
+  if bool(FLAGS.fasta_paths) == bool(FLAGS.fastas_dir):
+    raise app.UsageError('Exactly one of fasta_paths or fastas_dir must be specified.')
+  if FLAGS.fasta_paths:
+    fasta_paths = FLAGS.fasta_paths
+    logging.info('Using %d specified fasta paths', len(fasta_paths))
+  else:
+    fasta_paths = sorted(glob.glob(str(pathlib.Path(FLAGS.fastas_dir) / '*.fasta')))
+    logging.info('Found %d .fasta files in directory %s', len(fasta_paths), FLAGS.fastas_dir)
+  if not fasta_paths:
+    raise ValueError('Empty fastas list.')
+
+  # Mount each unique directory containing a fasta path as a unique target directory.
   target_fasta_paths = []
-  for i, fasta_path in enumerate(FLAGS.fasta_paths):
-    mount, target_path = _create_mount(f'fasta_path_{i}', fasta_path)
-    mounts.append(mount)
+  sourcedir_to_mountdir = {}
+  for i, fasta_path in enumerate(fasta_paths):
+    sourcedir = os.path.dirname(os.path.abspath(fasta_path))
+    mountdir = sourcedir_to_mountdir.get(sourcedir)
+    if mountdir:
+      target_path = os.path.join(mountdir, os.path.basename(fasta_path))
+    else:
+      mount, mountdir, target_path = _create_mount(f'fasta_path_{i}', fasta_path)
+      mounts.append(mount)
+      sourcedir_to_mountdir[sourcedir] = mountdir
     target_fasta_paths.append(target_path)
   command_args.append(f'--fasta_paths={",".join(target_fasta_paths)}')

Expand All @@ -153,19 +187,26 @@ def main(argv):
])
for name, path in database_paths:
if path:
mount, target_path = _create_mount(name, path)
mount, _, target_path = _create_mount(name, path)
mounts.append(mount)
command_args.append(f'--{name}={target_path}')

output_target_path = os.path.join(_ROOT_MOUNT_DIRECTORY, 'output')
mounts.append(types.Mount(output_target_path, output_dir, type='bind'))

# mount this repo directory to avoid re-building Docker container on code changes
mounts.append(types.Mount('/app/alphafold', 'ALPHAFOLD_REPO_DIR', type='bind'))

command_args.extend([
f'--output_dir={output_target_path}',
f'--model_names={",".join(model_names)}',
f'--max_template_date={FLAGS.max_template_date}',
f'--preset={FLAGS.preset}',
f'--benchmark={FLAGS.benchmark}',
f'--random_seed={FLAGS.random_seed}',
f'--features_only={FLAGS.features_only}',
f'--num_workers={FLAGS.num_workers}',
f'--do_relax={FLAGS.do_relax}',
'--logtostderr',
])

@@ -195,7 +236,6 @@ def main(argv):
 
 if __name__ == '__main__':
   flags.mark_flags_as_required([
-      'fasta_paths',
       'max_template_date',
   ])
   app.run(main)
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,3 +10,4 @@ ml-collections==0.1.0
 numpy==1.19.5
 scipy==1.7.0
 tensorflow-cpu==2.5.0
+multiprocess==0.70.12.2
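For context, multiprocess is the dill-based fork of the standard multiprocessing package with the same Pool API; presumably it backs the new --num_workers feature-generation parallelism. A minimal sketch, with placeholder function and input names:

import multiprocess

def run_feature_pipeline(fasta_path):
  ...  # search MSAs/templates and dump features for one input

if __name__ == '__main__':
  with multiprocess.Pool(processes=8) as pool:
    pool.map(run_feature_pipeline, ['a.fasta', 'b.fasta'])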