From 22aedf464aaa3f51eb577a7a26bbc4bfb4d03f7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Koziarkiewicz?= Date: Sun, 14 Jan 2018 17:23:01 +0100 Subject: [PATCH] #1 Added shape deconstruction from stage. --- .../scala/net/mikolak/travesty/Registry.scala | 37 ++++++++++++ .../net/mikolak/travesty/RegistrySpec.scala | 57 +++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 src/main/scala/net/mikolak/travesty/Registry.scala create mode 100644 src/test/scala/net/mikolak/travesty/RegistrySpec.scala diff --git a/src/main/scala/net/mikolak/travesty/Registry.scala b/src/main/scala/net/mikolak/travesty/Registry.scala new file mode 100644 index 0000000..ebe2d87 --- /dev/null +++ b/src/main/scala/net/mikolak/travesty/Registry.scala @@ -0,0 +1,37 @@ +package net.mikolak.travesty + +import akka.stream.{Graph, Shape} + +import scala.reflect.runtime.universe._ + +object Registry { + private[travesty] def deconstructShape[T <: Graph[_ <: Shape, _]: TypeTag](g: T): ShapeTypes = { + val tpe = typeOf[T] + val graphType = tpe.baseType(typeOf[Graph[_, _]].typeSymbol) + val bottomShapeType = graphType.typeArgs.head + + val shapeType = bottomShapeType.baseClasses.map(bottomShapeType.baseType).head + + val inletSize = g.shape.inlets.size + val outletSize = g.shape.outlets.size + val portTypes = shapeType.typeArgs + + val typeLists = if (inletSize + outletSize > portTypes.size) { + if (outletSize > inletSize) { + val (ins, outsPart) = portTypes.splitAt(inletSize) + (ins, outsPart ++ List.fill(outletSize - outsPart.size)(outsPart.lastOption.getOrElse(ins.last))) + } else { + val (insPart, outs) = portTypes.splitAt(portTypes.length - outletSize) + (insPart ++ List.fill(inletSize - insPart.size)(insPart.lastOption.getOrElse(outs.last)), outs) + } + + } else { + portTypes.splitAt(inletSize) + } + + ShapeTypes.tupled(typeLists) + } + +} + +case class ShapeTypes(inlets: List[Type], outlets: List[Type]) diff --git a/src/test/scala/net/mikolak/travesty/RegistrySpec.scala b/src/test/scala/net/mikolak/travesty/RegistrySpec.scala new file mode 100644 index 0000000..9310615 --- /dev/null +++ b/src/test/scala/net/mikolak/travesty/RegistrySpec.scala @@ -0,0 +1,57 @@ +package net.mikolak.travesty + +import akka.stream.{Graph, Shape, scaladsl} +import akka.stream.scaladsl.{BidiFlow, Broadcast, Flow, Merge, MergePreferred, Sink, Source, Unzip, Zip, ZipN} +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.{FlatSpec, MustMatchers} +import org.scalatest.words.MustVerb + +import scala.concurrent.Future +import scala.reflect.runtime.universe._ + +class RegistrySpec extends FlatSpec with MustMatchers with MustVerb with TableDrivenPropertyChecks { + + { + def tested[T <: Graph[_ <: Shape, _]: TypeTag](g: T) = Registry.deconstructShape(g) + + "deconstructShape" must "correctly define port types for basic shapes" in { + tested(Source.empty[String]) must be(ShapeTypes(Nil, List(typeOf[String]))) + tested(Sink.seq[Int]) must be(ShapeTypes(List(typeOf[Int]), Nil)) + tested(Flow[Boolean].map(identity)) must be(ShapeTypes(List(typeOf[Boolean]), List(typeOf[Boolean]))) + tested(Flow[A].map(_.toString)) must be(ShapeTypes(List(typeOf[A]), List(typeOf[java.lang.String]))) + } + + it must "correctly define port types for fanX shapes" in { + tested(Broadcast[String](4)) must be(ShapeTypes(List(typeOf[String]), List.fill(4)(typeOf[String]))) + tested(Merge[Boolean](7, true)) must be(ShapeTypes(List.fill(7)(typeOf[Boolean]), List(typeOf[Boolean]))) + + tested(Zip[A, String]) must be(ShapeTypes(List(typeOf[A], typeOf[String]), List(typeOf[(A, String)]))) + tested(Unzip[A, String]) must be(ShapeTypes(List(typeOf[(A, String)]), List(typeOf[A], typeOf[String]))) + tested(ZipN[A](20)) must be(ShapeTypes(List.fill(20)(typeOf[A]), List(typeOf[scala.collection.immutable.Seq[A]]))) + } + + it must "correctly define ports for uniformFanX shapes" in { + tested(MergePreferred[B](3, false)) must be(ShapeTypes(List.fill(3 + 1)(typeOf[B]), List(typeOf[B]))) + } + + it must "correctly define port types for BidiFlows" in { + tested(BidiFlow.fromFunctions[A, B, B, A](_.toB, _.toA)) must be( + ShapeTypes(List(typeOf[A], typeOf[B]), List(typeOf[B], typeOf[A]))) + } + + it must "correctly define port types for misc shapes" in { + tested(Flow[A].map(_.toB).async) must be(ShapeTypes(List(typeOf[A]), List(typeOf[B]))) + + tested(Flow[A].mapAsync(3)(a => Future.successful(a.toB)).async) must be(ShapeTypes(List(typeOf[A]), List(typeOf[B]))) + } + } +} + +trait A { + def toB: B + +} + +trait B { + def toA: A +}