Skip to content

Commit

Permalink
tidying up
Browse files Browse the repository at this point in the history
  • Loading branch information
kristyhoran committed Sep 26, 2024
1 parent c4f247a commit 3b12b82
Show file tree
Hide file tree
Showing 14 changed files with 186 additions and 75 deletions.
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
author_email="kristyhoran15@gmail.com",
maintainer="Kristy Horan",
maintainer_email="kristyhoran15@gmail.com",
python_requires=">=3.8, <4",
python_requires=">=3.10, <4",
packages=find_packages(exclude=["contrib", "docs", "tests"]),
zip_safe=False,
install_requires=["pandas","xlsxwriter","psutil","tqdm","requests","pytest"],
install_requires=["pandas","pytest","tabulate","unidecode"],
test_suite="nose.collector",
tests_require=["nose", "pytest","psutil"],
entry_points={
Expand All @@ -50,5 +50,5 @@
"Programming Language :: Python :: 3.7",
"Topic :: Scientific/Engineering :: Bio-Informatics",
],
package_data={"tbtamr": ["db/*","dep_config.json"]}
package_data={"tbtamr": ["db/*","configs/*"]}
)
11 changes: 9 additions & 2 deletions tbtamr/Annotate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys,gzip,pandas,pathlib,json, subprocess, os
from CustomLog import logger
from Utils import check_annotate
from .CustomLog import logger
from .Utils import check_annotate


def check_file(pth) -> bool:
Expand Down Expand Up @@ -73,6 +73,13 @@ def create_output_dir(seq_id) -> bool:

def annotate(vcf_file, seq_id):
    """Run snpEff annotation on ``vcf_file``, writing output under ``seq_id``.

    Attaches a per-sample log file (``<seq_id>/tbtamr.log``) before running.
    """
    import logging  # local import: the module import line in this hunk does not include logging

    # Create the output directory once and reuse the result in the guard —
    # the original called create_output_dir() a second time inside the condition.
    created = create_output_dir(seq_id=seq_id)

    # NOTE(review): each call adds another FileHandler to the shared logger,
    # so repeated calls will duplicate log lines — confirm this is intended.
    fh = logging.FileHandler(f'{seq_id}/tbtamr.log')
    fh.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(levelname)s:%(asctime)s] %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p')
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    if check_canannotate() and created and check_file(pth=vcf_file):
        run_snpeff(vcf_file=vcf_file, seq_id=seq_id)

Expand Down
10 changes: 8 additions & 2 deletions tbtamr/Call.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys,gzip,pandas,pathlib,json
from CustomLog import logger

from .CustomLog import logger

from mutamr import Fastq2vcf

Expand All @@ -18,5 +18,11 @@ def generatevcf(read1,read2,threads,ram,seq_id,keep,mindepth,minfrac,force,mtb,t
force = force,
tmp = tmp)
vcf = V.run()
fh = logging.FileHandler(f'{seq_id}/tbtamr.log')
fh.setLevel(logging.DEBUG)
formatter = logging.Formatter('[%(levelname)s:%(asctime)s] %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p')
fh.setFormatter(formatter)
logger.addHandler(fh)
logger.info(f"Your vcf file has been successfully created!")

return vcf
32 changes: 21 additions & 11 deletions tbtamr/Parse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys,gzip,pandas,pathlib,json,warnings,subprocess
from CustomLog import logger
from Annotate import annotate
import sys,gzip,pandas,pathlib,json,warnings,subprocess,logging
from .CustomLog import logger
from .Annotate import annotate
warnings.simplefilter(action='ignore', category=FutureWarning)
pandas.set_option("mode.chained_assignment", None)

def __init__(self,
             vcf,
             catalog,
             catalog_config,
             seq_id,
             force):
    """Store inputs, create the per-sample output directory and attach a
    per-sample log file handler.

    vcf            -- path to the vcf file to parse
    catalog        -- path to the resistance catalogue (csv)
    catalog_config -- path to the catalogue config json (parsed via get_config)
    seq_id         -- sample identifier; also used as the output directory name
    force          -- when True, an existing output directory is reused/overwritten
    """
    self.vcf_file = vcf
    self.catalog = catalog
    self.config = self.get_config(pth=catalog_config)
    self.seq_id = seq_id
    self.force = force

    self.create_output_dir(seq_id=self.seq_id, force=self.force)
    # Log to <seq_id>/tbtamr.log in addition to any existing handlers.
    # NOTE(review): every instance adds a new handler to the shared logger,
    # so constructing several parsers duplicates log lines — confirm intended.
    fh = logging.FileHandler(f'{seq_id}/tbtamr.log')
    fh.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(levelname)s:%(asctime)s] %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p')
    fh.setFormatter(formatter)
    logger.addHandler(fh)

def run_cmd(self, cmd) -> bool:

