Skip to content

Commit

Permalink
tidy up code
Browse files: browse the repository at this point in the history
  • Loading branch information
patrick-schultz committed Mar 21, 2024
1 parent 9981105 commit 358adb0
Showing 1 changed file with 49 additions and 111 deletions.
160 changes: 49 additions & 111 deletions hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ package is.hail.rvd
import is.hail.annotations._
import is.hail.backend.{ExecuteContext, HailStateManager}
import is.hail.compatibility
import is.hail.expr.{JSONAnnotationImpex, ir}
import is.hail.expr.ir.{
IR, Literal, PartitionNativeReader, PartitionZippedIndexedNativeReader,
PartitionZippedNativeReader, ReadPartition, Ref, ToStream, flatMapIR,
}
import is.hail.expr.{ir, JSONAnnotationImpex}
import is.hail.expr.ir.{flatMapIR, IR, Literal, PartitionNativeReader, PartitionZippedIndexedNativeReader, PartitionZippedNativeReader, ReadPartition, Ref, ToStream}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io._
import is.hail.io.fs.FS
Expand All @@ -17,6 +14,7 @@ import is.hail.types.encoded.ETypeSerializer
import is.hail.types.physical._
import is.hail.types.virtual._
import is.hail.utils._

import org.apache.spark.TaskContext
import org.apache.spark.sql.Row
import org.json4s.{DefaultFormats, Formats, JValue, ShortTypeHints}
Expand Down Expand Up @@ -496,125 +494,65 @@ case class IndexedRVDSpec2(
uidFieldName,
)

val body = (ctxs: Ref) =>
flatMapIR(ToStream(ctxs, true)) { ctx =>
ir.ReadPartition(ctx, requestedType.rowType, reader)
}

