Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Syntax for polymorphic values. #54

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ scalacOptions in Test ++= {

scalacOptions in Test += "-Yrangepos"

scalacOptions in Test += "-P:kind-projector:forall=true"

libraryDependencies += "com.novocode" % "junit-interface" % "0.11" % "test"
testOptions += Tests.Argument(TestFrameworks.JUnit, "-a", "-v")

Expand Down
37 changes: 37 additions & 0 deletions src/main/scala/Extractors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,41 @@ trait Extractors {
case _ => None
}
}
object TermLambda {
private val LambdaName = newTermName("Λ")

def unapply(tree: Tree): Option[(List[Tree], Tree)] = tree match {
case Apply(TypeApply(Ident(name), tParams), body :: Nil) if name == LambdaName => Some((tParams, body))
case _ => None
}
}
object TermNuType {
private val NuName = newTermName("ν")

def unapply(tree: Tree): Option[Tree] = tree match {
case TypeApply(Ident(name), tpe :: Nil) if name == NuName => Some(tpe)
case _ => None
}
}
object PolyVal {
def unapply(tree: Tree): Option[(Tree, TermName, List[Tree], Tree)] = tree match {

// Λ[A, B, ...](e) : T
case Typed(TermLambda(tParams, body), tpe) => Some((tpe, nme.apply, tParams, body))

// ν[T].method[A, B, ...](e)
case Apply(TypeApply(Select(TermNuType(tpe), method), tParams), body :: Nil) => Some((tpe, method.toTermName, tParams, body))

// ν[T][A, B, ...](e)
case Apply(TypeApply(TermNuType(tpe), tParams), body :: Nil) => Some((tpe, nme.apply, tParams, body))

// ν[T].method(e)
case Apply(Select(TermNuType(tpe), method), body :: Nil) => Some((tpe, method.toTermName, Nil, body))

// ν[T](e)
case Apply(TermNuType(tpe), body :: Nil) => Some((tpe, nme.apply, Nil, body))

case _ => None
}
}
}
139 changes: 84 additions & 55 deletions src/main/scala/KindProjector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,27 @@ class KindProjector(val global: Global) extends Plugin {
val name = "kind-projector"
val description = "Expand type lambda syntax"
val components = new KindRewriter(this, global) :: Nil

var enableForall = false

override def processOptions(options: List[String], error: String => Unit): Unit = {

// enable ∀ rewrites if "forall=true" is present
val (forallOpts, rest) = options partition { _.split("=")(0) == "forall" }
enableForall = forallOpts.lastOption match {
case Some(opt) =>
opt.split("=").tail match {
case Array("true") => true
case _ => false
}
case None => false
}

if(rest.nonEmpty) error(s"Unrecognized ${name} options: ${rest.mkString}")
}
}

