Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes #53

Merged
merged 8 commits into from
Sep 29, 2021
64 changes: 38 additions & 26 deletions clair3/CallVariants.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,11 +686,8 @@ def output_from(
hetero_InsDel_length_tuples, hetero_InsDel_probabilities,
) = all_pro
maximum_probability = 0.0
maximum_loops = 4
reference_base, alternate_base = None, None
loop_index = 0
while (reference_base is None or alternate_base is None) and loop_index < maximum_loops:
loop_index += 1
while (reference_base is None or alternate_base is None):
maximum_probability = max(
homo_Ref_probability,
max(homo_SNP_probabilities),
Expand Down Expand Up @@ -723,32 +720,37 @@ def output_from(

if is_homo_SNP:
reference_base = reference_sequence[tensor_position_center]
idx = homo_SNP_probabilities.index(maximum_probability)
base1, base2 = homo_SNP_bases_from(homo_SNP_probabilities)
alternate_base = base1 if base1 != reference_base else base2
sorted_alt_bases, alternate_base = find_alt_base(alt_info_dict, alternate_base)
if alternate_base is None or alternate_base == reference_base:
homo_SNP_probabilities[idx] = 0
continue

elif is_hetero_SNP:
base1, base2 = hetero_SNP_bases_from(hetero_SNP_probabilities)
idx = hetero_SNP_probabilities.index(maximum_probability)
reference_base = reference_sequence[tensor_position_center]
is_multi = base1 != reference_base and base2 != reference_base
if is_multi:
sorted_alt_bases, _ = find_alt_base(alt_info_dict)
if len(sorted_alt_bases) == 0:
break
if len(sorted_alt_bases) < 2:
alternate_base = sorted_alt_bases[0]
hetero_SNP_probabilities[np.argmax(hetero_SNP_probabilities)] = 0.0
break
hetero_SNP_probabilities[idx] = 0
continue
alternate_base = ','.join(sorted_alt_bases[:2])
else:
alternate_base = base1 if base1 != reference_base else base2
sorted_alt_bases, alternate_base = find_alt_base(alt_info_dict, alternate_base)
if alternate_base is None or alternate_base == reference_base:
hetero_SNP_probabilities[idx] = 0
continue


elif is_homo_insertion:
variant_length = None
idx = homo_Ins_probabilities.index(maximum_probability)
if add_indel_length:
idx = homo_Ins_probabilities.index(maximum_probability)
variant_length = homo_Ins_lengths[idx]
insertion_bases = insertion_bases_using_alt_info_from(
alt_info_dict=alt_info_dict,
Expand All @@ -757,7 +759,8 @@ def output_from(

insertion_length = len(insertion_bases)
if insertion_length == 0:
break
homo_Ins_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center]
alternate_base = insertion_bases

Expand All @@ -776,25 +779,27 @@ def output_from(
)
insertion_length = len(insertion_bases)
if insertion_length == 0:
break
hetero_ACGT_Ins_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center]
alternate_base = insertion_bases

is_SNP_Ins_multi = hetero_Ins_base != reference_base
if is_SNP_Ins_multi:
sorted_alt_bases, _ = find_alt_base(alt_info_dict)
if len(sorted_alt_bases) == 0:
break
hetero_ACGT_Ins_probabilities[idx] = 0
continue
else:
alternate_base = "{},{}".format(sorted_alt_bases[0], alternate_base)

elif is_hetero_InsIns:
insertion_bases_list = []
idx = hetero_InsIns_probabilities.index(maximum_probability)
if add_indel_length:
idx = hetero_InsIns_probabilities.index(maximum_probability)
variant_length_1, variant_length_2 = hetero_InsIns_length_tuples[idx]
del hetero_InsIns_probabilities[idx]
del hetero_InsIns_length_tuples[idx]
# del hetero_InsIns_probabilities[idx]
# del hetero_InsIns_length_tuples[idx]

insertion_bases1 = insertion_bases_using_alt_info_from(
alt_info_dict=alt_info_dict,
Expand All @@ -819,7 +824,8 @@ def output_from(
return_multi=True
)
if len(insertion_bases_list) < 2:
break
hetero_InsIns_probabilities[idx] = 0
continue
insertion_bases, another_insertion_bases = insertion_bases_list

reference_base = reference_sequence[tensor_position_center]
Expand All @@ -830,12 +836,13 @@ def output_from(
if alternate_base_1 != alternate_base_2:
alternate_base = "{},{}".format(alternate_base_1, alternate_base_2)
else:
reference_base, alternate_base = None, None
hetero_InsIns_probabilities[idx] = 0
continue

elif is_homo_deletion:
variant_length = None
idx = homo_Del_probabilities.index(maximum_probability)
if add_indel_length:
idx = homo_Del_probabilities.index(maximum_probability)
variant_length = homo_Del_lengths[idx]

deletion_bases = deletion_bases_using_alt_info_from(
Expand All @@ -844,7 +851,8 @@ def output_from(
)
deletion_length = len(deletion_bases)
if deletion_length == 0:
break
homo_Del_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center] + deletion_bases
alternate_base = reference_base[0]

Expand All @@ -862,7 +870,8 @@ def output_from(
)
deletion_length = len(deletion_bases)
if deletion_length == 0:
break
hetero_ACGT_Del_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center] + deletion_bases
alternate_base = reference_base[0]

Expand All @@ -874,8 +883,8 @@ def output_from(

elif is_hetero_DelDel:
deletion_bases_list = []
idx = hetero_DelDel_probabilities.index(maximum_probability)
if add_indel_length:
idx = hetero_DelDel_probabilities.index(maximum_probability)
variant_length_1, variant_length_2 = sorted(hetero_DelDel_length_tuples[idx],
reverse=True) # longer deletion should be in first position
deletion_base1 = deletion_bases_using_alt_info_from(
Expand Down Expand Up @@ -903,7 +912,8 @@ def output_from(
)

if len(deletion_bases_list) < 2:
break
hetero_DelDel_probabilities[idx] = 0
continue

deletion_bases, deletion_bases1 = deletion_bases_list

Expand All @@ -918,12 +928,13 @@ def output_from(
):
alternate_base = "{},{}".format(alternate_base_1, alternate_base_2)
else:
reference_base, alternate_base = None, None
hetero_DelDel_probabilities[idx] = 0
continue

elif is_insertion_and_deletion:
variant_length_1, variant_length_2 = None, None
idx = hetero_InsDel_probabilities.index(maximum_probability)
if add_indel_length:
idx = hetero_InsDel_probabilities.index(maximum_probability)
variant_length_1, variant_length_2 = hetero_InsDel_length_tuples[idx]

insertion_bases = insertion_bases_using_alt_info_from(
Expand All @@ -939,7 +950,8 @@ def output_from(
deletion_length = len(deletion_bases)

if insertion_length == 0 or deletion_length == 0:
break
hetero_InsDel_probabilities[idx] = 0
continue
reference_base = reference_sequence[tensor_position_center] + deletion_bases
alternate_base = "{},{}".format(
reference_base[0],
Expand Down
10 changes: 8 additions & 2 deletions clair3/Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def DataGenerator(x, data_size, shuffle_chunk_list, train_flag=True):
optimizer=optimizer
)
early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, mode="min")
model_save_callbakck = tf.keras.callbacks.ModelCheckpoint(ochk_prefix + ".{epoch:02d}", period=1, save_weights_only=False)
model_save_callback = tf.keras.callbacks.ModelCheckpoint(ochk_prefix + ".{epoch:02d}", period=1, save_weights_only=False)
model_best_callback = tf.keras.callbacks.ModelCheckpoint("best_val_loss", monitor='val_loss', save_best_only=True, mode="min")
train_log_callback = tf.keras.callbacks.CSVLogger("training.log", separator='\t')

# Use the first 20 elements to initialize the TensorFlow model in graph mode
output = model(np.array(table_dataset_list[0].root.position_matrix[:20]))
Expand All @@ -228,11 +230,15 @@ def DataGenerator(x, data_size, shuffle_chunk_list, train_flag=True):
validate_dataset = validate_dataset if add_validation_dataset else None
if args.chkpnt_fn is not None:
model.load_weights(args.chkpnt_fn)
logging.info("[INFO] Starting from model {}".format(args.chkpnt_fn))

train_history = model.fit(x=train_dataset,
epochs=max_epoch,
validation_data=validate_dataset,
callbacks=[early_stop_callback, model_save_callbakck],
callbacks=[early_stop_callback,
model_save_callback,
model_best_callback,
train_log_callback],
verbose=1,
shuffle=False)

Expand Down
2 changes: 1 addition & 1 deletion preprocess/CreateTensorPileup.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def generate_tensor(pos, pileup_bases, reference_sequence, reference_start, refe

minimum_snp_af_for_candidate = minimum_snp_af_for_candidate if minimum_snp_af_for_candidate > 0 else param.min_af
minimum_snp_af_for_candidate = max(minimum_snp_af_for_candidate, param.min_af_dict[platform]) if fast_mode else minimum_snp_af_for_candidate
minimum_indel_af_for_candidate = max(minimum_indel_af_for_candidate, param.min_af_dict[platform]) if minimum_indel_af_for_candidate > 0 else param.min_af_dict[platform]
minimum_indel_af_for_candidate = minimum_indel_af_for_candidate if minimum_indel_af_for_candidate > 0 else param.min_af_dict[platform]

# check whether the pileup list is non-empty and its first entry is a non-reference candidate
pass_af = len(pileup_list) and (pileup_list[0][0] != reference_base)
Expand Down
2 changes: 1 addition & 1 deletion run_clair3.sh
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ NC="\\033[0m"
ARGS=`getopt -o b:f:t:m:p:o:hv \
-l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\
bed_fn::,vcf_fn::,ctg_name::,sample_name::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,ref_pct_full::,\
snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_preix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,no_phasing_for_fa,call_snp_only,help,version -n 'run_clair3.sh' -- "$@"`
snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_prefix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,no_phasing_for_fa,call_snp_only,help,version -n 'run_clair3.sh' -- "$@"`

# Abort if the preceding getopt call exited with a non-zero status.
# NOTE: a space is required between `echo` and its argument; without it the
# shell looks for a command literally named `echoNo input. Terminating...`.
if [ $? != 0 ] ; then echo "No input. Terminating..." >&2 ; exit 1 ; fi
# Replace the positional parameters with the output getopt stored in ARGS.
eval set -- "${ARGS}"
Expand Down