Expand Down Expand Up @@ -152,8 +161,6 @@ def get_data(self, vcf_file):
_type = 'unzipped'

dpths,annot = self.check_data(vcf_file = vcf_file, _type = _type)
print(dpths)
print(annot)
if dpths and annot:
return data
elif dpths and not annot:
Expand Down Expand Up @@ -181,20 +188,22 @@ def variant_generator(self,vcf_file, genes) -> list:

return results

def create_output_dir(self, seq_id, force) -> bool:
    """Create the output directory named after ``seq_id``.

    force -- when True the directory may already exist (previous results are
             overridden); when False an existing directory is an error.

    Returns True on success; otherwise logs a critical message and exits.
    """
    logger.info(f"Will now create directory for {seq_id}")

    # With force, a failure can only be a genuine OS error; without force the
    # expected failure is that the folder already exists.
    msg = f"Something has gone wrong creating the folder for {seq_id}." if force else f"Folder for {seq_id} exists. Please use --force if you would like to override any previous results."
    try:
        pathlib.Path(f"{seq_id}").mkdir(exist_ok=force)
        return True
    except OSError:  # was a bare except, which also swallowed KeyboardInterrupt/SystemExit
        logger.critical(msg)
        raise SystemExit


def save_variants(self, df) -> bool:
    """Write the collected variants to ``<seq_id>/<seq_id>_variants.csv``.

    df -- a list of per-variant dicts (or anything pandas.DataFrame accepts).

    Returns True once the file has been written.
    """
    # Ensure the output directory exists, honouring the instance force flag
    # (the pre-commit call omitted force; the stored self.force is used here).
    self.create_output_dir(seq_id=self.seq_id, force=self.force)
    pandas.DataFrame(df).to_csv(f'{self.seq_id}/{self.seq_id}_variants.csv', index=False)
    return True

Expand All @@ -205,6 +214,7 @@ def get_catalog(self, catalog) -> pandas.DataFrame:
return pandas.read_csv(catalog, dtype= str)

def get_variant_data(self) -> list:

if self.check_file(pth = self.vcf_file) and self.check_file(pth = self.catalog):
try:
catalog = self.get_catalog(catalog=self.catalog)
Expand Down
107 changes: 81 additions & 26 deletions tbtamr/Predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys,gzip,pandas,pathlib,json
from CustomLog import logger
from version import db_version
import sys,gzip,pandas,pathlib,json,logging, re
from .CustomLog import logger
from .version import db_version
from datetime import date
import warnings;
warnings.simplefilter(action='ignore', category=FutureWarning)
Expand Down Expand Up @@ -41,6 +41,13 @@ def __init__(self,
'predicted drug resistance'
]


fh = logging.FileHandler(f'{self.seq_id}/tbtamr.log')
fh.setLevel(logging.DEBUG)
formatter = logging.Formatter('[%(levelname)s:%(asctime)s] %(message)s', datefmt='%Y-%m-%d %I:%M:%S %p')
fh.setFormatter(formatter)
logger.addHandler(fh)

def check_file(self, pth) -> bool:

if pathlib.Path(pth).exists():
Expand All @@ -50,7 +57,49 @@ def check_file(self, pth) -> bool:
raise SystemExit

return True

def check_var(self, catalog) -> bool:
    """Validate that every variant in the catalogue matches one of the
    expected variant patterns from the config.

    catalog -- catalogue DataFrame; the column named by
               config['variant_col'] is checked against the regexes listed
               in config['catalogue_variant'].

    Returns True when all variants match; otherwise logs the offenders
    (at most 10 listed) and exits.
    """
    logger.info(f"Checking that the variant format is as expected.")
    variants = list(catalog[self.config['variant_col']])  # renamed: 'vars' shadowed the builtin
    patterns = [re.compile(p) for p in self.config['catalogue_variant']]

    # A variant is acceptable if any pattern matches anywhere in the string.
    not_found = [v for v in variants if not any(p.search(v) for p in patterns)]

    if not not_found:
        logger.info("Variant format is as expected.")
        return True

    logger.critical(f"It seems that your {self.config['variant_col']} is not formatted as expected. The following variants: {' '.join(not_found) if len(not_found) <= 10 else 'more than 10 variants' } do not match any of your expected patterns: {' '.join(self.config['catalogue_variant'])}. Please check your inputs and try again.")
    raise SystemExit


