Skip to content

Commit

Permalink
changed argument lists to argument iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
bjornwallner committed Mar 25, 2024
1 parent 691b86a commit c81209d
Showing 1 changed file with 57 additions and 44 deletions.
101 changes: 57 additions & 44 deletions src/DockQ/DockQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import traceback
import itertools
import math
from functools import lru_cache, wraps
from functools import lru_cache, wraps, partial
from argparse import ArgumentParser
from tqdm import tqdm
from parallelbar import progress_map
Expand Down Expand Up @@ -761,7 +761,7 @@ def run_on_all_native_interfaces(
capri_peptide=capri_peptide,
low_memory=False,
)
if info:
if info and not low_memory:
info["chain1"], info["chain2"] = (
chain_map[chain_pair[0]],
chain_map[chain_pair[1]],
Expand Down Expand Up @@ -885,6 +885,7 @@ def product_without_dupl(*args, repeat=1):
#result = set(list(map(lambda x: tuple(sorted(x)), result))) # to remove symmetric duplicates
for prod in result:
yield tuple(prod)

def count_chain_combinations(chain_clusters):
counts={}
for chain in chain_clusters:
Expand All @@ -893,29 +894,18 @@ def count_chain_combinations(chain_clusters):
counts[chains]=0
counts[chains]+=1
number_of_combinations=np.prod([math.factorial(a) for a in counts.values()])
return number_of_combinations
#combos=itertools.product(*[itertools.permutations(chains) for chains in set([tuple(ch) for ch in chain_clusters.values()])])
return(number_of_combinations,counts)

#return(number_of_combinations,counts)
#set(chain_clusters.values())


def get_all_mappings(
model_structure, native_structure, model_chains, native_chains,initial_mapping,allowed_mismatches=0
):
model_chains_to_combo = [mc for mc in model_chains if mc not in initial_mapping.values()]
native_chains_to_combo = [nc for nc in native_chains if nc not in initial_mapping.keys()]

chain_clusters, reverse_map = group_chains(
model_structure,
native_structure,
model_chains_to_combo,
native_chains_to_combo,
allowed_mismatches,
)

def get_all_chain_maps(chain_clusters,initial_mapping,reverse_map,model_chains_to_combo,native_chains_to_combo):
all_mappings = product_without_dupl(
*[cluster for cluster in chain_clusters.values() if cluster]
)
chain_maps=[]
for mapping in all_mappings:
chain_map = {key:value for key, value in initial_mapping.items()}
if reverse_map:
Expand All @@ -926,11 +916,7 @@ def get_all_mappings(
chain_map.update({
native_chain: mapping[i] for i, native_chain in enumerate(native_chains_to_combo)
})
chain_maps.append(chain_map)
return chain_maps

def run_on_all_native_interfaces_multi(args):
return run_on_all_native_interfaces(*args)
yield(chain_map)


#@profile
Expand All @@ -955,35 +941,62 @@ def main():
best_result = None
best_mapping = None

model_chains_to_combo = [mc for mc in model_chains if mc not in initial_mapping.values()]
native_chains_to_combo = [nc for nc in native_chains if nc not in initial_mapping.keys()]




chain_maps=get_all_mappings(
chain_clusters, reverse_map = group_chains(
model_structure,
native_structure,
model_chains,
native_chains,
initial_mapping,
model_chains_to_combo,
native_chains_to_combo,
args.allowed_mismatches,
)

num_chain_combinations=count_chain_combinations(chain_clusters)
chain_maps=get_all_chain_maps(chain_clusters,initial_mapping,reverse_map,model_chains_to_combo,native_chains_to_combo)

low_memory=num_chain_combinations>100
run_chain_map=partial(run_on_all_native_interfaces,
model_structure,
native_structure,
no_align=args.no_align,
use_CA_only=args.use_CA,
capri_peptide=args.capri_peptide,
low_memory=low_memory) ##args: chain_map

if num_chain_combinations>1:
#chunk_size=max(1,num_chain_combinations // args.n_cpu)
#I suspect large chunk_size will result in large input arguments to the workers.
chuck_size=128

#for large num_chain_combinations it should be possible to divide the chain_maps into chunks
result_this_mappings=progress_map(run_chain_map,chain_maps, total=num_chain_combinations,n_cpu=args.n_cpu, chunk_size=chunk_size)
#get a fresh iterator: the first one was handed to progress_map and cannot be reused for the zip below
chain_maps=get_all_chain_maps(chain_clusters,initial_mapping,reverse_map,model_chains_to_combo,native_chains_to_combo)
for chain_map,result_this_mapping in zip(chain_maps,result_this_mappings):
total_dockq = sum(
[result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()]
)
if total_dockq > best_dockq:
best_dockq = total_dockq
best_result = result_this_mapping
best_mapping = chain_map



low_memory=len(chain_maps) > 100
chain_map_args=[(model_structure,native_structure,chain_map,args.no_align,args.use_CA,args.capri_peptide,low_memory) for chain_map in chain_maps]
if len(chain_maps)>1:
chunk_size=max(1,len(chain_maps) // args.n_cpu)
result_this_mappings=progress_map(run_on_all_native_interfaces_multi,chain_map_args, n_cpu=args.n_cpu, chunk_size=chunk_size)
else: #skip multi-threading for single jobs (skip the bar basically)
result_this_mappings=[run_on_all_native_interfaces(*chain_map_arg) for chain_map_arg in chain_map_args]
for chain_map,result_this_mapping in zip(chain_maps,result_this_mappings):
total_dockq = sum(
[result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()]
)
if total_dockq > best_dockq:
best_dockq = total_dockq
best_result = result_this_mapping
best_mapping = chain_map
# result_this_mappings=[run_chain_map(chain_map) for chain_map in chain_maps]
for chain_maps in chain_maps:
result_this_mapping=run_chain_map(chain_map)
total_dockq = sum(
[result["DockQ_F1" if args.optDockQF1 else "DockQ"] for result in result_this_mapping.values()]
)
if total_dockq > best_dockq:
best_dockq = total_dockq
best_result = result_this_mapping
best_mapping = chain_map


if low_memory: #retrieve the full output by rerunning the best chain mapping
best_result=run_on_all_native_interfaces(
model_structure,
Expand All @@ -1010,7 +1023,7 @@ def print_results(info, short=False, verbose=False, capri_peptide=False):
print(
f"Total DockQ over {len(info['best_result'])} native interfaces: {info['GlobalDockQ']:.3f} with {info['best_mapping_str']} model:native mapping"
)
# print(info["best_result"])
print(info["best_result"])
for chains, results in info["best_result"].items():
print(
f"DockQ{capri_peptide_str} {results['DockQ']:.3f} DockQ_F1 {results['DockQ_F1']:.3f} Fnat {results['fnat']:.3f} iRMS {results['irms']:.3f} LRMS {results['Lrms']:.3f} Fnonnat {results['fnonnat']:.3f} clashes {results['clashes']} mapping {results['chain1']}{results['chain2']}:{chains[0]}{chains[1]} {info['model']} {results['chain1']} {results['chain2']} -> {info['native']} {chains[0]} {chains[1]}"
Expand Down

0 comments on commit c81209d

Please sign in to comment.