Skip to content

Commit

Permalink
Merge pull request #6 from nickjcroucher/pyjar
Browse files Browse the repository at this point in the history
Changes to code for enhancement #1
  • Loading branch information
nickjcroucher authored Dec 18, 2020
2 parents 75e30e5 + d994063 commit e817f2d
Show file tree
Hide file tree
Showing 8 changed files with 845 additions and 173 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ cache:
- "build"
- "$HOME/.cache/pip"
python:
- "3.5"
- "3.8"
sudo: false
install:
- "source ./install_dependencies.sh"
Expand Down
2 changes: 1 addition & 1 deletion configure.ac
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ PKG_CHECK_MODULES([zlib], [zlib])
AC_CHECK_HEADERS([zlib.h math.h])

# Check for Python
AM_PATH_PYTHON([3.0],
AM_PATH_PYTHON([3.8],
[],
[AC_MSG_WARN([Python not found. Python is required to build presage python binding. Python can be obtained from http://www.python.org])])

Expand Down
336 changes: 253 additions & 83 deletions python/gubbins/common.py

Large diffs are not rendered by default.

62 changes: 38 additions & 24 deletions python/gubbins/pyjar.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
sys.stderr.write("This version of Gubbins requires python v3.8 or higher\n")
sys.exit(0)

from gubbins.utils import generate_shared_mem_array

####################################################
# Function to read an alignment in various formats #
####################################################
Expand Down Expand Up @@ -63,7 +65,7 @@ def read_info(infofile, type = 'raxml'):
print("Error: model information file " + infofile + " does not exist")
sys.exit()

if type not in ['raxml','iqtree','fasttree']:
if type not in ['raxml', 'raxmlng', 'iqtree','fasttree']:
sys.stderr.write('Only able to parse GTR-type models from raxml, iqtree or fasttree')
sys.exit()

Expand Down Expand Up @@ -106,6 +108,22 @@ def read_info(infofile, type = 'raxml'):
# order is ac ag at cg ct gt
words=line.split()
r=[float(words[9]), float(words[10]), float(words[11]), float(words[12]), float(words[13]), float(words[14])]
elif type == 'raxmlng':
sep_by_braces = line.replace('{','}').split('}')
if sep_by_braces[0] == "GTR":
r = [float(rate) for rate in sep_by_braces[1].split('/')]
f = [float(rate) for rate in sep_by_braces[3].split('/')]
elif sep_by_braces[0] == "K80":
sep_rates = [float(rate) for rate in sep_by_braces[1].split('/')]
r = [sep_rates[0], sep_rates[1], sep_rates[0], sep_rates[0], sep_rates[1], sep_rates[0]]
f = [0.25,0.25,0.25,0.25]
elif sep_by_braces[0] == "HKY":
sep_rates = [float(rate) for rate in sep_by_braces[1].split('/')]
r = [sep_rates[0], sep_rates[1], sep_rates[0], sep_rates[0], sep_rates[1], sep_rates[0]]
f = [float(rate) for rate in sep_by_braces[3].split('/')]
elif line.startswith("JC"):
f = [0.25,0.25,0.25,0.25]
r = [1.0,1.0,1.0,1.0,1.0,1.0]
elif type == 'iqtree':
if line.startswith('Base frequencies:'):
words=line.split()
Expand Down Expand Up @@ -182,7 +200,7 @@ def get_base_patterns(alignment, verbose):
print("Unique base patterns:", len(base_patterns))
return base_pattern_bases_array, square_base_pattern_positions_array

def reconstruct_alignment_column(column_indices, tree = None, alignment_sequence_names = None, ancestral_node_indices = None, base_patterns = None, base_pattern_positions = None, base_matrix = None, base_frequencies = None, new_aln = None, verbose = False):
def reconstruct_alignment_column(column_indices, tree = None, alignment_sequence_names = None, ancestral_node_indices = None, base_patterns = None, base_pattern_positions = None, base_matrix = None, base_frequencies = None, new_aln = None, threads = 1, verbose = False):

### TIMING
if verbose:
Expand All @@ -207,8 +225,12 @@ def reconstruct_alignment_column(column_indices, tree = None, alignment_sequence
base_pattern_positions = numpy.ndarray(base_pattern_positions.shape, dtype = base_pattern_positions.dtype, buffer = base_pattern_positions_shm.buf)

# Extract information for iterations
columns = base_patterns[column_indices].tolist()
column_positions = base_pattern_positions[column_indices,:]
if threads == 1:
columns = base_patterns
column_positions = base_pattern_positions
else:
columns = base_patterns[column_indices]
column_positions = base_pattern_positions[column_indices,:]

### TIMING
if verbose:
Expand Down Expand Up @@ -313,7 +335,7 @@ def reconstruct_alignment_column(column_indices, tree = None, alignment_sequence
max_root_base=None
max_root_base_likelihood=float("-inf")
for root_base in columnbases:
if node.L[root_base]>max_root_base_likelihood:
if node.L[root_base] > max_root_base_likelihood:
max_root_base_likelihood=node.L[root_base]
max_root_base=node.C[root_base]
node.r=max_root_base
Expand All @@ -323,7 +345,7 @@ def reconstruct_alignment_column(column_indices, tree = None, alignment_sequence

try:
#5a. Visit an unreconstructed internal node x whose father y has already been reconstructed. Denote by i the reconstructed amino acid at node y.
i=node.parent_node.r
i = node.parent_node.r
except AttributeError:
continue
#5b. Reconstruct node x by choosing Cx(i).
Expand All @@ -346,9 +368,9 @@ def reconstruct_alignment_column(column_indices, tree = None, alignment_sequence
reconstructed_alleles = {}
for node in tree.postorder_node_iter():
if node.is_leaf():
node.r=base[node.taxon.label]
node.r = base[node.taxon.label]
else:
has_child_base=False
has_child_base = False
for child in node.child_node_iter():
if child.r in bases:
has_child_base=True
Expand Down Expand Up @@ -430,10 +452,10 @@ def jar(alignment = None, base_patterns = None, base_pattern_positions = None, t
tree=read_tree(tree_filename)

# Read the info file and get frequencies and rates
if info_filename!="":
if info_filename != "":
if verbose:
print("Reading info file:", info_filename)
f, r=read_info(info_filename, type = info_filetype)
f,r = read_info(info_filename, type = info_filetype)
else:
if verbose:
print("Using default JC rates and frequencies")
Expand Down Expand Up @@ -476,22 +498,13 @@ def jar(alignment = None, base_patterns = None, base_pattern_positions = None, t
with SharedMemoryManager() as smm:

# Convert alignment to shared memory numpy array
new_aln_array_raw = smm.SharedMemory(size = new_aln_array.nbytes)
new_aln_shared_array = numpy.ndarray(new_aln_array.shape, dtype = new_aln_array.dtype, buffer = new_aln_array_raw.buf)
new_aln_shared_array[:] = new_aln_array[:]
new_aln_shared_array = NumpyShared(name = new_aln_array_raw.name, shape = new_aln_array.shape, dtype = new_aln_array.dtype)
new_aln_shared_array = generate_shared_mem_array(new_aln_array, smm)

# Convert base patterns to shared memory numpy array
base_patterns_raw = smm.SharedMemory(size = base_patterns.nbytes)
base_patterns_shared_array = numpy.ndarray(base_patterns.shape, dtype = base_patterns.dtype, buffer = base_patterns_raw.buf)
base_patterns_shared_array[:] = base_patterns[:]
base_patterns_shared_array = NumpyShared(name = base_patterns_raw.name, shape = base_patterns.shape, dtype = base_patterns.dtype)
base_patterns_shared_array = generate_shared_mem_array(base_patterns, smm)

# Convert base pattern positions to shared memory numpy array
base_pattern_positions_raw = smm.SharedMemory(size = base_pattern_positions.nbytes)
base_pattern_positions_shared_array = numpy.ndarray(base_pattern_positions.shape, dtype = base_pattern_positions.dtype, buffer = base_pattern_positions_raw.buf)
base_pattern_positions_shared_array[:] = base_pattern_positions[:]
base_pattern_positions_shared_array = NumpyShared(name = base_pattern_positions_raw.name, shape = base_pattern_positions.shape, dtype = base_pattern_positions.dtype)
base_pattern_positions_shared_array = generate_shared_mem_array(base_pattern_positions, smm)

# split list of sites into chunks per core
bp_list = list(range(len(base_patterns)))
Expand All @@ -509,6 +522,7 @@ def jar(alignment = None, base_patterns = None, base_pattern_positions = None, t
base_matrix = mb,
base_frequencies = f,
new_aln = new_aln_shared_array,
threads = threads,
verbose = verbose),
base_pattern_indices
)
Expand All @@ -520,10 +534,10 @@ def jar(alignment = None, base_patterns = None, base_pattern_positions = None, t
print("Printing alignment with internal node sequences: ", output_prefix+".joint.aln")
with open(output_prefix+".joint.aln", "w") as asr_output:
for taxon in alignment:
print(">"+taxon.id, file=asr_output)
print(">" + taxon.id, file = asr_output)
print(taxon.seq, file=asr_output)
for taxon in ancestral_node_indices:
print(">"+taxon, file=asr_output)
print(">" + taxon, file = asr_output)
print(''.join(out_aln[:,ancestral_node_indices[taxon]]), file=asr_output)

# Combine results for each base across the alignment
Expand Down
Loading

0 comments on commit e817f2d

Please sign in to comment.