Skip to content

Commit

Permalink
Support CRAM reference regions. (#1605)
Browse files Browse the repository at this point in the history
* Support CRAM reference regions.  Changed so that CRAM reading does not require an entire contig to be loaded at once, this will support faster loading of sections of cram files in IGV and other tools.  
* Breaking changes to the CRAMReferenceSource interface that add a new method getReferenceBasesByRegion.  Implementers will need to be updated..
  • Loading branch information
cmnbroad authored Jun 3, 2022
1 parent f461401 commit 489c419
Show file tree
Hide file tree
Showing 23 changed files with 725 additions and 175 deletions.
8 changes: 4 additions & 4 deletions src/main/java/htsjdk/samtools/CRAMIterator.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public class CRAMIterator implements SAMRecordIterator, Closeable {
private final CramContainerIterator containerIterator;
private final CramHeader cramHeader;
private final SAMFileHeader samFileHeader;
private final CRAMReferenceRegion cramReferenceState;
private final CRAMReferenceRegion cramReferenceRegion;
private final QueryInterval[] queryIntervals;

private ValidationStringency validationStringency;
Expand Down Expand Up @@ -64,7 +64,7 @@ public CRAMIterator(final InputStream inputStream,

this.validationStringency = validationStringency;
samFileHeader = containerIterator.getSamFileHeader();
cramReferenceState = new CRAMReferenceRegion(referenceSource, samFileHeader);
cramReferenceRegion = new CRAMReferenceRegion(referenceSource, samFileHeader.getSequenceDictionary());
cramHeader = containerIterator.getCramHeader();
firstContainerOffset = this.countingInputStream.getCount();
samRecords = new ArrayList<>(new CRAMEncodingStrategy().getReadsPerSlice());
Expand All @@ -81,7 +81,7 @@ public CRAMIterator(final SeekableStream seekableStream,

this.validationStringency = validationStringency;
samFileHeader = containerIterator.getSamFileHeader();
cramReferenceState = new CRAMReferenceRegion(referenceSource, samFileHeader);
cramReferenceRegion = new CRAMReferenceRegion(referenceSource, samFileHeader.getSequenceDictionary());
cramHeader = containerIterator.getCramHeader();
firstContainerOffset = this.countingInputStream.getCount();
samRecords = new ArrayList<>(new CRAMEncodingStrategy().getReadsPerSlice());
Expand Down Expand Up @@ -111,7 +111,7 @@ private BAMIteratorFilter.FilteringIteratorState nextContainer() {
if (containerMatchesQuery(container)) {
samRecords = container.getSAMRecords(
validationStringency,
cramReferenceState,
cramReferenceRegion,
compressorCache,
getSAMFileHeader());
samRecordIterator = samRecords.iterator();
Expand Down
215 changes: 179 additions & 36 deletions src/main/java/htsjdk/samtools/cram/build/CRAMReferenceRegion.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,73 +24,216 @@
*/
package htsjdk.samtools.cram.build;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.cram.ref.CRAMReferenceSource;
import htsjdk.samtools.cram.ref.ReferenceContext;
import htsjdk.samtools.cram.ref.ReferenceContextType;
import htsjdk.samtools.cram.structure.AlignmentContext;
import htsjdk.samtools.util.Log;
import htsjdk.utils.ValidationUtils;

/**
* A (cached) region of a reference. Maintains a CRAMReferenceSource for retrieving additional regions.
* Holds a region/fragment of a reference contig. Maintains a CRAMReferenceSource for retrieving additional regions.
* This is a mutable object that is used to traverse along a reference contig via serial calls to either
* {@link #fetchReferenceBases(int)} or {@link #fetchReferenceBasesByRegion(int, int, int)}. It caches the bases
* from the previous request, along with metadata about the (0-based) start offset, and length of the
* cached bases.
*/
public class CRAMReferenceRegion {
private static final Log log = Log.getInstance(CRAMReferenceRegion.class);
public static final int UNINITIALIZED_START = -1;
public static final int UNINITIALIZED_LENGTH = -1;

private final CRAMReferenceSource referenceSource;
private final SAMFileHeader samFileHeader;
private final SAMSequenceDictionary sequenceDictionary;

private byte[] referenceBases = null; // cache the reference bases
private int referenceBasesContextID = ReferenceContext.UNINITIALIZED_REFERENCE_ID;
private int referenceIndex = ReferenceContext.UNINITIALIZED_REFERENCE_ID;
private byte[] referenceBases = null;
private SAMSequenceRecord sequenceRecord = null;
private int regionStart = UNINITIALIZED_START; // 0-based start offset of the region
private int regionLength = UNINITIALIZED_LENGTH; // length of the bases cached by this reference region

/**
* @param cramReferenceSource {@link CRAMReferenceSource} to use to obtain reference bases
* @param samFileHeader {@link SAMFileHeader} to use to resolve reference contig names to reference index
* @param sequenceDictionary {@link SAMSequenceDictionary} to use to resolve reference contig names to reference index
*/
public CRAMReferenceRegion(final CRAMReferenceSource cramReferenceSource, final SAMFileHeader samFileHeader) {
if (cramReferenceSource == null) {
throw new IllegalArgumentException("A valid reference must be supplied to retrieve records from the CRAM stream.");
}
public CRAMReferenceRegion(final CRAMReferenceSource cramReferenceSource, final SAMSequenceDictionary sequenceDictionary) {
ValidationUtils.nonNull(cramReferenceSource, "cramReferenceSource");
ValidationUtils.nonNull(sequenceDictionary, "sequenceDictionary");

this.referenceSource = cramReferenceSource;
this.samFileHeader = samFileHeader;
this.sequenceDictionary = sequenceDictionary;
}

/**
* @return the currently cached reference bases (may ne null)
* @return the currently cached reference bases (may return null)
*/
public byte[] getCurrentReferenceBases() {
return referenceBases;
}

/**
* Return the reference bases for the given reference index.
* @param referenceIndex
* @return the current reference index or {@link ReferenceContext#UNINITIALIZED_REFERENCE_ID} if no index has
* been established
*/
public int getReferenceIndex() {
return referenceIndex;
}

/**
* @return the 0-based start position of the range of the current reference sequence region.
* {@link #UNINITIALIZED_START} if no region has been established
*/
public int getRegionStart() {
return regionStart;
}

/**
* @return the length of the current reference sequence region. {@link #UNINITIALIZED_LENGTH}
* if no region has been established
*/
public int getRegionLength() {
return regionLength;
}

/**
* Return the reference bases for an entire contig given a reference contig index.
*
* Note: Serial calls to this method on behalf of non-coordinate sorted inputs can result in
* thrashing and poor performance due to repeated calls to the underlying CRAMReferenceSource,
* especially when the CRAMReferenceSource is fetching bases from a remote reference.
*
* @return bases for the entire reference contig specifed by {@code referenceIndex}
* @param referenceIndex the reference index for which bases should be retrieved.
* @throws IllegalArgumentException if the requested index is not present in the sequence dictionary or if
* the sequence's bases cannot be retrieved from the CRAMReferenceSource
*/
public byte[] getReferenceBases(final int referenceIndex) {
// for non-coord sorted this could cause a lot of thrashing
if (referenceIndex != SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX) {
if (referenceBases == null || referenceIndex != referenceBasesContextID) {
final SAMSequenceRecord sequence = samFileHeader.getSequence(referenceIndex);
referenceBases = referenceSource.getReferenceBases(sequence, true);
if (referenceBases == null) {
throw new IllegalArgumentException(
String.format(
"A reference must be supplied (reference sequence %s not found).",
sequence));
}

referenceBasesContextID = referenceIndex;
public void fetchReferenceBases(final int referenceIndex) {
ValidationUtils.validateArg(referenceIndex >= 0, "reference index must be >= 0");

// Re-resolve the reference bases if we don't have a current region or if the region we have
// doesn't span the *entire* contig requested.
if ((referenceIndex != this.referenceIndex) ||
regionStart != 0 ||
(regionLength < referenceBases.length)) {
setCurrentSequence(referenceIndex);
referenceBases = referenceSource.getReferenceBases(sequenceRecord, true);
if (referenceBases == null) {
throw new IllegalArgumentException(
String.format("A reference must be supplied (reference sequence %s not found).", sequenceRecord));
}
return referenceBases;
regionStart = 0;
regionLength = sequenceRecord.getSequenceLength();
}
}

/**
* Get the reference bases for a region of a reference contig. If the current region does not match the
* requested region, the {@link #referenceSource} will be called to retrieve the bases.
*
* The caller cannot assume that the entire region requested is always fetched (if the requested range matches
* the alignment span for CRAM record or slice contains a CRAM records that is mapped beyond the end of the
* reference contig, fewer bases than were requested may be fetched.
*
* @param referenceIndex reference index for which to retrieve bases
* @param zeroBasedStart zero based start of the first base to be retrieved
* @param requestedFragmentLength length of the fragment to be retrieved
*
* @throws IllegalArgumentException if the requested sequence cannot be located in the sequence dictionary, or
* if the requested sequence cannot be provided by the underlying referenceSource
*/
public void fetchReferenceBasesByRegion(
final int referenceIndex,
final int zeroBasedStart,
final int requestedFragmentLength) {
ValidationUtils.validateArg(referenceIndex >= 0, "reference index must be non-negative");
ValidationUtils.validateArg(zeroBasedStart >= 0, "start must be >= 0");

if (referenceIndex == this.referenceIndex &&
zeroBasedStart == regionStart &&
requestedFragmentLength == regionLength) {
// exact match for what we already have
return;
}

if (referenceIndex != this.referenceIndex) {
setCurrentSequence(referenceIndex);
}

if (zeroBasedStart >= sequenceRecord.getSequenceLength()) {
throw new IllegalArgumentException(String.format("Requested start %d is beyond the sequence length %s",
zeroBasedStart,
sequenceRecord.getSequenceName()));
}

// retain whatever cached reference bases we may have to minimize subsequent re-fetching
return null;
referenceBases = referenceSource.getReferenceBasesByRegion(sequenceRecord, zeroBasedStart, requestedFragmentLength);
if (referenceBases == null) {
throw new IllegalArgumentException(
String.format("Failure getting reference bases for sequence %s", sequenceRecord.getSequenceName()));
} else if (referenceBases.length < requestedFragmentLength) {
log.warn("The bases of length " + referenceBases.length +
" returned by the reference source do not satisfy the requested fragment length " +
(zeroBasedStart + requestedFragmentLength));
}
regionStart = zeroBasedStart;
regionLength = referenceBases.length;
}

/**
* Fetch the bases to span an {@link AlignmentContext}.
* @param alignmentContext the alignment context for which to fetch bases. must be an AlignmentContext
* for a single reference {@link ReferenceContextType#SINGLE_REFERENCE_TYPE} slice (see
* {@link ReferenceContext#isMappedSingleRef()})
*/
public void fetchReferenceBasesByRegion(final AlignmentContext alignmentContext) {
ValidationUtils.validateArg(
alignmentContext.getReferenceContext().isMappedSingleRef(),
"a mapped single reference alignment context is required");
fetchReferenceBasesByRegion(
alignmentContext.getReferenceContext().getReferenceSequenceID(),
alignmentContext.getAlignmentStart() - 1, // 1-based alignment context to 0-based reference offset
alignmentContext.getAlignmentSpan());
}

public void setEmbeddedReference(final byte[] embeddedReferenceBytes, final int embeddedReferenceIndex) {
referenceBasesContextID = embeddedReferenceIndex;
referenceBases = embeddedReferenceBytes;
/**
* Set this {@link CRAMReferenceRegion} to use an embedded reference.
*
* @param embeddedReferenceBases the embedded reference bases to be used
* @param embeddedReferenceIndex the reference ID used in the slice containing the embedded reference
* @param zeroBasedStart the zero based reference start of the first base in the embedded reference bases
*/
public void setEmbeddedReferenceBases(
final byte[] embeddedReferenceBases,
final int embeddedReferenceIndex,
final int zeroBasedStart) {
ValidationUtils.nonNull(embeddedReferenceBases, "embeddedReferenceBases");
setCurrentSequence(embeddedReferenceIndex);
referenceBases = embeddedReferenceBases;
regionStart = zeroBasedStart;
regionLength = embeddedReferenceBases.length;
}

/**
* @return the length of the entire reference contig maintained by this region. note that this is not the
* same as the length of the current bases maintained by this region (this can happen if the region contains
* a fragment, or an embedded reference fragment), or -1 if no current region is established
*/
public int getFullContigLength() {
return sequenceRecord == null ? -1 : sequenceRecord.getSequenceLength();
}

private void setCurrentSequence(final int referenceIndex) {
this.referenceIndex= referenceIndex;
this.sequenceRecord = getSAMSequenceRecord(referenceIndex);
}

private SAMSequenceRecord getSAMSequenceRecord(final int referenceIndex) {
final SAMSequenceRecord samSequenceRecord = sequenceDictionary.getSequence(referenceIndex);
if (samSequenceRecord == null) {
throw new IllegalArgumentException(
String.format("Reference sequence index %d not found", referenceIndex));
}
return samSequenceRecord;
}

}
16 changes: 13 additions & 3 deletions src/main/java/htsjdk/samtools/cram/build/SliceFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import htsjdk.samtools.cram.common.CramVersions;
import htsjdk.samtools.cram.ref.CRAMReferenceSource;
import htsjdk.samtools.cram.ref.ReferenceContext;
import htsjdk.samtools.cram.structure.AlignmentContext;
import htsjdk.samtools.cram.structure.CRAMEncodingStrategy;
import htsjdk.samtools.cram.structure.CRAMCompressionRecord;
import htsjdk.samtools.cram.structure.CompressionHeader;
Expand Down Expand Up @@ -67,7 +68,7 @@ public SliceFactory(
final SAMFileHeader samFileHeader,
final long globalRecordCounter) {
this.encodingStrategy = cramEncodingStrategy;
this.cramReferenceRegion = new CRAMReferenceRegion(cramReferenceSource, samFileHeader);
this.cramReferenceRegion = new CRAMReferenceRegion(cramReferenceSource, samFileHeader.getSequenceDictionary());
minimumSingleReferenceSliceThreshold = encodingStrategy.getMinimumSingleReferenceSliceSize();
maxRecordsPerSlice = this.encodingStrategy.getReadsPerSlice();
this.coordinateSorted = samFileHeader.getSortOrder() == SAMFileHeader.SortOrder.coordinate;
Expand Down Expand Up @@ -136,7 +137,11 @@ public List<Slice> createSlices(
containerByteOffset,
sliceStagingEntry.getGlobalRecordCounter()
);
slice.setReferenceMD5(cramReferenceRegion.getCurrentReferenceBases());
final AlignmentContext sliceAlignmentContext = slice.getAlignmentContext();
if (sliceAlignmentContext.getReferenceContext().isMappedSingleRef()) {
cramReferenceRegion.fetchReferenceBasesByRegion(sliceAlignmentContext);
slice.setReferenceMD5(cramReferenceRegion);
}
slices.add(slice);
}
cramRecordSliceEntries.clear();
Expand All @@ -150,11 +155,16 @@ private final List<CRAMCompressionRecord> convertToCRAMRecords(final List<SAMRec
final List<CRAMCompressionRecord> cramCompressionRecords = new ArrayList<>();
for (final SAMRecord samRecord : samRecords) {
int referenceIndex = samRecord.getReferenceIndex();
byte[] referenceBases = null;
if (referenceIndex != SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX) {
cramReferenceRegion.fetchReferenceBases(referenceIndex);
referenceBases = cramReferenceRegion.getCurrentReferenceBases();
}
final CRAMCompressionRecord cramCompressionRecord = new CRAMCompressionRecord(
CramVersions.DEFAULT_CRAM_VERSION,
encodingStrategy,
samRecord,
cramReferenceRegion.getReferenceBases(referenceIndex),
referenceBases,
recordIndex++,
readGroupNameToID);
cramCompressionRecords.add(cramCompressionRecord);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,19 @@
* exception.
*/
public class CRAMLazyReferenceSource implements CRAMReferenceSource {
private static final String NO_REF_MESSAGE = "A reference must be supplied that includes the reference sequence for %s.";

@Override
public byte[] getReferenceBases(final SAMSequenceRecord sequenceRecord, final boolean tryNameVariants) {
throw new IllegalArgumentException(
String.format("A reference must be supplied that includes the reference sequence for %s.",
sequenceRecord.getSequenceName()));
throw new IllegalArgumentException(String.format(NO_REF_MESSAGE, sequenceRecord.getSequenceName()));
}

@Override
public byte[] getReferenceBasesByRegion(
final SAMSequenceRecord sequenceRecord,
final int zeroBasedStart,
final int requestedRegionLength) {
throw new IllegalArgumentException(String.format(NO_REF_MESSAGE, sequenceRecord.getSequenceName()));
}

}
Loading

0 comments on commit 489c419

Please sign in to comment.