if (filterIntervals) {
val (nestedContexts, newRangeBounds) = part.rangeBounds.indices.flatMap { oldPartIdx =>
val oldInterval = part.rangeBounds(oldPartIdx)
val overlapRange = extendedNP.queryInterval(oldInterval)
if (overlapRange.isEmpty) None
else {
val ctxs = overlapRange.map { newPartIdx =>
val partFile = partFiles(oldPartIdx)
val intersectionInterval =
extendedNP.rangeBounds(newPartIdx)
.intersect(extendedNP.kord, oldInterval).get
Row(
oldPartIdx.toLong,
s"$path/parts/$partFile",
s"$path/${indexSpec.relPath}/$partFile.idx",
RVDPartitioner.intervalToIRRepresentation(intersectionInterval, part.kType.size),
)
}
val newInterval = Interval(
extendedNP.rangeBounds(overlapRange.head).left,
extendedNP.rangeBounds(overlapRange.last).right,
)
Some((
ctxs,
oldInterval.intersect(extendedNP.kord, newInterval).get,
))
}
}.unzip

assert(TArray(TArray(reader.contextType)).typeCheck(nestedContexts))

val contexts = ir.ToStream(ir.Literal(TArray(TArray(reader.contextType)), nestedContexts))
/** Builds one reader context row for the piece of on-disk partition `oldPartIdx` that falls
  * inside requested partition `newPartIdx`.
  *
  * The row carries, in order: the old partition index (as a Long), the path of the partition
  * file, the path of its index file, and the IR representation of the intersection of the two
  * partitions' range bounds.
  *
  * NOTE(review): `.get` presumes the two intervals overlap; callers in view only pass index
  * pairs obtained from `queryInterval`, but confirm that this always guarantees a non-empty
  * intersection.
  */
def makeCtx(oldPartIdx: Int, newPartIdx: Int): Row = {
  val sourceBounds = part.rangeBounds(oldPartIdx)
  val partFile = partFiles(oldPartIdx)
  val targetBounds = extendedNP.rangeBounds(newPartIdx)
  // Restrict the requested interval to what this on-disk partition actually covers.
  val clipped = targetBounds.intersect(extendedNP.kord, sourceBounds).get

  val dataPath = s"$path/parts/$partFile"
  val indexPath = s"$path/${indexSpec.relPath}/$partFile.idx"
  val clippedRepr = RVDPartitioner.intervalToIRRepresentation(clipped, part.kType.size)

  Row(oldPartIdx.toLong, dataPath, indexPath, clippedRepr)
}

{ (globals: IR) =>
TableStage(
globals,
new RVDPartitioner(part.sm, part.kType, newRangeBounds),
TableStageDependency.none,
contexts,
body,
val (nestedContexts, newPartitioner) = if (filterIntervals) {
val contextsAndBounds = for {
oldPartIdx <- part.rangeBounds.indices
oldInterval = part.rangeBounds(oldPartIdx)
overlapRange = extendedNP.queryInterval(oldInterval)
if overlapRange.nonEmpty
} yield {
val ctxs = overlapRange.map(newPartIdx => makeCtx(oldPartIdx, newPartIdx))
val newInterval = Interval(
extendedNP.rangeBounds(overlapRange.head).left,
extendedNP.rangeBounds(overlapRange.last).right,
)
(
ctxs,
oldInterval.intersect(extendedNP.kord, newInterval).get,
)
}
val (nestedContexts, newRangeBounds) = contextsAndBounds.unzip

(nestedContexts, new RVDPartitioner(part.sm, part.kType, newRangeBounds))
} else {
val nestedContexts = extendedNP.rangeBounds.indices.map { newPartIdx =>
val newInterval = extendedNP.rangeBounds(newPartIdx)
val overlapRange = part.queryInterval(newInterval)
overlapRange.map { oldPartIdx =>
val oldInterval = part.rangeBounds(oldPartIdx)
val partFile = partFiles(oldPartIdx)
val intersectionInterval =
extendedNP.rangeBounds(newPartIdx)
.intersect(extendedNP.kord, oldInterval).get
Row(
oldPartIdx.toLong,
s"$path/parts/$partFile",
s"$path/${indexSpec.relPath}/$partFile.idx",
RVDPartitioner.intervalToIRRepresentation(intersectionInterval, part.kType.size),
)
}
overlapRange.map(oldPartIdx => makeCtx(oldPartIdx, newPartIdx))
}

assert(TArray(TArray(reader.contextType)).typeCheck(nestedContexts))
(nestedContexts, extendedNP)
}

val contexts = ir.ToStream(ir.Literal(TArray(TArray(reader.contextType)), nestedContexts))
assert(TArray(TArray(reader.contextType)).typeCheck(nestedContexts))

{ (globals: IR) =>
TableStage(
globals,
extendedNP,
TableStageDependency.none,
contexts,
body,
)
}
{ (globals: IR) =>
TableStage(
globals,
newPartitioner,
TableStageDependency.none,
contexts = ir.ToStream(ir.Literal(TArray(TArray(reader.contextType)), nestedContexts)),
body = (ctxs: Ref) =>
flatMapIR(ToStream(ctxs, true)) { ctx =>
ir.ReadPartition(ctx, requestedType.rowType, reader)
},
)
}

// val (contextsValues, tmpRangeBounds) = (for {
// newInterval <- extendedNP.rangeBounds
// oldPartIdx <- part.queryInterval(newInterval)
// } yield {
// val oldInterval = part.rangeBounds(oldPartIdx)
// val intersectionInterval =
// newInterval.intersect(extendedNP.kord.intervalEndpointOrdering, oldInterval).get
//
// val partFile = partFiles(oldPartIdx)
// val ctx = Row(
// oldPartIdx.toLong,
// s"$path/parts/$partFile",
// s"$path/${indexSpec.relPath}/$partFile.idx",
// RVDPartitioner.intervalToIRRepresentation(intersectionInterval, part.kType.size),
// )
// (ctx, intersectionInterval)
// }).toFastSeq.unzip
// val tmpPartitioner = new RVDPartitioner(part.sm, part.kType, tmpRangeBounds)

// assert(TArray(TArray(reader.contextType)).typeCheck(nestedContexts))
//
// val contexts = ir.ToStream(ir.Literal(TArray(TArray(reader.contextType)), nestedContexts))
//
//// val body = (ctx: IR) => ir.ReadPartition(ctx, requestedType.rowType, reader)
//
// { (globals: IR) =>
// val ts = TableStage(
// globals,
// if (filterIntervals) part else extendedNP,
// TableStageDependency.none,
// contexts,
// body,
// )
// if (filterIntervals)
// ts.repartitionNoShuffle(ctx, part, dropEmptyPartitions = true)
// else ts.repartitionNoShuffle(ctx, extendedNP)
// }

case None =>
super.readTableStage(ctx, path, requestedType, uidFieldName, newPartitioner, filterIntervals)
}
Expand Down

0 comments on commit 358adb0

Please sign in to comment.