Merge pull request #55 from nanoporetech/training_features_pipes
Add wrapper scripts to pipe training tensors directly to Tensor2Bin
zhengzhenxian authored Sep 29, 2021
2 parents be99492 + 2b838d5 commit cc313f7
Showing 10 changed files with 781 additions and 276 deletions.
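In rough terms, the change lets a tensor-producing stage stream its TSV output straight into Tensor2Bin instead of going through intermediate files. A minimal sketch of that wiring (the subcommand invocations and flags here are assumptions for illustration; the only detail taken from this diff is the 'PIPE' sentinel that makes Tensor2Bin read from stdin):

import shlex
import subprocess

# Hypothetical producer: writes training tensors as TSV lines to stdout.
producer = subprocess.Popen(
    shlex.split("python clair3.py CreateTrainingTensor"),  # flags omitted/assumed
    stdout=subprocess.PIPE,
)
# Consumer: with --tensor_fn set to the sentinel 'PIPE', Tensor2Bin reads stdin
# (flag spelling assumed; the sentinel itself appears in the utils.py diff below).
consumer = subprocess.Popen(
    shlex.split("python clair3.py Tensor2Bin --tensor_fn PIPE"),  # other flags assumed
    stdin=producer.stdout,
)
producer.stdout.close()  # let the producer see SIGPIPE if the consumer exits early
consumer.wait()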
clair3.py: 2 additions & 0 deletions
@@ -18,7 +18,9 @@
     'RealignReads',
     'CreateTensorPileup',
     "CreateTensorFullAlignment",
+    'CreateTrainingTensor',
     'SplitExtendBed',
+    'MergeBin',
     'MergeVcf',
     'SelectHetSnp',
     'SelectCandidates',
clair3/utils.py: 120 additions & 109 deletions
@@ -1,10 +1,11 @@
 import sys
 import gc
+import copy
 import shlex
 import os
 import tables
 import numpy as np
-from random import random
+from functools import partial
 
 from clair3.task.main import *
 from shared.interval_tree import bed_tree_from, is_region_in
@@ -206,63 +207,80 @@ def print_bin_size(path, prefix=None):
     print('[INFO] total: {}'.format(total))
 
 
-def bin_reader_generator_from(subprocess_list, Y, is_tree_empty, tree, miss_variant_set, is_allow_duplicate_chr_pos=False, non_variant_subsample_ratio=1.0):
+def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, miss_variant_set, is_allow_duplicate_chr_pos=False, maximum_non_variant_ratio=None):
 
     """
     Bin reader generator for bin file generation.
-    subprocess_list: a list of tensor generators, one per tensor file.
+    tensor_fn: tensor file.
+    Y_true_var: dictionary (contig name: label information) containing all true variant information (should not be changed).
     Y: dictionary (contig name: label information) to store all variant and non-variant information.
     tree: dictionary (contig name: intervaltree) for quick region querying.
     miss_variant_set: true variants that can go missing after read downsampling.
     is_allow_duplicate_chr_pos: whether to allow duplicate positions when training; if downsampled data exists, lower-depth duplicates get a random prefix character.
-    non_variant_subsample_ratio: defines a maximum non-variant ratio for training. More non-variant data is generally desirable, but it greatly increases training
+    maximum_non_variant_ratio: defines a maximum non-variant ratio for training. More non-variant data is generally desirable, but it greatly increases training
     time, especially on ONT data, so a 1:1 or 1:2 variant:non-variant candidate ratio is usually used.
     """
 
     X = {}
+    ref_list = []
     total = 0
-    for f in subprocess_list:
-        for row_idx, row in enumerate(f.stdout):
-            chrom, coord, seq, string, alt_info = row.split("\t")
-            alt_info = alt_info.rstrip()
-            if not (is_tree_empty or is_region_in(tree, chrom, int(coord))):
-                continue
-            seq = seq.upper()
-            if seq[param.flankingBaseNum] not in 'ACGT':
-                continue
-            key = chrom + ":" + coord
-            is_reference = key not in Y
-
-            if key in miss_variant_set:
-                continue
-
-            if is_reference and non_variant_subsample_ratio < 1.0 and random() >= non_variant_subsample_ratio:
-                continue
-            if key not in X:
-                X[key] = (string, alt_info, seq)
-            elif is_allow_duplicate_chr_pos:
-                new_key = ""
-                for character in PREFIX_CHAR_STR:
-                    tmp_key = character + key
-                    if tmp_key not in X:
-                        new_key = tmp_key
-                        break
-                if len(new_key) > 0:
-                    X[new_key] = (string, alt_info, seq)
-            if is_reference:
-                Y[key] = output_labels_from_reference(BASE2BASE[seq[param.flankingBaseNum]])
-
-            if len(X) == shuffle_bin_size:
-                yield X, total
-                X = {}
-            total += 1
-            if total % 100000 == 0:
-                print("[INFO] Processed %d tensors" % total, file=sys.stderr)
-        f.stdout.close()
-        f.wait()
-    yield X, total
-    yield None, total
+    for row_idx, row in enumerate(tensor_fn):
+        chrom, coord, seq, string, alt_info = row.split("\t")
+        alt_info = alt_info.rstrip()
+        if not (is_tree_empty or is_region_in(tree, chrom, int(coord))):
+            continue
+        seq = seq.upper()
+        if seq[param.flankingBaseNum] not in 'ACGT':
+            continue
+        key = chrom + ":" + coord
+        is_reference = key not in Y_true_var
+
+        if key in miss_variant_set:
+            continue
+
+        if key not in X:
+            X[key] = (string, alt_info, seq)
+            if is_reference:
+                ref_list.append(key)
+        elif is_allow_duplicate_chr_pos:
+            new_key = ""
+            for character in PREFIX_CHAR_STR:
+                tmp_key = character + key
+                if tmp_key not in X:
+                    new_key = tmp_key
+                    break
+            if len(new_key) > 0:
+                X[new_key] = (string, alt_info, seq)
+                if is_reference:
+                    ref_list.append(new_key)
+
+        if is_reference and key not in Y:
+            Y[key] = output_labels_from_reference(BASE2BASE[seq[param.flankingBaseNum]])
+
+        if len(X) == shuffle_bin_size:
+            if maximum_non_variant_ratio is not None:
+                _filter_non_variants(X, ref_list, maximum_non_variant_ratio)
+            yield X, total, False
+            X = {}
+            ref_list = []
+        total += 1
+        if total % 100000 == 0:
+            print("[INFO] Processed %d tensors" % total, file=sys.stderr)
+
+    if maximum_non_variant_ratio is not None:
+        _filter_non_variants(X, ref_list, maximum_non_variant_ratio)
+    yield X, total, True
+
+
+def _filter_non_variants(X, ref_list, maximum_non_variant_ratio):
+    non_variant_num = len(ref_list)
+    variant_num = len(X) - non_variant_num
+    if non_variant_num > variant_num * maximum_non_variant_ratio:
+        non_variant_keep_fraction = maximum_non_variant_ratio * variant_num / (1. * non_variant_num)
+        probabilities = np.random.random_sample((non_variant_num,))
+        for key, p in zip(ref_list, probabilities):
+            if p > non_variant_keep_fraction:
+                X.pop(key)
 
 
 def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow_duplicate_chr_pos=True, chunk_id=None,
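As a quick check of the _filter_non_variants arithmetic above, a worked example (numbers invented for illustration): with 1000 variants and 2000 non-variants in a shuffle bin and maximum_non_variant_ratio = 1.0, each non-variant survives with probability 0.5, leaving roughly a 1:1 ratio on average.

# Worked example of the keep-fraction arithmetic in _filter_non_variants.
variant_num, non_variant_num, maximum_non_variant_ratio = 1000, 2000, 1.0

# 2000 > 1000 * 1.0, so downsampling triggers.
non_variant_keep_fraction = maximum_non_variant_ratio * variant_num / (1. * non_variant_num)
print(non_variant_keep_fraction)  # 0.5 -> ~1000 non-variants kept on average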
@@ -288,7 +306,8 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow_duplicate_chr_pos=True, chunk_id=None,
 
     tree = bed_tree_from(bed_file_path=bed_fn)
     is_tree_empty = len(tree.keys()) == 0
-    Y, miss_variant_set = variant_map_from(var_fn, tree, is_tree_empty)
+    Y_true_var, miss_variant_set = variant_map_from(var_fn, tree, is_tree_empty)
+    Y = copy.deepcopy(Y_true_var)
 
     global param
     float_type = 'int32'
@@ -300,32 +319,12 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow_duplicate_chr_pos=True, chunk_id=None,
 
     tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape
 
-    variant_num, non_variant_num, non_variant_subsample_ratio = 0, 0, 1.0
-    if maximum_non_variant_ratio is not None and candidate_details_fn_prefix:
-        candidate_details_fn_prefix = candidate_details_fn_prefix.split('/')
-        directry, file_prefix = '/'.join(candidate_details_fn_prefix[:-1]), candidate_details_fn_prefix[-1]
-        file_list = [f for f in os.listdir(directry) if f.startswith(file_prefix)]
-        for f in file_list:
-            fn = open(os.path.join(directry, f), 'r')
-            for row in fn:
-                chr_pos = row.split('\t')[0]
-                key = chr_pos.replace(' ', ':')
-                if key in Y:
-                    variant_num += 1
-                else:
-                    non_variant_num += 1
-            fn.close()
-
-        max_non_variant_num = variant_num * maximum_non_variant_ratio
-        if max_non_variant_num < non_variant_num:
-            non_variant_subsample_ratio = float(max_non_variant_num / non_variant_num)
-        print("[INFO] variants/non variants/subsample ratio: {}/{}/{}".format(variant_num, non_variant_num,
-                                                                              round(non_variant_subsample_ratio, 4)),
-              file=sys.stderr)
-    # select all match prefix if file path not exists
     subprocess_list = []
-    if os.path.exists(tensor_fn):
-        subprocess_list.append(subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, tensor_fn))))
+    if tensor_fn == 'PIPE':
+        subprocess_list.append(sys.stdin)
+    elif os.path.exists(tensor_fn):
+        subprocess_list.append(subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, tensor_fn))).stdout)
+    # select all matching prefixes if the file path does not exist
     else:
         tensor_fn = tensor_fn.split('/')
         directry, file_prefix = '/'.join(tensor_fn[:-1]), tensor_fn[-1]
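The dispatch above works because everything appended to subprocess_list is now a line-iterable text stream rather than a process handle, so sys.stdin and a zstd child's stdout can be consumed identically. A small sketch of that idea (open_tensor_stream is a hypothetical helper, assuming the repo's subprocess_popen wraps subprocess.Popen in text mode):

import shlex
import subprocess
import sys

def open_tensor_stream(tensor_fn):
    # Hypothetical helper mirroring the dispatch in the hunk above.
    if tensor_fn == 'PIPE':
        return sys.stdin  # already a line-iterable text stream
    # 'zstd -fdc' decompresses to stdout; text=True makes the pipe line-iterable too.
    proc = subprocess.Popen(shlex.split("zstd -fdc {}".format(tensor_fn)),
                            stdout=subprocess.PIPE, text=True)
    return proc.stdout

# Either way, the reader sees the same interface:
for row in open_tensor_stream('PIPE'):
    chrom, coord, seq, string, alt_info = row.split("\t")
    break  # sketch only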
@@ -346,7 +345,8 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow_duplicate_chr_pos=True, chunk_id=None,
                 return 0
         for file_name in all_file_name:
             subprocess_list.append(
-                subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, os.path.join(directry, file_name)))))
+                subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, os.path.join(directry, file_name)))).stdout)
+
     tables.set_blosc_max_threads(64)
     int_atom = tables.Atom.from_dtype(np.dtype(float_type))
     string_atom = tables.StringAtom(itemsize=param.no_of_positions + 50)
@@ -361,49 +361,60 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow_duplicate_chr_pos=True, chunk_id=None,
     table_dict = update_table_dict()
 
     # generator to avoid high memory occupancy
-    bin_reader_generator = bin_reader_generator_from(subprocess_list=subprocess_list,
-                                                     Y=Y,
-                                                     is_tree_empty=is_tree_empty,
-                                                     tree=tree,
-                                                     miss_variant_set=miss_variant_set,
-                                                     is_allow_duplicate_chr_pos=is_allow_duplicate_chr_pos,
-                                                     non_variant_subsample_ratio=non_variant_subsample_ratio)
-    total_compressed = 0
-    while True:
-        X, total = next(bin_reader_generator)
-        if X is None or not len(X):
-            break
-        all_chr_pos = sorted(X.keys())
-        if shuffle == True:
-            np.random.shuffle(all_chr_pos)
-        for key in all_chr_pos:
-
-            string, alt_info, seq = X[key]
-            del X[key]
-            label = None
-            if key in Y:
-                label = Y[key]
-                pos = key + ':' + seq
-                if not is_allow_duplicate_chr_pos:
-                    del Y[key]
-            elif is_allow_duplicate_chr_pos:
-                tmp_key = key[1:]
-                label = Y[tmp_key]
-                pos = tmp_key + ':' + seq
-            if label is None:
-                print(key)
-                continue
-            total_compressed = write_table_dict(table_dict, string, label, pos, total_compressed, alt_info,
-                                                tensor_shape, pileup)
-
-            if total_compressed % 500 == 0 and total_compressed > 0:
-                table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type)
-
-        if total_compressed % 50000 == 0:
-            print("[INFO] Compressed %d tensor" % (total_compressed), file=sys.stderr)
+    bin_reader_generator = partial(bin_reader_generator_from,
+                                   Y_true_var=Y_true_var,
+                                   Y=Y,
+                                   is_tree_empty=is_tree_empty,
+                                   tree=tree,
+                                   miss_variant_set=miss_variant_set,
+                                   is_allow_duplicate_chr_pos=is_allow_duplicate_chr_pos,
+                                   maximum_non_variant_ratio=maximum_non_variant_ratio)
+
+    total_compressed = 0
+    for fin in subprocess_list:
+        bin_g = bin_reader_generator(tensor_fn=fin)
+        completed = False
+        while not completed:
+            try:
+                X, total, completed = next(bin_g)
+            except StopIteration:
+                completed = True
+
+            if X is None or not len(X):
+                break
+            all_chr_pos = sorted(X.keys())
+            if shuffle == True:
+                np.random.shuffle(all_chr_pos)
+            for key in all_chr_pos:
+
+                string, alt_info, seq = X[key]
+                del X[key]
+                label = None
+                if key in Y:
+                    label = Y[key]
+                    pos = key + ':' + seq
+                    if not is_allow_duplicate_chr_pos:
+                        del Y[key]
+                elif is_allow_duplicate_chr_pos:
+                    tmp_key = key[1:]
+                    label = Y[tmp_key]
+                    pos = tmp_key + ':' + seq
+                if label is None:
+                    print(key)
+                    continue
+                total_compressed = write_table_dict(table_dict, string, label, pos, total_compressed, alt_info,
+                                                    tensor_shape, pileup)
+
+                if total_compressed % 500 == 0 and total_compressed > 0:
+                    table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type)
+
+            if total_compressed % 50000 == 0:
+                print("[INFO] Compressed %d tensor" % (total_compressed), file=sys.stderr)
+        fin.close()
 
     if total_compressed % 500 != 0 and total_compressed > 0:
         table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type)
 
     table_file.close()
     print("[INFO] Compressed %d/%d tensor" % (total_compressed, total), file=sys.stderr)