def check_rules_and_cols(self, rules, catalog) -> bool:
    """Check that every catalogue column referenced by the interpretation
    criteria actually exists in the catalogue.

    rules   -- criteria DataFrame; any column whose name contains 'column'
               holds names of catalogue columns.
    catalog -- catalogue DataFrame whose columns are validated against.

    Returns True when consistent; logs and exits on the first missing column.
    """
    logger.info("Checking that your criteria files are correctly formatted.")
    rule_cols = [c for c in rules.columns if 'column' in c]
    referenced = []
    for c in rule_cols:
        referenced.extend([name for name in list(rules[c].unique()) if name != ""])
    catalog_cols = list(catalog.columns)

    for i in referenced:
        if i not in catalog_cols:
            logger.critical(f"There is something wrong with your interpretation criteria and/or your catalogue. Column {i} is in your interpretation criteria but is not present in your catalogue. Please check your inputs and try again.")
            raise SystemExit
    logger.info(f"Criteria files appear to be formatted correctly.")
    return True



def collect_af(self, variants, var):

for v in variants:
Expand All @@ -61,7 +110,6 @@ def collect_af(self, variants, var):
def collect_resistance_mechs(self,catalog,variants) -> pandas.DataFrame:

vars = [var['variant'] for var in variants]

mechs = catalog[catalog[self.config['variant_col']].isin(vars)]
mechs['af'] = mechs[self.config['variant_col']].apply(lambda x:self.collect_af(variants=variants, var = x))

Expand Down Expand Up @@ -97,21 +145,21 @@ def check_shape(self, rule) -> bool:
def extract_mutations(self, dr, result) -> list:
    """Pull the reportable mutation names for drug ``dr`` out of ``result``.

    Reads the '<drug> - mechanisms' field (';'-separated entries of the form
    'mutation (confidence)') and keeps the mutation token for entries whose
    confidence passes check_conf_reporting().
    """
    mt = []
    key = f"{dr.lower()} - mechanisms"  # result keys are stored lower-cased
    if key in result:
        # 'No reportable mechanims' [sic] is a sentinel written elsewhere —
        # its spelling must stay in sync with the producer.
        mt = [m.split()[0]
              for m in result[key].split(';')
              if m != 'No reportable mechanims'
              and self.check_conf_reporting(val=m.split()[-1].strip('()'))]

    return mt

def update_result(self, result, dr, rule) -> dict:
    """Override the interpretation for drug ``dr`` in ``result``.

    rule -- a (index, row) pair (e.g. from DataFrame.iterrows()); the row
            supplies the new 'interpretation' and a 'description'.

    Returns the (mutated) result dict; untouched when the drug has no
    interpretation column in result.
    """
    if f"{dr.lower()} - interpretation" in result:
        result[f"{dr.lower()} - interpretation"] = rule[1]['interpretation']
        # keep the rule description so the override can be audited later
        result[f"{dr.lower()} - override"] = rule[1]['description']

    return result

def construct_rule(self, row):
def construct_rule(self, row) -> str:

if isinstance(row, tuple):
d = row[1].to_dict()
Expand Down Expand Up @@ -159,15 +207,15 @@ def apply_rule_override(self, dr, mechs, rules, result) -> dict:
tbl_to_check = tbl
# print(tbl_to_check)
rle = self.construct_rule(row = row)
# print(rle)

# set up date to False by default - leave existing result
update = False
# if there is a shape/length criteria to the rule
if self.check_shape(row[1]['shape']):
# logger.info(f"Will check shape of df")
shape_rule = f"len(mt) {row[1]['shape']}"
if eval(shape_rule): # if the shape/size criteria is met then apply the rest of the rule
# print(tbl_to_check)

tbl_to_check = tbl_to_check.query(rle)
update = True if not tbl_to_check.empty else False
else: # if no size/shape criteria - just apply the rule
Expand Down Expand Up @@ -239,13 +287,15 @@ def apply_rule_default(self,dr, mechs, rules, result) -> dict:
def get_resistance_profile(self, result) -> dict:
    """Collect the drugs called resistant in ``result``.

    Returns (by_tier, all_drugs): by_tier maps 'first-line'/'other' to the
    resistant drugs of that tier; all_drugs is the flat list.
    NOTE(review): annotated '-> dict' but actually returns a 2-tuple —
    kept as-is since callers unpack two values.
    """
    drs = {"first-line": [], "other": []}
    alldrs = []

    for tier in drs:
        for dr in self.config["drugs_to_report"][tier]:
            # result keys are stored lower-cased (post-commit convention)
            if result[f"{dr.lower()} - interpretation"] in self.config["resistance_levels"]:
                drs[tier].append(dr)
                alldrs.append(dr)

    return drs, alldrs

def get_dlm(self, cond) -> str:

dl = ("","")
Expand Down Expand Up @@ -297,10 +347,12 @@ def classification(self, rules, result) -> dict:

def compare_mechs_rules(self,interpretation_rules, classification_rules, mechs, result) -> dict:

