diff --git a/README.md b/README.md
index fb85c2e..df67e53 100644
--- a/README.md
+++ b/README.md
@@ -79,8 +79,26 @@ sc._jsc.hadoopConfiguration().set("fs.s3a.secret.key", secret_key)
 sc._jsc.hadoopConfiguration().set("fs.s3a.path.style.access", "true")
 ```
 
+## Segment version selection
+
+The connector automatically detects Apache Druid segments based on the directory structure on Deep Storage (without
+accessing the metadata store). As in Apache Druid, only the latest version of every interval is loaded. Any
+segment granularity is supported, including mixed granularities (e.g. hourly with daily).
+
+For example, if the storage contains the following segments:
+
+```
+1     2     3     4     5   <- Intervals
+| A 1 | B 2 | C 3 | E 5 |
+| D 4       |
+```
+
+Segments C, D, and E would be selected, while A and B would be skipped (as they were most likely compacted into D,
+which has a newer version).
+
+Thus, even if multiple versions of the data exist in storage, only the latest version is loaded (without
+duplicates).
+
 ## Current limitations
 
 - The connector is reading only the dimensions, skipping all the metrics.
-- Only segments with daily granularity are supported.
-- Only latest interval version can be loaded.
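For reference, the selection rule described above can be reproduced with the `MixedGranularitySolver` introduced later in this change; the `Seg` case class and the millisecond-based intervals below are illustrative stand-ins that mirror the A–E diagram:

```scala
import bi.deep.segments.MixedGranularitySolver
import bi.deep.segments.MixedGranularitySolver.Overlapping
import org.joda.time.Interval

object VersionSelectionSketch extends App {
  // Illustrative stand-in for a segment: a time interval plus a version number
  case class Seg(name: String, interval: Interval, version: Int)

  implicit val overlapping: Overlapping[Seg] = new Overlapping[Seg] {
    override def overlaps(left: Seg, right: Seg): Boolean = left.interval.overlaps(right.interval)
  }
  implicit val ordering: Ordering[Seg] = Ordering.by((s: Seg) => s.version)

  // Same layout as the diagram above: D is a newer version covering the intervals of A and B
  val segments = List(
    Seg("A", new Interval(0L, 1L), 1),
    Seg("B", new Interval(1L, 2L), 2),
    Seg("C", new Interval(2L, 3L), 3),
    Seg("D", new Interval(0L, 2L), 4),
    Seg("E", new Interval(3L, 4L), 5)
  )

  // Prints List(C, D, E): within every group of overlapping entries only the newest version survives
  println(MixedGranularitySolver.solve(segments).map(_.name).sorted)
}
```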
diff --git a/src/main/scala/bi/deep/DruidDataReader.scala b/src/main/scala/bi/deep/DruidDataReader.scala
index 58a7b9c..2868179 100644
--- a/src/main/scala/bi/deep/DruidDataReader.scala
+++ b/src/main/scala/bi/deep/DruidDataReader.scala
@@ -1,6 +1,7 @@
 package bi.deep
 
 import org.apache.druid.segment.{DruidRowConverter, QueryableIndexIndexableAdapter}
+import org.apache.hadoop.fs.FileSystem
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader.InputPartitionReader
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -11,6 +12,9 @@ import java.io.File
 case class DruidDataReader(filePath: String, schema: StructType, config: Config)
   extends InputPartitionReader[InternalRow] with DruidSegmentReader {
 
+  @transient
+  private implicit val fs: FileSystem = config.factory.fileSystemFor(filePath)
+
   private var current: Option[InternalRow] = None
   private val targetRowSize: Int = schema.size
   private val timestampIdx: Option[Int] = {
@@ -18,7 +22,7 @@ case class DruidDataReader(filePath: String, schema: StructType, config: Config)
     else None
   }
 
-  private lazy val rowConverter: DruidRowConverter = withSegment(filePath, config, filePath) { file =>
+  private lazy val rowConverter: DruidRowConverter = withSegment(filePath, config) { file =>
     val qi = indexIO.loadIndex(new File(file.getAbsolutePath))
     val qiia = new QueryableIndexIndexableAdapter(qi)
     val segmentSchema = DruidSchemaReader.readSparkSchema(qi)
diff --git a/src/main/scala/bi/deep/DruidDataSourceReader.scala b/src/main/scala/bi/deep/DruidDataSourceReader.scala
index 2da845c..1b82ef1 100644
--- a/src/main/scala/bi/deep/DruidDataSourceReader.scala
+++ b/src/main/scala/bi/deep/DruidDataSourceReader.scala
@@ -1,7 +1,8 @@
 package bi.deep
 
+import bi.deep.segments.{SegmentStorage, Segment}
 import org.apache.druid.common.guava.GuavaUtils
-import org.apache.hadoop.fs.LocatedFileStatus
+import org.apache.hadoop.fs.{FileSystem, LocatedFileStatus, Path}
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition}
@@ -13,19 +14,28 @@ import scala.collection.JavaConverters.seqAsJavaListConverter
 
 
 class DruidDataSourceReader(config: Config) extends DataSourceReader {
 
-  val mainPath: String = config.inputPath + config.dataSource + "/"
-  private lazy val filesPaths: Array[LocatedFileStatus] = LatestSegmentSelector(config, mainPath).getPathsArray
-  private lazy val schema: StructType = readSchema()
+  @transient
+  private lazy val mainPath: Path = new Path(config.inputPath, config.dataSource)
+  @transient
+  private implicit lazy val fs: FileSystem = config.factory.fileSystemFor(mainPath)
+  @transient
+  private lazy val storage = SegmentStorage(mainPath)
 
-  override def readSchema(): StructType = {
-    val schemaReader = new DruidSchemaReader(mainPath, config)
-    schemaReader.calculateSchema(filesPaths)
-  }
+  @transient
+  private lazy val schemaReader = DruidSchemaReader(config)
+
+  @transient
+  private lazy val filesPaths: Seq[Segment] = storage.findValidSegments(config.startDate, config.endDate)
+
+  @transient
+  private lazy val schema: StructType = readSchema()
+
+  override def readSchema(): StructType = schemaReader.calculateSchema(filesPaths)
 
-  private def fileToReaderFactory(file: LocatedFileStatus): InputPartition[InternalRow] = {
-    DruidDataReaderFactory(file.getPath.toString, schema, config)
+  private def fileToReaderFactory(file: Segment): InputPartition[InternalRow] = {
+    DruidDataReaderFactory(file.path.toString, schema, config)
   }
 
   override def planInputPartitions(): util.List[InputPartition[InternalRow]] = {
diff --git a/src/main/scala/bi/deep/DruidSchemaReader.scala b/src/main/scala/bi/deep/DruidSchemaReader.scala
index 49576fa..21ff53c 100644
--- a/src/main/scala/bi/deep/DruidSchemaReader.scala
+++ b/src/main/scala/bi/deep/DruidSchemaReader.scala
@@ -1,19 +1,19 @@
 package bi.deep
 
+import bi.deep.segments.Segment
 import org.apache.druid.segment.QueryableIndex
 import org.apache.druid.segment.column.ColumnCapabilitiesImpl
-import org.apache.hadoop.fs.LocatedFileStatus
+import org.apache.hadoop.fs.FileSystem
 import org.apache.spark.sql.types.{LongType, StructField, StructType}
 
 import scala.collection.JavaConverters.collectionAsScalaIterableConverter
 import scala.collection.immutable.ListMap
 
 
-class DruidSchemaReader(path: String, config: Config) extends DruidSegmentReader {
-
-  def calculateSchema(files: Array[LocatedFileStatus]): StructType = {
+case class DruidSchemaReader(config: Config) extends DruidSegmentReader {
+  def calculateSchema(files: Seq[Segment])(implicit fs: FileSystem): StructType = {
     val druidSchemas = files.map { file =>
-      withSegment(file.getPath.toString, config, path) { segmentDir =>
+      withSegment(file.path.toString, config) { segmentDir =>
         val qi = indexIO.loadIndex(segmentDir)
         DruidSchemaReader.readDruidSchema(qi)
       }
diff --git a/src/main/scala/bi/deep/DruidSegmentReader.scala b/src/main/scala/bi/deep/DruidSegmentReader.scala
index fa37f59..8db213d 100644
--- a/src/main/scala/bi/deep/DruidSegmentReader.scala
+++ b/src/main/scala/bi/deep/DruidSegmentReader.scala
@@ -8,14 +8,10 @@ import org.apache.commons.io.FileUtils
 import org.apache.druid.common.config.NullHandling
 import org.apache.druid.guice.annotations.Json
 import org.apache.druid.guice.{GuiceInjectableValues, GuiceInjectors}
-import org.apache.druid.jackson.DefaultObjectMapper
-import org.apache.druid.java.util.emitter.EmittingLogger
 import org.apache.druid.query.DruidProcessingConfig
 import org.apache.druid.segment.IndexIO
-import org.apache.druid.segment.column.ColumnConfig
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.log4j.{Level, Logger}
-import org.slf4j.LoggerFactory
 
 import java.io.File
 import java.nio.file.Files
@@ -44,9 +40,7 @@ trait DruidSegmentReader {
     )
   }
 
-  def withSegment[R](file: String, config: Config, path: String)(handler: File => R): R = {
-    val fileSystem = config.factory.fileSystemFor(path)
-
+  def withSegment[R](file: String, config: Config)(handler: File => R)(implicit fs: FileSystem): R = {
     val segmentDir = if (config.tempSegmentDir != "") {
       val temp = new File(config.tempSegmentDir, sha1Hex(file))
       FileUtils.forceMkdir(temp)
@@ -57,7 +51,7 @@ trait DruidSegmentReader {
 
     try {
       val segmentFile = new File(segmentDir, DRUID_SEGMENT_FILE)
-      fileSystem.copyToLocalFile(new Path(file), new Path(segmentFile.toURI))
+      fs.copyToLocalFile(new Path(file), new Path(segmentFile.toURI))
       ZipUtils.unzip(segmentFile, segmentDir)
       handler(segmentDir)
     }
diff --git a/src/main/scala/bi/deep/FileSystemFactory.scala b/src/main/scala/bi/deep/FileSystemFactory.scala
index 9c96784..fd9b921 100644
--- a/src/main/scala/bi/deep/FileSystemFactory.scala
+++ b/src/main/scala/bi/deep/FileSystemFactory.scala
@@ -1,6 +1,6 @@
 package bi.deep
 
-import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.spark.util.SerializableHadoopConfiguration
 
 import java.net.URI
@@ -11,4 +11,8 @@ class FileSystemFactory(hadoopConf: SerializableHadoopConfiguration) extends Ser
   def fileSystemFor(path: String): FileSystem = {
     FileSystem.get(new URI(path), hadoopConf.config)
   }
+
+  def fileSystemFor(path: Path): FileSystem = {
+    FileSystem.get(path.toUri, hadoopConf.config)
+  }
 }
\ No newline at end of file
diff --git a/src/main/scala/bi/deep/LatestSegmentSelector.scala b/src/main/scala/bi/deep/LatestSegmentSelector.scala
deleted file mode 100644
index 988a229..0000000
--- a/src/main/scala/bi/deep/LatestSegmentSelector.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-package bi.deep
-
-import bi.deep.Config.DRUID_SEGMENT_FILE
-import bi.deep.HadoopUtils.remoteIterator
-import bi.deep.LatestSegmentSelector.{getIntervals, listPrimaryPartitions}
-import org.apache.hadoop.fs.{LocatedFileStatus, Path}
-
-import java.time.LocalDate
-
-
-case class LatestSegmentSelector(config: Config, mainPathName: String) {
-
-  private val directoryPaths = listPrimaryPartitions(getIntervals(config.startDate, config.endDate), mainPathName)
-  private var results: Array[LocatedFileStatus] = Array.empty[LocatedFileStatus]
-
-  private def latestDate(file: LocatedFileStatus, path: String): String = {
-    file.toString.replace(path, "").split('/')(1)
-  }
-
-  protected def lastCreatedDir(files: List[LocatedFileStatus], path: String): String = {
-    latestDate(files.maxBy(latestDate(_, path)), path)
-  }
-
-  def createPathsList(path: String): List[LocatedFileStatus] = {
-    val files = config.factory.fileSystemFor(path)
-      .listFiles(new Path(path), true)
-      .toList
-    files.filter(p => p.getPath.getName == DRUID_SEGMENT_FILE)
-  }
-
-  def latestSegmentForPath(path: String): Unit = {
-    val indexFiles = createPathsList(path)
-    val lastDate = lastCreatedDir(indexFiles, path)
-    val latestFiles = createPathsList(path + "/" + lastDate)
-    results = results ++ latestFiles
-  }
-
-  def getPathsArray: Array[LocatedFileStatus] = {
-    directoryPaths.foreach(latestSegmentForPath)
-    results
-  }
-}
-
-object LatestSegmentSelector {
-
-  def getIntervals(startDate: LocalDate, endDate: LocalDate): List[(LocalDate, LocalDate)] = {
-    Stream.iterate(startDate)(_.plusDays(1)).takeWhile(_.isBefore(endDate))
-      .map(date => (date, date.plusDays(1))).toList
-  }
-
-  def listPrimaryPartitions(dates: List[(LocalDate, LocalDate)], path: String): List[String] = {
-    path.split(':')(0) match {
-      case "hdfs" => dates.map(x => path + s"${x._1}T00_00_00.000Z_${x._2}T00_00_00.000Z")
-      case _ => dates.map(x => path + s"${x._1}T00:00:00.000Z_${x._2}T00:00:00.000Z")
-    }
-  }
-}
\ No newline at end of file
diff --git a/src/main/scala/bi/deep/segments/MixedGranularitySolver.scala b/src/main/scala/bi/deep/segments/MixedGranularitySolver.scala
new file mode 100644
index 0000000..aa3e9c9
--- /dev/null
+++ b/src/main/scala/bi/deep/segments/MixedGranularitySolver.scala
@@ -0,0 +1,41 @@
+package bi.deep.segments
+
+object MixedGranularitySolver {
+  trait Overlapping[T] {
+
+    /** Defines whether two values overlap (i.e. have a common part) */
+    def overlaps(left: T, right: T): Boolean
+
+  }
+
+  private def upsert[K, V](map: Map[K, V])(key: K, value: V)(f: (V, V) => V): Map[K, V] = {
+    map.get(key) match {
+      case None => map.updated(key, value)
+      case Some(existing) => map.updated(key, f(existing, value))
+    }
+  }
+
+  private def upsertMany[K, V](map: Map[K, V])(kv: (K, V)*)(f: (V, V) => V): Map[K, V] = {
+    kv.foldLeft(map) { case (map, (k, v)) => upsert(map)(k, v)(f) }
+  }
+
+  def solve[T](values: List[T])(implicit overlapping: Overlapping[T], ordering: Ordering[T]): List[T] = {
+    if (values.length < 2) values
+    else {
+      val collisions = values.combinations(2).foldLeft(Map.empty[T, Set[T]]) { case (collisions, left :: right :: Nil) =>
+        if (overlapping.overlaps(left, right)) upsertMany(collisions)(
+          left -> Set(left, right),
+          right -> Set(right, left)
+        )(_ ++ _)
+        else upsertMany(collisions)(
+          left -> Set(left),
+          right -> Set(right)
+        )(_ ++ _)
+      }
+
+      val results = collisions.foldLeft(Set.empty[T]) { case (solutions, (_, collisions)) => solutions + collisions.max }
+      results.toList
+    }
+
+  }
+}
diff --git a/src/main/scala/bi/deep/segments/Segment.scala b/src/main/scala/bi/deep/segments/Segment.scala
new file mode 100644
index 0000000..545c65f
--- /dev/null
+++ b/src/main/scala/bi/deep/segments/Segment.scala
@@ -0,0 +1,27 @@
+package bi.deep.segments
+
+import bi.deep.utils.Parsing
+import org.apache.hadoop.fs.{LocatedFileStatus, Path}
+import org.apache.log4j.Logger
+import org.joda.time.{Instant, Interval}
+
+import scala.util.Try
+
+
+case class Segment(path: Path, interval: Interval, version: Instant, partition: Int)
+
+
+object Segment {
+  implicit private val logger: Logger = Logger.getLogger("SegmentParser")
+
+  private val pattern = """(\d+)""".r
+
+  def apply(segment: SegmentVersion, status: LocatedFileStatus): Option[Segment] = Parsing.withLogger {
+    val relative = status.getPath.toString.replace(segment.path.toString, "")
+    for {
+      partition <- Try(pattern.findFirstIn(relative).get.toInt)
+    } yield {
+      new Segment(status.getPath, segment.interval, segment.version, partition)
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/bi/deep/segments/SegmentInterval.scala b/src/main/scala/bi/deep/segments/SegmentInterval.scala
new file mode 100644
index 0000000..c02a990
--- /dev/null
+++ b/src/main/scala/bi/deep/segments/SegmentInterval.scala
@@ -0,0 +1,22 @@
+package bi.deep.segments
+
+import bi.deep.utils.{Parsing, TimeUtils}
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.log4j.Logger
+import org.joda.time.Interval
+
+
+/** A path to a Segment's Interval */
+case class SegmentInterval(path: Path, interval: Interval) {
+  def overlappedBy(range: Interval): Boolean = range.overlaps(interval)
+}
+
+object SegmentInterval {
+  implicit private val logger: Logger = Logger.getLogger("SegmentIntervalParser")
+
+  def apply(status: FileStatus): Option[SegmentInterval] = Parsing.withLogger {
+    for {
+      interval <- TimeUtils.parseIntervalFromHadoop(status.getPath.getName)
+    } yield new SegmentInterval(status.getPath, interval)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/bi/deep/segments/SegmentStorage.scala b/src/main/scala/bi/deep/segments/SegmentStorage.scala
new file mode 100644
index 0000000..bab63f9
--- /dev/null
+++ b/src/main/scala/bi/deep/segments/SegmentStorage.scala
@@ -0,0 +1,68 @@
+package bi.deep.segments
+
+import bi.deep.segments.MixedGranularitySolver.Overlapping
+import bi.deep.utils.Extensions.RemoteIteratorExt
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.log4j.Logger
+import org.joda.time.Interval
+
+import java.time.{LocalDate, ZoneId}
+
+case class SegmentStorage(root: Path) {
+
+  import SegmentStorage._
+
+  private val logger = Logger.getLogger(classOf[SegmentStorage])
+
+  def findIntervals(start: LocalDate, end: LocalDate)(implicit fs: FileSystem): List[SegmentInterval] = {
+
+    val startTime = start.atStartOfDay(ZoneId.of("UTC")).toInstant
+    val endTime = end.atStartOfDay(ZoneId.of("UTC")).toInstant
+    val range = new Interval(startTime.toEpochMilli, endTime.toEpochMilli)
+
+    fs.listStatusIterator(root).asScala
+      .flatMap(SegmentInterval(_))
+      .filter(_.overlappedBy(range))
+      .toList
+  }
+
+  def findLatestVersion(segmentInterval: SegmentInterval)(implicit fs: FileSystem): SegmentVersion = {
+    fs.listStatusIterator(segmentInterval.path).asScala
+      .flatMap(SegmentVersion(segmentInterval, _))
+      .maxBy(_.version.getMillis)
+  }
+
+  def findPartitions(segmentVersion: SegmentVersion)(implicit fs: FileSystem): List[Segment] = {
+    fs.listFiles(segmentVersion.path, true).asScala
+      .flatMap(Segment(segmentVersion, _))
+      .toList
+  }
+
+  def findValidSegments(start: LocalDate, end: LocalDate)(implicit fs: FileSystem): List[Segment] = {
+    val versions = findIntervals(start, end)
+      .map(findLatestVersion)
+
+    val segments = MixedGranularitySolver
+      .solve(versions)
+      .flatMap(findPartitions)
+
+    logger.info(s"Found ${segments.length} segments in range [$start, $end]")
+    segments
+  }
+
+}
+
+object SegmentStorage {
+
+  implicit val segmentOverlapping: Overlapping[SegmentVersion] = new Overlapping[SegmentVersion] {
+    /** Defines whether two values overlap (i.e. have a common part) */
+    override def overlaps(left: SegmentVersion, right: SegmentVersion): Boolean = {
+      left.interval.overlaps(right.interval)
+    }
+  }
+
+  implicit val segmentOrdering: Ordering[SegmentVersion] = new Ordering[SegmentVersion] {
+    override def compare(left: SegmentVersion, right: SegmentVersion): Int = left.version.compareTo(right.version)
+  }
+
+}
\ No newline at end of file
diff --git a/src/main/scala/bi/deep/segments/SegmentVersion.scala b/src/main/scala/bi/deep/segments/SegmentVersion.scala
new file mode 100644
index 0000000..88c23ae
--- /dev/null
+++ b/src/main/scala/bi/deep/segments/SegmentVersion.scala
@@ -0,0 +1,18 @@
+package bi.deep.segments
+
+import bi.deep.utils.{Parsing, TimeUtils}
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.log4j.Logger
+import org.joda.time.{Instant, Interval}
+
+
+case class SegmentVersion(path: Path, interval: Interval, version: Instant)
+
+object SegmentVersion {
+  implicit private val logger: Logger = Logger.getLogger("SegmentVersionParser")
+  def apply(segment: SegmentInterval, status: FileStatus): Option[SegmentVersion] = Parsing.withLogger {
+    for {
+      version <- TimeUtils.parseInstantFromHadoop(status.getPath.getName)
+    } yield new SegmentVersion(status.getPath, segment.interval, version)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/bi/deep/utils/Extensions.scala b/src/main/scala/bi/deep/utils/Extensions.scala
new file mode 100644
index 0000000..8749486
--- /dev/null
+++ b/src/main/scala/bi/deep/utils/Extensions.scala
@@ -0,0 +1,17 @@
+package bi.deep.utils
+
+import org.apache.hadoop.fs.RemoteIterator
+
+object Extensions {
+
+  implicit class RemoteIteratorExt[T](val iterator: RemoteIterator[T]) extends AnyVal {
+
+    def asScala: Iterator[T] = new Iterator[T] {
+      override def hasNext: Boolean = iterator.hasNext
+
+      override def next(): T = iterator.next()
+    }
+
+  }
+
+}
diff --git a/src/main/scala/bi/deep/utils/Parsing.scala b/src/main/scala/bi/deep/utils/Parsing.scala
new file mode 100644
index 0000000..ef7d86b
--- /dev/null
+++ b/src/main/scala/bi/deep/utils/Parsing.scala
@@ -0,0 +1,18 @@
+package bi.deep.utils
+
+import org.apache.log4j.Logger
+
+import scala.util.{Failure, Success, Try}
+
+object Parsing {
+
+  def withLogger[T](body: => Try[T])(implicit logger: Logger): Option[T] = {
+    body match {
+      case Success(value) => Some(value)
+      case Failure(exception) =>
+        logger.warn(s"Failed to parse segment path: ${exception.getMessage}", exception)
+        None
+    }
+  }
+
+}
diff --git a/src/main/scala/bi/deep/utils/TimeUtils.scala b/src/main/scala/bi/deep/utils/TimeUtils.scala
new file mode 100644
index 0000000..caa3594
--- /dev/null
+++ b/src/main/scala/bi/deep/utils/TimeUtils.scala
@@ -0,0 +1,16 @@
+package bi.deep.utils
+
+import org.joda.time.{Instant, Interval}
+
+import scala.util.Try
+
+object TimeUtils {
+
+  def parseInstantFromHadoop(value: String): Try[Instant] = Try(Instant.parse(value.replace('_', ':')))
+
+  def parseIntervalFromHadoop(value: String): Try[Interval] = for {
+    start <- parseInstantFromHadoop(value.substring(0, 24))
+    end <- parseInstantFromHadoop(value.substring(25))
+  } yield new Interval(start, end)
+
+}
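As a quick illustration (not part of the change set itself), the helpers above decode the directory names written to deep storage, where colons may be encoded as underscores (e.g. on HDFS); the sample names below are made up:

```scala
import bi.deep.utils.TimeUtils

object TimeUtilsSketch extends App {
  // Interval directory name with ":" encoded as "_"; each instant is exactly 24 characters long
  val intervalDir = "2022-01-01T00_00_00.000Z_2022-01-02T00_00_00.000Z"

  // e.g. Success(2022-01-01T00:00:00.000Z/2022-01-02T00:00:00.000Z)
  println(TimeUtils.parseIntervalFromHadoop(intervalDir))

  // Version directories contain a single instant (the version timestamp)
  println(TimeUtils.parseInstantFromHadoop("2022-01-05T10_15_30.000Z"))
}
```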
diff --git a/src/test/scala/bi/deep/LatestSegmentSelectorTest.scala b/src/test/scala/bi/deep/LatestSegmentSelectorTest.scala
deleted file mode 100644
index 8f4c140..0000000
--- a/src/test/scala/bi/deep/LatestSegmentSelectorTest.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-package bi.deep
-
-import bi.deep.LatestSegmentSelector.listPrimaryPartitions
-import org.scalatest.funsuite.AnyFunSuite
-
-import java.time.LocalDate
-
-class LatestSegmentSelectorTest extends AnyFunSuite {
-
-  test("Correctly return path when List of dates is not empty") {
-    val path: String = "hdfs://test/"
-    val dates = List((LocalDate.parse("2022-01-01"), LocalDate.parse("2022-01-02")), (LocalDate.parse("2022-01-02"), LocalDate.parse("2022-01-03")))
-    val result = listPrimaryPartitions(dates, path)
-    val expected = List("hdfs://test/2022-01-01T00_00_00.000Z_2022-01-02T00_00_00.000Z", "hdfs://test/2022-01-02T00_00_00.000Z_2022-01-03T00_00_00.000Z")
-    assertResult(expected)(result)
-  }
-
-  test("Empty List of dates") {
-    val path: String = "hdfs://test/"
-    val dates = List.empty
-    val result = listPrimaryPartitions(dates, path)
-    val expected = List.empty
-    assertResult(expected)(result)
-  }
-
-
-  test("Correctly return one-element list of (LocalDate, LocalDate) when startDate - endDate equals 1 day") {
-    val startDate: LocalDate = LocalDate.parse("2022-01-01")
-    val endDate: LocalDate = LocalDate.parse("2022-01-02")
-    val result = LatestSegmentSelector.getIntervals(startDate, endDate)
-    val expected = List((startDate, endDate))
-    assertResult(expected)(result)
-  }
-
-  test("Correctly return an empty list when endDate > startDate") {
-    val startDate: LocalDate = LocalDate.parse("2022-01-10")
-    val endDate: LocalDate = LocalDate.parse("2022-01-05")
-    val result = LatestSegmentSelector.getIntervals(startDate, endDate)
-    val expected = List.empty
-    assertResult(expected)(result)
-  }
-
-  test("Correctly return an empty list when endDate = startDate") {
-    val startDate: LocalDate = LocalDate.parse("2022-01-10")
-    val endDate: LocalDate = LocalDate.parse("2022-01-10")
-    val result = LatestSegmentSelector.getIntervals(startDate, endDate)
-    val expected = List.empty
-    assertResult(expected)(result)
-  }
-
-  test("Correctly return a list of (LocalDate, LocalDate)") {
-    val startDate: LocalDate = LocalDate.parse("2022-01-05")
-    val endDate: LocalDate = LocalDate.parse("2022-01-10")
-    val result = LatestSegmentSelector.getIntervals(startDate, endDate)
-    val expected = List("2022-01-05", "2022-01-06", "2022-01-07", "2022-01-08", "2022-01-09").map(LocalDate.parse)
-      .map(date => date -> date.plusDays(1))
-    assertResult(expected)(result)
-  }
-
-  test("Correctly add day to LocalDate") {
-    val startDate: LocalDate = LocalDate.parse("2022-01-05")
-    val result = LocalDate.parse("2022-01-06")
-    val expected = startDate.plusDays(1)
-    assertResult(expected)(result)
-  }
-}
diff --git a/src/test/scala/bi/deep/segments/MixedGranularitySolverTest.scala b/src/test/scala/bi/deep/segments/MixedGranularitySolverTest.scala
new file mode 100644
index 0000000..4c54477
--- /dev/null
+++ b/src/test/scala/bi/deep/segments/MixedGranularitySolverTest.scala
@@ -0,0 +1,97 @@
+package bi.deep.segments
+
+import bi.deep.segments.MixedGranularitySolver.{Overlapping, solve}
+import org.joda.time.Interval
+import org.scalatest.funsuite.AnyFunSuite
+
+class MixedGranularitySolverTest extends AnyFunSuite {
+
+  case class Foo(interval: Interval, priority: Int)
+
+  object Foo {
+    def apply(start: Long, end: Long, priority: Int): Foo = new Foo(new Interval(start, end), priority)
+  }
+
+  private implicit val overlapping: Overlapping[Foo] = new Overlapping[Foo] {
+    /** Defines whether two values overlap (i.e. have a common part) */
+    override def overlaps(left: Foo, right: Foo): Boolean = left.interval.overlaps(right.interval)
+  }
+
+  private implicit val ordering: Ordering[Foo] = Ordering.fromLessThan((l, r) => Ordering.Int.lt(l.priority, r.priority))
+
+  test("solver without overlaps") {
+    val values = List(Foo(0, 1, 1), Foo(1, 2, 1), Foo(2, 3, 1), Foo(3, 4, 1))
+
+    val results = solve(values)
+
+    assertResult(values.toSet)(results.toSet)
+  }
+
+  test("solver with single overlap - ordering decides") {
+    val values = List(Foo(0, 1, 1), Foo(0, 1, 10))
+    val result = solve(values)
+
+    assertResult(List(Foo(0, 1, 10)))(result)
+  }
+
+  /**
+   * 0     1     2
+   * | A 1 | B 2 |
+   * | C 3       |
+   * */
+  test("longer is greater than the ones it overlaps - 2") {
+    val a = Foo(0, 1, 1)
+    val b = Foo(1, 2, 2)
+    val c = Foo(0, 2, 3)
+
+    val result = solve(List(a, b, c)).toSet
+    assertResult(Set(c))(result)
+  }
+
+  /**
+   * 0     1     2     3
+   * | A 1 | B 2 | D 3 |
+   * | C 4             |
+   * */
+  test("longer is greater than the ones it overlaps - 3") {
+    val a = Foo(0, 1, 1)
+    val b = Foo(1, 2, 2)
+    val c = Foo(0, 3, 4)
+    val d = Foo(2, 3, 3)
+
+    val result = solve(List(a, b, c, d)).toSet
+    assertResult(Set(c))(result)
+  }
+
+  /**
+   * 0     1     2
+   * | A 1       |
+   * | B 2 | C 3 |
+   */
+  test("shorter ones are greater than the one they overlap - 2") {
+    val a = Foo(0, 2, 1)
+    val b = Foo(0, 1, 2)
+    val c = Foo(1, 2, 3)
+
+    val result = solve(List(a, b, c)).toSet
+    assertResult(Set(b, c))(result)
+  }
+
+  /**
+   * 0     1     2     3     4
+   * | A 1 | B 2 | C 3 | E 5 |
+   * | D 4       |
+   */
+  test("shorter ones are greater than the one they overlap - 2, mix") {
+    val a = Foo(0, 1, 1)
+    val b = Foo(1, 2, 2)
+    val c = Foo(2, 3, 3)
+    val d = Foo(0, 2, 4)
+    val e = Foo(3, 4, 5)
+
+
+    val result = solve(List(a, b, c, d, e)).toSet
+    assertResult(Set(d, c, e))(result)
+  }
+
+}