Implement PathSeq taxon hit scoring in Spark #3406

Merged
merged 2 commits on Aug 9, 2017
@@ -1,20 +1,37 @@
package org.broadinstitute.hellbender.tools.spark.pathseq;

import org.broadinstitute.hellbender.exceptions.GATKException;

/**
* Pathogen abundance scores assigned to a taxonomic node and reported by the ClassifyReads tool.
* See the ClassifyReads tool for scoring documentation.
* Pathogen abundance scores assigned to a taxonomic node and reported by the PathSeqScoreSpark tool.
* See the PathSeqScoreSpark tool for scoring documentation.
*/
public final class PSPathogenTaxonScore {
public double score = 0; //PathSeq abundance score calculated in the ClassifyReads tool
public double scoreNormalized = 0; //Score normalized to percent of total
public int reads = 0; //Number of total reads mapped
public int unambiguous = 0; //Number of reads mapped unambiguously to this node
public long refLength = 0; //Length of reference in bp

public final static String outputHeader = "score\tscore_normalized\treads\tunambiguous\treference_length";

public double selfScore = 0; //Total abundance score assigned directly to this taxon
public double descendentScore = 0; //Sum of descendents' scores
public double scoreNormalized = 0; //selfScore + descendentScore, normalized to percent of total selfScores
public int totalReads = 0; //Number of total reads mapped
public int unambiguousReads = 0; //Number of reads mapped unambiguously to this node
public long referenceLength = 0; //Length of reference in bp

@Override
public String toString() {
return score + "\t" + scoreNormalized + "\t" + reads + "\t" + unambiguous + "\t" + refLength;
return (selfScore + descendentScore) + "\t" + scoreNormalized + "\t" + totalReads + "\t" + unambiguousReads + "\t" + referenceLength;
}

public PSPathogenTaxonScore add(final PSPathogenTaxonScore other) {
final PSPathogenTaxonScore result = new PSPathogenTaxonScore();
result.selfScore = this.selfScore + other.selfScore;
result.descendentScore = this.descendentScore + other.descendentScore;
result.totalReads = this.totalReads + other.totalReads;
result.unambiguousReads = this.unambiguousReads + other.unambiguousReads;
if (this.referenceLength != other.referenceLength) {
throw new GATKException("Cannot add PSPathogenTaxonScores with different reference lengths.");
}
result.referenceLength = this.referenceLength;
return result;
}
}
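
The new add method is shaped to serve as an associative combiner for Spark's reduceByKey (see its use in PSScorer below). A minimal sketch of merging two partial scores for the same taxon, with hypothetical field values:

// Hypothetical values; assumes the PSPathogenTaxonScore class as shown above.
final PSPathogenTaxonScore partialA = new PSPathogenTaxonScore();
partialA.selfScore = 1.5;
partialA.totalReads = 10;
partialA.unambiguousReads = 4;
partialA.referenceLength = 4000000L;

final PSPathogenTaxonScore partialB = new PSPathogenTaxonScore();
partialB.selfScore = 0.5;
partialB.descendentScore = 2.0;
partialB.totalReads = 6;
partialB.referenceLength = 4000000L; //must match partialA's, or add() throws GATKException

final PSPathogenTaxonScore combined = partialA.add(partialB);
// combined.selfScore == 2.0, combined.descendentScore == 2.0,
// combined.totalReads == 16, combined.unambiguousReads == 4,
// combined.referenceLength == 4000000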
@@ -5,10 +5,12 @@
import htsjdk.samtools.*;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.hellbender.exceptions.UserException;
import org.broadinstitute.hellbender.tools.spark.sv.utils.SVUtils;
import org.broadinstitute.hellbender.utils.Utils;
import org.broadinstitute.hellbender.utils.gcs.BucketUtils;
import org.broadinstitute.hellbender.utils.read.GATKRead;
@@ -44,7 +46,7 @@ public JavaRDD<GATKRead> scoreReads(final JavaSparkContext ctx,

//Load taxonomy database, created by running PathSeqBuildReferenceTaxonomy with this reference
final PSTaxonomyDatabase taxDB = readTaxonomyDatabase(scoreArgs.taxonomyDatabasePath);
final Broadcast<Map<String, String>> accessionToTaxIdBroadcast = ctx.broadcast(taxDB.accessionToTaxId);
final Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast = ctx.broadcast(taxDB);

//Check header against database
if (scoreArgs.headerWarningFile != null) {
@@ -54,19 +56,21 @@
//Determine which alignments are valid hits and return their tax IDs in PSPathogenAlignmentHit
//Also adds pathseq tags containing the hit IDs to the reads
final JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> readHits = mapGroupedReadsToTax(groupedReads,
scoreArgs.minCoverage, scoreArgs.minIdentity, accessionToTaxIdBroadcast);

//Collect PSPathogenAlignmentHit objects
final Iterable<PSPathogenAlignmentHit> hitInfo = collectValues(readHits);
scoreArgs.minCoverage, scoreArgs.minIdentity, taxonomyDatabaseBroadcast);

//Get the original reads, now with their pathseq hit tags set
final JavaRDD<GATKRead> readsFinal = flattenIterableKeys(readHits);

//Compute taxonomic scores
final Map<String, PSPathogenTaxonScore> taxScores = computeTaxScores(hitInfo, taxDB.tree);
//Compute taxonomic scores from the alignment hits
final JavaRDD<PSPathogenAlignmentHit> alignmentHits = readHits.map(Tuple2::_2);
final JavaPairRDD<String, PSPathogenTaxonScore> taxScoresRdd = alignmentHits.mapPartitionsToPair(iter -> computeTaxScores(iter, taxonomyDatabaseBroadcast.value()));

//Reduce scores by taxon and compute normalized scores
Map<String, PSPathogenTaxonScore> taxScoresMap = new HashMap<>(taxScoresRdd.reduceByKey(PSPathogenTaxonScore::add).collectAsMap());
taxScoresMap = computeNormalizedScores(taxScoresMap, taxDB.tree);

//Write scores to file
writeScoresFile(taxScores, taxDB.tree, scoreArgs.scoresPath);
writeScoresFile(taxScoresMap, taxDB.tree, scoreArgs.scoresPath);

return readsFinal;
}
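
The structural change above: scoring is no longer computed on the driver from collected hits; each partition scores its hits locally, and only the small per-taxon score maps are merged and collected. A simplified sketch of the pattern (types as in the diff, variable names hypothetical):

// Score hits within each partition, yielding (taxId, partial score) pairs.
final JavaPairRDD<String, PSPathogenTaxonScore> partialScores =
        alignmentHits.mapPartitionsToPair(iter -> computeTaxScores(iter, taxonomyDatabaseBroadcast.value()));
// Merge partial scores per taxon across partitions, then bring only the
// small per-taxon map (not the reads) back to the driver.
final Map<String, PSPathogenTaxonScore> merged =
        new HashMap<>(partialScores.reduceByKey(PSPathogenTaxonScore::add).collectAsMap());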
@@ -161,15 +165,15 @@ public static void writeMissingReferenceAccessions(final String path, final SAMF
static JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> mapGroupedReadsToTax(final JavaRDD<Iterable<GATKRead>> pairs,
final double minCoverage,
final double minIdentity,
final Broadcast<Map<String, String>> refNameToTaxBroadcast) {
final Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast) {
return pairs.map(readIter -> {

//Number of reads in the pair (1 for unpaired reads)
final int numReads = (int) Utils.stream(readIter).count();

//Get tax IDs of all alignments in all reads that meet the coverage/identity criteria.
final Stream<String> taxIds = Utils.stream(readIter)
.flatMap(read -> getValidHits(read, refNameToTaxBroadcast, minCoverage, minIdentity).stream());
.flatMap(read -> getValidHits(read, taxonomyDatabaseBroadcast.value(), minCoverage, minIdentity).stream());

//Get list of tax IDs that are hits in all reads
final List<String> hitTaxIds;
@@ -204,7 +208,7 @@ static JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> mapGroupedRea
* Gets set of sufficiently well-mapped hits
*/
private static Set<String> getValidHits(final GATKRead read,
final Broadcast<Map<String, String>> refNameToTaxBroadcast,
PSTaxonomyDatabase taxonomyDatabase,
final double minCoverage,
final double minIdentity) {

@@ -226,7 +230,7 @@ private static Set<String> getValidHits(final GATKRead read,

//Throw out duplicates and accessions not in the taxonomic database so it returns a set of unique tax IDs
// for each read in the pair
return hits.stream().map(contig -> refNameToTaxBroadcast.value().containsKey(contig) ? refNameToTaxBroadcast.value().get(contig) : null)
return hits.stream().map(contig -> taxonomyDatabase.accessionToTaxId.containsKey(contig) ? taxonomyDatabase.accessionToTaxId.get(contig) : null)
.filter(Objects::nonNull).collect(Collectors.toSet());
}

@@ -277,25 +281,22 @@ public static boolean isValidAlignment(final Cigar cigar, final int numMismatche
}

/**
* Computes abundance scores and returns map from taxonomic id to scores
* Computes abundance scores and returns key-values of taxonomic id and scores
*/
public static Map<String, PSPathogenTaxonScore> computeTaxScores(final Iterable<PSPathogenAlignmentHit> readTaxHits,
final PSTree tree) {
public static Iterator<Tuple2<String, PSPathogenTaxonScore>> computeTaxScores(final Iterator<PSPathogenAlignmentHit> taxonHits,
final PSTaxonomyDatabase taxonomyDatabase) {
final PSTree tree = taxonomyDatabase.tree;
final Map<String, PSPathogenTaxonScore> taxIdsToScores = new HashMap<>();
final Set<String> invalidIds = new HashSet<>();
Double sum = 0.0;
for (final PSPathogenAlignmentHit hit : readTaxHits) {
while (taxonHits.hasNext()) {
final PSPathogenAlignmentHit hit = taxonHits.next();
final Collection<String> hitTaxIds = new ArrayList<>(hit.taxIDs);

//Find and omit hits to tax ID's not in the database
final Set<String> invalidHitIds = new HashSet<>(hitTaxIds.size());
for (final String taxid : hitTaxIds) {
if (!tree.hasNode(taxid) || tree.getLengthOf(taxid) == 0) {
invalidHitIds.add(taxid);
}
final Set<String> hitInvalidTaxIds = new HashSet<>(SVUtils.hashMapCapacity(hitTaxIds.size()));
for (final String taxId : hitTaxIds) {
if (!tree.hasNode(taxId) || tree.getLengthOf(taxId) == 0) hitInvalidTaxIds.add(taxId);
}
hitTaxIds.removeAll(invalidHitIds);
invalidIds.addAll(invalidHitIds);
hitTaxIds.removeAll(hitInvalidTaxIds);
invalidIds.addAll(hitInvalidTaxIds);

//Number of genomes hit by this read and number of mates in the tuple (1 for single, 2 for pair)
final int numHits = hitTaxIds.size();
@@ -305,43 +306,69 @@ public static Map<String, PSPathogenTaxonScore> computeTaxScores(final Iterable<
final String lowestCommonAncestor = tree.getLCA(hitTaxIds);
final List<String> lcaPath = tree.getPathOf(lowestCommonAncestor);
for (final String taxId : lcaPath) {
getOrAddScoreInfo(taxId, taxIdsToScores, tree).unambiguous += hit.numMates;
getOrAddScoreInfo(taxId, taxIdsToScores, tree).unambiguousReads += hit.numMates;
}

//Scores normalized by genome length and degree of ambiguity (number of hits)
final Set<String> hitPathNodes = new HashSet<>(); //Set of all unique hits and ancestors
for (final String taxId : hitTaxIds) {
if (!tree.hasNode(taxId) || tree.getLengthOf(taxId) == 0) {
invalidIds.add(taxId);
continue;
}
final Double score = SCORE_GENOME_LENGTH_UNITS * hit.numMates / (numHits * tree.getLengthOf(taxId));
sum += score;
//Get list containing this node and its ancestors
final List<String> path = tree.getPathOf(taxId);
hitPathNodes.addAll(path);
for (final String pathTaxID : path) {
PSPathogenTaxonScore info = getOrAddScoreInfo(pathTaxID, taxIdsToScores, tree);
info.score += score;
taxIdsToScores.put(pathTaxID, info);
for (final String pathTaxId : path) {
final PSPathogenTaxonScore info = getOrAddScoreInfo(pathTaxId, taxIdsToScores, tree);
if (pathTaxId.equals(taxId)) {
info.selfScore += score;
} else {
info.descendentScore += score;
}
taxIdsToScores.put(pathTaxId, info);
}
}

//"reads" score is the number of reads that COULD belong to each node i.e. an upper-bound
for (final String taxId : hitPathNodes) {
getOrAddScoreInfo(taxId, taxIdsToScores, tree).reads += hit.numMates;
getOrAddScoreInfo(taxId, taxIdsToScores, tree).totalReads += hit.numMates;
}
}
PSUtils.logItemizedWarning(logger, invalidIds, "The following taxonomic ID hits were ignored because " +
"they either could not be found in the tree or had a reference length of 0 (this may happen when " +
"the catalog file, taxdump file, and/or pathogen reference are inconsistent)");
return taxIdsToScores.entrySet().stream().map(entry -> new Tuple2<>(entry.getKey(), entry.getValue())).iterator();
}
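
As a worked example of the per-hit scoring (assuming, hypothetically, SCORE_GENOME_LENGTH_UNITS = 1e6): a read pair (hit.numMates = 2) with valid hits to two genomes (numHits = 2), each with a reference length of 5 Mbp, contributes 1e6 * 2 / (2 * 5e6) = 0.2 to each hit taxon's selfScore, 0.2 to the descendentScore of each of that taxon's ancestors, 2 to totalReads at every node on a hit path, and 2 to unambiguousReads only along the path of the two hits' lowest common ancestor.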

/**
* Assigns scores normalized to 100%. For each taxon, its normalized score is its own selfScore divided by the
* sum of all selfScores, plus the sum of its children's normalized scores.
*/
final static Map<String, PSPathogenTaxonScore> computeNormalizedScores(final Map<String, PSPathogenTaxonScore> taxIdsToScores,
final PSTree tree) {
//Get the sum of all scores that were assigned directly to each taxon (as opposed to being propagated up from descendents)
double sum = 0.;
for (final PSPathogenTaxonScore score : taxIdsToScores.values()) {
sum += score.selfScore;
}

//Scores normalized to 100%
//Gets each taxon's normalized selfScore and adds it to the taxon and all of its ancestors
for (final Map.Entry<String, PSPathogenTaxonScore> entry : taxIdsToScores.entrySet()) {
final String readName = entry.getKey();
final PSPathogenTaxonScore score = entry.getValue();
final String taxId = entry.getKey();
final double selfScore = entry.getValue().selfScore;
final double normalizedScore;
if (sum == 0) {
score.scoreNormalized = 0;
normalizedScore = 0;
} else {
score.scoreNormalized = 100.0 * score.score / sum;
normalizedScore = 100.0 * selfScore / sum;
}
final List<String> path = tree.getPathOf(taxId);
for (final String pathTaxId : path) {
taxIdsToScores.get(pathTaxId).scoreNormalized += normalizedScore;
}
taxIdsToScores.replace(readName, score);
}
PSUtils.logItemizedWarning(logger, invalidIds, "The following taxonomic ID hits were ignored because " +
"they either could not be found in the tree or had a reference length of 0 (this may happen when " +
"the catalog file, taxdump file, and/or pathogen reference are inconsistent)");
return taxIdsToScores;
}
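
A small worked example of the normalization (hypothetical numbers): suppose a root taxon A has selfScore 0 and two children B and C with selfScores 6 and 2. The selfScore sum is 8, so B's scoreNormalized is 100 * 6 / 8 = 75 and C's is 25; because each normalized selfScore is also added along the path to the root, A ends up with scoreNormalized 75 + 25 = 100.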

@@ -356,7 +383,7 @@ private static PSPathogenTaxonScore getOrAddScoreInfo(final String taxIds,
score = taxScores.get(taxIds);
} else {
score = new PSPathogenTaxonScore();
score.refLength = tree.getLengthOf(taxIds);
score.referenceLength = tree.getLengthOf(taxIds);
taxScores.put(taxIds, score);
}
return score;