Add wrapper scripts to pipe training tensors directly to Tensor2Bin #55

Merged: 12 commits, Sep 29, 2021
2 changes: 2 additions & 0 deletions clair3.py
@@ -18,7 +18,9 @@
'RealignReads',
'CreateTensorPileup',
"CreateTensorFullAlignment",
'CreateTrainingTensor',
'SplitExtendBed',
'MergeBin',
'MergeVcf',
'SelectHetSnp',
'SelectCandidates',
229 changes: 120 additions & 109 deletions clair3/utils.py
@@ -1,10 +1,11 @@
import sys
import gc
import copy
import shlex
import os
import tables
import numpy as np
from random import random
from functools import partial

from clair3.task.main import *
from shared.interval_tree import bed_tree_from, is_region_in
@@ -206,63 +207,80 @@ def print_bin_size(path, prefix=None):
print('[INFO] total: {}'.format(total))


def bin_reader_generator_from(subprocess_list, Y, is_tree_empty, tree, miss_variant_set, is_allow_duplicate_chr_pos=False, non_variant_subsample_ratio=1.0):
def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, miss_variant_set, is_allow_duplicate_chr_pos=False, maximum_non_variant_ratio=None):

"""
Bin reader generator for bin file generation.
subprocess_list: a list of tensor generators, one per tensor file.
tensor_fn: a tensor stream (an opened file handle, or stdin when reading from a pipe).
Y_true_var: dictionary (contig name: label information) containing all true variant information (should not be changed).
Y: dictionary (contig name: label information) to store all variant and non-variant information.
tree: dictionary (contig name: intervaltree) for quick region querying.
miss_variant_set: set of true variants that can go missing after read downsampling.
is_allow_duplicate_chr_pos: whether to allow duplicate positions during training; when downsampled data exist, the lower-depth copy gets a random prefix character.
non_variant_subsample_ratio: probability of keeping each non-variant candidate for training.
maximum_non_variant_ratio: caps the ratio of non-variant to variant candidates used for training. More non-variant data usually helps, but it greatly increases training
time, especially on ONT data; a 1:1 or 1:2 ratio of variant to non-variant candidates is typical here.
"""

X = {}
ref_list = []
total = 0
for f in subprocess_list:
for row_idx, row in enumerate(f.stdout):
chrom, coord, seq, string, alt_info = row.split("\t")
alt_info = alt_info.rstrip()
if not (is_tree_empty or is_region_in(tree, chrom, int(coord))):
continue
seq = seq.upper()
if seq[param.flankingBaseNum] not in 'ACGT':
continue
key = chrom + ":" + coord
is_reference = key not in Y

if key in miss_variant_set:
continue
for row_idx, row in enumerate(tensor_fn):
chrom, coord, seq, string, alt_info = row.split("\t")
alt_info = alt_info.rstrip()
if not (is_tree_empty or is_region_in(tree, chrom, int(coord))):
continue
seq = seq.upper()
if seq[param.flankingBaseNum] not in 'ACGT':
continue
key = chrom + ":" + coord
is_reference = key not in Y_true_var

if is_reference and non_variant_subsample_ratio < 1.0 and random() >= non_variant_subsample_ratio:
continue
if key not in X:
X[key] = (string, alt_info, seq)
elif is_allow_duplicate_chr_pos:
new_key = ""
for character in PREFIX_CHAR_STR:
tmp_key = character + key
if tmp_key not in X:
new_key = tmp_key
break
if len(new_key) > 0:
X[new_key] = (string, alt_info, seq)
if key in miss_variant_set:
continue

if key not in X:
X[key] = (string, alt_info, seq)
if is_reference:
Y[key] = output_labels_from_reference(BASE2BASE[seq[param.flankingBaseNum]])

if len(X) == shuffle_bin_size:
yield X, total
X = {}
total += 1
if total % 100000 == 0:
print("[INFO] Processed %d tensors" % total, file=sys.stderr)
f.stdout.close()
f.wait()
yield X, total
yield None, total
ref_list.append(key)
elif is_allow_duplicate_chr_pos:
new_key = ""
for character in PREFIX_CHAR_STR:
tmp_key = character + key
if tmp_key not in X:
new_key = tmp_key
break
if len(new_key) > 0:
X[new_key] = (string, alt_info, seq)
if is_reference:
ref_list.append(new_key)

if is_reference and key not in Y:
Y[key] = output_labels_from_reference(BASE2BASE[seq[param.flankingBaseNum]])

if len(X) == shuffle_bin_size:
if maximum_non_variant_ratio is not None:
_filter_non_variants(X, ref_list, maximum_non_variant_ratio)
yield X, total, False
X = {}
ref_list = []
total += 1
if total % 100000 == 0:
print("[INFO] Processed %d tensors" % total, file=sys.stderr)

if maximum_non_variant_ratio is not None:
_filter_non_variants(X, ref_list, maximum_non_variant_ratio)
yield X, total, True
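
A note on the changed protocol: the old generator consumed every subprocess itself and signalled exhaustion with a trailing yield None, total, while the refactored one reads a single stream and instead flags its last batch with a third completed element. A minimal sketch of a consumer written against that contract (the function names are illustrative, not part of this diff):

# Minimal consumer for the new (X, total, completed) contract (illustrative).
def drain(bin_g, handle_batch):
    completed, total = False, 0
    while not completed:
        try:
            X, total, completed = next(bin_g)
        except StopIteration:   # defensive: generator already exhausted
            break
        if X:                   # the final batch may be empty
            handle_batch(X)
    return total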


def _filter_non_variants(X, ref_list, maximum_non_variant_ratio):
non_variant_num = len(ref_list)
variant_num = len(X) - non_variant_num
if non_variant_num > variant_num * maximum_non_variant_ratio:
non_variant_keep_fraction = maximum_non_variant_ratio * variant_num / (1. * non_variant_num)
probabilities = np.random.random_sample((non_variant_num,))
for key, p in zip(ref_list, probabilities):
if p > non_variant_keep_fraction:
X.pop(key)
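
To make the subsampling concrete: with 1,000 variant and 5,000 non-variant candidates in a chunk and maximum_non_variant_ratio=2.0, the keep fraction is 2.0 * 1000 / 5000 = 0.4, so each non-variant survives with probability 0.4 and roughly 2,000 remain. A standalone sketch of the same logic on toy data (the keys are illustrative):

# Standalone sketch of the per-chunk non-variant subsampling (toy data).
import numpy as np

X = {"chr1:%d" % i: None for i in range(6000)}         # 6,000 candidates in the chunk
ref_list = ["chr1:%d" % i for i in range(1000, 6000)]  # 5,000 of them are non-variants

maximum_non_variant_ratio = 2.0
non_variant_num = len(ref_list)                                             # 5,000
variant_num = len(X) - non_variant_num                                      # 1,000
keep_fraction = maximum_non_variant_ratio * variant_num / non_variant_num  # 0.4
for key, p in zip(ref_list, np.random.random_sample(non_variant_num)):
    if p > keep_fraction:
        X.pop(key)
print(len(X))  # ~3,000 remain (~1,000 variants + ~2,000 non-variants)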


def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow_duplicate_chr_pos=True, chunk_id=None,
@@ -288,7 +306,8 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow

tree = bed_tree_from(bed_file_path=bed_fn)
is_tree_empty = len(tree.keys()) == 0
Y, miss_variant_set = variant_map_from(var_fn, tree, is_tree_empty)
Y_true_var, miss_variant_set = variant_map_from(var_fn, tree, is_tree_empty)
Y = copy.deepcopy(Y_true_var)

global param
float_type = 'int32'
@@ -300,32 +319,12 @@

tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape

variant_num, non_variant_num, non_variant_subsample_ratio = 0, 0, 1.0
if maximum_non_variant_ratio is not None and candidate_details_fn_prefix:
candidate_details_fn_prefix = candidate_details_fn_prefix.split('/')
directry, file_prefix = '/'.join(candidate_details_fn_prefix[:-1]), candidate_details_fn_prefix[-1]
file_list = [f for f in os.listdir(directry) if f.startswith(file_prefix)]
for f in file_list:
fn = open(os.path.join(directry, f), 'r')
for row in fn:
chr_pos = row.split('\t')[0]
key = chr_pos.replace(' ', ':')
if key in Y:
variant_num += 1
else:
non_variant_num += 1
fn.close()

max_non_variant_num = variant_num * maximum_non_variant_ratio
if max_non_variant_num < non_variant_num:
non_variant_subsample_ratio = float(max_non_variant_num / non_variant_num)
print("[INFO] variants/non variants/subsample ratio: {}/{}/{}".format(variant_num, non_variant_num,
round(non_variant_subsample_ratio, 4)),
file=sys.stderr)
# select all files matching the prefix if the file path does not exist
subprocess_list = []
if os.path.exists(tensor_fn):
subprocess_list.append(subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, tensor_fn))))
if tensor_fn == 'PIPE':
subprocess_list.append(sys.stdin)
elif os.path.exists(tensor_fn):
subprocess_list.append(subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, tensor_fn))).stdout)
# select all files matching the prefix if the file path does not exist
else:
tensor_fn = tensor_fn.split('/')
directry, file_prefix = '/'.join(tensor_fn[:-1]), tensor_fn[-1]
@@ -346,7 +345,8 @@
return 0
for file_name in all_file_name:
subprocess_list.append(
subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, os.path.join(directry, file_name)))))
subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, os.path.join(directry, file_name)))).stdout)
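
The PIPE sentinel is the core of this PR: it lets a wrapper script stream tensors from an upstream stage straight into Tensor2Bin's stdin instead of going through an intermediate compressed file. A self-contained sketch of the three-way input selection, with the zstd decompression factored into a helper (the helper and function names, and the example command line, are illustrative assumptions):

# Sketch of the input selection above; helper names are illustrative.
# Hypothetical wrapper pipeline:
#   python clair3.py CreateTrainingTensor ... | python clair3.py Tensor2Bin --tensor_fn PIPE ...
import os
import shlex
import subprocess
import sys

def zstd_stream(path, zstd="zstd"):
    # spawn `zstd -fdc <path>` and expose its stdout as a text line stream
    return subprocess.Popen(shlex.split("{} -fdc {}".format(zstd, path)),
                            stdout=subprocess.PIPE, universal_newlines=True).stdout

def open_tensor_streams(tensor_fn):
    if tensor_fn == "PIPE":
        return [sys.stdin]                        # rows piped in by a wrapper script
    if os.path.exists(tensor_fn):
        return [zstd_stream(tensor_fn)]           # a single compressed tensor file
    directory, prefix = os.path.split(tensor_fn)  # otherwise treat the path as a prefix
    return [zstd_stream(os.path.join(directory, name))
            for name in sorted(os.listdir(directory)) if name.startswith(prefix)]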

tables.set_blosc_max_threads(64)
int_atom = tables.Atom.from_dtype(np.dtype(float_type))
string_atom = tables.StringAtom(itemsize=param.no_of_positions + 50)
@@ -361,49 +361,60 @@ def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow
table_dict = update_table_dict()

# generator to avoid high memory occupy
bin_reader_generator = bin_reader_generator_from(subprocess_list=subprocess_list,
Y=Y,
is_tree_empty=is_tree_empty,
tree=tree,
miss_variant_set=miss_variant_set,
is_allow_duplicate_chr_pos=is_allow_duplicate_chr_pos,
non_variant_subsample_ratio=non_variant_subsample_ratio)
total_compressed = 0
while True:
X, total = next(bin_reader_generator)
if X is None or not len(X):
break
all_chr_pos = sorted(X.keys())
if shuffle == True:
np.random.shuffle(all_chr_pos)
for key in all_chr_pos:

string, alt_info, seq = X[key]
del X[key]
label = None
if key in Y:
label = Y[key]
pos = key + ':' + seq
if not is_allow_duplicate_chr_pos:
del Y[key]
elif is_allow_duplicate_chr_pos:
tmp_key = key[1:]
label = Y[tmp_key]
pos = tmp_key + ':' + seq
if label is None:
print(key)
continue
total_compressed = write_table_dict(table_dict, string, label, pos, total_compressed, alt_info,
tensor_shape, pileup)

if total_compressed % 500 == 0 and total_compressed > 0:
table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type)
bin_reader_generator = partial(bin_reader_generator_from,
Y_true_var=Y_true_var,
Y=Y,
is_tree_empty=is_tree_empty,
tree=tree,
miss_variant_set=miss_variant_set,
is_allow_duplicate_chr_pos=is_allow_duplicate_chr_pos,
maximum_non_variant_ratio=maximum_non_variant_ratio)
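
functools.partial pre-binds every argument except tensor_fn, so the loop below can create one fresh generator per input stream while sharing the same Y, tree, and ratio settings across all of them. A toy illustration of the same pattern (names are hypothetical):

# Toy illustration of the partial-binding pattern (hypothetical names).
from functools import partial

def reader(tensor_fn, prefix):
    for row in tensor_fn:
        yield prefix + row

make_reader = partial(reader, prefix="chr")
for stream in (iter(["1\t100"]), iter(["2\t200"])):
    for row in make_reader(tensor_fn=stream):
        print(row)  # each stream gets its own independent generator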

if total_compressed % 50000 == 0:
print("[INFO] Compressed %d tensor" % (total_compressed), file=sys.stderr)
total_compressed = 0
for fin in subprocess_list:
bin_g = bin_reader_generator(tensor_fn=fin)
completed = False
while not completed:
try:
X, total, completed = next(bin_g)
except StopIteration:
completed = True

if X is None or not len(X):
break
all_chr_pos = sorted(X.keys())
if shuffle == True:
np.random.shuffle(all_chr_pos)
for key in all_chr_pos:

string, alt_info, seq = X[key]
del X[key]
label = None
if key in Y:
label = Y[key]
pos = key + ':' + seq
if not is_allow_duplicate_chr_pos:
del Y[key]
elif is_allow_duplicate_chr_pos:
tmp_key = key[1:]
label = Y[tmp_key]
pos = tmp_key + ':' + seq
if label is None:
print(key)
continue
total_compressed = write_table_dict(table_dict, string, label, pos, total_compressed, alt_info,
tensor_shape, pileup)

if total_compressed % 500 == 0 and total_compressed > 0:
table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type)

if total_compressed % 50000 == 0:
print("[INFO] Compressed %d tensor" % (total_compressed), file=sys.stderr)
fin.close()

if total_compressed % 500 != 0 and total_compressed > 0:
table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type)

table_file.close()
print("[INFO] Compressed %d/%d tensor" % (total_compressed, total), file=sys.stderr)

Loading