class KindRewriter(plugin: Plugin, val global: Global)
class KindRewriter(plugin: KindProjector, val global: Global)
extends PluginComponent with Transform with TypingTransformers with TreeDSL {

import global._
Expand Down Expand Up @@ -128,7 +146,60 @@ class KindRewriter(plugin: Plugin, val global: Global)
def makeTypeParamContra(name: Name, bounds: TypeBoundsTree = DefaultBounds): TypeDef =
TypeDef(Modifiers(PARAM | CONTRAVARIANT), makeTypeName(name), Nil, bounds)

def polyLambda(tree: Tree): Tree = tree match {
// Given a name, e.g. A or `+A` or `A <: Foo`, build a type
// parameter tree using the given name, bounds, variance, etc.
def makeTypeParamFromName(ident: Ident): TypeDef = {
val decoded = NameTransformer.decode(ident.name.toString)
val src = s"type _X_[$decoded] = Unit"
sp.parse(src) match {
case Some(TypeDef(_, _, List(tpe), _)) => tpe.duplicate
case None => reporter.error(ident.pos, s"Can't parse param: ${ident.name}"); null
}
}

// Like makeTypeParam, but can be used recursively in the case of types
// that are themselves parameterized.
def makeComplexTypeParam(t: Tree): TypeDef = t match {
case id @ Ident(_) =>
makeTypeParamFromName(id)

case AppliedTypeTree(Ident(name), ps) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case x =>
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
null.asInstanceOf[TypeDef]
}

def typeArgsToTypeParams(args: List[Tree]): List[TypeDef] = args.map {
case id @ Ident(_) =>
makeTypeParamFromName(id)

case AppliedTypeTree(Ident(Plus), Ident(name) :: Nil) =>
makeTypeParamCo(name)

case AppliedTypeTree(Ident(Minus), Ident(name) :: Nil) =>
makeTypeParamContra(name)

case AppliedTypeTree(Ident(name), ps) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case x =>
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
null.asInstanceOf[TypeDef]
}

def polyTerm(tree: Tree): Tree = tree match {
case PolyLambda(methodName, (arrowType @ UnappliedType(_ :: targs)) :: Nil, Function1Tree(name, body)) =>
val (f, g) = targs match {
case a :: b :: Nil => (a, b)
Expand All @@ -139,42 +210,21 @@ class KindRewriter(plugin: Plugin, val global: Global)
atPos(tree.pos.makeTransparent)(
q"new $arrowType { def $methodName[$TParam]($name: $f[$TParam]): $g[$TParam] = $body }"
)
case PolyVal(targetType, methodName, tArgs, body) if plugin.enableForall =>
atPos(tree.pos.makeTransparent)(tArgs match {
case Nil =>
val tParam = newTypeName(freshName("A"))
q"new $targetType { def $methodName[$tParam] = $body }"
case _ =>
val tParams = typeArgsToTypeParams(tArgs)
q"new $targetType { def $methodName[..$tParams] = $body }"
})
case _ => tree
}

// The transform method -- this is where the magic happens.
override def transform(tree: Tree): Tree = {

// Given a name, e.g. A or `+A` or `A <: Foo`, build a type
// parameter tree using the given name, bounds, variance, etc.
def makeTypeParamFromName(name: Name): TypeDef = {
val decoded = NameTransformer.decode(name.toString)
val src = s"type _X_[$decoded] = Unit"
sp.parse(src) match {
case Some(TypeDef(_, _, List(tpe), _)) => tpe
case None => reporter.error(tree.pos, s"Can't parse param: $name"); null
}
}

// Like makeTypeParam, but can be used recursively in the case of types
// that are themselves parameterized.
def makeComplexTypeParam(t: Tree): TypeDef = t match {
case Ident(name) =>
makeTypeParamFromName(name)

case AppliedTypeTree(Ident(name), ps) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case x =>
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
null.asInstanceOf[TypeDef]
}

// Given the list a::as, this method finds the last argument in the list
// (the "subtree") and returns that separately from the other arguments.
// The stack is just used to enable tail recursion, and a and as are
Expand Down Expand Up @@ -205,28 +255,7 @@ class KindRewriter(plugin: Plugin, val global: Global)
// Lambda[(A, B) => Function2[A, Int, B]] case.
def handleLambda(a: Tree, as: List[Tree]): Tree = {
val (args, subtree) = parseLambda(a, as, Nil)
val innerTypes = args.map {
case Ident(name) =>
makeTypeParamFromName(name)

case AppliedTypeTree(Ident(Plus), Ident(name) :: Nil) =>
makeTypeParamCo(name)

case AppliedTypeTree(Ident(Minus), Ident(name) :: Nil) =>
makeTypeParamContra(name)

case AppliedTypeTree(Ident(name), ps) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
val tparams = ps.map(makeComplexTypeParam)
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)

case x =>
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
null.asInstanceOf[TypeDef]
}
val innerTypes = typeArgsToTypeParams(args)
makeTypeProjection(innerTypes, subtree)
}

Expand Down Expand Up @@ -322,7 +351,7 @@ class KindRewriter(plugin: Plugin, val global: Global)
// given a tree, see if it could possibly be a type lambda
// (either placeholder syntax or lambda syntax). if so, handle
// it, and if not, transform it in the normal way.
val result = polyLambda(tree match {
val result = polyTerm(tree match {

// Lambda[A => Either[A, Int]] case.
case AppliedTypeTree(Ident(TypeLambda1), AppliedTypeTree(target, a :: as) :: Nil) =>
Expand Down
15 changes: 9 additions & 6 deletions src/test/scala/polylambda.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package d_m

import org.junit.Test

trait ~>[-F[_], +G[_]] {
def apply[A](x: F[A]): G[A]
}
trait ~>>[-F[_], +G[_]] {
def dingo[B](x: F[B]): G[B]
}
final case class Const[A, B](getConst: A)

class PolyLambdas {

trait ~>[-F[_], +G[_]] {
def apply[A](x: F[A]): G[A]
}

trait ~>>[-F[_], +G[_]] {
def dingo[B](x: F[B]): G[B]
}

type ToSelf[F[_]] = F ~> F

val kf1 = Lambda[Option ~> Vector](_.toVector)
Expand Down
96 changes: 96 additions & 0 deletions src/test/scala/polyval.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package d_m

import org.junit.Test

class PolyVals {

trait Forall[F[_]] {
def apply[A]: F[A]
}

trait Forall2[F[_, _]] {
def apply[A, B]: F[A, B]
}

trait ForallK[F[_[_]]] {
def apply[G[_]]: F[G]
}


trait Semigroup[A] {
def combine(x: A, y: A): A
}

def listSemigroup[A]: Semigroup[List[A]] = new Semigroup[List[A]] {
def combine(x: List[A], y: List[A]): List[A] = x ++ y
}

trait Functor[F[_]]
trait Monad[F[_]] {
def functor: Functor[F]
}

final class Fun[A, B](val run: A => B)

// universally quantified semigroup
type SemigroupK[F[_]] = Forall[λ[α => Semigroup[F[α]]]]

// natural transformations
type ~>[F[_] , G[_] ] = Forall [λ[α => F[α] => G[α]]]
type ≈>[F[_[_]], G[_[_]]] = ForallK[λ[α[_] => F[α] => G[α]]]
type ~~>[F[_, _], G[_, _]] = Forall2[λ[(α, β) => F[α, β] => G[α, β]]]

// Const functor and constructors
type ConstA[A] = Forall[Const[A, ?]]
type ConstMaker1 = Forall[λ[α => α => ConstA[α]]]
type ConstMaker2 = Forall2[λ[(α, β) => α => Const[α, β]]]

// existentials via universals
type Consumer[F[_], R] = Forall[λ[A => F[A] => R]]
type Exists[F[_]] = Forall[λ[R => Consumer[F, R] => R]]
def existential[F[_], A](fa: F[A]): Exists[F] = ν[Exists[F]](_[A](fa))

@Test
def testSemigroupK(): Unit = {
val listSemigroupK = ν[SemigroupK[List]](listSemigroup)
assert(listSemigroupK[Int].combine(List(1, 2), List(3, 4)) == List(1, 2, 3, 4))
}

@Test
def testNaturalTransformations(): Unit = {
val headOption1 = ν[List ~> Option](_.headOption)
val headOption2 = ν[List ~> Option].apply[A]((l: List[A]) => l.headOption)

val monadToFunctor1 = ν[Monad ≈> Functor][F[_]](_.functor)
val monadToFunctor2 = ν[Monad ≈> Functor].apply[F[_]]((m: Monad[F]) => m.functor)

val fun = Λ[A, B](new Fun(_)): Function1 ~~> Fun

val listFunctor = new Functor[List] {}
val listMonad = new Monad[List] { def functor = listFunctor }

assert(headOption1[Int](List(1, 2)) == Some(1))
assert(headOption2[Int](List(1, 2)) == Some(1))
assert(monadToFunctor1[List](listMonad) == listFunctor)
assert(monadToFunctor2[List](listMonad) == listFunctor)
assert(fun[String, Int](_.length).run("foo") == 3)
}

@Test
def testConst(): Unit = {
val const42 = ν[ConstA[Int]][B](new Const[Int, B](42))
val constMaker = ν[ConstMaker1][A ](a => ν[ConstA[A]][B](new Const[A, B](a)))
val constMaker2 = ν[ConstMaker2][A, B](a => new Const[A, B](a) )

assert(const42[String].getConst == 42)
assert(constMaker[Int](42)[String].getConst == 42)
assert(constMaker2[Int, String](42).getConst == 42)
}

@Test
def testExistential(): Unit = {
val list = existential(List("one", "two", "three"))
val len = ν[Consumer[List, Int]](_.length)
assert(list[Int](len) == 3)
}
}