diff --git a/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala b/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala index de35f108c7e..e6d83d4333d 100644 --- a/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala +++ b/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala @@ -512,19 +512,24 @@ case class IndexedRVDSpec2( } val (nestedContexts, newPartitioner) = if (filterIntervals) { + /* We want to filter to intervals in newPartitioner, while preserving the old partitioning, + * but dropping any partitions we know would be empty. So we construct a map from old + * partitions to the range of overlapping new partitions, dropping any with an empty range. */ val contextsAndBounds = for { - oldPartIdx <- part.rangeBounds.indices - oldInterval = part.rangeBounds(oldPartIdx) + (oldInterval, oldPartIdx) <- part.rangeBounds.toFastSeq.zipWithIndex overlapRange = extendedNP.queryInterval(oldInterval) if overlapRange.nonEmpty } yield { val ctxs = overlapRange.map(newPartIdx => makeCtx(oldPartIdx, newPartIdx)) + // the interval spanning all overlapping filter intervals val newInterval = Interval( extendedNP.rangeBounds(overlapRange.head).left, extendedNP.rangeBounds(overlapRange.last).right, ) ( ctxs, + // Shrink oldInterval to the rows filtered to. + // By construction we know oldInterval and newInterval overlap oldInterval.intersect(extendedNP.kord, newInterval).get, ) } @@ -532,11 +537,14 @@ case class IndexedRVDSpec2( (nestedContexts, new RVDPartitioner(part.sm, part.kType, newRangeBounds)) } else { - val nestedContexts = extendedNP.rangeBounds.indices.map { newPartIdx => - val newInterval = extendedNP.rangeBounds(newPartIdx) - val overlapRange = part.queryInterval(newInterval) - overlapRange.map(oldPartIdx => makeCtx(oldPartIdx, newPartIdx)) - } + /* We want to use newPartitioner as the partitioner, dropping any rows not contained in any + * new partition. So we construct a map from new partitioner to the range of overlapping old + * partitions. */ + val nestedContexts = + extendedNP.rangeBounds.toFastSeq.zipWithIndex.map { case (newInterval, newPartIdx) => + val overlapRange = part.queryInterval(newInterval) + overlapRange.map(oldPartIdx => makeCtx(oldPartIdx, newPartIdx)) + } (nestedContexts, extendedNP) }