logger.info(f"Applying citeria for interpretation.")
# print(rules[rules['rule_type'] != 'default'])
for dr in self.config["drugs_to_infer"]:
result = self.apply_rule_default(dr = dr.lower(), mechs=mechs, rules=interpretation_rules[interpretation_rules['rule_type'] == 'default'], result = result)
result = self.apply_rule_override(dr= dr.lower(), mechs=mechs,rules=interpretation_rules[interpretation_rules['rule_type'] != 'default'], result=result)
logger.info(f"Applying citeria for classification of resistance profile.")
result = self.classification(rules = classification_rules, result=result)
result_df = pandas.DataFrame.from_dict(result, orient= 'index').T
result_df.to_csv(f"{self.seq_id}/tbtamr_results.csv", index = False)
Expand All @@ -322,7 +374,7 @@ def check_for_cascade(self, result, cols) -> bool:

def generate_drug_cols(self, dr) -> list:
    """Return the three per-drug result column names for drug ``dr``.

    Column keys are lower-cased to match how results are stored elsewhere.
    """
    return [f"{dr.lower()} - {i}" for i in ['mechanisms', 'interpretation', 'confidence']]

def cascade_report(self, result, starter_cols):

Expand Down Expand Up @@ -360,17 +412,18 @@ def generate_reporting_df(self, result, output, cols) -> pandas.DataFrame:
df.to_csv(f"{self.seq_id}/{output}.csv", index = False)

def make_cascade(self, result) -> bool:
    """Write the cascade line-list report for this sample.

    NOTE(review): annotated '-> bool' but falls off the end and returns
    None — confirm callers do not rely on the return value.
    """
    logger.info("Generating cascade report.")

    # Build the cascade column list from the base columns, then write the csv.
    cols = self.cascade_report(result=result, starter_cols = self.cols)
    self.generate_reporting_df(result=result, cols = cols, output="tbtamr_linelist_cascade_report")

def make_line_list(self, result, cols) -> bool:

logger.info("Generating linelist for reporting.")
# wrangle reportable/not reportable
for dr in self.config['drugs_to_infer']:
mch = result[f"{dr} - mechanisms"].split(';')
mch = result[f"{dr.lower()} - mechanisms"].split(';')
mchs = []
# check if mech should be reported - based on conf in string
for m in mch:
Expand All @@ -380,17 +433,17 @@ def make_line_list(self, result, cols) -> bool:
mchs.append(m.split()[0])
else:
mchs.append(m)
result[f"{dr} - mechanisms"] = ';'.join(mchs)
result[f"{dr.lower()} - mechanisms"] = ';'.join(mchs)
# check if conf should be reported
cf = result[f"{dr} - confidence"].split(';')
cf = result[f"{dr.lower()} - confidence"].split(';')
conf = ""
for c in cf:
if self.check_conf_reporting(val=c):
conf = c
result[f"{dr} - confidence"] = conf
result[f"{dr.lower()} - confidence"] = conf
# check interpretation
if result[f"{dr} - interpretation"] not in self.config['resistance_levels']:
result[f"{dr} - interpretation"] = "Susceptible"
if result[f"{dr.lower()} - interpretation"] not in self.config['resistance_levels']:
result[f"{dr.lower()} - interpretation"] = "Susceptible"

# get cols
dr2report = self.config['drugs_to_report']
Expand Down Expand Up @@ -492,7 +545,9 @@ def run_prediction(self) -> None:
mechs = self.collect_resistance_mechs(catalog=ctlg, variants=self.variants)
interpretation_rules = self.get_rules(rules = self.interpretation_rules)
classification_rules = self.get_rules(rules = self.classification_rules)
result = self.compare_mechs_rules(interpretation_rules = interpretation_rules, classification_rules=classification_rules,mechs=mechs, result = result)
self.make_line_list(result = result, cols = self.cols)
if self.cascade:
self.make_cascade(result=result)
if self.check_var(catalog = ctlg) and self.check_rules_and_cols(rules = interpretation_rule, catalog = ctlg):
result = self.compare_mechs_rules(interpretation_rules = interpretation_rules, classification_rules=classification_rules,mechs=mechs, result = result)
self.make_line_list(result = result, cols = self.cols)
if self.cascade:
self.make_cascade(result=result)

2 changes: 1 addition & 1 deletion tbtamr/Search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tabulate import tabulate
from unidecode import unidecode

from CustomLog import logger
from .CustomLog import logger

def check_file(pth) -> bool:

Expand Down
File renamed without changes.
2 changes: 0 additions & 2 deletions tbtamr/configs/db_config.csv

This file was deleted.

6 changes: 5 additions & 1 deletion tbtamr/configs/db_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,9 @@
"ethionamide"
]
}
}
},
"catalogue_variant": [
"[A-Za-z0-9]+[_][a-z]+",
"[A-Za-z0-9]+[_][a-z][.]\\S+"
]
}
File renamed without changes.
Loading

0 comments on commit 3b12b82

Please sign in to comment.