diff --git a/nanocompore/DataStore.py b/nanocompore/DataStore.py
index 39179bd..73547c9 100644
--- a/nanocompore/DataStore.py
+++ b/nanocompore/DataStore.py
@@ -1,110 +1,171 @@
 # -*- coding: utf-8 -*-
-from collections import *
+from enum import Enum
 import datetime
 import os
-import sqlite3 as lite
+import sqlite3
+import contextlib
+from itertools import zip_longest, product

 # Third party
 from loguru import logger

-import nanocompore as pkg
+from nanocompore.common import NanocomporeError
+
+
+class DBCreateMode(Enum):
+    """Options for handling (non-) existence of the SQLite database file"""
+    MUST_EXIST = "r" # open for reading, error if file doesn't exist
+    CREATE_MAYBE = "a" # use an existing database, otherwise create one
+    OVERWRITE = "w" # always create a new database, overwrite if it exists
+

 class DataStore(object):
-    """ Init analysis and check args"""
-
-    create_reads_query = ("CREATE TABLE IF NOT EXISTS reads ("
-                          "id INTEGER NOT NULL PRIMARY KEY,"
-                          "name VARCHAR NOT NULL UNIQUE,"
-                          "sampleid INTEGER NOT NULL,"
-                          "transcriptid VARCHAR NOT NULL,"
-                          "refstart INT NOT NULL,"
-                          "refend INT NOT NULL,"
-                          "numevents INT NOT NULL,"
-                          "numsignals INT NOT NULL,"
-                          "dwelltime REAL NOT NULL,"
-                          "FOREIGN KEY(sampleid) REFERENCES samples(id)"
-                          "FOREIGN KEY(transcriptid) REFERENCES transcripts(id),"
-                          "UNIQUE(id, name)"
-                          ")"
-                          )
-
-    create_kmers_query = ("CREATE TABLE IF NOT EXISTS kmers ("
-                          "id INTEGER NOT NULL PRIMARY KEY,"
-                          "readid INTEGER NOT NULL,"
-                          "position INTEGER NOT NULL,"
-                          "sequence INTEGER NOT NULL,"
-                          "num_events INTEGER NOT NULL,"
-                          "num_signals INTEGER NOT NULL,"
-                          "status VARCHAR NOT NULL,"
-                          "dwell_time REAL NOT NULL,"
-                          "NNNNN_dwell_time REAL NOT NULL,"
-                          "mismatch_dwell_time REAL NOT NULL,"
-                          "median REAL NOT NULL,"
-                          "mad REAL NOT NULL,"
-                          "FOREIGN KEY(readid) REFERENCES reads(id)"
-                          ")"
-                          )
-
-    create_samples_query = ("CREATE TABLE IF NOT EXISTS samples ("
-                            "id INTEGER NOT NULL PRIMARY KEY,"
-                            "name VARCHAR NOT NULL UNIQUE"
-                            ")"
-                            )
-
-
-    create_transcripts_query = ("CREATE TABLE IF NOT EXISTS transcripts ("
-                                "id INTEGER NOT NULL PRIMARY KEY,"
-                                "name VARCHAR NOT NULL UNIQUE"
-                                ")"
-                                )
-
-
-
-    def __init__(self, db_path:str):
-        self.__db_path=db_path
-        db_is_new = not os.path.exists(self.__db_path)
-        logger.debug(f"DB file doesn't exist: {db_is_new}")
-        if db_is_new: self.__init_db()
+    """Store Nanocompore data in an SQLite database - base class"""

-    def __enter__(self):
-        self.__open_db_connection()
-        return self
+    table_defs = {} # table name -> column definitions (to be filled by derived classes)
+
+    def __init__(self,
+                 db_path:str,
+                 create_mode=DBCreateMode.MUST_EXIST):
+        self._db_path = db_path
+        self._create_mode = create_mode
+        self._connection = None
+        self._cursor = None

-    def __exit__(self,exc_type, exc_value, traceback):
-        self.__connection.commit()
-        self.__close_db_connection()
+    def _init_db(self):
+        if self.table_defs:
+            logger.debug("Setting up database tables")
+            try:
+                for table, column_defs in self.table_defs.items():
+                    if type(column_defs) is not str: # list/tuple expected
+                        column_defs = ", ".join(column_defs)
+                    query = f"CREATE TABLE IF NOT EXISTS {table} ({column_defs})"
+                    self._cursor.execute(query)
+                self._connection.commit()
+            except:
+                logger.error(f"Error creating database table '{table}'")
+                raise

-    def __open_db_connection(self):
+    def __enter__(self):
+        init_db = False
+        if self._create_mode == DBCreateMode.MUST_EXIST and not os.path.exists(self._db_path):
+            raise NanocomporeError(f"Database file '{self._db_path}' does not exist")
+        if self._create_mode == DBCreateMode.OVERWRITE:
+            with contextlib.suppress(FileNotFoundError): # file may not exist
+                os.remove(self._db_path)
+                logger.debug(f"Removed existing database file '{self._db_path}'")
+            init_db = True
+        if self._create_mode == DBCreateMode.CREATE_MAYBE and not os.path.exists(self._db_path):
+            init_db = True
         try:
-            logger.debug("Connecting to DB")
-            self.__connection = lite.connect(self.__db_path);
-            self.__cursor = self.__connection.cursor()
+            logger.debug("Connecting to database")
+            self._connection = sqlite3.connect(self._db_path)
+            self._connection.row_factory = sqlite3.Row
+            self._cursor = self._connection.cursor()
         except:
             logger.error("Error connecting to database")
             raise
+        if init_db:
+            self._init_db()
+        return self

-    def __close_db_connection(self):
-        if self.__connection:
-            logger.debug("Closing connection to DB")
-            self.__connection.commit()
-            self.__connection.close()
-            self.__connection = None
-            self.__cursor = None
-
-    def __init_db(self):
-        logger.debug("Setting up DB tables")
-        self.__open_db_connection()
-        try:
-            self.__cursor.execute(self.create_reads_query)
-            self.__cursor.execute(self.create_kmers_query)
-            self.__cursor.execute(self.create_samples_query)
-            self.__cursor.execute(self.create_transcripts_query)
-        except:
-            self.__close_db_connection()
-            logger.error("Error creating tables")
-            raise
-        self.__connection.commit()
-        self.__close_db_connection()
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self._connection:
+            logger.debug("Closing database connection")
+            self._connection.commit()
+            self._connection.close()
+            self._connection = None
+            self._cursor = None
+
+    @property
+    def cursor(self):
+        return self._cursor
+
+
+class DataStore_EventAlign(DataStore):
+    """Store Nanocompore data in an SQLite database - subclass for Eventalign_collapse results"""
+
+    # "reads" table:
+    table_def_reads = ["id INTEGER NOT NULL PRIMARY KEY",
+                       "name VARCHAR NOT NULL UNIQUE",
+                       "sampleid INTEGER NOT NULL",
+                       "transcriptid VARCHAR NOT NULL",
+                       "refstart INT NOT NULL",
+                       "refend INT NOT NULL",
+                       "numevents INT NOT NULL",
+                       "numsignals INT NOT NULL",
+                       "dwelltime REAL NOT NULL",
+                       "kmers INT NOT NULL",
+                       "missing_kmers INT NOT NULL",
+                       "NNNNN_kmers INT NOT NULL",
+                       "mismatch_kmers INT NOT NULL",
+                       "valid_kmers INT NOT NULL",
+                       "FOREIGN KEY(sampleid) REFERENCES samples(id)",
+                       "FOREIGN KEY(transcriptid) REFERENCES transcripts(id)"]
+
+    # "kmer_sequences" table:
+    table_def_kmer_seqs = ["id INTEGER NOT NULL PRIMARY KEY",
+                           "sequence VARCHAR NOT NULL UNIQUE"]
+
+    # "kmer_status" table:
+    table_def_kmer_status = ["id INTEGER NOT NULL PRIMARY KEY",
+                             "status VARCHAR NOT NULL UNIQUE"]
+
+    # "kmers" table:
+    # TODO: is combination of "readid" and "position" unique per kmer?
+    # if so, use those as combined primary key (for access efficiency)?
+    table_def_kmers = ["id INTEGER NOT NULL PRIMARY KEY",
+                       "readid INTEGER NOT NULL",
+                       "position INTEGER NOT NULL",
+                       "sequenceid INTEGER",
+                       # "sequence VARCHAR NOT NULL",
+                       # "num_events INTEGER NOT NULL",
+                       # "num_signals INTEGER NOT NULL",
+                       "statusid INTEGER NOT NULL",
+                       "dwell_time REAL NOT NULL",
+                       # "NNNNN_dwell_time REAL NOT NULL",
+                       # "mismatch_dwell_time REAL NOT NULL",
+                       "median REAL NOT NULL",
+                       "mad REAL NOT NULL",
+                       "FOREIGN KEY(readid) REFERENCES reads(id)",
+                       "FOREIGN KEY(sequenceid) REFERENCES kmer_sequences(id)",
+                       "FOREIGN KEY(statusid) REFERENCES kmer_status(id)"]
+
+    # "samples" table:
+    table_def_samples = ["id INTEGER NOT NULL PRIMARY KEY",
+                         "name VARCHAR NOT NULL UNIQUE",
+                         "condition VARCHAR"]
+
+    # "transcripts" table:
+    table_def_transcripts = ["id INTEGER NOT NULL PRIMARY KEY",
+                             "name VARCHAR NOT NULL UNIQUE"]
+
+    table_defs = {"reads": table_def_reads,
+                  "kmer_sequences": table_def_kmer_seqs,
+                  "kmer_status": table_def_kmer_status,
+                  "kmers": table_def_kmers,
+                  "samples": table_def_samples,
+                  "transcripts": table_def_transcripts}
+
+    status_mapping = {"valid": 0, "NNNNN": 1, "mismatch": 2}
+    sequence_mapping = {} # filled by "__init__"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        ## set up mapping table for sequences:
+        self.sequence_mapping = {}
+        seq_prod = product(["A", "C", "G", "T"], repeat=5)
+        for i, seq in enumerate(seq_prod):
+            self.sequence_mapping["".join(seq)] = i
+
+    def _init_db(self):
+        super()._init_db()
+        ## fill "kmer_status" and "kmer_sequences" tables:
+        self._cursor.executemany("INSERT INTO kmer_status VALUES (?, ?)",
+                                 [(i, x) for x, i in self.status_mapping.items()])
+        self._cursor.executemany("INSERT INTO kmer_sequences VALUES (?, ?)",
+                                 [(i, x) for x, i in self.sequence_mapping.items()])
+        self._connection.commit()

     def store_read(self, read):
         """
@@ -116,18 +177,19 @@ def store_read(self, read):
         """
         tx_id = self.get_transcript_id_by_name(read.ref_id, create_if_not_exists=True)
         sample_id = self.get_sample_id_by_name(read.sample_name, create_if_not_exists=True)
+        values = (read.read_id, sample_id, tx_id, read.ref_start, read.ref_end,
+                  read.n_events, read.n_signals, read.dwell_time) + tuple(read.kmers_status.values())
         try:
-            self.__cursor.execute("INSERT INTO reads VALUES(NULL, ?, ?, ?, ?, ?, ?, ?, ?)",
-                                  (read.read_id, sample_id, tx_id, read.ref_start, read.ref_end,
-                                   read.n_events, read.n_signals, read.dwell_time))
-            read_id = self.__cursor.lastrowid
+            self._cursor.execute("INSERT INTO reads VALUES(NULL" + ", ?" * len(values) + ")",
+                                 values)
+            read_id = self._cursor.lastrowid
         except Exception:
-            logger.error("Error inserting read into DB")
+            logger.error("Error inserting read into database")
             raise Exception

         for kmer in read.kmer_l:
             self.__store_kmer(kmer=kmer, read_id=read_id)
-        self.__connection.commit()
+        self._connection.commit()
         # TODO check for success and return true/false

     def __store_kmer(self, kmer, read_id):
         """
@@ -139,12 +201,14 @@ def __store_kmer(self, kmer, read_id):
         """
         res = kmer.get_results() # needed for 'median' and 'mad' values
         try:
-            self.__cursor.execute("INSERT INTO kmers VALUES(NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
-                                  (read_id, res["ref_pos"], res["ref_kmer"], res["num_events"],
-                                   res["num_signals"], res["status"], res["dwell_time"],
-                                   res["NNNNN_dwell_time"], res["mismatch_dwell_time"], res["median"], res["mad"]))
+            status_id = self.status_mapping[res["status"]]
+            # in case of unexpected kmer seq., this should give None (NULL in the DB):
+            seq_id = self.sequence_mapping.get(res["ref_kmer"])
+            self._cursor.execute("INSERT INTO kmers VALUES(NULL, ?, ?, ?, ?, ?, ?, ?)",
+                                 (read_id, res["ref_pos"], seq_id, status_id,
+                                  res["dwell_time"], res["median"], res["mad"]))
         except Exception:
-            logger.error("Error inserting kmer into DB")
+            logger.error("Error inserting kmer into database")
             raise Exception

     def get_transcript_id_by_name(self, tx_name, create_if_not_exists=False):
@@ -159,18 +223,18 @@ def get_transcript_id_by_name(self, tx_name, create_if_not_exists=False):
                      ");"
                      )
         try:
-            self.__cursor.execute(query)
+            self._cursor.execute(query)
         except Exception:
-            logger.error("There was an error while inserting a new transcript in the DB")
+            logger.error("Error while inserting transcript into the database")
             raise Exception

         query = f"SELECT id from transcripts WHERE name = '{tx_name}'"
         try:
-            self.__cursor.execute(query)
-            record = self.__cursor.fetchone()
-            self.__connection.commit()
+            self._cursor.execute(query)
+            record = self._cursor.fetchone()
+            self._connection.commit()
         except Exception:
-            logger.error("There was an error while selecting the transcript_id from the DB")
+            logger.error("Error while selecting transcript ID from the database")
             raise Exception
         if record is not None:
             return record[0]
@@ -189,20 +253,210 @@ def get_sample_id_by_name(self, sample_name, create_if_not_exists=False):
                      ");"
                      )
         try:
-            self.__cursor.execute(query)
+            self._cursor.execute(query)
         except Exception:
-            logger.error("There was an error while inserting a new sample in the DB")
+            logger.error("Error while inserting sample into the database")
             raise Exception

         query = f"SELECT id from samples WHERE name = '{sample_name}'"
         try:
-            self.__cursor.execute(query)
-            record = self.__cursor.fetchone()
-            self.__connection.commit()
+            self._cursor.execute(query)
+            record = self._cursor.fetchone()
+            self._connection.commit()
         except Exception:
-            logger.error("There was an error while selecting the sample_id from the DB")
+            logger.error("Error while selecting sample ID from the database")
             raise Exception
         if record is not None:
             return record[0]
         else:
             return None
+
+    def get_samples(self, sample_dict=None):
+        if not self._connection:
+            raise NanocomporeError("Database connection not yet opened")
+        expected_samples = []
+        if sample_dict: # query only relevant samples
+            for samples in sample_dict.values():
+                expected_samples += samples
+            if not expected_samples:
+                raise NanocomporeError("No sample names in 'sample_dict'")
+            where = " WHERE name IN ('%s')" % "', '".join(expected_samples)
+        else:
+            where = ""
+        db_samples = {}
+        try:
+            self._cursor.execute("SELECT * FROM samples" + where)
+            for row in self._cursor:
+                db_samples[row["id"]] = row["name"]
+        except Exception:
+            logger.error("Error reading sample names from database")
+            raise Exception
+        for sample in expected_samples: # check that requested samples are in DB
+            if sample not in db_samples.values():
+                raise NanocomporeError(f"Sample '{sample}' not present in database")
+        return db_samples
+
+    # TODO: is this function never used?
+    def store_sample_info(self, sample_dict):
+        if not self._connection:
+            raise NanocomporeError("Database connection not yet opened")
+        # query: insert sample; if it exists, update condition if that's missing
+        query = "INSERT INTO samples(id, name, condition) VALUES (NULL, ?, ?) " \
+            "ON CONFLICT(name) DO UPDATE SET condition = excluded.condition " \
+            "WHERE condition IS NULL"
+        for condition, samples in sample_dict.items():
+            try:
+                self._cursor.executemany(query, [(condition, sample) for sample in samples])
+            except:
+                logger.error(f"Error storing sample information for condition '{condition}'")
+                raise
+        self._connection.commit()
+
+
+class DataStore_SampComp(DataStore):
+    """Store Nanocompore data in an SQLite database - subclass for SampComp results"""
+
+    # "parameters" table:
+    table_def_parameters = ["univariate_test VARCHAR CHECK (univariate_test in ('ST', 'MW', 'KS'))",
+                            "gmm_covariance_type VARCHAR",
+                            "gmm_test VARCHAR CHECK (gmm_test in ('anova', 'logit'))"]
+    # TODO: add more parameters
+
+    # "transcripts" table:
+    table_def_transcripts = ["id INTEGER NOT NULL PRIMARY KEY",
+                             "name VARCHAR NOT NULL UNIQUE"]
+
+    # "whitelist" table:
+    table_def_whitelist = ["transcriptid INTEGER NOT NULL",
+                           "readid INTEGER NOT NULL UNIQUE", # foreign key for "reads" table in EventAlign DB
+                           "FOREIGN KEY (transcriptid) REFERENCES transcripts(id)"]
+
+    # "kmer_stats" table:
+    table_def_kmer_stats = ["id INTEGER NOT NULL PRIMARY KEY",
+                            "transcriptid INTEGER NOT NULL",
+                            "kmer INTEGER NOT NULL",
+                            "c1_mean_intensity REAL",
+                            "c2_mean_intensity REAL",
+                            "c1_median_intensity REAL",
+                            "c2_median_intensity REAL",
+                            "c1_sd_intensity REAL",
+                            "c2_sd_intensity REAL",
+                            "c1_mean_dwell REAL",
+                            "c2_mean_dwell REAL",
+                            "c1_median_dwell REAL",
+                            "c2_median_dwell REAL",
+                            "c1_sd_dwell REAL",
+                            "c2_sd_dwell REAL",
+                            "intensity_pvalue REAL",
+                            "dwell_pvalue REAL",
+                            "adj_intensity_pvalue REAL",
+                            "adj_dwell_pvalue REAL",
+                            "UNIQUE (transcriptid, kmer)",
+                            "FOREIGN KEY (transcriptid) REFERENCES transcripts(id)"]
+    # TODO: are "c1" and "c2" (conditions) properly defined?
+
+    # "gmm_stats" table:
+    table_def_gmm_stats = ["kmer_statsid INTEGER NOT NULL UNIQUE",
+                           "n_components INTEGER NOT NULL",
+                           "cluster_counts VARCHAR",
+                           "test_stat REAL",
+                           "test_pvalue REAL",
+                           "adj_test_pvalue REAL",
+                           "FOREIGN KEY (kmer_statsid) REFERENCES kmer_stats(id)"]
+
+    table_defs = {"parameters": table_def_parameters,
+                  "transcripts": table_def_transcripts,
+                  "whitelist": table_def_whitelist,
+                  "kmer_stats": table_def_kmer_stats}
+    # table "gmm_stats" is only added when needed (see "__init__")
+
+    def __init__(self,
+                 db_path:str,
+                 create_mode=DBCreateMode.MUST_EXIST,
+                 with_gmm=True,
+                 with_sequence_context=False):
+        super().__init__(db_path, create_mode)
+        self.__with_gmm = with_gmm
+        self.__with_sequence_context = with_sequence_context
+        if with_gmm:
+            self.table_defs["gmm_stats"] = self.table_def_gmm_stats
+        if with_sequence_context: # add additional columns for context p-values
+            # column definitions must go before table constraints!
+            constraints = self.table_defs["kmer_stats"][-2:]
+            self.table_defs["kmer_stats"] = (self.table_defs["kmer_stats"][:-2] +
+                                             ["intensity_pvalue_context REAL",
+                                              "dwell_pvalue_context REAL",
+                                              "adj_intensity_pvalue_context REAL",
+                                              "adj_dwell_pvalue_context REAL"] +
+                                             constraints)
+            if with_gmm:
+                constraints = self.table_defs["gmm_stats"][-1:]
+                self.table_defs["gmm_stats"] = (self.table_defs["gmm_stats"][:-1] +
+                                                ["test_pvalue_context REAL",
+                                                 "adj_test_pvalue_context REAL"] +
+                                                constraints)
+
+
+    def __insert_transcript_get_id(self, tx_name):
+        try:
+            self._cursor.execute("SELECT id FROM transcripts WHERE name = ?", [tx_name])
+            if (row := self._cursor.fetchone()) is not None:
+                return row["id"]
+            self._cursor.execute("INSERT INTO transcripts VALUES (NULL, ?)", [tx_name])
+            self._connection.commit()
+            # TODO: if there could be multiple writing threads, "INSERT OR IGNORE"
+            # query should go before "SELECT"
+            return self._cursor.lastrowid
+        except:
+            logger.error(f"Failed to insert/look up transcript '{tx_name}'")
+            raise
+
+
+    def store_test_results(self, tx_name, test_results):
+        if not self._connection:
+            raise NanocomporeError("Database connection not yet opened")
+        tx_id = self.__insert_transcript_get_id(tx_name)
+        for kmer, res in test_results.items():
+            values = [tx_id, kmer]
+            values += res["shift_stats"].values()
+            # insert 'None' (NULL) into adj. p-value columns:
+            values += [res.get("intensity_pvalue"), res.get("dwell_pvalue"), None, None]
+            if self.__with_sequence_context:
+                values += [res.get("intensity_pvalue_context"), res.get("dwell_pvalue_context"), None, None]
+            try:
+                self._cursor.execute("INSERT INTO kmer_stats VALUES (NULL" + ", ?" * len(values) + ")", values)
+            except:
+                logger.error(f"Error storing statistics for transcript '{tx_name}', kmer {kmer}")
+                raise
+            kmer_statsid = self._cursor.lastrowid
+            if self.__with_gmm:
+                # insert 'None' (NULL) into adj. p-value columns:
+                values = [kmer_statsid, res["gmm_model"].n_components, res.get("gmm_cluster_counts"),
+                          res.get("gmm_test_stat"), res.get("gmm_pvalue"), None]
+                if self.__with_sequence_context:
+                    values += [res.get("gmm_pvalue_context"), None]
+                qmarks = ", ".join(["?"] * len(values))
+                try:
+                    self._cursor.execute(f"INSERT INTO gmm_stats VALUES ({qmarks})", values)
+                except:
+                    logger.error(f"Error storing GMM stats for transcript '{tx_name}', kmer {kmer}")
+                    raise
+        self._connection.commit()
+
+
+    def store_whitelist(self, whitelist):
+        if not self._connection:
+            raise NanocomporeError("Database connection not yet opened")
+        for tx_name, read_dict in whitelist:
+            try:
+                tx_id = self.__insert_transcript_get_id(tx_name)
+                for cond_reads in read_dict.values():
+                    for sample_reads in cond_reads.values():
+                        values = zip_longest([], sample_reads, fillvalue=tx_id)
+                        self._cursor.executemany("INSERT INTO whitelist VALUES (?, ?)", values)
+                        # TODO: store sample/condition information (again)?
+                        # it can be retrieved from "reads"/"samples" tables given "readid"
+                self._connection.commit()
+            except:
+                logger.error(f"Error storing whitelisted reads for transcript '{tx_name}'")
+                raise
diff --git a/nanocompore/Eventalign_collapse.py b/nanocompore/Eventalign_collapse.py
index abb3b92..6ad8439 100644
--- a/nanocompore/Eventalign_collapse.py
+++ b/nanocompore/Eventalign_collapse.py
@@ -19,7 +19,7 @@
 # Local imports
 from nanocompore.common import *
 from nanocompore.SuperParser import SuperParser
-from nanocompore.DataStore import DataStore
+from nanocompore.DataStore import DataStore_EventAlign, DBCreateMode

 # Disable multithreading for MKL and openBlas
 os.environ["MKL_NUM_THREADS"] = "1"
@@ -28,34 +28,31 @@
 os.environ["OMP_NUM_THREADS"] = "1"
 os.environ['OPENBLAS_NUM_THREADS'] = '1'

-log_level_dict = {"debug":"DEBUG", "info":"INFO", "warning":"WARNING"}
+log_level_dict = {"debug": "DEBUG", "info": "INFO", "warning": "WARNING"}

 #logger.remove()

 #~~~~~~~~~~~~~~MAIN CLASS~~~~~~~~~~~~~~#
 class Eventalign_collapse ():

-    def __init__ (self,
-        eventalign_fn:str,
-        sample_name:str,
-        outpath:str="./",
-        outprefix:str="out",
-        overwrite:bool = False,
-        n_lines:int=None,
-        nthreads:int = 3,
-        progress:bool = False):
+    def __init__(self,
+                 eventalign_fn:str,
+                 sample_name:str,
+                 output_db_path:str,
+                 overwrite:bool = False,
+                 n_lines:int=None,
+                 nthreads:int = 3,
+                 progress:bool = False):
+        # TODO: is 'overwrite' a useful option, as data from multiple samples needs to be accumulated in the same DB?
         """
         Collapse the nanopolish eventalign events at kmer level
         * eventalign_fn
             Path to a nanopolish eventalign tsv output file, or a list of file, or a regex (can be gzipped)
         * sample_name
             The name of the sample being processed
-        * outpath
-            Path to the output folder (will be created if it does exist yet)
-        * outprefix
-            text outprefix for all the files generated
+        * output_db_path
+            Path to the output (database) file
         * overwrite
-            If the output directory already exists, the standard behaviour is to raise an error to prevent overwriting existing data
-            This option ignore the error and overwrite data if they have the same outpath and outprefix.
+            Overwrite an existing output file?
         * n_lines
            Maximum number of read to parse.
* nthreads @@ -74,18 +71,19 @@ def __init__ (self, # Save args to self values self.__sample_name = sample_name - self.__outpath = outpath - self.__outprefix = outprefix self.__eventalign_fn = eventalign_fn + self.__output_db_path = output_db_path + self.__overwrite = overwrite self.__n_lines = n_lines self.__nthreads = nthreads - 2 # subtract 1 for reading and 1 for writing self.__progress = progress # Input file field selection typing and renaming self.__select_colnames = ["contig", "read_name", "position", "reference_kmer", "model_kmer", "event_length", "samples"] - self.__change_colnames = {"contig":"ref_id" ,"position":"ref_pos", "read_name":"read_id", "samples":"sample_list", "event_length":"dwell_time"} + self.__change_colnames = {"contig": "ref_id", "position": "ref_pos", "read_name": "read_id", "samples": "sample_list", "event_length": "dwell_time"} self.__cast_colnames = {"ref_pos":int, "dwell_time":np.float32, "sample_list":lambda x: [float(i) for i in x.split(",")]} + def __call__(self): """ Run the analysis @@ -101,10 +99,8 @@ def __call__(self): ps_list.append (mp.Process (target=self.__split_reads, args=(in_q, error_q))) for i in range (self.__nthreads): ps_list.append (mp.Process (target=self.__process_read, args=(in_q, out_q, error_q))) - ps_list.append (mp.Process (target=self.__write_output_to_db, args=(out_q, error_q))) - - # TODO: Check that sample_name does not exist already in DB + ps_list.append(mp.Process(target=self.__write_output, args=(out_q, error_q))) # Start processes and monitor error queue try: @@ -135,8 +131,9 @@ def __call__(self): logger.error("An error occured while trying to kill processes\n") raise E + #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~PRIVATE METHODS~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~# - def __split_reads (self, in_q, error_q): + def __split_reads(self, in_q, error_q): """ Mono-threaded reader """ @@ -146,38 +143,37 @@ def __split_reads (self, in_q, error_q): try: # Open input file with superParser + # TODO: benchmark performance compared to csv.DictReader (std. lib.) with SuperParser( fn = self.__eventalign_fn, select_colnames = self.__select_colnames, cast_colnames = self.__cast_colnames, change_colnames = self.__change_colnames, - n_lines=self.__n_lines) as sp: - - for l in sp: - n_events+=1 + n_lines = self.__n_lines) as sp: - # First event exception - if n_events==1: - cur_ref_id = l["ref_id"] - cur_read_id = l["read_id"] - event_l = [l] + # First line/event - initialise + l = next(iter(sp)) + # TODO: read ID should be unique, so no need to check transcript - correct? 
+ cur_read_id = l["read_id"] + event_l = [l] + n_events = 1 - # Same read/ref group = just append to current event group - elif l["ref_id"] == cur_ref_id and l["read_id"] == cur_read_id: + # All following lines + for l in sp: + n_events += 1 + # Same read = just append to current event group + if l["read_id"] == cur_read_id: event_l.append(l) - # If new read/ref group detected = enqueue previous event group and start new one else: - n_reads+=1 + n_reads += 1 in_q.put(event_l) - - cur_ref_id = l["ref_id"] cur_read_id = l["read_id"] event_l = [l] - # Last event line exception + # Last event/line in_q.put(event_l) - n_reads+=1 + n_reads += 1 # Manage exceptions and add error trackback to error queue except Exception: @@ -190,7 +186,8 @@ def __split_reads (self, in_q, error_q): in_q.put(None) logger.debug("Parsed Reads:{} Events:{}".format(n_reads, n_events)) - def __process_read (self, in_q, out_q, error_q): + + def __process_read(self, in_q, out_q, error_q): """ Multi-threaded workers collapsing events at kmer level """ @@ -227,107 +224,41 @@ def __process_read (self, in_q, out_q, error_q): logger.debug("Processed Reads:{} Kmers:{} Events:{} Signals:{}".format(n_reads, n_kmers, n_events, n_signals)) out_q.put(None) - def __write_output_to_db (self, out_q, error_q): + + def __write_output(self, out_q, error_q): """ - Mono-threaded Writer + Single-threaded writer """ logger.debug("Start writing output to DB") - pr = profile.Profile() - pr.enable() + # pr = profile.Profile() + # pr.enable() n_reads = 0 + db_create_mode = DBCreateMode.OVERWRITE if self.__overwrite else DBCreateMode.CREATE_MAYBE try: - with DataStore(db_path=os.path.join(self.__outpath, self.__outprefix+"nanocompore.db")) as datastore, tqdm (unit=" reads") as pbar: + with DataStore_EventAlign(self.__output_db_path, db_create_mode) as datastore, \ + tqdm (unit=" reads") as pbar: # Iterate over out queue until nthread poison pills are found for _ in range (self.__nthreads): for read in iter (out_q.get, None): - n_reads+=1 + n_reads += 1 datastore.store_read(read) pbar.update(1) except Exception: logger.error("Error adding read to DB") - error_q.put (NanocomporeError(traceback.format_exc())) + error_q.put(NanocomporeError(traceback.format_exc())) finally: logger.info ("Output reads written:{}".format(n_reads)) # Kill error queue with poison pill error_q.put(None) - pr.disable() - pr.dump_stats("prof") - - def __write_output (self, out_q, error_q): - """ - Mono-threaded Writer - """ - logger.debug("Start writing output files") - - byte_offset = n_reads = n_kmers = 0 + # pr.disable() + # pr.dump_stats("prof") - # Init variables for index files - idx_fn = os.path.join(self.__outpath, self.__outprefix+"_eventalign_collapse.tsv.idx") - data_fn = os.path.join(self.__outpath, self.__outprefix+"_eventalign_collapse.tsv") - - try: - # Open output files and tqdm progress bar - with open (data_fn, "w") as data_fp, open (idx_fn, "w") as idx_fp, tqdm (unit=" reads", disable=not self.__progress) as pbar: - - # Iterate over out queue until nthread poison pills are found - for _ in range (self.__nthreads): - for read in iter (out_q.get, None): - read_res_d = read.get_read_results() - kmer_res_l = read.get_kmer_results() - n_reads+=1 - - # Define file header from first read and first kmer - if byte_offset == 0: - idx_header_list = list(read_res_d.keys())+["byte_offset","byte_len"] - idx_header_str = "\t".join(idx_header_list) - data_header_list = list(kmer_res_l[0].keys()) - data_header_str = "\t".join(data_header_list) - - # Write index file header 
- idx_fp.write ("{}\n".format(idx_header_str)) - - # Write data file header - byte_len = 0 - header_str = "#{}\t{}\n{}\n".format(read_res_d["read_id"], read_res_d["ref_id"], data_header_str) - data_fp.write(header_str) - byte_len+=len(header_str) - - # Write kmer data matching data field order - for kmer in kmer_res_l: - n_kmers+=1 - data_str = "\t".join([str(kmer[f]) for f in data_header_list])+"\n" - data_fp.write(data_str) - byte_len+=len(data_str) - - # Add byte - read_res_d["byte_offset"] = byte_offset - read_res_d["byte_len"] = byte_len-1 - idx_str = "\t".join([str(read_res_d[f]) for f in idx_header_list]) - idx_fp.write("{}\n".format(idx_str)) - - # Update pbar - byte_offset+=byte_len - pbar.update(1) - - # Flag last line - data_fp.write ("#\n") - - # Manage exceptions and add error trackback to error queue - except Exception: - logger.error("Error in Writer") - error_q.put (NanocomporeError(traceback.format_exc())) - - finally: - logger.debug("Written Reads:{} Kmers:{}".format(n_reads, n_kmers)) - logger.info ("Output reads written:{}".format(n_reads)) - # Kill error queue with poison pill - error_q.put(None) #~~~~~~~~~~~~~~~~~~~~~~~~~~HELPER CLASSES~~~~~~~~~~~~~~~~~~~~~~~~~~# -class Read (): +class Read: """Helper class representing a single read""" def __init__ (self, read_id, ref_id, sample_name): @@ -375,13 +306,13 @@ def add_event (self, event_d): @property def kmers_status (self): d = OrderedDict() - d["kmers"] = self.ref_end-self.ref_start+1 + d["kmers"] = self.ref_end - self.ref_start + 1 d["missing_kmers"] = d["kmers"] - len(self.kmer_l) - d["NNNNN_kmers"]=0 - d["mismatch_kmers"]=0 - d["valid_kmers"]=0 + d["NNNNN_kmers"] = 0 + d["mismatch_kmers"] = 0 + d["valid_kmers"] = 0 for k in self.kmer_l: - d[k.status+"_kmers"]+=1 + d[k.status + "_kmers"] += 1 return d def get_read_results (self): @@ -401,7 +332,7 @@ def get_kmer_results (self): l = [kmer.get_results() for kmer in self.kmer_l] return l -class Kmer (): +class Kmer: """Helper class representing a single kmer""" def __init__ (self): @@ -464,7 +395,5 @@ def get_results(self): d["mismatch_dwell_time"] = self.mismatch_dwell_time d["status"] = self.status d["median"] = statistics.median(self.sample_list) - d["mad"] = statistics.median([ abs( i-d["median"] ) for i in self.sample_list]) - + d["mad"] = statistics.median([abs(i - d["median"]) for i in self.sample_list]) return d - diff --git a/nanocompore/PostProcess.py b/nanocompore/PostProcess.py new file mode 100644 index 0000000..d9f1ae5 --- /dev/null +++ b/nanocompore/PostProcess.py @@ -0,0 +1,199 @@ +# -*- coding: utf-8 -*- + +#~~~~~~~~~~~~~~IMPORTS~~~~~~~~~~~~~~# +# Std lib +from loguru import logger + +# Third party +# ... + +# Local package +from nanocompore.common import * +from nanocompore.DataStore import DataStore_EventAlign, DataStore_SampComp + +#~~~~~~~~~~~~~~MAIN CLASS~~~~~~~~~~~~~~# +class PostProcess(object): + """Helper class for post-processing `SampComp` results""" + + def __init__(self, sampcomp_db_path:str, eventalign_db_path:str, bed_path:str=None): + self._sampcomp_db_path = sampcomp_db_path + self._eventalign_db_path = eventalign_db_path + self._bed_path = bed_path + + + def save_all(self, outpath_prefix=None, pvalue_thr=0.01): + """ + Save all text reports including genomic coordinate if a bed file was provided + * outpath_prefix + outpath + prefix to use as a basename for output files. + If not given, it will use the same prefix as the database. 
+ * pvalue_thr + pvalue threshold to report significant sites in bed files + """ + if not outpath_prefix: + outpath_prefix = self._db_path.replace("SampComp.db", "") + logger.debug("Save reports to {}".format(outpath_prefix)) + + # Save reports + logger.debug("\tSaving extended tabular report") + self.save_report(output_fn = outpath_prefix + "nanocompore_results.tsv") + logger.debug("\tSaving shift results") + self.save_shift_stats(output_fn = outpath_prefix + "nanocompore_shift_stats.tsv") + + # Save bed and bedgraph files for each method used + if self._bed_path: + logger.debug("\tSaving significant genomic coordinates in Bed and Bedgraph format") + for m in self._metadata["pvalue_tests"]: + self.save_to_bed( + output_fn = outpath_prefix+"sig_sites_{}_thr_{}.bed".format(m, pvalue_thr), + bedgraph=False, pvalue_field=m, pvalue_thr=pvalue_thr, span=5, title="Nanocompore Significant Sites") + self.save_to_bed( + output_fn = outpath_prefix+"sig_sites_{}_thr_{}.bedgraph".format(m, pvalue_thr), + bedgraph=True, pvalue_field=m, pvalue_thr=pvalue_thr, title="Nanocompore Significant Sites") + + + def save_to_bed(self, output_fn=None, bedgraph=False, pvalue_field=None, pvalue_thr=0.01, span=5, convert=None, assembly=None, title=None): + """ + Save the position of significant positions in the genome space in BED6 or BEDGRAPH format. + The resulting file can be used in a genome browser to visualise significant genomic locations. + The option is only available if `SampCompDB` if initialised with a BED file containing genome annotations. + * output_fn + Path to file where to write the data + * bedgraph + save file in bedgraph format instead of bed + * pvalue_field + specifies what column to use as BED score (field 5, as -log10) + * pvalue_thr + only report positions with pvalue<=thr + * span + The size of each BED feature. + If size=5 (default) features correspond to kmers. + If size=1 features correspond to the first base of each kmer. + * convert + one of 'ensembl_to_ucsc' or 'ucsc_to_ensembl". Convert chromosome named between Ensembl and Ucsc conventions + * assembly + required if convert is used. One of "hg38" or "mm10" + """ + if self._bed_path is None: + raise NanocomporeError("In order to generate a BED file PostProcess needs to be initialised with a transcriptome BED") + if span < 1: + raise NanocomporeError("span has to be >=1") + if span != 5 and bedgraph: + raise NanocomporeError("Span is ignored when generating bedGraph files") + if pvalue_field not in self.results: + raise NanocomporeError(("The field '%s' is not in the results" % pvalue_field)) + if "results" not in self.__dict__: + raise NanocomporeError("It looks like there's not results slot in SampCompDB") + if convert not in [None, "ensembl_to_ucsc", "ucsc_to_ensembl"]: + raise NanocomporeError("Convert value not valid") + if convert is not None and assembly is None: + raise NanocomporeError("The assembly argument is required in order to do the conversion. 
Choose one of 'hg38' or 'mm10' ") + + with open(output_fn, "w") as bed_file: + if title is not None: + if not bedgraph: + bed_file.write('track type=bed name="%s" description="%s"\n' % (title, pvalue_field)) + else: + bed_file.write('track type=bedGraph name="%s" description="%s"\n' % (title, pvalue_field)) + + Record = namedtuple('Record', ['chr', 'genomicPos', 'ref_id', 'strand', 'ref_kmer', pvalue_field]) + threshold = -log(pvalue_thr, 10) + for record in self.results[list(Record._fields)].itertuples(index=False, name="Record"): + pvalue = getattr(record, pvalue_field) + if np.isnan(pvalue): + pvalue = 0 + elif pvalue < sys.float_info.min: + pvalue = -log(sys.float_info.min, 10) + else: + pvalue = -log(pvalue, 10) + if not bedgraph and pvalue < threshold: + continue + if bedgraph: + if record.strand == "+": + start_pos = record.genomicPos + 2 + else: + start_pos = record.genomicPos - 2 + end_pos = start_pos + 1 + else: + if record.strand == "+": + start_pos = record.genomicPos + else: + start_pos = record.genomicPos - span + 1 + end_pos = start_pos + span + line = bedline([record.chr, start_pos, end_pos, "%s_%s" % (record.ref_id, record.ref_kmer), + pvalue, record.strand]) + if convert == "ensembl_to_ucsc": + line = line.translateChr(assembly=assembly, target="ucsc", patches=True) + elif convert == "ucsc_to_ensembl": + line = line.translateChr(assembly=assembly, target="ens", patches=True) + if bedgraph: + bed_file.write("%s\t%s\t%s\t%s\n" % (line.chr, line.start, line.end, line.score)) + else: + bed_file.write("%s\t%s\t%s\t%s\t%s\t%s\n" % (line.chr, line.start, line.end, + line.name, line.score, line.strand)) + + + def save_report(self, output_fn:str=None, include_shift_stats:bool=True): + """ + Saves a tabulated text dump of the database containing all the statistical results for all the positions + * output_fn + Path to file where to write the data. If None, data is returned to the standard output. + """ + ## TODO: can this be done in a "with ..." clause? + if output_fn is None: + fp = sys.stdout + elif isinstance(output_fn, str): + try: + fp = open(output_fn, "w") + except: + raise NanocomporeError("Error opening output file %s" % output_fn) + else: + raise NanocomporeError("output_fn needs to be a string or None") + + with DataStore_SampComp(self._sampcomp_db_path) as sc_db, \ + DataStore_EventAlign(self._eventalign_db_path) as ea_db: + # do we have GMM results? + query = "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'gmm_stats'" + sc_db.cursor.execute(query) + with_gmm = sc_db.cursor.fetchone() is not None + query = "SELECT * FROM kmer_stats LEFT JOIN transcripts ON transcriptid = transcripts.id" + if with_gmm: + query += " LEFT JOIN gmm_stats ON kmer_stats.id = gmm_stats.kmer_statsid" + query += " ORDER BY transcriptid, kmer" + first_row = True + shift_stat_columns = [] + univariate_pvalue_columns = [] + gmm_pvalue_columns = [] + for row in sc_db.cursor.execute(query): + # retrieve k-mer sequence: + ea_query = "SELECT sequence FROM kmers LEFT JOIN reads ON readid = reads.id WHERE transcriptid = ? AND position = ? LIMIT 1" + ea_db.cursor.execute(ea_query, (row["transcriptid"], row["kmer"])) + seq = ea_db.cursor.fetchone()[0] + out_dict = {"transcript": row["name"], + "position": row["kmer"], + "sequence": seq} + # TODO: add chromosome, genomic pos., strand information (from where?) 
+ if first_row: # check which columns we have (do this only once) + univariate_pvalue_columns = [col for col in row.keys() + if ("intensity_pvalue" in col) or ("dwell_pvalue" in col)] + if include_shift_stats: + shift_stat_columns = [col for col in row.keys() if col.startswith(("c1_", "c2_"))] + if with_gmm: + gmm_pvalue_columns = [col for col in row.keys() if "test_pvalue" in col] + + for col in shift_stat_columns: + out_dict[col] = row[col] + for col in univariate_pvalue_columns: + out_dict[col] = row[col] + if with_gmm: + out_dict["GMM_n_components"] = row["n_components"] + out_dict["GMM_cluster_counts"] = row["cluster_counts"] + out_dict["GMM_test_stat"] = row["test_stat"] + for col in gmm_pvalue_columns: + out_dict[col.replace("test", "GMM", 1)] = row[col] + + if first_row: # write header line + fp.write("\t".join(out_dict.keys()) + "\n") + # write output data: + fp.write("\t".join(str(x) for x in out_dict.values()) + "\n") + first_row = False diff --git a/nanocompore/SampComp.py b/nanocompore/SampComp.py index 108f779..7303b4d 100644 --- a/nanocompore/SampComp.py +++ b/nanocompore/SampComp.py @@ -11,15 +11,16 @@ # Third party from loguru import logger -import yaml from tqdm import tqdm import numpy as np from pyfaidx import Fasta +from statsmodels.stats.multitest import multipletests # Local package from nanocompore.common import * +from nanocompore.DataStore import * from nanocompore.Whitelist import Whitelist -from nanocompore.TxComp import txCompare +from nanocompore.TxComp import TxComp from nanocompore.SampCompDB import SampCompDB import nanocompore as pkg @@ -28,69 +29,69 @@ os.environ["MKL_THREADING_LAYER"] = "sequential" os.environ["NUMEXPR_NUM_THREADS"] = "1" os.environ["OMP_NUM_THREADS"] = "1" -os.environ['OPENBLAS_NUM_THREADS'] = '1' +os.environ["OPENBLAS_NUM_THREADS"] = "1" #~~~~~~~~~~~~~~MAIN CLASS~~~~~~~~~~~~~~# class SampComp(object): - """ Init analysis and check args""" + """Init analysis and check args""" #~~~~~~~~~~~~~~FUNDAMENTAL METHODS~~~~~~~~~~~~~~# + # TODO: use enums for univariate and gmm test parameters? def __init__(self, - eventalign_fn_dict:dict, - fasta_fn:str, - bed_fn:str = None, - outpath:str = "results", - outprefix:str = "out_", - overwrite:bool = False, - whitelist:Whitelist = None, - comparison_methods:list = ["GMM", "KS"], - logit:bool = True, - anova:bool = False, - allow_warnings:bool = False, - sequence_context:int = 0, - sequence_context_weights:str = "uniform", - min_coverage:int = 30, - min_ref_length:int = 100, - downsample_high_coverage:int = 5000, - max_invalid_kmers_freq:float = 0.1, - select_ref_id:list = [], - exclude_ref_id:list = [], - nthreads:int = 3, - progress:bool = False): + input_db_path:str, + output_db_path:str, + sample_dict:dict, + fasta_fn:str = "", + overwrite:bool = False, + whitelist:Whitelist = None, + univariate_test:str = "KS", # or: "MW", "ST" + fit_gmm:bool = True, + gmm_test:str = "logit", # or: "anova" + allow_anova_warnings:bool = False, + sequence_context:int = 0, + sequence_context_weighting:str = "uniform", + min_coverage:int = 30, + min_ref_length:int = 100, + downsample_high_coverage:int = 5000, + max_invalid_kmers_freq:float = 0.1, + select_ref_id:list = [], + exclude_ref_id:list = [], + nthreads:int = 3, + progress:bool = False): """ - Initialise a `SampComp` object and generates a white list of references with sufficient coverage for subsequent analysis. + Initialise a `SampComp` object and generate a whitelist of references with sufficient coverage for subsequent analysis. 
The retuned object can then be called to start the analysis. - * eventalign_fn_dict - Multilevel dictionnary indicating the condition_label, sample_label and file name of the eventalign_collapse output. - 2 conditions are expected and at least 2 sample replicates per condition are highly recommended. - One can also pass YAML file describing the samples instead. - Example `d = {"S1": {"R1":"path1.tsv", "R2":"path2.tsv"}, "S2": {"R1":"path3.tsv", "R2":"path4.tsv"}}` - * outpath - Path to the output folder. - * outprefix - text outprefix for all the files generated by the function. - * overwrite - If the output directory already exists, the standard behaviour is to raise an error to prevent overwriting existing data - This option ignore the error and overwrite data if they have the same outpath and outprefix. + Args: + * input_db_path + Path to the SQLite database file with event-aligned read/kmer data + * output_db_path + Path to the SQLite database file for storing results + * sample_dict + Dictionary containing lists of (unique) sample names, grouped by condition + Example: d = {"control": ["C1", "C2"], "treatment": ["T1", "T2"]} * fasta_fn Path to a fasta file corresponding to the reference used for read alignment. - * bed_fn - Path to a BED file containing the annotation of the transcriptome used as reference when mapping. + Not needed if 'whitelist' argument is provided. + * overwrite + If the output database already exists, overwrite it with a new database? + By default, new data will be added to previous data. * whitelist - Whitelist object previously generated with nanocompore Whitelist. If not given, will be automatically generated. - * comparison_methods - Statistical method to compare the 2 samples (mann_whitney or MW, kolmogorov_smirnov or KS, t_test or TT, gaussian_mixture_model or GMM). - This can be a list or a comma separated string. {MW,KS,TT,GMM} - * logit - Force logistic regression even if we have less than 2 replicates in any condition. - * allow_warnings + Whitelist object previously generated with nanocompore Whitelist. + If not given, will be automatically generated. + * univariate_test + Statistical test to compare the two conditions ('MW' for Mann-Whitney, 'KS' for Kolmogorov-Smirnov or 'ST' for Student's t), or empty for no test. + * fit_gmm + Fit a Gaussian mixture model (GMM) to the intensity/dwell-time distribution? + * gmm_test + Method to compare samples based on the GMM ('logit' or 'anova'), or empty for no comparison. + * allow_anova_warnings If True runtime warnings during the ANOVA tests don't raise an error. * sequence_context - Extend statistical analysis to contigous adjacent base if available. - * sequence_context_weights - type of weights to used for combining p-values. {uniform,harmonic} + Extend statistical analysis to contiguous adjacent bases if available. + * sequence_context_weighting + type of weighting to used for combining p-values. {uniform,harmonic} * min_coverage minimal read coverage required in all sample. 
* min_ref_length @@ -113,84 +114,77 @@ def __init__(self, # Save init options in dict for later log_init_state(loc=locals()) - # If eventalign_fn_dict is not a dict try to load a YAML file instead - if type(eventalign_fn_dict) == str: - logger.debug("Parsing YAML file") - if not access_file(eventalign_fn_dict): - raise NanocomporeError("{} is not a valid file".format(eventalign_fn_dict)) - with open(eventalign_fn_dict, "r") as fp: - eventalign_fn_dict = yaml.load(fp, Loader=yaml.SafeLoader) - # Check eventalign_dict file paths and labels - eventalign_fn_dict = self.__check_eventalign_fn_dict(eventalign_fn_dict) - logger.debug(eventalign_fn_dict) - - # Check if fasta and bed files exist - if not access_file(fasta_fn): - raise NanocomporeError("{} is not a valid FASTA file".format(fasta_fn)) - if bed_fn and not access_file(bed_fn): - raise NanocomporeError("{} is not a valid BED file".format(bed_fn)) + check_sample_dict(sample_dict) + logger.debug(sample_dict) # Check threads number if nthreads < 3: raise NanocomporeError("The minimum number of threads is 3") # Parse comparison methods - if comparison_methods: - if type(comparison_methods) == str: - comparison_methods = comparison_methods.split(",") - for i, method in enumerate(comparison_methods): - method = method.upper() - if method in ["MANN_WHITNEY", "MW"]: - comparison_methods[i]="MW" - elif method in ["KOLMOGOROV_SMIRNOV", "KS"]: - comparison_methods[i]="KS" - elif method in ["T_TEST", "TT"]: - comparison_methods[i]="TT" - elif method in ["GAUSSIAN_MIXTURE_MODEL", "GMM"]: - comparison_methods[i]="GMM" - else: - raise NanocomporeError("Invalid comparison method {}".format(method)) + if univariate_test and (univariate_test not in ["MW", "KS", "ST"]): + raise NanocomporeError(f"Invalid univariate test {univariate_test}") + if fit_gmm and gmm_test and (gmm_test not in ["logit", "anova"]): + raise NanocomporeError(f"Invalid GMM-based test {gmm_test}") if not whitelist: - whitelist = Whitelist( - eventalign_fn_dict = eventalign_fn_dict, - fasta_fn = fasta_fn, - min_coverage = min_coverage, - min_ref_length = min_ref_length, - downsample_high_coverage = downsample_high_coverage, - max_invalid_kmers_freq = max_invalid_kmers_freq, - select_ref_id = select_ref_id, - exclude_ref_id = exclude_ref_id) + whitelist = Whitelist(input_db_path, + sample_dict, + fasta_fn, + min_coverage = min_coverage, + min_ref_length = min_ref_length, + downsample_high_coverage = downsample_high_coverage, + max_invalid_kmers_freq = max_invalid_kmers_freq, + select_ref_id = select_ref_id, + exclude_ref_id = exclude_ref_id) elif not isinstance(whitelist, Whitelist): raise NanocomporeError("Whitelist is not valid") + self.__output_db_path = output_db_path + self.__db_args = {"with_gmm": fit_gmm, "with_sequence_context": (sequence_context > 0)} + db_create_mode = DBCreateMode.OVERWRITE if overwrite else DBCreateMode.CREATE_MAYBE + db = DataStore_SampComp(output_db_path, db_create_mode, **self.__db_args) + with db: + db.store_whitelist(whitelist) + # TODO: move this to '__call__'? 
+ # Set private args from whitelist args self.__min_coverage = whitelist._Whitelist__min_coverage self.__downsample_high_coverage = whitelist._Whitelist__downsample_high_coverage self.__max_invalid_kmers_freq = whitelist._Whitelist__max_invalid_kmers_freq # Save private args - self.__eventalign_fn_dict = eventalign_fn_dict - self.__db_fn = os.path.join(outpath, outprefix+"SampComp.db") + self.__input_db_path = input_db_path + self.__sample_dict = sample_dict self.__fasta_fn = fasta_fn - self.__bed_fn = bed_fn self.__whitelist = whitelist - self.__comparison_methods = comparison_methods - self.__logit = logit - self.__anova = anova - self.__allow_warnings = allow_warnings - self.__sequence_context = sequence_context - self.__sequence_context_weights = sequence_context_weights self.__nthreads = nthreads - 2 self.__progress = progress # Get number of samples - n = 0 - for sample_dict in self.__eventalign_fn_dict.values(): - for sample_lab in sample_dict.keys(): - n+=1 - self.__n_samples = n + self.__n_samples = 0 + for samples in sample_dict.values(): + self.__n_samples += len(samples) + + # If statistical tests are requested, initialise the "TxComp" object: + if univariate_test or fit_gmm: + random_state = np.random.RandomState(seed=42) + self.__tx_compare = TxComp(random_state, + univariate_test=univariate_test, + fit_gmm=fit_gmm, + gmm_test=gmm_test, + sequence_context=sequence_context, + sequence_context_weighting=sequence_context_weighting, + min_coverage=self.__min_coverage, + allow_anova_warnings=allow_anova_warnings) + else: + self.__tx_compare = None + ## used to adjust p-values: + self.__univariate_test = univariate_test + self.__gmm_test = gmm_test if fit_gmm else "" + self.__sequence_context = (sequence_context > 0) + def __call__(self): """ @@ -207,7 +201,7 @@ def __call__(self): ps_list.append(mp.Process(target=self.__list_refid, args=(in_q, error_q))) for i in range(self.__nthreads): ps_list.append(mp.Process(target=self.__process_references, args=(in_q, out_q, error_q))) - ps_list.append(mp.Process(target=self.__write_output, args=(out_q, error_q))) + ps_list.append(mp.Process(target=self.__write_output_to_db, args=(out_q, error_q))) # Start processes and monitor error queue try: @@ -226,12 +220,6 @@ def __call__(self): for q in (in_q, out_q, error_q): q.close() - # Return database wrapper object - return SampCompDB( - db_fn=self.__db_fn, - fasta_fn=self.__fasta_fn, - bed_fn=self.__bed_fn) - # Catch error, kill all processed and reraise error except Exception as E: logger.error("An error occured. Killing all processes and closing queues\n") @@ -244,6 +232,67 @@ def __call__(self): logger.error("An error occured while trying to kill processes\n") raise E + # Adjust p-values for multiple testing: + if self.__univariate_test or self.__gmm_test: + logger.info("Running multiple testing correction") + self.__adjust_pvalues() + # context-based p-values are not independent tests, so adjust them separately: + if self.__sequence_context: + self.__adjust_pvalues(sequence_context=True) + + + def process_transcript(self, tx_id, whitelist_reads): + """Process a transcript given filtered reads from Whitelist""" + logger.debug(f"Processing transcript: {tx_id}") + + # Kmer data from whitelisted reads from all samples for this transcript + # Structure: kmer position -> condition -> sample -> data + kmer_data = defaultdict(lambda: {condition: + defaultdict(lambda: {"intensity": [], + "dwell": [], + "coverage": 0, + "kmers_stats": {"valid": 0, + # "missing": 0, # TODO: needed? 
+ "NNNNN": 0, + "mismatching": 0}}) + for condition in self.__sample_dict}) + n_reads = n_kmers = 0 + + # Read kmer data from database + with DataStore_EventAlign(self.__input_db_path) as db: + for cond_lab, sample_dict in whitelist_reads.items(): + for sample_id, read_ids in sample_dict.items(): + if not read_ids: continue # TODO: error? + n_reads += len(read_ids) + values = ", ".join([str(read_id) for read_id in read_ids]) + query = f"SELECT * FROM kmers WHERE readid IN ({values})" + for row in db.cursor.execute(query): + n_kmers += 1 + pos = row["position"] + # TODO: check that kmer seq. agrees with FASTA? + data = kmer_data[pos][cond_lab][sample_id] + data["intensity"].append(row["median"]) + data["dwell"].append(row["dwell_time"]) + data["coverage"] += 1 + status = row["status"] + data["kmers_stats"][status] += 1 + + logger.debug(f"Data loaded for transcript: {tx_id}") + test_results = {} + if self.__tx_compare: + test_results = self.__tx_compare(tx_id, kmer_data) + # TODO: check "gmm_anova_failed" state of TxComp object + + # Remove 'default_factory' functions from 'kmer_data' to enable pickle/multiprocessing + kmer_data.default_factory = None + for pos_dict in kmer_data.values(): + for cond_dict in pos_dict.values(): + cond_dict.default_factory = None + + return {"kmer_data": kmer_data, "test_results": test_results, + "n_reads": n_reads, "n_kmers": n_kmers} + + #~~~~~~~~~~~~~~PRIVATE MULTIPROCESSING METHOD~~~~~~~~~~~~~~# def __list_refid(self, in_q, error_q): """Add valid refid from whitelist to input queue to dispatch the data among the workers""" @@ -252,7 +301,7 @@ def __list_refid(self, in_q, error_q): for ref_id, ref_dict in self.__whitelist: logger.debug("Adding {} to in_q".format(ref_id)) in_q.put((ref_id, ref_dict)) - n_tx+=1 + n_tx += 1 # Manage exceptions and add error trackback to error queue except Exception: @@ -265,101 +314,26 @@ def __list_refid(self, in_q, error_q): in_q.put(None) logger.debug("Parsed transcripts:{}".format(n_tx)) + def __process_references(self, in_q, out_q, error_q): """ Consume ref_id, agregate intensity and dwell time at position level and perform statistical analyses to find significantly different regions """ - n_tx = n_reads = n_lines = 0 + n_tx = n_reads = n_kmers = 0 try: logger.debug("Worker thread started") - # Open all files for reading. 
File pointer are stored in a dict matching the ref_dict entries - fp_dict = self.__eventalign_fn_open() - - # Process refid in input queue + # Process references in input queue for ref_id, ref_dict in iter(in_q.get, None): - logger.debug("Worker thread processing new item from in_q: {}".format(ref_id)) - # Create an empty dict for all positions first - ref_pos_list = self.__make_ref_pos_list(ref_id) - - for cond_lab, sample_dict in ref_dict.items(): - for sample_lab, read_list in sample_dict.items(): - fp = fp_dict[cond_lab][sample_lab] - - for read in read_list: - - # Move to read, save read data chunk and reset file pointer - fp.seek(read["byte_offset"]) - line_list = fp.read(read["byte_len"]).split("\n") - fp.seek(0) - - # Check read_id ref_id concordance between index and data file - header = numeric_cast_list(line_list[0][1:].split("\t")) - if not header[0] == read["read_id"] or not header[1] == read["ref_id"]: - raise NanocomporeError("Index and data files are not matching:\n{}\n{}".format(header, read)) - - # Extract col names from second line - col_names = line_list[1].split("\t") - # Check that all required fields are present - if not all_values_in (["ref_pos", "ref_kmer", "median", "dwell_time"], col_names): - raise NanocomporeError("Required fields not found in the data file: {}".format(col_names)) - # Verify if kmers events stats values are present or not - kmers_stats = all_values_in (["NNNNN_dwell_time", "mismatch_dwell_time"], col_names) - - # Parse data files kmers per kmers - prev_pos = None - for line in line_list[2:]: - # Transform line to dict and cast str numbers to actual numbers - kmer = numeric_cast_dict(keys=col_names, values=line.split("\t")) - pos = kmer["ref_pos"] - - # Check consistance between eventalign data and reference sequence - if kmer["ref_kmer"] != ref_pos_list[pos]["ref_kmer"]: - ref_pos_list[pos]["ref_kmer"] = ref_pos_list[pos]["ref_kmer"]+"!!!!" 
- #raise NanocomporeError ("Data reference kmer({}) doesn't correspond to the reference sequence ({})".format(ref_pos_list[pos]["ref_kmer"], kmer["ref_kmer"])) - - # Fill dict with the current pos values - ref_pos_list[pos]["data"][cond_lab][sample_lab]["intensity"].append(kmer["median"]) - ref_pos_list[pos]["data"][cond_lab][sample_lab]["dwell"].append(kmer["dwell_time"]) - ref_pos_list[pos]["data"][cond_lab][sample_lab]["coverage"] += 1 - - if kmers_stats: - # Fill in the missing positions - if prev_pos and pos-prev_pos > 1: - for missing_pos in range(prev_pos+1, pos): - ref_pos_list[missing_pos]["data"][cond_lab][sample_lab]["kmers_stats"]["missing"] += 1 - # Also fill in with normalised position event stats - n_valid = (kmer["dwell_time"]-(kmer["NNNNN_dwell_time"]+kmer["mismatch_dwell_time"])) / kmer["dwell_time"] - n_NNNNN = kmer["NNNNN_dwell_time"] / kmer["dwell_time"] - n_mismatching = kmer["mismatch_dwell_time"] / kmer["dwell_time"] - ref_pos_list[pos]["data"][cond_lab][sample_lab]["kmers_stats"]["valid"] += n_valid - ref_pos_list[pos]["data"][cond_lab][sample_lab]["kmers_stats"]["NNNNN"] += n_NNNNN - ref_pos_list[pos]["data"][cond_lab][sample_lab]["kmers_stats"]["mismatching"] += n_mismatching - # Save previous position - prev_pos = pos - - n_lines+=1 - n_reads+=1 - - logger.debug("Data for {} loaded.".format(ref_id)) - if self.__comparison_methods: - random_state=np.random.RandomState(seed=42) - ref_pos_list = txCompare( - ref_id=ref_id, - ref_pos_list=ref_pos_list, - methods=self.__comparison_methods, - sequence_context=self.__sequence_context, - sequence_context_weights=self.__sequence_context_weights, - min_coverage= self.__min_coverage, - allow_warnings=self.__allow_warnings, - logit=self.__logit, - anova=self.__anova, - random_state=random_state) + logger.debug(f"Worker thread processing new item from in_q: {ref_id}") + results = self.process_transcript(ref_id, ref_dict) + n_tx += 1 + n_reads += results["n_reads"] + n_kmers += results["n_kmers"] # Add the current read details to queue - logger.debug("Adding %s to out_q"%(ref_id)) - out_q.put((ref_id, ref_pos_list)) - n_tx+=1 + logger.debug(f"Adding '{ref_id}' to out_q") + out_q.put((ref_id, results["test_results"])) # Manage exceptions and add error trackback to error queue except Exception as e: @@ -368,133 +342,80 @@ def __process_references(self, in_q, out_q, error_q): # Deal poison pill and close file pointer finally: - logger.debug("Processed Transcrits:{} Reads:{} Lines:{}".format(n_tx, n_reads, n_lines)) + logger.debug(f"Processed {n_tx} transcripts, {n_reads} reads, {n_kmers} kmers") logger.debug("Adding poison pill to out_q") - self.__eventalign_fn_close(fp_dict) out_q.put(None) - def __write_output(self, out_q, error_q): - # Get results out of the out queue and write in shelve - pvalue_tests = set() - ref_id_list = [] - n_tx = n_pos = 0 + + def __write_output_to_db(self, out_q, error_q): + n_tx = 0 try: - with shelve.open(self.__db_fn, flag='n') as db, tqdm(total=len(self.__whitelist), unit=" Processed References", disable= not self.__progress) as pbar: + # Database was already created earlier to store the whitelist! 
+ db = DataStore_SampComp(self.__output_db_path, DBCreateMode.MUST_EXIST, **self.__db_args) + with db: # Iterate over the counter queue and process items until all poison pills are found for _ in range(self.__nthreads): - for ref_id, ref_pos_list in iter(out_q.get, None): - ref_id_list.append(ref_id) - logger.debug("Writer thread writing %s"%ref_id) - # Get pvalue fields available in analysed data before - for pos_dict in ref_pos_list: - if 'txComp' in pos_dict: - for res in pos_dict['txComp'].keys(): - if "pvalue" in res: - n_pos+=1 - pvalue_tests.add(res) - # Write results in a shelve db - db [ref_id] = ref_pos_list - pbar.update(1) - n_tx+=1 - - # Write list of refid - db["__ref_id_list"] = ref_id_list - - # Write metadata - db["__metadata"] = { - "package_name": pkg.__version__, - "package_version": pkg.__name__, - "timestamp": str(datetime.datetime.now()), - "comparison_methods": self.__comparison_methods, - "pvalue_tests": sorted(list(pvalue_tests)), - "sequence_context": self.__sequence_context, - "min_coverage": self.__min_coverage, - "n_samples": self.__n_samples} - - # Manage exceptions and add error trackback to error queue + for ref_id, test_results in iter(out_q.get, None): + logger.debug("Writer thread storing transcript %s" % ref_id) + db.store_test_results(ref_id, test_results) + n_tx += 1 except Exception: - logger.error("Error in Writer") + logger.error("Error in writer thread") error_q.put(traceback.format_exc()) - finally: - logger.debug("Written Transcripts:{} Valid positions:{}".format(n_tx, n_pos)) - logger.info ("All Done. Transcripts processed: {}".format(n_tx)) + logger.info(f"All done. Transcripts processed: {n_tx}") # Kill error queue with poison pill error_q.put(None) - #~~~~~~~~~~~~~~PRIVATE HELPER METHODS~~~~~~~~~~~~~~# - def __check_eventalign_fn_dict(self, d): - """""" - # Check that the number of condition is 2 and raise a warning if there are less than 2 replicates per conditions - if len(d) != 2: - raise NanocomporeError("2 conditions are expected. Found {}".format(len(d))) - for cond_lab, sample_dict in d.items(): - if len(sample_dict) == 1: - logger.info("Only 1 replicate found for condition {}".format(cond_lab)) - logger.info("This is not recommended. The statistics will be calculated with the logit method") - - # Test if files are accessible and verify that there are no duplicated replicate labels - duplicated_lab = False - rep_lab_list = [] - rep_fn_list = [] - for cond_lab, sd in d.items(): - for rep_lab, fn in sd.items(): - if not access_file(fn): - raise NanocomporeError("Cannot access eventalign file: {}".format(fn)) - if fn in rep_fn_list: - raise NanocomporeError("Duplicated eventalign file detected: {}".format(fn)) - if rep_lab in rep_lab_list: - duplicated_lab = True - rep_lab_list.append(rep_lab) - rep_fn_list.append(fn) - if not duplicated_lab: - return d - - # If duplicated replicate labels found, prefix labels with condition name - else: - logger.debug("Found duplicated labels in the replicate names. 
Prefixing with condition name") - d_clean = OrderedDict() - for cond_lab, sd in d.items(): - d_clean[cond_lab] = OrderedDict() - for rep_lab, fn in sd.items(): - d_clean[cond_lab]["{}_{}".format(cond_lab, rep_lab)] = fn - return d_clean - - def __eventalign_fn_open(self): - """""" - fp_dict = OrderedDict() - for cond_lab, sample_dict in self.__eventalign_fn_dict.items(): - fp_dict[cond_lab] = OrderedDict() - for sample_lab, fn in sample_dict.items(): - fp_dict[cond_lab][sample_lab] = open(fn, "r") - return fp_dict - - def __eventalign_fn_close(self, fp_dict): - """""" - for sample_dict in fp_dict.values(): - for fp in sample_dict.values(): - fp.close() - - def __make_ref_pos_list(self, ref_id): - """""" - ref_pos_list = [] - with Fasta(self.__fasta_fn) as fasta: - ref_fasta = fasta [ref_id] - ref_len = len(ref_fasta) - ref_seq = str(ref_fasta) - - for pos in range(ref_len-4): - pos_dict = OrderedDict() - pos_dict["ref_kmer"] = ref_seq[pos:pos+5] - pos_dict["data"] = OrderedDict() - for cond_lab, s_dict in self.__eventalign_fn_dict.items(): - pos_dict["data"][cond_lab] = OrderedDict() - for sample_lab in s_dict.keys(): - - pos_dict["data"][cond_lab][sample_lab] = { - "intensity":[], - "dwell":[], - "coverage":0, - "kmers_stats":{"missing":0,"valid":0,"NNNNN":0,"mismatching":0}} - ref_pos_list.append(pos_dict) - return ref_pos_list + + # TODO: move this to 'DataStore_SampComp'? + def __adjust_pvalues(self, method="fdr_bh", sequence_context=False): + """Perform multiple testing correction of p-values and update database""" + db = DataStore_SampComp(self.__output_db_path, DBCreateMode.MUST_EXIST, **self.__db_args) + with db: + pvalues = [] + index = [] + # for "context-averaged" p-values, add a suffix to the column names: + col_suffix = "_context" if sequence_context else "" + if self.__univariate_test: + query = f"SELECT id, intensity_pvalue{col_suffix}, dwell_pvalue{col_suffix} FROM kmer_stats" + try: + for row in db.cursor.execute(query): + for pv_col in ["intensity_pvalue", "dwell_pvalue"]: + pv_col += col_suffix + pv = row[pv_col] + # "multipletests" doesn't handle NaN values well, so skip those: + if (pv is not None) and not np.isnan(pv): + pvalues.append(pv) + index.append({"table": "kmer_stats", "id_col": "id", + "id": row["id"], "pv_col": pv_col}) + except: + logger.error("Error reading p-values from table 'kmer_stats'") + raise + if self.__gmm_test: + pv_col = "test_pvalue" + col_suffix + query = f"SELECT kmer_statsid, {pv_col} FROM gmm_stats WHERE {pv_col} IS NOT NULL" + try: + for row in db.cursor.execute(query): + pv = row[pv_col] + # "multipletests" doesn't handle NaN values well, so skip those: + if not np.isnan(pv): # 'None' (NULL) values have been excluded in the query + pvalues.append(pv) + index.append({"table": "gmm_stats", "id_col": "kmer_statsid", + "id": row["kmer_statsid"], "pv_col": pv_col}) + except: + logger.error("Error reading p-values from table 'gmm_stats'") + raise + logger.debug(f"Number of p-values for multiple testing correction: {len(pvalues)}") + if not pvalues: + return + adjusted = multipletests(pvalues, method=method)[1] + assert len(pvalues) == len(adjusted) + # sqlite module can't handle numpy float64 values, so convert to floats using "tolist": + for ind, adj_pv in zip(index, adjusted.tolist()): + query = "UPDATE {table} SET adj_{pv_col} = ? 
WHERE {id_col} = {id}".format_map(ind) + try: + db.cursor.execute(query, (adj_pv, )) + except: + logger.error("Error updating adjusted p-value for ID {id} in table '{table}'".format_map(ind)) + raise diff --git a/nanocompore/SuperParser.py b/nanocompore/SuperParser.py index ea8b528..e4cf46c 100644 --- a/nanocompore/SuperParser.py +++ b/nanocompore/SuperParser.py @@ -20,18 +20,20 @@ def __init__ (self, change_colnames={}, cast_colnames={}): """ - Open a parser for field delimited file and return an iterator yield lines as namedtuples - Transparently parse gziped and multiple file with the same header + Open a parser for field-delimited files and return an iterator yielding lines as namedtuples + Transparently parse gzipped files and multiple files with the same header * fn Path to a field delimited file or a list of files or a regex or a mix of everything (Can be gzipped) * select_colnames - List of column names to use parse and return + List of column names to parse and return * sep field separator * comment skip any line starting with this string + * change_colnames + Dict with mapping (from: to) for changing column names * cast_colnames - Dict corresponding to fields (based on colnames) to cast in a given python type + Dict corresponding to fields (based on colnames) to cast into a given python type """ # Init logger and counter @@ -43,7 +45,7 @@ def __init__ (self, self._n_lines = n_lines # Input file opening - self.f_list = self._open_files (fn) + self.f_list = self._open_files(fn) # Define colnames based on file header. It needs to be the same for all the files to parse fn, fp = self.f_list[0] @@ -88,16 +90,16 @@ def __init__ (self, #~~~~~~~~~~~~~~MAGIC AND PROPERTY METHODS~~~~~~~~~~~~~~# - def __len__ (self): + def __len__(self): size = 0 for fn, fp in self.f_list: size+= int(os.path.getsize(fn)) return size-self._header_len - def __enter__ (self): + def __enter__(self): return self - def close (self): + def close(self): for i, j in self.counter.most_common(): logger.debug("{}: {}".format(i, j)) for fn, fp in self.f_list: @@ -110,26 +112,26 @@ def close (self): def __exit__(self, exception_type, exception_val, trace): self.close() - def __iter__ (self): + def __iter__(self): # Iterate over files for fn, fp in self.f_list: logger.debug("Starting to parse file {}".format(fn)) # Iterate over line in file for line in fp: - self.counter["Lines Parsed"]+=1 + self.counter["Lines Parsed"] += 1 if self._comment and line.startswith(self._comment): - self.counter["Comment lines skipped"]+=1 + self.counter["Comment lines skipped"] += 1 continue try: line = self._parse_line(line) - self.counter["Lines successfully parsed"]+=1 + self.counter["Lines successfully parsed"] += 1 yield line # early stopping condition if self._n_lines and self.counter["Lines successfully parsed"] == self._n_lines: return except (SuperParserError, TypeError) as E: - self.counter["Malformed or Invalid Lines"]+=1 + self.counter["Malformed or Invalid Lines"] += 1 logger.debug(E) logger.debug("File {}: Invalid line {}".format(fn, line)) logger.debug("End of file: {}".format(fn)) @@ -137,12 +139,12 @@ def __iter__ (self): #~~~~~~~~~~~~~~PRIVATE METHODS~~~~~~~~~~~~~~# - def _get_first_line_header (self, fp): + def _get_first_line_header(self, fp): header_line = next(fp) header_list = header_line.rstrip().split(self._sep) return header_list - def _parse_line (self, line): + def _parse_line(self, line): # Split line line = line.rstrip().split(self._sep) @@ -167,12 +169,12 @@ def _parse_line (self, line): # Return parsed line as 
a dict return line_d - def _open_files (self, fn_list): + def _open_files(self, fn_list): """Transparently open files, lists, regex, gzipped or not""" f_list = [] # Standard input - if fn_list is 0: + if fn_list == 0: fn = "stdin" fp = open(0) return [(fn,fp)] @@ -184,7 +186,7 @@ def _open_files (self, fn_list): if isinstance(fn_list, (list, tuple, set)): for fn_regex in fn_list: for fn in iglob(fn_regex): - self.counter["Input files"]+=1 + self.counter["Input files"] += 1 if fn.endswith(".gz"): logger.debug("Opening file {} in gzip mode".format(fn)) fp = gzip.open(fn, "rt") @@ -192,9 +194,7 @@ def _open_files (self, fn_list): logger.debug("Opening file {} in normal mode".format(fn)) fp = open(fn, "r") f_list.append((fn,fp)) - return f_list - else: raise SuperParserError ("Invalid file type") diff --git a/nanocompore/TxComp.py b/nanocompore/TxComp.py index 5898d5c..4565b4d 100644 --- a/nanocompore/TxComp.py +++ b/nanocompore/TxComp.py @@ -5,7 +5,6 @@ from collections import OrderedDict, Counter, defaultdict import warnings - # Third party from loguru import logger from scipy.stats import mannwhitneyu, ttest_ind, chi2, f_oneway @@ -21,400 +20,418 @@ from nanocompore.common import * -def txCompare( - ref_id, - ref_pos_list, - random_state, - methods=None, - sequence_context=0, - min_coverage=20, - ref=None, - sequence_context_weights="uniform", - anova=True, - logit=False, - allow_warnings=False): - logger.debug("TxCompare") - - if sequence_context_weights != "uniform" and sequence_context_weights != "harmonic": - raise NanocomporeError("Invalid sequence_context_weights (uniform or harmonic)") - - n_lowcov = 0 - tests = set() - # If we have less than 2 replicates in any condition skip anova and force logit method - if not all([ len(i)>1 for i in ref_pos_list[0]['data'].values() ]): - anova=False - logit=True - for pos, pos_dict in enumerate(ref_pos_list): - logger.trace(f"Processing position {pos}") - # Filter out low coverage positions - lowcov = False - for cond_dict in pos_dict["data"].values(): - for sample_val in cond_dict.values(): - if sample_val["coverage"] < min_coverage: - lowcov=True - ref_pos_list[pos]["lowCov"]=lowcov - - # Perform stat tests if not low cov - if lowcov: - logger.trace(f"Position {pos} is low coverage, skipping") - n_lowcov+=1 +class TxComp(object): + """Compare transcript data from two samples using statistical methods""" + + def __init__(self, + random_state, + univariate_test:str, + fit_gmm:bool, + gmm_test:str, + sequence_context:int=0, + sequence_context_weighting:str="uniform", # or: "harmonic" + min_coverage:int=20, + allow_anova_warnings:bool=False): + self.__random_state = random_state + self.__univariate_test = univariate_test + self.__fit_gmm = fit_gmm + self.__gmm_test = gmm_test + self.__min_coverage = min_coverage + self.__sequence_context = sequence_context + if sequence_context > 0: + if sequence_context_weighting == "harmonic": + # Generate weights as a symmetrical harmonic series + self.__sequence_context_weights = self.__harmonic_series() + elif sequence_context_weighting == "uniform": + self.__sequence_context_weights = [1] * (2 * self.__sequence_context + 1) + else: + raise NanocomporeError("Invalid sequence context weighting ('uniform' or 'harmonic')") + self.__allow_anova_warnings = allow_anova_warnings + self.gmm_anova_failed = False + + + def __call__(self, ref_id, kmer_data): + """Perform comparisons for one transcript ('ref_id') given k-mer data""" + logger.debug("TxComp()") + + n_lowcov = 0 + # If we have less than 2 replicates in any 
condition skip anova and force logit method + # TODO: looking at the first kmer only may not be reliable - find a better way + if self.__fit_gmm and (self.__gmm_test == "anova") and \ + not all([len(samples) > 1 for samples in next(iter(kmer_data.values())).values()]): + logger.warning("Not enough replicates for 'anova' GMM test. Switching to 'logit' test.") + self.__gmm_test = "logit" + self.gmm_anova_failed = True else: - res = dict() - data = pos_dict['data'] - condition_labels = tuple(data.keys()) + self.gmm_anova_failed = False + + results = {} + for pos, pos_dict in kmer_data.items(): + logger.trace(f"Processing position {pos}") + # Filter out low coverage positions + if self.__has_low_coverage(pos_dict): + logger.trace(f"Position {pos} has low coverage, skipping") + n_lowcov += 1 + continue + + # Perform stat tests + res = {} + condition_labels = tuple(pos_dict.keys()) if len(condition_labels) != 2: - raise NanocomporeError("The %s method only supports two conditions" % method) - condition1_intensity = np.concatenate([ rep['intensity'] for rep in data[condition_labels[0]].values() ]) - condition2_intensity = np.concatenate([ rep['intensity'] for rep in data[condition_labels[1]].values() ]) - condition1_dwell = np.concatenate([ rep['dwell'] for rep in data[condition_labels[0]].values() ]) - condition2_dwell = np.concatenate([ rep['dwell'] for rep in data[condition_labels[1]].values() ]) - - for met in methods: - logger.trace(f"Running {met} test on position {pos}") - if met in ["MW", "KS", "TT"] : - try: - pvalues = nonparametric_test(condition1_intensity, condition2_intensity, condition1_dwell, condition2_dwell, method=met) - except: - raise NanocomporeError("Error doing {} test on reference {}".format(met, ref_id)) - res["{}_intensity_pvalue".format(met)]=pvalues[0] - res["{}_dwell_pvalue".format(met)]=pvalues[1] - tests.add("{}_intensity_pvalue".format(met)) - tests.add("{}_dwell_pvalue".format(met)) - elif met == "GMM": - try: - gmm_results = gmm_test(data, anova=anova, logit=logit, allow_warnings=allow_warnings, random_state=random_state) - except: - raise NanocomporeError("Error doing GMM test on reference {}".format(ref_id)) - res["GMM_model"] = gmm_results['gmm'] - if anova: - res["GMM_anova_pvalue"] = gmm_results['anova']['pvalue'] - res["GMM_anova_model"] = gmm_results['anova'] - tests.add("GMM_anova_pvalue") - if logit: - res["GMM_logit_pvalue"] = gmm_results['logit']['pvalue'] - res["GMM_logit_model"] = gmm_results['logit'] - tests.add("GMM_logit_pvalue") + raise NanocomporeError("Need exactly two conditions for comparison") + condition1_intensity = np.concatenate([rep['intensity'] for rep in pos_dict[condition_labels[0]].values()]) + condition2_intensity = np.concatenate([rep['intensity'] for rep in pos_dict[condition_labels[1]].values()]) + condition1_dwell = np.concatenate([rep['dwell'] for rep in pos_dict[condition_labels[0]].values()]) + condition2_dwell = np.concatenate([rep['dwell'] for rep in pos_dict[condition_labels[1]].values()]) + + if self.__univariate_test: + logger.trace(f"Running {self.__univariate_test} test on position {pos}") + try: + pvalues = self.__nonparametric_test(condition1_intensity, condition2_intensity, + condition1_dwell, condition2_dwell) + except: + raise NanocomporeError(f"Error running {self.__univariate_test} test on transcript {ref_id}") + res["intensity_pvalue"] = pvalues[0] + res["dwell_pvalue"] = pvalues[1] + + if self.__fit_gmm: + logger.trace(f"Fitting GMM on position {pos}") + try: + gmm_results = self.__gmm_fit(pos_dict) + 
except: + raise NanocomporeError(f"Error running GMM test on transcript {ref_id}") + for key, value in gmm_results.items(): + res["gmm_" + key] = value # Calculate shift statistics logger.trace(f"Calculatign shift stats for {pos}") - res['shift_stats'] = shift_stats(condition1_intensity, condition2_intensity, condition1_dwell, condition2_dwell) + res["shift_stats"] = self.__shift_stats(condition1_intensity, condition2_intensity, + condition1_dwell, condition2_dwell) # Save results in main logger.trace(f"Saving test results for {pos}") - ref_pos_list[pos]['txComp'] = res - logger.debug("Skipped {} positions because not present in all samples with sufficient coverage".format(n_lowcov)) - - # Combine pvalue within a given sequence context - if sequence_context > 0: - logger.debug ("Calculate weighs and cross correlation matrices by tests") - if sequence_context_weights == "harmonic": - # Generate weights as a symmetrical harmonic series - weights = harmomic_series(sequence_context) + results[pos] = res + + logger.debug(f"Skipped {n_lowcov} positions because not present in all samples with sufficient coverage") + + if self.__sequence_context > 0: + if self.__univariate_test: + self.__combine_adjacent_pvalues(results, "intensity_pvalue") + self.__combine_adjacent_pvalues(results, "dwell_pvalue") + if self.__fit_gmm and self.__gmm_test: + self.__combine_adjacent_pvalues(results, "gmm_pvalue") + + return results + + + def __combine_adjacent_pvalues(self, results, pvalue_key): + logger.debug(f"Calculating cross correlation matrix for '{pvalue_key}'") + # Collect pvalue list for test + pval_list = [] + for res_dict in results.values(): + # TODO: avoid 'None'/'np.nan' checks below by checking and replacing here? + pval_list.append(res_dict.get(pvalue_key)) + # Compute cross correlation matrix + corr_matrix = self.__cross_corr_matrix(pval_list) + + logger.debug("Combine adjacent position pvalues with Hou's method position by position") + combined_label = f"{pvalue_key}_context" + # Iterate over each position in previously generated result dictionary + for mid_pos, res_dict in results.items(): + # If the mid p-value is NaN, also set the context p-value to NaN + if (res_dict[pvalue_key] is None) or np.isnan(res_dict[pvalue_key]): + results[mid_pos][combined_label] = np.nan + continue + ## Otherwise collect adjacent p-values and combine them: + pval_list = [] + for pos in range(mid_pos - self.__sequence_context, mid_pos + self.__sequence_context + 1): + # If any of the positions is missing or any of the p-values in the context is NaN, consider it 1 + if (pos not in results) or (results[pos][pvalue_key] is None) or np.isnan(results[pos][pvalue_key]): + pval_list.append(1) + else: # just extract the corresponding pvalue + pval_list.append(results[pos][pvalue_key]) + # Combine collected pvalues and add to dict + results[mid_pos][combined_label] = self.__combine_pvalues_hou(pval_list, corr_matrix) + + + def __nonparametric_test(self, condition1_intensity, condition2_intensity, + condition1_dwell, condition2_dwell): + if self.__univariate_test == "MW": + stat_test = lambda x, y: mannwhitneyu(x, y, alternative='two-sided') + elif self.__univariate_test == "KS": + stat_test = ks_twosamp + elif self.__univariate_test == "ST": + stat_test = lambda x, y: ttest_ind(x, y, equal_var=False) else: - weights = [1]*(2*sequence_context+1) - - # Collect pvalue lists per tests - pval_list_dict = defaultdict(list) - for pos_dict in ref_pos_list: - if 'txComp' in pos_dict: - for test in tests: - 
pval_list_dict[test].append(pos_dict['txComp'][test]) - elif pos_dict["lowCov"]: - for test in tests: - pval_list_dict[test].append(np.nan) - # Compute cross correlation matrix per test - corr_matrix_dict = OrderedDict() - for test in tests: - corr_matrix_dict[test] = cross_corr_matrix(pval_list_dict[test], sequence_context) - - logger.debug("Combine adjacent position pvalues with Hou's method position per position") - # Iterate over each positions in previously generated result dictionary - for mid_pos in range(len(ref_pos_list)): - # Perform test only if middle pos is valid - if not ref_pos_list[mid_pos]["lowCov"]: - pval_list_dict = defaultdict(list) - for pos in range(mid_pos-sequence_context, mid_pos+sequence_context+1): - for test in tests: - # If any of the positions is missing or any of the pvalues in the context is lowCov or NaN, consider it 1 - if pos < 0 or pos >= len(ref_pos_list) or ref_pos_list[pos]["lowCov"] or np.isnan(ref_pos_list[pos]["txComp"][test]): - pval_list_dict[test].append(1) - # else just extract the corresponding pvalue - else: - pval_list_dict[test].append(ref_pos_list[pos]["txComp"][test]) - # Combine collected pvalues and add to dict - for test in tests: - test_label = "{}_context_{}".format(test, sequence_context) - # If the mid p-value is.nan, force to nan also the context p-value - if np.isnan(ref_pos_list[mid_pos]["txComp"][test]): - ref_pos_list[mid_pos]['txComp'][test_label] = np.nan - else: - ref_pos_list[mid_pos]['txComp'][test_label] = combine_pvalues_hou(pval_list_dict[test], weights, corr_matrix_dict[test]) - - return ref_pos_list - -def nonparametric_test(condition1_intensity, condition2_intensity, condition1_dwell, condition2_dwell, method=None): - - if method in ["mann_whitney", "MW"]: - stat_test = lambda x,y: mannwhitneyu(x, y, alternative='two-sided') - elif method in ["kolmogorov_smirnov", "KS"]: - stat_test = ks_twosamp - elif method in ["t_test", "TT"]: - stat_test = lambda x,y: ttest_ind(x, y, equal_var=False) - else: - raise NanocomporeError("Invalid statistical method name (MW, KS, ttest)") - - pval_intensity = stat_test(condition1_intensity, condition2_intensity)[1] - if pval_intensity == 0: - pval_intensity = np.finfo(np.float).tiny - - pval_dwell = stat_test(condition1_dwell, condition2_dwell)[1] - if pval_dwell == 0: - pval_dwell = np.finfo(np.float).tiny - return(pval_intensity, pval_dwell) - - -def gmm_test(data, random_state, anova=True, logit=False, verbose=True, allow_warnings=False): - # Condition labels - condition_labels = tuple(data.keys()) - # List of sample labels - sample_labels = list(data[condition_labels[0]].keys()) + list(data[condition_labels[1]].keys()) - - if len(sample_labels) != len(set(sample_labels)): - raise NanocomporeError("Sample labels have to be unique and it looks like some are not.") - - # Dictionary Sample_label:Condition_label - sample_condition_labels = { sk:k for k,v in data.items() for sk in v.keys() } - if len(condition_labels) != 2: - raise NanocomporeError("gmm_test only supports two conditions") - - # Merge the intensities and dwell times of all samples in a single array - global_intensity = np.concatenate(([v['intensity'] for v in data[condition_labels[0]].values()]+[v['intensity'] for v in data[condition_labels[1]].values()]), axis=None) - global_dwell = np.concatenate(([v['dwell'] for v in data[condition_labels[0]].values()]+[v['dwell'] for v in data[condition_labels[1]].values()]), axis=None) - global_dwell = np.log10(global_dwell) - - # Scale the intensity and dwell time arrays - X = 
StandardScaler().fit_transform([(i, d) for i,d in zip(global_intensity, global_dwell)]) - - # Generate an array of sample labels - Y = [ k for k,v in data[condition_labels[0]].items() for _ in v['intensity'] ] + [ k for k,v in data[condition_labels[1]].items() for _ in v['intensity'] ] - - gmm_fit = fit_best_gmm(X, max_components=2, cv_types=['full'], random_state=random_state) - gmm_mod, gmm_type, gmm_ncomponents = gmm_fit - - # If the best GMM has 2 clusters do an anova test on the log odd ratios - if gmm_ncomponents == 2: - # Assign data points to the clusters - y_pred = gmm_mod.predict(X) - counters = dict() - # Count how many reads in each cluster for each sample - for lab in sample_labels: - counters[lab] = Counter(y_pred[[i==lab for i in Y]]) - cluster_counts = count_reads_in_cluster(counters) - if anova: - aov_results = gmm_anova_test(counters, sample_condition_labels, condition_labels, gmm_ncomponents, allow_warnings) - else: - aov_results=None - - if logit: - logit_results = gmm_logit_test(Y, y_pred, sample_condition_labels, condition_labels) + raise NanocomporeError("Invalid univariate test name (MW, KS, ST)") + + pval_intensity = stat_test(condition1_intensity, condition2_intensity)[1] + if pval_intensity == 0: + pval_intensity = np.finfo(np.float).tiny + + pval_dwell = stat_test(condition1_dwell, condition2_dwell)[1] + if pval_dwell == 0: + pval_dwell = np.finfo(np.float).tiny + return (pval_intensity, pval_dwell) + + + def __gmm_fit(self, data): + # Condition labels + condition_labels = tuple(data.keys()) + # List of sample labels + sample_labels = list(data[condition_labels[0]].keys()) + list(data[condition_labels[1]].keys()) + + if len(sample_labels) != len(set(sample_labels)): + raise NanocomporeError("Sample labels have to be unique and it looks like some are not.") + + # Dictionary Sample_label:Condition_label + sample_condition_labels = {sk:k for k,v in data.items() for sk in v.keys()} + if len(condition_labels) != 2: + raise NanocomporeError("GMM fitting only supports two conditions") + + # Merge the intensities and dwell times of all samples in a single array + global_intensity = np.concatenate(([v['intensity'] for v in data[condition_labels[0]].values()] + + [v['intensity'] for v in data[condition_labels[1]].values()]), axis=None) + global_dwell = np.concatenate(([v['dwell'] for v in data[condition_labels[0]].values()] + + [v['dwell'] for v in data[condition_labels[1]].values()]), axis=None) + global_dwell = np.log10(global_dwell) + + # Scale the intensity and dwell time arrays + X = StandardScaler().fit_transform([(i, d) for i, d in zip(global_intensity, global_dwell)]) + + # Generate an array of sample labels + Y = [k for k, v in data[condition_labels[0]].items() for _ in v['intensity']] + \ + [k for k, v in data[condition_labels[1]].items() for _ in v['intensity']] + + gmm_mod, gmm_type, gmm_ncomponents = self.__fit_best_gmm(X, max_components=2, cv_types=['full']) + + if gmm_ncomponents == 2: + # Assign data points to the clusters + y_pred = gmm_mod.predict(X) + counters = dict() + # Count how many reads in each cluster for each sample + for lab in sample_labels: + counters[lab] = Counter(y_pred[[i == lab for i in Y]]) + cluster_counts = self.__count_reads_in_cluster(counters) + if self.__gmm_test == "anova": + pvalue, stat, details = self.__gmm_anova_test(counters, sample_condition_labels, + condition_labels, gmm_ncomponents) + elif self.__gmm_test == "logit": + pvalue, stat, details = self.__gmm_logit_test(Y, y_pred, sample_condition_labels, condition_labels) 
else: - logit_results=None - - elif gmm_ncomponents == 1: - aov_results = {'pvalue': np.nan, 'delta_logit': np.nan, 'table': "NC", 'cluster_counts': "NC"} - logit_results = {'pvalue': np.nan, 'coef': "NC", 'model': "NC"} - cluster_counts = "NC" - else: - raise NanocomporeError("GMM models with n_component>2 are not supported") - - return({'anova':aov_results, 'logit': logit_results, 'gmm':{'model': gmm_mod, 'cluster_counts': cluster_counts}}) - -def fit_best_gmm(X, random_state, max_components=2, cv_types=['spherical', 'tied', 'diag', 'full']): - # Loop over multiple cv_types and n_components and for each fit a GMM - # calculate the BIC and retain the lowest - lowest_bic = np.infty - bic = [] - n_components_range = range(1, max_components+1) - for cv_type in cv_types: - for n_components in n_components_range: - # Fit a Gaussian mixture with EM - gmm = GaussianMixture(n_components=n_components, covariance_type=cv_type, random_state=random_state) - gmm.fit(X) - bic.append(gmm.bic(X)) - if bic[-1] < lowest_bic: - lowest_bic = bic[-1] - best_gmm = gmm - best_gmm_type = cv_type - best_gmm_ncomponents = n_components - return((best_gmm, best_gmm_type, best_gmm_ncomponents)) - -def gmm_anova_test(counters, sample_condition_labels, condition_labels, gmm_ncomponents, allow_warnings=False): - labels= [] - logr = [] - for sample,counter in counters.items(): - # Save the condition label the corresponds to the current sample - labels.append(sample_condition_labels[sample]) - # The Counter dictionaries in counters are not ordered - # The following line enforces the order and adds 1 to avoid empty clusters - ordered_counter = [ counter[i]+1 for i in range(gmm_ncomponents)] - total = sum(ordered_counter) - normalised_ordered_counter = [ i/total for i in ordered_counter ] - # Loop through ordered_counter and divide each value by the first - logr.append(np.log(normalised_ordered_counter[0]/(1-normalised_ordered_counter[0]))) - logr = np.around(np.array(logr), decimals=9) - logr_s1 = [logr[i] for i,l in enumerate(labels) if l==condition_labels[0]] - logr_s2 = [logr[i] for i,l in enumerate(labels) if l==condition_labels[1]] - # If the SS for either array is 0, skip the anova test - if sum_of_squares(logr_s1-np.mean(logr_s1)) == 0 and sum_of_squares(logr_s2-np.mean(logr_s2)) == 0: - if not allow_warnings: - raise NanocomporeError("While doing the Anova test we found a sample with within variance = 0. 
Use --allow_warnings to ignore.") + pvalue = stat = details = cluster_counts = None + + return {"model": gmm_mod, "cluster_counts": cluster_counts, "pvalue": pvalue, "test_stat": stat, + "test_details": details} + + + def __fit_best_gmm(self, X, max_components=2, cv_types=['spherical', 'tied', 'diag', 'full']): + # Loop over multiple cv_types and n_components and for each fit a GMM + # calculate the BIC and retain the lowest + lowest_bic = np.infty + bic = [] + n_components_range = range(1, max_components + 1) + for cv_type in cv_types: + for n_components in n_components_range: + # Fit a Gaussian mixture with EM + gmm = GaussianMixture(n_components=n_components, covariance_type=cv_type, + random_state=self.__random_state) + gmm.fit(X) + bic.append(gmm.bic(X)) + if bic[-1] < lowest_bic: + lowest_bic = bic[-1] + best_gmm = gmm + best_gmm_type = cv_type + best_gmm_ncomponents = n_components + return (best_gmm, best_gmm_type, best_gmm_ncomponents) + + + def __gmm_anova_test(self, counters, sample_condition_labels, condition_labels, gmm_ncomponents): + labels = [] + logr = [] + for sample, counter in counters.items(): + # Save the condition label the corresponds to the current sample + labels.append(sample_condition_labels[sample]) + # The Counter dictionaries in counters are not ordered + # The following line enforces the order and adds 1 to avoid empty clusters + ordered_counter = [counter[i] + 1 for i in range(gmm_ncomponents)] + total = sum(ordered_counter) + normalised_ordered_counter = [i / total for i in ordered_counter] + # Loop through ordered_counter and divide each value by the first + logr.append(np.log(normalised_ordered_counter[0] / (1 - normalised_ordered_counter[0]))) + logr = np.around(np.array(logr), decimals=9) + logr_s1 = [logr[i] for i, l in enumerate(labels) if l == condition_labels[0]] + logr_s2 = [logr[i] for i, l in enumerate(labels) if l == condition_labels[1]] + # If the SS for either array is 0, skip the anova test + if sum_of_squares(logr_s1 - np.mean(logr_s1)) == 0 and sum_of_squares(logr_s2 - np.mean(logr_s2)) == 0: + if not self.__allow_anova_warnings: + raise NanocomporeError("While doing the Anova test we found a sample with within variance = 0. Use --allow_anova_warnings to ignore.") + else: + aov_table = "Within variance is 0" + aov_pvalue = np.finfo(np.float).tiny else: - aov_table = "Within variance is 0" - aov_pvalue = np.finfo(np.float).tiny - else: + with warnings.catch_warnings(): + # Convert warnings to errors in order to catch them + warnings.filterwarnings('error') + try: + aov_table = f_oneway(logr_s1, logr_s2) + aov_pvalue = aov_table.pvalue + except RuntimeWarning: + if not self.__allow_anova_warnings: + raise NanocomporeError("While doing the Anova test a runtime warning was raised. Use --allow_anova_warnings to ignore.") + else: + warnings.filterwarnings('default') + aov_table = f_oneway(logr_s1, logr_s2) + aov_pvalue = np.finfo(np.float).tiny + if aov_pvalue == 0: + raise NanocomporeError("The Anova test returned a p-value of 0. This is most likely an error somewhere") + # Calculate the delta log odds ratio, i.e. 
the difference of the means of the log odds ratios between the two conditions + aov_delta_logit = float(np.mean(logr_s1) - np.mean(logr_s2)) + aov_details = {'table': aov_table, 'log_ratios': logr} + return (aov_pvalue, aov_delta_logit, aov_details) + + + @staticmethod + def __gmm_logit_test(Y, y_pred, sample_condition_labels, condition_labels): + Y = [sample_condition_labels[i] for i in Y] + y_pred = np.append(y_pred, [0, 0, 1, 1]) + Y.extend([condition_labels[0], condition_labels[1], condition_labels[0], condition_labels[1]]) + Y = pd.get_dummies(Y) + Y['intercept'] = 1 + logit = dm.Logit(y_pred, Y[['intercept', condition_labels[1]]]) with warnings.catch_warnings(): - # Convert warnings to errors in order to catch them warnings.filterwarnings('error') try: - aov_table = f_oneway(logr_s1, logr_s2) - aov_pvalue = aov_table.pvalue - except RuntimeWarning: - if not allow_warnings: - raise NanocomporeError("While doing the Anova test a runtime warning was raised. Use --allow_warnings to ignore.") - else: - warnings.filterwarnings('default') - aov_table = f_oneway(logr_s1, logr_s2) - aov_pvalue = np.finfo(np.float).tiny - if aov_pvalue == 0: - raise NanocomporeError("The Anova test returned a p-value of 0. This is most likely an error somewhere") - # Calculate the delta log odds ratio, i.e. the difference of the means of the log odds ratios between the two conditions - aov_delta_logit=float(np.mean(logr_s1)-np.mean(logr_s2)) - aov_results = {'pvalue': aov_pvalue, 'delta_logit': aov_delta_logit, 'table': aov_table, 'log_ratios':logr} - return(aov_results) - -def gmm_logit_test(Y, y_pred, sample_condition_labels, condition_labels): - Y = [ sample_condition_labels[i] for i in Y] - y_pred=np.append(y_pred, [0,0,1,1]) - Y.extend([condition_labels[0], condition_labels[1], condition_labels[0], condition_labels[1]]) - Y = pd.get_dummies(Y) - Y['intercept']=1 - logit = dm.Logit(y_pred,Y[['intercept',condition_labels[1]]] ) - with warnings.catch_warnings(): - warnings.filterwarnings('error') - try: - logit_mod=logit.fit(disp=0) - logit_pvalue, logit_coef = logit_mod.pvalues[1], logit_mod.params[1] - except ConvergenceWarning: - logit_mod, logit_pvalue, logit_coef = "NC", 1, "NC" - if logit_pvalue == 0: - logit_pvalue = np.finfo(np.float).tiny - logit_results = {'pvalue': logit_pvalue, 'coef': logit_coef, 'model': logit_mod} - return(logit_results) - -def count_reads_in_cluster(counters): - cluster_counts = list() - for k,v in counters.items(): - cluster_counts.append("%s:%s/%s" % (k, v[0], v[1])) - cluster_counts="__".join(cluster_counts) - return(cluster_counts) - -def shift_stats(condition1_intensity, condition2_intensity, condition1_dwell, condition2_dwell): - """Calculate shift statistics""" - shift_stats = OrderedDict([ - ('c1_mean_intensity', np.mean(condition1_intensity)), - ('c2_mean_intensity', np.mean(condition2_intensity)), - ('c1_median_intensity', np.median(condition1_intensity)), - ('c2_median_intensity', np.median(condition2_intensity)), - ('c1_sd_intensity', np.std(condition1_intensity)), - ('c2_sd_intensity', np.std(condition2_intensity)), - ('c1_mean_dwell', np.mean(condition1_dwell)), - ('c2_mean_dwell', np.mean(condition2_dwell)), - ('c1_median_dwell', np.median(condition1_dwell)), - ('c2_median_dwell', np.median(condition2_dwell)), - ('c1_sd_dwell', np.std(condition1_dwell)), - ('c2_sd_dwell', np.std(condition2_dwell)) - ]) - return(shift_stats) - - -def cross_corr_matrix(pvalues_vector, context=2): - """ Calculate the cross correlation matrix of the - pvalues for a given context. 
- """ - if len(pvalues_vector)<(context*3)+3: - raise NanocomporeError("Not enough p-values for a context of order %s"%context) - - pvalues_vector = np.array([ i if not np.isnan(i) else 1 for i in pvalues_vector ]) - if any(pvalues_vector==0) or any(np.isinf(pvalues_vector)) or any(pvalues_vector>1): - raise NanocomporeError("At least one p-value is invalid") - - matrix=[] - s=pvalues_vector.size - if all(p==1 for p in pvalues_vector): - return(np.ones((context*2+1, context*2+1))) - - for i in range(-context,context+1): - row=[] - for j in range(-context,context+1): - row.append(np.corrcoef((np.roll(pvalues_vector,i)[context:s-context]), (np.roll(pvalues_vector,j)[context:s-context]))[0][1]) - matrix.append(row) - return(np.array(matrix)) - -def combine_pvalues_hou(pvalues, weights, cor_mat): - """ Hou's method for the approximation for the distribution of the weighted - combination of non-independent or independent probabilities. - If any pvalue is nan, returns nan. - https://doi.org/10.1016/j.spl.2004.11.028 - pvalues: list of pvalues to be combined - weights: the weights of the pvalues - cor_mat: a matrix containing the correlation coefficients between pvalues - Test: when weights are equal and cor=0, hou is the same as Fisher - print(combine_pvalues([0.1,0.02,0.1,0.02,0.3], method='fisher')[1]) - print(hou([0.1,0.02,0.1,0.02,0.3], [1,1,1,1,1], np.zeros((5,5)))) - """ - if(len(pvalues) != len(weights)): - raise NanocomporeError("Can't combine pvalues is pvalues and weights are not the same length.") - if( cor_mat.shape[0] != cor_mat.shape[1] or cor_mat.shape[0] != len(pvalues)): - raise NanocomporeError("The correlation matrix needs to be squared, with each dimension equal to the length of the pvalued vector.") - if all(p==1 for p in pvalues): - return 1 - if any((p==0 or np.isinf(p) or p>1) for p in pvalues): - raise NanocomporeError("At least one p-value is invalid") - - # Covariance estimation as in Kost and McDermott (eq:8) - # https://doi.org/10.1016/S0167-7152(02)00310-3 - cov = lambda r: (3.263*r)+(0.710*r**2)+(0.027*r**3) - k=len(pvalues) - cov_sum=np.float64(0) - sw_sum=np.float64(0) - w_sum=np.float64(0) - tau=np.float64(0) - for i in range(k): - for j in range(i+1,k): - cov_sum += weights[i]*weights[j]*cov(cor_mat[i][j]) - sw_sum += weights[i]**2 - w_sum += weights[i] - # Calculate the weighted Fisher's combination statistic - tau += weights[i] * (-2*np.log(pvalues[i])) - # Correction factor - c = (2*sw_sum+cov_sum) / (2*w_sum) - # Degrees of freedom - f = (4*w_sum**2) / (2*sw_sum+cov_sum) - # chi2.sf is the same as 1-chi2.cdf but is more accurate - combined_p_value = chi2.sf(tau/c,f) - # Return a very small number if pvalue = 0 - if combined_p_value == 0: - combined_p_value = np.finfo(np.float).tiny - return combined_p_value - -def harmomic_series(sequence_context): - weights = [] - for i in range(-sequence_context, sequence_context+1): - weights.append(1/(abs(i)+1)) - return weights - -def sum_of_squares(x): - """ - Square each element of the input array and return the sum - """ - x = np.atleast_1d(x) - return np.sum(x*x) + logit_mod = logit.fit(disp=0) + logit_pvalue, logit_coef = logit_mod.pvalues[1], logit_mod.params[1] + except ConvergenceWarning: + logit_mod, logit_pvalue, logit_coef = None, 1, None + if logit_pvalue == 0: + logit_pvalue = np.finfo(np.float).tiny + logit_details = {'model': logit_mod} + return (logit_pvalue, logit_coef, logit_details) + + + @staticmethod + def __count_reads_in_cluster(counters): + cluster_counts = list() + for k, v in counters.items(): 
+ cluster_counts.append("%s:%s/%s" % (k, v[0], v[1])) + cluster_counts = "_".join(cluster_counts) + return cluster_counts + + + @staticmethod + def __shift_stats(condition1_intensity, condition2_intensity, condition1_dwell, condition2_dwell): + """Calculate shift statistics""" + shift_stats = OrderedDict([ + ('c1_mean_intensity', np.mean(condition1_intensity)), + ('c2_mean_intensity', np.mean(condition2_intensity)), + ('c1_median_intensity', np.median(condition1_intensity)), + ('c2_median_intensity', np.median(condition2_intensity)), + ('c1_sd_intensity', np.std(condition1_intensity)), + ('c2_sd_intensity', np.std(condition2_intensity)), + ('c1_mean_dwell', np.mean(condition1_dwell)), + ('c2_mean_dwell', np.mean(condition2_dwell)), + ('c1_median_dwell', np.median(condition1_dwell)), + ('c2_median_dwell', np.median(condition2_dwell)), + ('c1_sd_dwell', np.std(condition1_dwell)), + ('c2_sd_dwell', np.std(condition2_dwell)) + ]) + return shift_stats + + + def __cross_corr_matrix(self, pvalues_vector): + """Calculate the cross correlation matrix of the pvalues for a given context.""" + context = self.__sequence_context + if len(pvalues_vector) < (context * 3) + 3: + raise NanocomporeError(f"Not enough p-values for a context of {context}") + + pvalues_vector = np.array([i if (i is not None) and not np.isnan(i) else 1 for i in pvalues_vector]) + if any(pvalues_vector == 0) or any(np.isinf(pvalues_vector)) or any(pvalues_vector > 1): + raise NanocomporeError("At least one p-value is invalid") + + matrix = [] + s = pvalues_vector.size + if all(p == 1 for p in pvalues_vector): + return np.ones((context * 2 + 1, context * 2 + 1)) + + for i in range(-context, context + 1): + row = [] + for j in range(-context, context + 1): + row.append(np.corrcoef((np.roll(pvalues_vector, i)[context:s - context]), + (np.roll(pvalues_vector, j)[context:s - context]))[0][1]) + matrix.append(row) + return np.array(matrix) + + + def __combine_pvalues_hou(self, pvalues, cor_mat): + """ Hou's method for the approximation for the distribution of the weighted + combination of non-independent or independent probabilities. + If any pvalue is nan, returns nan. + https://doi.org/10.1016/j.spl.2004.11.028 + pvalues: list of pvalues to be combined + cor_mat: a matrix containing the correlation coefficients between pvalues + Test: when weights are equal and cor=0, hou is the same as Fisher + print(combine_pvalues([0.1,0.02,0.1,0.02,0.3], method='fisher')[1]) + print(hou([0.1,0.02,0.1,0.02,0.3], [1,1,1,1,1], np.zeros((5,5)))) + """ + weights = self.__sequence_context_weights + # TODO: are the following sanity checks necessary/useful? 
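As the docstring of __combine_pvalues_hou notes, with uniform weights and a zero correlation matrix Hou's weighted combination collapses to Fisher's method. The short standalone check below re-derives the statistic outside the class for that special case (cov_sum is zero by construction) and compares it against scipy's Fisher combination; it is a sanity check, not part of the package.

# Quick check: with uniform weights and no correlation, Hou's statistic equals Fisher's.
import numpy as np
from scipy.stats import chi2, combine_pvalues

def hou_uniform_no_correlation(pvalues):
    k = len(pvalues)
    tau = sum(-2 * np.log(p) for p in pvalues)   # weighted Fisher statistic, all weights = 1
    c = (2 * k + 0) / (2 * k)                    # correction factor; cov_sum = 0 here, so c = 1
    f = (4 * k**2) / (2 * k + 0)                 # degrees of freedom = 2k
    return chi2.sf(tau / c, f)

pvals = [0.1, 0.02, 0.1, 0.02, 0.3]
print(hou_uniform_no_correlation(pvals))            # ~0.0024
print(combine_pvalues(pvals, method="fisher")[1])   # same combined p-value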
+ if len(pvalues) != len(weights): + raise NanocomporeError("Can't combine pvalues if pvalues and weights are not the same length.") + if cor_mat.shape[0] != cor_mat.shape[1] or cor_mat.shape[0] != len(pvalues): + raise NanocomporeError("The correlation matrix needs to be square, with each dimension equal to the length of the pvalued vector.") + if all(p == 1 for p in pvalues): + return 1 + if any((p == 0 or np.isinf(p) or p > 1) for p in pvalues): + raise NanocomporeError("At least one p-value is invalid") + + # Covariance estimation as in Kost and McDermott (eq:8) + # https://doi.org/10.1016/S0167-7152(02)00310-3 + cov = lambda r: (3.263*r)+(0.710*r**2)+(0.027*r**3) + k = len(pvalues) + cov_sum = np.float64(0) + sw_sum = np.float64(0) + w_sum = np.float64(0) + tau = np.float64(0) + for i in range(k): + for j in range(i + 1, k): + cov_sum += weights[i] * weights[j] * cov(cor_mat[i][j]) + sw_sum += weights[i]**2 + w_sum += weights[i] + # Calculate the weighted Fisher's combination statistic + tau += weights[i] * (-2 * np.log(pvalues[i])) + # Correction factor + c = (2 * sw_sum + cov_sum) / (2 * w_sum) + # Degrees of freedom + f = (4 * w_sum**2) / (2 * sw_sum + cov_sum) + # chi2.sf is the same as 1 - chi2.cdf but is more accurate + combined_p_value = chi2.sf(tau / c, f) + # Return a very small number if pvalue = 0 + if combined_p_value == 0: + combined_p_value = np.finfo(np.float).tiny + return combined_p_value + + + def __harmonic_series(self): + weights = [] + for i in range(-self.__sequence_context, self.__sequence_context + 1): + weights.append(1 / (abs(i) + 1)) + return weights + + + @staticmethod + def __sum_of_squares(x): + """ + Square each element of the input array and return the sum + """ + x = np.atleast_1d(x) + return np.sum(x * x) + + + def __has_low_coverage(self, pos_dict): + for cond_dict in pos_dict.values(): + for sample_val in cond_dict.values(): + if sample_val["coverage"] < self.__min_coverage: + return True + return False diff --git a/nanocompore/Whitelist.py b/nanocompore/Whitelist.py index d666e30..00e82d0 100755 --- a/nanocompore/Whitelist.py +++ b/nanocompore/Whitelist.py @@ -6,6 +6,7 @@ import logging from loguru import logger import random +import sqlite3 # Third party import numpy as np @@ -14,6 +15,7 @@ # Local package from nanocompore.common import * +from nanocompore.DataStore import DataStore_EventAlign # Set global random seed downsample_random_seed = 42 @@ -24,22 +26,26 @@ class Whitelist(object): #~~~~~~~~~~~~~~MAGIC METHODS~~~~~~~~~~~~~~# def __init__(self, - eventalign_fn_dict, - fasta_fn, - min_coverage = 10, - min_ref_length = 100, - downsample_high_coverage = False, - max_invalid_kmers_freq = 0.1, - max_NNNNN_freq = 0.1, - max_mismatching_freq = 0.1, - max_missing_freq = 0.1, - select_ref_id = [], - exclude_ref_id = []): + db_path, + sample_dict, + fasta_fn, + min_coverage = 10, + min_ref_length = 100, + downsample_high_coverage = False, + max_invalid_kmers_freq = 0.1, + max_NNNNN_freq = 0.1, + max_mismatching_freq = 0.1, + max_missing_freq = 0.1, + select_ref_id = [], + exclude_ref_id = []): """ - ######################################################### - * eventalign_fn_dict - Multilevel dictionnary indicating the condition_label, sample_label and file name of the eventalign_collapse output - example d = {"S1": {"R1":"path1.tsv", "R2":"path2.tsv"}, "S2": {"R1":"path3.tsv", "R2":"path4.tsv"}} + Generate a whitelist of reads that fulfill filtering criteria + Args: + * db_path + Path to the SQLite database file with event-aligned read/kmer data + * 
sample_dict + Dictionary containing lists of (unique) sample names, grouped by condition + example d = {"control": ["C1", "C2"], "treatment": ["T1", "T2"]} * fasta_fn Path to a fasta file corresponding to the reference used for read alignemnt * min_coverage @@ -63,53 +69,86 @@ def __init__(self, if given, refid in the list will be excluded from the analysis """ - # Check index files - self.__filter_invalid_kmers = True - for sample_dict in eventalign_fn_dict.values(): - for fn in sample_dict.values(): - idx_fn = fn+".idx" - if not access_file(idx_fn): - raise NanocomporeError("Cannot access eventalign_collapse index file {}".format(idx_fn)) - # Check header line and set a flag to skip filter if the index file does not contain kmer status information - with open(idx_fn, "r") as fp: - header = fp.readline().rstrip().split("\t") - if not all_values_in (("ref_id", "read_id", "byte_offset", "byte_len"), header): - raise NanocomporeError("The index file {} does not contain the require header fields".format(idx_fn)) - if not all_values_in (("kmers", "NNNNN_kmers", "mismatch_kmers", "missing_kmers"), header): - self.__filter_invalid_kmers = False - logger.debug("Invalid kmer information not available in index file") - - self.__eventalign_fn_dict = eventalign_fn_dict - - # Get number of samples - n = 0 - for sample_dict in self.__eventalign_fn_dict.values(): - for sample_lab in sample_dict.keys(): - n+=1 - self.__n_samples = n - - # Test is Fasta can be opened + check_sample_dict(sample_dict) + + # Create look-up dict. of conditions + cond_dict = {} + for cond, samples in sample_dict.items(): + for sample in samples: + cond_dict[sample] = cond + + # Test if Fasta can be opened try: with Fasta(fasta_fn): self._fasta_fn = fasta_fn except IOError: raise NanocomporeError("The fasta file cannot be opened") - # Create reference index for both files - logger.info("Reading eventalign index files") - ref_reads = self.__read_eventalign_index( - eventalign_fn_dict = eventalign_fn_dict, - max_invalid_kmers_freq = max_invalid_kmers_freq, - max_NNNNN_freq = max_NNNNN_freq, - max_mismatching_freq = max_mismatching_freq, - max_missing_freq = max_missing_freq, - select_ref_id = select_ref_id, - exclude_ref_id = exclude_ref_id) + # Database interaction + with DataStore_EventAlign(db_path) as db: + db_samples = db.get_samples(sample_dict) + + # How many samples are in the DB? If we want all, we don't need a constraint below. + try: + db_sample_count = db.cursor.execute("SELECT COUNT(*) FROM samples").fetchone()[0] + except Exception: + logger.error("Error counting samples in database") + raise Exception + + # Set up filters by adding conditions for DB query + select = ["reads.id AS readid", "sampleid", "transcriptid", "transcripts.name AS transcriptname"] + where = [] + # Get reads only from a subset of samples? 
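The query assembled in this block exposes derived per-read frequencies (for example the invalid k-mer fraction) as output columns and then filters on them, relying on SQLite accepting result-column aliases in the WHERE clause. The toy example below shows the same pattern end to end on an in-memory database; the schema is heavily reduced and the table contents and threshold are invented.

# Toy illustration of the filtering-query pattern used here, on a reduced schema.
import sqlite3

con = sqlite3.connect(":memory:")
con.row_factory = sqlite3.Row
con.executescript("""
    CREATE TABLE reads (id INTEGER PRIMARY KEY, name TEXT, sampleid INT,
                        kmers INT, valid_kmers INT);
    INSERT INTO reads (name, sampleid, kmers, valid_kmers) VALUES
        ('read_a', 1, 100, 95),   -- 5% invalid kmers: kept
        ('read_b', 1, 100, 80),   -- 20% invalid kmers: filtered out
        ('read_c', 2, 200, 190);  -- 5% invalid kmers: kept
""")

max_invalid_kmers_freq = 0.1
select = ["id", "name", "sampleid",
          "1.0 - CAST(valid_kmers AS REAL) / kmers AS invalid_freq"]
# Referencing the alias in WHERE is a SQLite extension, mirroring the surrounding code:
where = [f"invalid_freq <= {max_invalid_kmers_freq}"]
query = f"SELECT {', '.join(select)} FROM reads WHERE {' AND '.join(where)}"

for row in con.execute(query):
    print(dict(row))   # read_a and read_c pass the filter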
+ if len(db_samples) < db_sample_count: + where = ["sampleid IN (%s)" % ", ".join(map(str, db_samples))] + + if select_ref_id: + select.append("reads.name AS readname") + where.append("readname IN ('%s')" % "', '".join(select_ref_id)) + elif exclude_ref_id: + select.append("reads.name AS readname") + where.append("readname NOT IN ('%s')" % "', '".join(exclude_ref_id)) + + if max_invalid_kmers_freq is not None: + if max_invalid_kmers_freq < 1.0: + select.append("1.0 - CAST(valid_kmers AS REAL) / kmers AS invalid_freq") + where.append(f"invalid_freq <= {max_invalid_kmers_freq}") + else: + if max_NNNNN_freq < 1.0: + select.append("CAST(NNNNN_kmers AS REAL) / kmers AS NNNNN_freq") + where.append(f"NNNNN_freq <= {max_NNNNN_freq}") + if max_mismatching_freq < 1.0: + select.append("CAST(mismatch_kmers AS REAL) / kmers AS mismatch_freq") + where.append(f"mismatch_freq <= {max_mismatching_freq}") + if max_missing_freq < 1.0: + select.append("CAST(missing_kmers AS REAL) / kmers AS missing_freq") + where.append(f"missing_freq <= {max_missing_freq}") + + query = "SELECT %s FROM reads LEFT JOIN transcripts ON transcriptid = transcripts.id" % \ + ", ".join(select) + if where: + query += " WHERE %s" % " AND ".join(where) + + # Dict. structure: transcript -> condition -> sample -> list of reads + ref_reads = {} + logger.info("Querying reads from database") + try: + db.cursor.execute(query) + for row in db.cursor: + read_id = row["readid"] + sample_id = row["sampleid"] + condition = cond_dict[db_samples[sample_id]] + ref_id = row["transcriptname"] + ref_reads.setdefault(ref_id, {}).setdefault(condition, {}).\ + setdefault(sample_id, []).append(read_id) + except Exception: + logger.error("Error querying reads from database") + raise Exception # Filtering at transcript level logger.info("Filtering out references with low coverage") self.ref_reads = self.__select_ref( - ref_reads = ref_reads, + ref_reads=ref_reads, min_coverage=min_coverage, min_ref_length=min_ref_length, downsample_high_coverage=downsample_high_coverage) @@ -119,6 +158,7 @@ def __init__(self, self.__downsample_high_coverage = downsample_high_coverage self.__max_invalid_kmers_freq = max_invalid_kmers_freq + def __repr__(self): return "Whitelist: Number of references: {}".format(len(self)) @@ -141,7 +181,7 @@ def __len__(self): def __iter__(self): for i, j in self.ref_reads.items(): - yield(i,j) + yield (i, j) def __getitem__(self, items): return self.ref_reads.get(items, None) @@ -152,81 +192,6 @@ def ref_id_list(self): return list(self.ref_reads.keys()) #~~~~~~~~~~~~~~PRIVATE METHODS~~~~~~~~~~~~~~# - def __read_eventalign_index(self, - eventalign_fn_dict, - max_invalid_kmers_freq, - max_NNNNN_freq, - max_mismatching_freq, - max_missing_freq, - select_ref_id, - exclude_ref_id): - """Read the 2 index files and sort by sample and ref_id in a multi level dict""" - - ref_reads = OrderedDict() - - for cond_lab, sample_dict in eventalign_fn_dict.items(): - for sample_lab, fn in sample_dict.items(): - idx_fn = fn+".idx" - with open(idx_fn) as fp: - - # Get column names from header - col_names = fp.readline().rstrip().split() - c = Counter() - for line in fp: - try: - # Transform line to dict and cast str numbers to actual numbers - read = numeric_cast_dict(keys=col_names, values=line.rstrip().split("\t")) - - # Filter out ref_id if a select_ref_id list or exclude_ref_id list was provided - if select_ref_id and not read["ref_id"] in select_ref_id: - raise NanocomporeError("Ref_id not in select list") - elif exclude_ref_id and read["ref_id"] in 
exclude_ref_id: - raise NanocomporeError("Ref_id in exclude list") - - # Filter out reads with high number of invalid kmers if information available - if self.__filter_invalid_kmers: - if max_invalid_kmers_freq: - invalid_kmers_freq = (read["NNNNN_kmers"]+read["mismatch_kmers"]+read["missing_kmers"])/read["kmers"] - if invalid_kmers_freq > max_invalid_kmers_freq: - raise NanocomporeError("High fraction of invalid kmers ({}%) for read {}".format(round(invalid_kmers_freq*100,2), read["read_id"])) - else: - NNNNN_kmers_freq = read["NNNNN_kmers"]/read["kmers"] - max_mismatching_freq = read["mismatch_kmers"]/read["kmers"] - max_missing_freq = read["missing_kmers"]/read["kmers"] - if NNNNN_kmers_freq > max_NNNNN_freq: - raise NanocomporeError("High fraction of NNNNN kmers ({}%) for read {}".format(round(NNNNN_kmers_freq*100,2), read["read_id"])) - elif max_mismatching_freq > max_mismatching_freq: - raise NanocomporeError("High fraction of mismatching kmers ({}%) for read {}".format(round(max_mismatching_freq*100,2), read["read_id"])) - elif max_missing_freq > max_missing_freq: - raise NanocomporeError("High fraction of missing kmers ({}%) for read {}".format(round(max_missing_freq*100,2), read["read_id"])) - - # Create dict arborescence and save valid reads - if not read["ref_id"] in ref_reads: - ref_reads[read["ref_id"]] = OrderedDict() - if not cond_lab in ref_reads[read["ref_id"]]: - ref_reads[read["ref_id"]][cond_lab] = OrderedDict() - if not sample_lab in ref_reads[read["ref_id"]][cond_lab]: - ref_reads[read["ref_id"]][cond_lab][sample_lab] = [] - - # Fill in list of reads - ref_reads[read["ref_id"]][cond_lab][sample_lab].append(read) - c ["valid reads"] += 1 - - except NanocomporeError as E: - c [str(E)] += 1 - - logger.debug("\tCondition:{} Sample:{} {}".format(cond_lab, sample_lab, counter_to_str(c))) - # Fill in missing condition/sample slots in case - # a ref_id is missing from one of the eventalign files - for ref_id in ref_reads.keys(): - for cond_lab, sample_dict in eventalign_fn_dict.items(): - for sample_lab in sample_dict.keys(): - if not cond_lab in ref_reads[ref_id]: - ref_reads[ref_id][cond_lab] = OrderedDict() - if not sample_lab in ref_reads[ref_id][cond_lab]: - ref_reads[ref_id][cond_lab][sample_lab] = [] - logger.info("\tReferences found in index: {}".format(len(ref_reads))) - return ref_reads def __select_ref(self, ref_reads, @@ -234,6 +199,7 @@ def __select_ref(self, min_ref_length, downsample_high_coverage): """Select ref_id with a minimal coverage in both sample + downsample if needed""" + # TODO: replace 'OrderedDict' with 'dict' for improved performance? 
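Per its docstring, the __select_ref step that follows keeps only transcripts with sufficient coverage in every sample and, when downsample_high_coverage is set, subsamples overly deep read lists; a module-level seed (downsample_random_seed) is defined for reproducibility. The helper below is a simplified stand-in written only to illustrate that per-sample idea; the real implementation differs in structure and operates on the nested ref_reads dictionary.

# Simplified stand-in for the per-sample coverage filter and downsampling idea in
# __select_ref (not the actual implementation).
import random

downsample_random_seed = 42  # the module defines a seed like this for reproducible downsampling

def filter_sample_reads(read_ids, min_coverage=30, downsample_high_coverage=5000):
    """Return a (possibly downsampled) read list, or None if coverage is too low."""
    if len(read_ids) < min_coverage:
        return None
    if downsample_high_coverage and len(read_ids) > downsample_high_coverage:
        random.seed(downsample_random_seed)
        read_ids = random.sample(read_ids, downsample_high_coverage)
    return read_ids

print(filter_sample_reads(list(range(10))))          # None: below min_coverage
print(len(filter_sample_reads(list(range(10000)))))  # 5000 after downsampling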
valid_ref_reads = OrderedDict() c = Counter() with Fasta(self._fasta_fn) as fasta: @@ -257,7 +223,7 @@ def __select_ref(self, # If all valid add to new dict logger.trace(f"ref_id {ref_id} has enough coverage in all samples: keeping it") - valid_ref_reads [ref_id] = valid_dict + valid_ref_reads[ref_id] = valid_dict # Save extra info for debug c["valid_ref_id"] += 1 diff --git a/nanocompore/__main__.py b/nanocompore/__main__.py index 314906d..4dcfcda 100644 --- a/nanocompore/__main__.py +++ b/nanocompore/__main__.py @@ -19,77 +19,115 @@ from nanocompore.SampComp import SampComp from nanocompore.SimReads import SimReads from nanocompore.Eventalign_collapse import Eventalign_collapse +from nanocompore.PostProcess import PostProcess from nanocompore.common import * #~~~~~~~~~~~~~~MAIN PARSER ENTRY POINT~~~~~~~~~~~~~~# -def main(args=None): +def main(): # General parser parser = argparse.ArgumentParser(description=package_description, formatter_class=argparse.RawTextHelpFormatter) - parser.add_argument('--version', '-v', action='version', version='v'+package_version) + parser.add_argument('--version', '-v', action='version', version='v' + package_version) subparsers = parser.add_subparsers(dest='subcommand', description=textwrap.dedent(""" nanocompore implements the following subcommands\n - \t* eventalign_collapse : Collapse the nanopolish eventalign output at kmers level and compute kmer level statistics\n - \t* sampcomp : Compare 2 samples and find significant signal differences\n - \t* simreads : Simulate reads as a NanopolishComp like file from a fasta file and an inbuild model""")) + \t* eventalign_collapse : Collapse the nanopolish eventalign output at kmer level and compute kmer-level statistics\n + \t* sampcomp : Compare samples from two conditions and find significant signal differences\n + \t* simreads : Simulate reads as a NanopolishComp-like file from a FASTA file and a built-in model""")) subparsers.required = True - # Sampcomp subparser + # Eventalign_collapse subparser + parser_ec = subparsers.add_parser("eventalign_collapse", formatter_class=argparse.RawDescriptionHelpFormatter, + description=textwrap.dedent(""" + Collapse the nanopolish eventalign output at kmer level and compute kmer-level statistics + * Minimal example: + nanocompore eventalign_collapse -i nanopolish_eventalign.tsv -s T1""")) + parser_ec.set_defaults(func=eventalign_collapse_main) + + parser_ec_in = parser_ec.add_argument_group("Input options") + parser_ec_in.add_argument("--input", "-i", default=0, + help="Path to a nanopolish eventalign tsv file, or a list of files, or a regex (can be gzipped). It can be omitted if piped to standard input (default: piped to stdin)") + parser_ec_in.add_argument("--sample", "-s", default=None, required=True, help="Unique identifier of the sample") + + parser_ec_out = parser_ec.add_argument_group("Output options") + parser_ec_out.add_argument("--output", "-o", default="eventalign_collapse.db", + help="Path or filename of database output file (default: %(default)s)") + + parser_ec_run = parser_ec.add_argument_group("Run options") + parser_ec_run.add_argument("--n_lines", "-l", default=None, type=int, + help="Number of lines to parse (default: no limit)") + + parser_ec_misc = parser_ec.add_argument_group("Other options") + parser_ec_misc.add_argument("--nthreads", "-t", default=3, type=int, + help="Total number of threads. 
+
+    # SampComp subparser
     parser_sc = subparsers.add_parser('sampcomp', formatter_class=argparse.RawDescriptionHelpFormatter,
         description=textwrap.dedent("""
         Compare 2 samples and find significant signal differences\n
-        * Minimal example with file_list arguments\n
-            nanocompore sampcomp -1 f1.tsv,f2.tsv -2 f3.tsv,f4.tsv -f ref.fa -o results
-        * Minimal example with sample YAML file\n
-            nanocompore sampcomp -y samples.yaml -f ref -o results"""))
+        * Minimal example:\n
+            nanocompore sampcomp -i eventalign_collapse.db -1 C1,C2 -2 T1,T2 -f ref.fa"""))
     parser_sc.set_defaults(func=sampcomp_main)
-    parser_sc_sample_yaml = parser_sc.add_argument_group('YAML sample files', description="Option allowing to describe sample files in a YAML file")
-    parser_sc_sample_yaml.add_argument("--sample_yaml", "-y", default=None, type=str, metavar="sample_yaml",
-        help="YAML file containing the sample file labels. See formatting in documentation. (required if --file_list1 and --file_list2 not given)")
-    parser_sc_sample_args = parser_sc.add_argument_group('Arguments sample files', description="Option allowing to describe sample files directly as command line arguments")
-    parser_sc_sample_args.add_argument("--file_list1", "-1", default=None, type=str, metavar="/path/to/Condition1_rep1,/path/to/Condition1_rep2",
-        help="Comma separated list of NanopolishComp files for label 1. (required if --sample_yaml not given)")
-    parser_sc_sample_args.add_argument("--file_list2", "-2", default=None, type=str, metavar="/path/to/Condition2_rep1,/path/to/Condition2_rep2",
-        help="Comma separated list of NanopolishComp files for label 2. (required if --sample_yaml not given)")
-    parser_sc_sample_args.add_argument("--label1", type=str, metavar="Condition1", default="Condition1",
-        help="Label for files in --file_list1 (default: %(default)s)")
-    parser_sc_sample_args.add_argument("--label2", type=str, metavar="Condition2", default="Condition2",
-        help="Label for files in --file_list2 (default: %(default)s)")
-    parser_sc_io = parser_sc.add_argument_group('Input options')
-    parser_sc_io.add_argument("--fasta", "-f", type=str, required=True,
-        help="Fasta file used for mapping (required)")
-    parser_sc_io.add_argument("--bed", type=str, default=None,
-        help="BED file with annotation of transcriptome used for mapping (optional)")
-    parser_sc_filtering = parser_sc.add_argument_group('Transcript filtering options')
-    parser_sc_filtering.add_argument("--max_invalid_kmers_freq", type=float, default=0.1,
-        help="Max fequency of invalid kmers (default: %(default)s)")
-    parser_sc_filtering.add_argument("--min_coverage", type=int, default=30,
-        help="Minimum coverage required in each condition to do the comparison (default: %(default)s)")
-    parser_sc_filtering.add_argument("--downsample_high_coverage", type=int, default=5000,
-        help="Transcripts with high coverage will be downsampled (default: %(default)s)")
-    parser_sc_filtering.add_argument("--min_ref_length", type=int, default=100,
-        help="Minimum length of a reference transcript to include it in the analysis (default: %(default)s)")
-    parser_sc_testing = parser_sc.add_argument_group('Statistical testing options')
-    parser_sc_testing.add_argument("--comparison_methods", type=str, default="GMM,KS",
-        help="Comma separated list of comparison methods. Valid methods are: GMM,KS,TT,MW. (default: %(default)s)")
-    parser_sc_testing.add_argument("--sequence_context", type=int, default=0, choices=range(0,5),
-        help="Sequence context for combining p-values (default: %(default)s)")
-    parser_sc_testing.add_argument("--sequence_context_weights", type=str, default="uniform", choices=["uniform", "harmonic"],
-        help="Type of weights to use for combining p-values")
-    parser_sc_testing.add_argument("--pvalue_thr", type=float, default=0.05,
+
+    # TODO: YAML input option still needed?
+    # parser_sc_sample_yaml = parser_sc.add_argument_group('YAML sample files', description="Option allowing to describe sample files in a YAML file")
+    # parser_sc_sample_yaml.add_argument("--sample_yaml", "-y", default=None, type=str, metavar="sample_yaml",
+    #     help="YAML file containing the sample file labels. See formatting in documentation. (Required if --file_list1 and --file_list2 not given)")
+
+    parser_sc_in = parser_sc.add_argument_group('Input options')
+    parser_sc_in.add_argument("--input", "-i", required=True,
+        help="Path to the input database, i.e. 'eventalign_collapse' output (required)")
+    parser_sc_in.add_argument("--fasta", "-f", required=True,
+        help="Fasta file used for mapping (required)")
+    parser_sc_in.add_argument("--bed", default=None,
+        help="BED file with annotation of transcriptome used for mapping (optional)")
+    parser_sc_in.add_argument("--samples1", "-1", required=True, metavar="C1,C2",
+        help="Comma-separated list of sample identifiers for condition 1 (e.g. control).")
+    parser_sc_in.add_argument("--samples2", "-2", required=True, metavar="T1,T2",
+        help="Comma-separated list of sample identifiers for condition 2 (e.g. treatment).")
+    # TODO: where are these labels used?
+    parser_sc_in.add_argument("--label1", metavar="Condition1", default="Control",
+        help="Label for condition 1 (default: %(default)s)")
+    parser_sc_in.add_argument("--label2", metavar="Condition2", default="Treatment",
+        help="Label for condition 2 (default: %(default)s)")
+
+    parser_sc_out = parser_sc.add_argument_group("Output options")
+    parser_sc_out.add_argument("--output", "-o", default="sampcomp.db",
+        help="Path or filename of database output file (default: %(default)s)")
+    parser_sc_out.add_argument("--report", "-r", default="sampcomp.tsv",
+        help="Path or filename of report output file (default: %(default)s)")
+
+    parser_sc_filter = parser_sc.add_argument_group("Transcript filtering options")
+    parser_sc_filter.add_argument("--max_invalid_kmers_freq", type=float, default=0.1,
+        help="Maximum frequency of invalid kmers (default: %(default)s)")
+    parser_sc_filter.add_argument("--min_coverage", type=int, default=30,
+        help="Minimum coverage required in each condition to perform the comparison (default: %(default)s)")
+    parser_sc_filter.add_argument("--downsample_high_coverage", type=int, default=5000,
+        help="Downsample transcripts with high coverage to this number of reads (default: %(default)s)")
+    parser_sc_filter.add_argument("--min_ref_length", type=int, default=100,
+        help="Minimum length of a reference transcript for inclusion in the analysis (default: %(default)s)")
+
+    parser_sc_test = parser_sc.add_argument_group('Statistical testing options')
+    parser_sc_test.add_argument("--pvalue_threshold", "-p", type=float, default=0.05,
         help="Adjusted p-value threshold for reporting significant sites (default: %(default)s)")
-    parser_sc_testing.add_argument("--logit", action='store_true',
-        help="Use logistic regression testing downstream of GMM method. This is a legacy option and is now the deault.")
-    parser_sc_testing.add_argument("--anova", action='store_true',
-        help="Use Anova test downstream of GMM method (default: %(default)s)")
-    parser_sc_testing.add_argument("--allow_warnings", action='store_true', default=False,
-        help="If True runtime warnings during the ANOVA tests don't raise an error (default: %(default)s)")
+    parser_sc_test.add_argument("--univariate_test", choices=["KS", "MW", "ST", "none"], default="KS",
+        help="Univariate test for comparing kmer data between conditions. KS: Kolmogorov-Smirnov test, MW: Mann-Whitney test, ST: Student's t-test, none: no univariate test. (default: %(default)s)")
+    parser_sc_test.add_argument("--no_gmm", action="store_true",
+        help="Do not perform the GMM fit and subsequent test (see --gmm_test) (default: %(default)s)")
+    parser_sc_test.add_argument("--gmm_test", choices=["logit", "anova", "none"], default="logit",
+        help="Statistical test performed after GMM fitting (unless --no_gmm is used). (default: %(default)s)")
+    parser_sc_test.add_argument("--allow_warnings", action="store_true",
+        help="If True, runtime warnings during the ANOVA tests (see --gmm_test) don't raise an error (default: %(default)s)")
+    parser_sc_test.add_argument("--sequence_context", type=int, default=0, choices=range(0,5),
+        help="Sequence context for combining p-values (default: %(default)s)")
+    parser_sc_test.add_argument("--sequence_context_weights", default="uniform", choices=["uniform", "harmonic"],
+        help="Type of position weighting to use for combining p-values (default: %(default)s)")
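
The two --sequence_context_weights choices name how neighbouring positions contribute when p-values are combined over a sequence context. As an illustration only (not necessarily the exact scheme used internally), "harmonic" can be read as down-weighting neighbours by their distance from the central position, while "uniform" treats all 2N+1 positions equally:

    # Per-position weights for --sequence_context N; the centre gets weight 1.
    def context_weights(context, scheme="uniform"):
        offsets = range(-context, context + 1)
        if scheme == "harmonic":
            return [1 / (abs(o) + 1) for o in offsets]
        return [1.0 for _ in offsets]

    context_weights(2, "harmonic")  # approximately [0.33, 0.5, 1.0, 0.5, 0.33]
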
+
     parser_sc_misc = parser_sc.add_argument_group('Other options')
     parser_sc_misc.add_argument("--nthreads", "-t", type=int, default=3,
-        help="Number of threads (default: %(default)s)")
+            help="Number of threads (default: %(default)s)")
 
-    # simreads subparser
+    # SimReads subparser
     parser_sr = subparsers.add_parser('simreads', formatter_class=argparse.RawDescriptionHelpFormatter,
         description=textwrap.dedent("""
         Simulate reads as a NanopolishComp like file from a fasta file and an inbuild model\n
@@ -98,9 +136,15 @@ def main(args=None):
         * Minimal example with alteration of model intensity loc parameter for 50% of the reads
            nanocompore simreads -f ref.fa -o results -n 50 --intensity_mod 2 --mod_reads_freq 0.5 --mod_bases_freq 0.2"""))
     parser_sr.set_defaults(func=simreads_main)
-    parser_sr_io = parser_sr.add_argument_group('Input options')
-    parser_sr_io.add_argument("--fasta", "-f", type=str, required=True,
-        help="Fasta file containing references to use to generate artificial reads")
+
+    parser_sr_in = parser_sr.add_argument_group('Input options')
+    parser_sr_in.add_argument("--fasta", "-f", required=True,
+        help="FASTA file containing transcript sequences to use for artificial reads")
+
+    parser_sr_out = parser_sr.add_argument_group("Output options")
+    parser_sr_out.add_argument("--output", "-o", default="out",
+        help="Prefix for output files (default: %(default)s)")
+
     parser_sr_modify = parser_sr.add_argument_group('Signal modification options')
     parser_sr_modify.add_argument("--intensity_mod", type=float, default=0,
         help="Fraction of intensity distribution SD by which to modify the intensity distribution loc value (default: %(default)s)")
@@ -110,66 +154,45 @@ def main(args=None):
         help="Frequency of reads to modify (default: %(default)s)")
     parser_sr_modify.add_argument("--mod_bases_freq", type=float, default=0.25,
         help="Frequency of bases to modify in each read (if possible) (default: %(default)s)")
-    parser_sr_modify.add_argument("--mod_bases_type", type=str, default="A", choices=["A","T","C","G"],
+    parser_sr_modify.add_argument("--mod_bases_type", default="A", choices=["A","T","C","G"],
         help="Base for which to modify the signal (default: %(default)s)")
     parser_sr_modify.add_argument("--mod_extend_context", type=int, default=2,
         help="number of adjacent base affected by the signal modification following an harmonic series (default: %(default)s)")
     parser_sr_modify.add_argument("--min_mod_dist", type=int, default=6,
         help="Minimal distance between 2 bases to modify (default: %(default)s)")
     parser_sr_misc = parser_sr.add_argument_group('Other options')
-    parser_sr_misc.add_argument("--run_type", type=str, default="RNA", choices=["RNA", "DNA"],
+    parser_sr_misc.add_argument("--run_type", default="RNA", choices=["RNA", "DNA"],
         help="Define the run type model to import (default: %(default)s)")
     parser_sr_misc.add_argument("--nreads_per_ref", "-n", type=int, default=100,
         help="Number of reads to generate per references (default: %(default)s)")
     parser_sr_misc.add_argument("--pos_rand_seed", type=int, default=42 ,
-        help="Define a seed for randon position picking to get a deterministic behaviour (default: %(default)s)")
+        help="Define a seed for random position picking to get a deterministic behaviour (default: %(default)s)")
     parser_sr_misc.add_argument("--not_bound", action='store_true', default=False,
         help="Do not bind the values generated by the distributions to the observed min and max observed values from the model file (default: %(default)s)")
 
-    # Eventalign_collapse subparser
-    parser_ec = subparsers.add_parser("eventalign_collapse", formatter_class=argparse.RawDescriptionHelpFormatter,
-        description=textwrap.dedent("""
-        Collapse the nanopolish eventalign output at kmers level and compute kmer level statistics
-        * Minimal example
-            nanocompore eventalign_collapse -i nanopolish_eventalign.tsv -outprefix out\n"""))
-    parser_ec.set_defaults(func=eventalign_collapse_main)
-    parser_ec_io = parser_ec.add_argument_group("Input options")
-    parser_ec_io.add_argument("--eventalign", "-i", default=0,
-        help="Path to a nanopolish eventalign tsv output file, or a list of file, or a regex (can be gzipped). It can be ommited if piped to standard input (default: piped to stdin)")
-    parser_ec_io.add_argument("--sample_name", "-s", default=None, required=True, help="Unique identifier of the sample")
-    parser_ec_rp = parser_ec.add_argument_group("Run parameters options")
-    parser_ec_rp.add_argument("--n_lines", "-l", default=None , type=int ,
-        help = "Number of lines to parse.(default: no limits")
-    parser_ec_misc = parser_ec.add_argument_group("Other options")
-    parser_ec_misc.add_argument("--nthreads", "-t", default=3, type=int,
-        help="Total number of threads. 2 threads are reserved for the reader and the writer (default: %(default)s)")
 
-    # Add common options for all parsers
+    for out_group in [parser_ec_out, parser_sc_out]:
+        out_group.add_argument("--outdir", "-d", default="",
+            help="Directory for output files. Will be prepended to --output if given. (default: %(default)s)")
+        out_group.add_argument("--overwrite", "-w", action="store_true",
+            help="Overwrite existing output files? (default: %(default)s)")
     for sp in [parser_sc, parser_sr, parser_ec]:
-        sp_output = sp.add_argument_group("Output options")
-        sp_output.add_argument("--outpath", "-o", type=str, default="./",
-            help="Path to the output folder (default: %(default)s)")
-        sp_output.add_argument("--outprefix", "-p", type=str, default="out",
-            help="text outprefix for all the files generated (default: %(default)s)")
-        sp_output.add_argument("--overwrite", "-w", action='store_true', default=False,
-            help="Use --outpath even if it exists already (default: %(default)s)")
         sp_verbosity = sp.add_argument_group("Verbosity options")
-        sp_verbosity.add_argument("--log_level", type=str, default="info", choices=["warning", "info", "debug"],
-            help="Set the log level (default: %(default)s)")
-        sp_verbosity.add_argument("--progress", default=False, action='store_true',
-            help="Display a progress bar during execution (default: %(default)s)")
+        sp_verbosity.add_argument("--log_level", default="info", choices=["warning", "info", "debug"],
+            help="Set the log level (default: %(default)s)")
+        sp_verbosity.add_argument("--progress", action="store_true",
+            help="Display a progress bar during execution (default: %(default)s)")
 
-    # Parse agrs and
     args = parser.parse_args()
 
     # Check if output folder already exists
     try:
-        mkdir(fn=args.outpath, exist_ok=args.overwrite)
+        mkdir(fn=args.outdir, exist_ok=True)
     except (NanocomporeError, FileExistsError) as E:
-        raise NanocomporeError("Could not create the output folder. Try using `--overwrite` option or use another directory")
+        raise NanocomporeError(f"Could not create the output folder: {args.outdir}")
 
     # Set logger
-    log_fn = os.path.join(args.outpath, args.outprefix+"_{}.log".format(vars(args)["subcommand"]))
+    log_fn = os.path.join(args.outdir, vars(args)["subcommand"] + ".log")
     set_logger(args.log_level, log_fn=log_fn)
 
     # Call relevant subfunction
@@ -177,86 +200,95 @@ def main(args=None):
 
 #~~~~~~~~~~~~~~SUBCOMMAND FUNCTIONS~~~~~~~~~~~~~~#
 
+def eventalign_collapse_main(args):
+    """"""
+    logger.warning("Running Eventalign_collapse")
+
+    outpath = args.output
+    if args.outdir:
+        outpath = os.path.normpath(os.path.join(args.outdir, outpath))
+
+    # Init Eventalign_collapse
+    e = Eventalign_collapse(eventalign_fn = args.input,
+                            sample_name = args.sample,
+                            output_db_path = outpath,
+                            overwrite = args.overwrite,
+                            n_lines = args.n_lines,
+                            nthreads = args.nthreads,
+                            progress = args.progress)
+
+    # Run eventalign_collapse
+    e()
+
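
Both subcommand entry points derive their database path from --output plus the optional --outdir in the same way (the pattern reappears in sampcomp_main below). A small standalone illustration with hypothetical argument values:

    import os

    outdir, output = "results", "eventalign_collapse.db"
    outpath = output
    if outdir:
        outpath = os.path.normpath(os.path.join(outdir, outpath))
    print(outpath)  # results/eventalign_collapse.db
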
 def sampcomp_main(args):
     """"""
     logger.warning("Running SampComp")
 
-    # Load eventalign_fn_dict from a YAML file or assemble eventalign_fn_dict for the command line option
-    if args.sample_yaml:
-        eventalign_fn_dict = args.sample_yaml
-    elif args.file_list1 and args.file_list2:
-        eventalign_fn_dict = build_eventalign_fn_dict(args.file_list1, args.file_list2, args.label1, args.label2)
-    else:
-        raise NanocomporeError("Samples eventalign files have to be provided with either `--sample_yaml` or `--file_list1` and `--file_list2`")
+    outpath = args.output
+    if args.outdir:
+        outpath = os.path.normpath(os.path.join(args.outdir, outpath))
+
+    sample_dict = build_sample_dict(args.samples1, args.samples2, args.label1, args.label2)
+
+    univar_test = args.univariate_test if args.univariate_test != "none" else None
+    gmm_test = args.gmm_test if args.gmm_test != "none" else None
 
     # Init SampComp
-    s = SampComp(
-        eventalign_fn_dict = eventalign_fn_dict,
-        max_invalid_kmers_freq = args.max_invalid_kmers_freq,
-        outpath = args.outpath,
-        outprefix = args.outprefix,
-        overwrite = args.overwrite,
-        fasta_fn = args.fasta,
-        bed_fn = args.bed,
-        nthreads = args.nthreads,
-        min_coverage = args.min_coverage,
-        min_ref_length = args.min_ref_length,
-        downsample_high_coverage = args.downsample_high_coverage,
-        comparison_methods = args.comparison_methods,
-        logit = True,
-        anova = args.anova,
-        allow_warnings = args.allow_warnings,
-        sequence_context = args.sequence_context,
-        sequence_context_weights = args.sequence_context_weights,
-        progress = args.progress)
+    s = SampComp(input_db_path = args.input,
+                 output_db_path = outpath,
+                 sample_dict = sample_dict,
+                 fasta_fn = args.fasta,
+                 overwrite = args.overwrite,
+                 whitelist = None,
+                 univariate_test = univar_test,
+                 fit_gmm = not args.no_gmm,
+                 gmm_test = gmm_test,
+                 allow_anova_warnings = args.allow_warnings,
+                 sequence_context = args.sequence_context,
+                 sequence_context_weighting = args.sequence_context_weights,
+                 min_coverage = args.min_coverage,
+                 min_ref_length = args.min_ref_length,
+                 downsample_high_coverage = args.downsample_high_coverage,
+                 max_invalid_kmers_freq = args.max_invalid_kmers_freq,
+                 nthreads = args.nthreads,
+                 progress = args.progress)
 
     # Run SampComp
-    db = s()
+    s()
 
     # Save all reports
-    if(db):
-        db.save_all(pvalue_thr=args.pvalue_thr)
+    if not args.report:
+        return
+
+    report_path = args.report
+    if args.outdir:
+        report_path = os.path.normpath(os.path.join(args.outdir, report_path))
+
+    p = PostProcess(outpath, args.input, args.bed)
+    p.save_report(report_path)  # TODO: update "save_all()" and call that instead
+
 
 def simreads_main(args):
     """"""
     logger.warning("Running SimReads")
 
     # Run SimReads
-    SimReads(
-        fasta_fn = args.fasta,
-        outpath = args.outpath,
-        outprefix = args.outprefix,
-        overwrite = args.overwrite,
-        run_type = args.run_type,
-        nreads_per_ref = args.nreads_per_ref,
-        intensity_mod = args.intensity_mod,
-        dwell_mod = args.dwell_mod,
-        mod_reads_freq = args.mod_reads_freq,
-        mod_bases_freq = args.mod_bases_freq,
-        mod_bases_type = args.mod_bases_type,
-        mod_extend_context = args.mod_extend_context,
-        min_mod_dist = args.min_mod_dist,
-        pos_rand_seed = args.pos_rand_seed,
-        not_bound = args.not_bound,
-        progress = args.progress)
-
-def eventalign_collapse_main (args):
-    """"""
-    logger.warning("Running Eventalign_collapse")
-
-    # Init Eventalign_collapse
-    e = Eventalign_collapse (
-        eventalign_fn = args.eventalign,
-        sample_name = args.sample_name,
-        outpath = args.outpath,
-        outprefix = args.outprefix,
-        overwrite = args.overwrite,
-        n_lines = args.n_lines,
-        nthreads = args.nthreads,
-        progress = args.progress)
-
-    # Run eventalign_collapse
-    e()
+    SimReads(fasta_fn = args.fasta,
+             outpath = args.outdir,
+             outprefix = args.output,
+             overwrite = args.overwrite,
+             run_type = args.run_type,
+             nreads_per_ref = args.nreads_per_ref,
+             intensity_mod = args.intensity_mod,
+             dwell_mod = args.dwell_mod,
+             mod_reads_freq = args.mod_reads_freq,
+             mod_bases_freq = args.mod_bases_freq,
+             mod_bases_type = args.mod_bases_type,
+             mod_extend_context = args.mod_extend_context,
+             min_mod_dist = args.min_mod_dist,
+             pos_rand_seed = args.pos_rand_seed,
+             not_bound = args.not_bound,
+             progress = args.progress)
 
 #~~~~~~~~~~~~~~CLI ENTRYPOINT~~~~~~~~~~~~~~#
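
The common.py changes below replace the old file-based eventalign_fn_dict with a plain condition-to-samples mapping. A minimal sketch of the structure returned by build_sample_dict (defined below), using the sample names from the sampcomp usage example:

    sample_dict = build_sample_dict("C1,C2", "T1,T2", "Control", "Treatment")
    # {'Control': ['C1', 'C2'], 'Treatment': ['T1', 'T2']}
    check_sample_dict(sample_dict)  # passes: two conditions, two replicates each, no duplicates
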
diff --git a/nanocompore/common.py b/nanocompore/common.py
index d9637ff..ca9b7e9 100644
--- a/nanocompore/common.py
+++ b/nanocompore/common.py
@@ -28,14 +28,40 @@ class NanocomporeWarning (Warning):
 
 #~~~~~~~~~~~~~~FUNCTIONS~~~~~~~~~~~~~~#
 
-def build_eventalign_fn_dict(file_list1, file_list2, label1, label2):
+def build_sample_dict(sample_list1, sample_list2, label1, label2):
     """
-    Build the eventalign_fn_dict from file lists and labels
+    Build dictionary with sample information from sample lists and condition labels
     """
-    d = OrderedDict()
-    d[label1] = {"{}_{}".format(label1, i): v for i, v in enumerate(file_list1.split(","),1)}
-    d[label2] = {"{}_{}".format(label2, i): v for i, v in enumerate(file_list2.split(","),1)}
-    return d
+    return {label1: sample_list1.split(","), label2: sample_list2.split(",")}
+
+def check_sample_dict(sample_dict):
+    # Check general structure
+    if type(sample_dict) not in (dict, OrderedDict):
+        raise NanocomporeError(f"Expected a dictionary. Got a '{type(sample_dict)}'.")
+    if len(sample_dict) != 2:
+        raise NanocomporeError(f"Expected two conditions. Found {len(sample_dict)}.")
+    for condition, samples in sample_dict.items():
+        if type(samples) is not list:
+            raise NanocomporeError(f"Expected a list of sample names for condition '{condition}'. "
+                                   f"Got a '{type(samples)}'.")
+        if not samples:
+            raise NanocomporeError(f"Empty sample list for condition '{condition}'.")
+        if len(samples) == 1:
+            logger.warning(f"Only one replicate found for condition '{condition}'. "
+                           "This is not recommended. "
+                           "Statistics will be calculated using the logit method.")
+    # Check for duplicate sample names
+    for condition, samples in sample_dict.items():
+        if len(set(samples)) < len(samples):
+            raise NanocomporeError(f"Duplicate sample names for condition '{condition}'.")
+    all_samples = list(sample_dict.values())  # there must be two lists - already checked
+    if any([sample in all_samples[1] for sample in all_samples[0]]):
+        logger.warning("Found sample name shared between conditions. "
+                       "Prefixing all sample names with their condition.")
+        for condition, samples in sample_dict.items():
+            # can't modify 'samples' directly here!
+            sample_dict[condition] = [f"{condition}_{sample}" for sample in samples]
+
 def set_logger (log_level, log_fn=None):
     log_level = log_level.upper()