-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
109 lines (92 loc) · 3.48 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/home/zhangjq/anaconda3/envs/september/bin/python
# -*-coding:utf-8 -*-
'''
@File : predict.py
@Time : 2022/02/21 17:11:24
@Author : zhangjq
@Version : 1.0
@Contact : zhangjq@tib.cas.cn
@License : (C)Copyright 2022-2023, zhangjq
@Desc : Enjoy your dinner
'''
import os
import re
import numpy as np
import pandas as pd
from joblib import load
from Bio import SeqIO
from features.api import calc_feat
import argparse
def create_logger(name, silent=False, to_disk=True, log_file=None):
    """Create (or reconfigure) a named logger that emits bare messages.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        silent: If False, attach a DEBUG-level handler writing to stdout.
        to_disk: If True, attach INFO-level file handler(s) opened in ``'w'``
            (truncate) mode.
        log_file: A path, or a list/tuple of paths, for the file handler(s).
            Defaults to ``log/log_<MMDD>_<HHMM>.txt`` stamped with the current
            UTC time. Missing parent directories are created.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging
    import os
    import sys
    from time import gmtime, strftime

    log = logging.getLogger(name)
    log.setLevel(logging.DEBUG)
    log.propagate = False
    # getLogger() returns the same object for the same name, so drop handlers
    # left over from an earlier call instead of duplicating every message.
    for old in list(log.handlers):
        log.removeHandler(old)
        old.close()

    # fmt has no %(asctime)s, so datefmt currently has no visible effect.
    formatter = logging.Formatter(fmt='%(message)s', datefmt='%Y/%m/%d %I:%M:%S')

    if not silent:
        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(formatter)
        log.addHandler(ch)

    if to_disk:
        if log_file is None:
            log_file = strftime("log/log_%m%d_%H%M.txt", gmtime())
        if isinstance(log_file, (str, os.PathLike)):
            paths = [log_file]
        else:
            paths = list(log_file)
        for path in paths:
            parent = os.path.dirname(os.fspath(path))
            if parent:
                # FileHandler does not create missing directories (e.g. "log/").
                os.makedirs(parent, exist_ok=True)
            fh = logging.FileHandler(path, mode='w')
            fh.setLevel(logging.INFO)
            fh.setFormatter(formatter)
            log.addHandler(fh)
    return log
######################################################################################
# python predict.py --fasta ./examples/example.fna --out example_out.csv #
######################################################################################
def predict(data: pd.DataFrame, return_prob: bool = False,
            model_path: str = './models/scp4ssd.joblib') -> np.ndarray:
    """Classify each row of *data* with the model stored at *model_path*.

    Args:
        data: Feature table passed unchanged to the model's ``predict`` /
            ``predict_proba``. Presumably its columns must match the features
            the model was trained on (TODO confirm against ``calc_feat``).
        return_prob: If True, append the ``predict_proba`` output as extra
            columns after the predicted label.
        model_path: joblib file holding the fitted model. The default is
            resolved relative to the current working directory.

    Returns:
        A 2-D array. It has shape ``(n_rows, 1)`` holding the predicted labels,
        or ``(n_rows, 1 + k)`` when *return_prob* is True, where ``k`` is the
        number of columns returned by ``predict_proba``.
    """
    clf = load(model_path)
    # Column vector so it can be concatenated with the probability matrix.
    prediction = np.expand_dims(clf.predict(data), axis=1)
    if not return_prob:
        return prediction
    return np.concatenate((prediction, clf.predict_proba(data)), axis=1)
def parse_args(argv=None):
    """Parse the command-line options of the prediction script.

    Example:
        python predict.py --fasta ./examples/example.fna --out example_out.csv

    Args:
        argv: Argument list to parse. ``None`` (the default) reads
            ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``fasta`` (str), ``out`` (str) and
        ``verb`` (bool).
    """
    def str2bool(text):
        # A plain string value would make "-v False" truthy, so parse it explicitly.
        lowered = text.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {text!r}')

    parser = argparse.ArgumentParser(
        description='Predict the isEasySynthesis label for each sequence in a FASTA file.',
        epilog='example: python predict.py --fasta ./examples/example.fna '
               '--out ./examples/example_out.csv',
    )
    parser.add_argument(
        '--fasta', '-f', help='fasta file of atcg sequences', required=True)
    parser.add_argument(
        '--out', '-o', help='output csv file name', required=True)
    # nargs='?' + const lets both "-v" alone and "-v True/False" work.
    parser.add_argument(
        '--verb', '-v', help='(opt) shows program progress', nargs='?',
        const=True, default=False, type=str2bool, required=False)
    return parser.parse_args(argv)
def main():
    """Command-line entry point: featurize a FASTA file, predict, write a CSV.

    The output CSV contains the feature columns produced by ``calc_feat``,
    the ``isEasySynthesis`` prediction, and each record's ``seq`` and
    ``description`` taken from the FASTA file.
    """
    args = parse_args()
    # Progress messages go to stdout only when --verb is given.
    logger = create_logger('predict', silent=not args.verb, to_disk=False)

    logger.info('calculating features for %s', args.fasta)
    df = calc_feat(args.fasta)

    logger.info('predicting')
    # Predict before adding seq/description so the model input is exactly
    # calc_feat's output.
    df['isEasySynthesis'] = predict(df)

    # Parse the FASTA once and reuse the records for both text columns.
    # pd.Series aligns on df's index just like the former one-column DataFrame.
    records = list(SeqIO.parse(args.fasta, 'fasta'))
    df['seq'] = pd.Series([str(rec.seq) for rec in records])
    df['description'] = pd.Series([str(rec.description) for rec in records])

    df.to_csv(args.out, index=False)
    logger.info('saved predictions to %s', args.out)
if __name__ == "__main__":
    # Script entry point; importing this module does not run the CLI.
    main()