From 79cf9f385cc91c35e3219bff655435fb64eaf232 Mon Sep 17 00:00:00 2001 From: Yuriy Mazepin Date: Sun, 26 May 2024 23:31:42 +0300 Subject: [PATCH] purs on scala3 --- ops/src/main/scala-3/Common.scala | 82 +++++++ ops/src/main/scala-3/ops.scala | 224 +++++++++++++++++++ purs/src/main/scala-3/decoders.scala | 308 +++++++++++++++++++++++++++ purs/src/main/scala-3/encoders.scala | 154 ++++++++++++++ purs/src/main/scala-3/io.scala | 13 ++ purs/src/main/scala-3/main.scala | 172 +++++++++++++++ purs/src/main/scala-3/purs.scala | 85 ++++++++ purs/src/main/scala-3/tpe.scala | 6 + 8 files changed, 1044 insertions(+) create mode 100644 ops/src/main/scala-3/Common.scala create mode 100644 ops/src/main/scala-3/ops.scala create mode 100644 purs/src/main/scala-3/decoders.scala create mode 100644 purs/src/main/scala-3/encoders.scala create mode 100644 purs/src/main/scala-3/io.scala create mode 100644 purs/src/main/scala-3/main.scala create mode 100644 purs/src/main/scala-3/purs.scala create mode 100644 purs/src/main/scala-3/tpe.scala diff --git a/ops/src/main/scala-3/Common.scala b/ops/src/main/scala-3/Common.scala new file mode 100644 index 0000000..b8a37d1 --- /dev/null +++ b/ops/src/main/scala-3/Common.scala @@ -0,0 +1,82 @@ +package proto + +import scala.quoted.* +import scala.collection.immutable.ArraySeq +import compiletime.asMatchable +import scala.annotation.* + +trait Common: + implicit val qctx: Quotes + import qctx.reflect.{*, given} + import qctx.reflect.defn.* + import report.* + + val ArrayByteType: TypeRepr = TypeRepr.of[Array[Byte]] + val ArraySeqByteType: TypeRepr = TypeRepr.of[ArraySeq[Byte]] + val NTpe: TypeRepr = TypeRepr.of[N] + val ItetableType: TypeRepr = TypeRepr.of[scala.collection.Iterable[?]] + val ArrayType: TypeRepr = TypeRepr.of[Array[?]] + + extension (s: Symbol) + def constructorParams: List[Symbol] = s.primaryConstructor.paramSymss.find(_.headOption.fold(false)( _.isTerm)).getOrElse(Nil) + def tpe: TypeRepr = + s.tree match + case x: ClassDef => x.constructor.returnTpt.tpe + case ValDef(_,tpt,_) => tpt.tpe + case Bind(_, pattern: Term) => pattern.tpe + + def defaultMethodName(i: Int): String = s"$$lessinit$$greater$$default$$${i+1}" + + val commonTypes: List[TypeRepr] = + TypeRepr.of[String] :: TypeRepr.of[Int] :: TypeRepr.of[Long] :: TypeRepr.of[Boolean] :: TypeRepr.of[Double] :: TypeRepr.of[Float] :: ArrayByteType :: ArraySeqByteType :: Nil + + val packedTypes: List[TypeRepr] = + TypeRepr.of[Int] :: TypeRepr.of[Long] :: TypeRepr.of[Boolean] :: TypeRepr.of[Double] :: TypeRepr.of[Float] :: Nil + + extension (t: TypeRepr) + def isNType: Boolean = t =:= NTpe + def isCaseClass: Boolean = t.typeSymbol.flags.is(Flags.Case) + def isCaseObject: Boolean = t.termSymbol.flags.is(Flags.Case) + def isCaseType: Boolean = t.isCaseClass || t.isCaseObject + def isSealedTrait: Boolean = t.typeSymbol.flags.is(Flags.Sealed) && t.typeSymbol.flags.is(Flags.Trait) + def isIterable: Boolean = t <:< ItetableType && !t.isArraySeqByte + def isArray: Boolean = t <:< ArrayType && !t.isArrayByte + def isRepeated: Boolean = t.isIterable || t.isArray + def isString: Boolean = t =:= TypeRepr.of[String] + def isInt: Boolean = t =:= TypeRepr.of[Int] + def isLong: Boolean = t =:= TypeRepr.of[Long] + def isBoolean: Boolean = t =:= TypeRepr.of[Boolean] + def isDouble: Boolean = t =:= TypeRepr.of[Double] + def isFloat: Boolean = t =:= TypeRepr.of[Float] + def isArrayByte: Boolean = t =:= ArrayByteType + def isArraySeqByte: Boolean = t =:= ArraySeqByteType + def isCommonType: Boolean = commonTypes.exists(_ =:= t) + def isPackedType: Boolean = packedTypes.exists(_ =:= t) + + def isOption: Boolean = t.asMatchable match + case AppliedType(t1, _) if t1.typeSymbol == OptionClass => true + case _ => false + + def typeArgs: List[TypeRepr] = t.dealias.asMatchable match + case AppliedType(t1, args) => args + case _ => Nil + + def optionArgument: TypeRepr = t.asMatchable match + case AppliedType(t1, args) if t1.typeSymbol == OptionClass => args.head + case _ => errorAndAbort(s"It isn't Option type: ${t.typeSymbol.name}") + + def knownFinalSubclasses: List[Symbol] = + @tailrec def loop(q: List[Symbol], acc: List[Symbol]): List[Symbol] = q match + case Nil => acc + case x :: xs if x.tpe.isSealedTrait => loop(x.tpe.typeSymbol.children ++ xs, acc) + case x :: xs => loop(xs, x :: acc) + loop(t.typeSymbol.children, Nil) + + def findN(x: Symbol): Option[Int] = { + x.annotations.collect{ + case Apply(Select(New(tpt), _), List(Literal(IntConstant(num)))) + if tpt.tpe.asMatchable.isNType => num + } match + case List(x) => Some(x) + case _ => None + } \ No newline at end of file diff --git a/ops/src/main/scala-3/ops.scala b/ops/src/main/scala-3/ops.scala new file mode 100644 index 0000000..abd95bf --- /dev/null +++ b/ops/src/main/scala-3/ops.scala @@ -0,0 +1,224 @@ +package proto + +import scala.quoted.* +import scala.annotation.* + +trait Ops extends Common: + implicit val qctx: Quotes + import qctx.reflect.{*, given} + import report.* + + private[proto] case class ChildMeta(name: String, tpe: TypeRepr, n: Int, noargs: Boolean, rec: Boolean) + + sealed trait Tpe { val tpe: TypeRepr; val name: String } + final case class TraitType(tpe: TypeRepr, name: String, children: Seq[ChildMeta], firstLevel: Boolean) extends Tpe + final case class RegularType(tpe: TypeRepr, name: String) extends Tpe + final case class RecursiveType(tpe: TypeRepr, name: String) extends Tpe + final case class NoargsType(tpe: TypeRepr, name: String) extends Tpe + final case class TupleType(tpe: TypeRepr, name: String, tpe_1: TypeRepr, tpe_2: TypeRepr) extends Tpe + + sealed trait DefVal + sealed trait HasDefFun { val value: String } + sealed trait FillDef { val value: String } + + case object NoDef extends DefVal + + case object NoneDef extends DefVal with HasDefFun { val value = "Nothing" } + case object SeqDef extends DefVal with HasDefFun { val value = "[]" } + + final case class StrDef(v: String) extends DefVal with HasDefFun with FillDef { val value = s""""$v"""" } + final case class OthDef(v: Any) extends DefVal with HasDefFun with FillDef { val value = v.toString } + + sealed trait IterablePurs + final case class ArrayPurs(tpe: TypeRepr) extends IterablePurs + final case class ArrayTuplePurs(tpe1: TypeRepr, tpe2: TypeRepr) extends IterablePurs + + def collectDefExpressions(tpe: TypeRepr): Seq[(Expr[String], Expr[Any])] = { + val types = collectTypeRepr(tpe, tail=Nil, acc=Nil) + types.flatMap{ tpe => + val tpe_sym = tpe.typeSymbol + + tpe.typeSymbol.constructorParams.zipWithIndex.map{ case (sym, i) => + if sym.flags.is(Flags.HasDefault) then + val compMod = tpe_sym.companionModule + val id = s"`${tpe_sym.name}_${sym.name}_default`" + compMod.methodMember(defaultMethodName(i)) match + case List(methodSym) => + Some(Expr(id) -> Ref(compMod).select(methodSym).asExpr) + case _ => errorAndAbort(s"${sym.fullName}`: default value method not found") + else None + } + }.flatten + } + + def collectTypeRepr(head: TypeRepr, tail: Seq[TypeRepr], acc: Seq[TypeRepr]): Seq[TypeRepr] = { + val (tail1, acc1): (Seq[TypeRepr], Seq[TypeRepr]) = + if acc.exists(_ =:= head) then + (tail, acc) + else if head.isSealedTrait then + val children = head.knownFinalSubclasses.map(_.tpe) + (children++tail, acc) + else if head.isOption then + val typeArgType = head.optionArgument + if !typeArgType.isCommonType then (typeArgType+:tail, acc) + else (tail, acc) + else if head.isRepeated then + head.typeArgs match + case x :: Nil => + val typeArgType = x + if !typeArgType.isCommonType then (typeArgType+:tail, acc) + else (tail, acc) + case x :: y :: Nil => + val zs = List(x, y).filter(!_.isCommonType) + (zs++tail, acc) + case _ => throw new Exception(s"too many type args for ${head}") + else + val ys = head.typeSymbol.constructorParams.map(_.tpe) + (ys++tail, acc:+head) + + tail1 match { + case h +: t => collectTypeRepr(h, tail=t, acc=acc1) + case _ => acc1 + } + } + + // @tailrec + def collectTpes(tpe: TypeRepr): Seq[Tpe] = { + collectTpes(tpe, tail=Nil, acc=Nil, firstLevel=true) + } + + def collectTpes(head: TypeRepr, tail: Seq[TypeRepr], acc: Seq[Tpe], firstLevel: Boolean): Seq[Tpe] = { + val (tail1, acc1): (Seq[TypeRepr], Seq[Tpe]) = + if acc.exists(_.tpe =:= head) then + (tail, acc) + else if head.isSealedTrait then + val children = findChildren(head) + (children.map(_.tpe)++tail, acc:+TraitType(head, head.typeSymbol.name, children, firstLevel)) + else if head.isOption then + val typeArgType = head.optionArgument + if !typeArgType.isCommonType then (typeArgType+:tail, acc) + else (tail, acc) + else if head.isRepeated then + head.typeArgs match + case x :: Nil => + val typeArgType = x + if !typeArgType.isCommonType then (typeArgType+:tail, acc) + else (tail, acc) + case x :: y :: Nil => + val zs = List(x, y).filter(!_.isCommonType) + (zs++tail, acc:+TupleType(TypeRepr.of[Tuple2[Unit, Unit]].appliedTo(x::y::Nil), tupleFunName(x, y), x, y)) //todo check + case _ => throw new Exception(s"too many type args for ${head}") + else + val (ys, z) = type_to_tpe(head) + (ys++tail, acc:+z) + + tail1 match { + case h +: t => collectTpes(h, tail=t, acc=acc1, firstLevel=false) + case _ => acc1 + } + } + + def fields(tpe: TypeRepr): List[(String, TypeRepr, Int, DefVal)] = { + val tpe_Sym = tpe.typeSymbol + val tpe_CompanionSym = tpe_Sym.companionModule + val compClass = tpe_Sym.companionClass + val compModRef = Ref(tpe_CompanionSym) + + tpe.typeSymbol.constructorParams.zipWithIndex.map{ case (sym, i) => + val tpe1 = sym.tpe + val defval = + if sym.flags.is(Flags.HasDefault) then + val id = s"`${tpe_Sym.name}_${sym.name}_default`" + tpe_CompanionSym.methodMember(defaultMethodName(i)) match + case List(x) if x.isDefDef => + val y = x.tree.asInstanceOf[DefDef] + if (y.returnTpt.tpe.isString) StrDef(id) + else OthDef(id) + case _ => errorAndAbort(s"${sym.fullName}`: default value method not found") + else if tpe1.isOption then + NoneDef + else if tpe1.isRepeated then + SeqDef + else + NoDef + (sym.name, tpe1, findN(sym), defval) + + }.collect{ case (a, b, Some(n), dv) => (a, b, n, dv) }.sortBy(_._3) + } + + def findChildren(tpe: TypeRepr): Seq[ChildMeta] = + tpe.knownFinalSubclasses.map(x => x -> findN(x)).collect{ case (x, Some(n)) => x -> n }.sortBy(_._2).map{ + case (x: Symbol, n) => + val tpe1: TypeRepr = x.tpe + ChildMeta(name=x.name, tpe1, n, noargs=fields(tpe1).isEmpty, rec=isRecursive(tpe1)) + } + + def isRecursive(base: TypeRepr): Boolean = { + @tailrec def loop(compareTo: List[TypeRepr]): Boolean = compareTo match + case Nil => false + case x :: _ if x =:= base => true + case x :: xs => loop(x.typeArgs ++ xs) + loop(fields(base).map(_._2).filter(tpe => !tpe.isCommonType)) + } + + def tupleFunName(tpe_1: TypeRepr, tpe_2: TypeRepr): String = { + (pursType(tpe_1)._1.filter(_.isLetter) + "_" + pursType(tpe_2)._1).filter(_.isLetter) + } + + def type_to_tpe(head: TypeRepr): (Seq[TypeRepr], Tpe) = { + val xs = fields(head).map(_._2) + val ys = xs.filter(x => !x.isCommonType) + val z = + if (xs.isEmpty) NoargsType(head, head.typeSymbol.name.stripSuffix("$")) //strip $ for case objects + else if (isRecursive(head)) RecursiveType(head, head.typeSymbol.name) + else RegularType(head, head.typeSymbol.name) + (ys, z) + } + + def pursType(tpe: TypeRepr): (String, String) = { + def trim(x: String): String = x.stripPrefix("(").stripSuffix(")") + val (a, b) = pursTypePars(tpe) + trim(a) -> trim(b) + } + + def pursTypePars(tpe: TypeRepr): (String, String) = { + if tpe.isString then + "String" -> "(Maybe String)" + else if tpe.isInt then + "Int" -> "(Maybe Int)" + else if tpe.isLong then + "BigInt" -> "(Maybe BigInt)" + else if tpe.isBoolean then + "Boolean" -> "(Maybe Boolean)" + else if tpe.isDouble then + "Number" -> "(Maybe Number)" + else if tpe.isArrayByte then + "Uint8Array" -> "(Maybe Uint8Array)" + else if tpe.isOption then + val typeArg = tpe.optionArgument + if typeArg.isLong then + "(Maybe BigInt)" -> "(Maybe BigInt)" + else if typeArg.isDouble then + "(Maybe Number)" -> "(Maybe Number)" + else + val name = typeArg.typeSymbol.name + s"(Maybe $name)" -> s"Maybe $name)" + else if tpe.isRepeated then + iterablePurs(tpe) match + case ArrayPurs(tpe) => + val name = tpe.typeSymbol.name + s"(Array $name)" -> s"(Array $name)" + case ArrayTuplePurs(tpe1, tpe2) => + val name1 = pursTypePars(tpe1)._1 + val name2 = pursTypePars(tpe2)._1 + s"(Array (Tuple $name1 $name2))" -> s"(Array (Tuple $name1 $name2))" + else + val name = tpe.typeSymbol.name + name -> s"(Maybe $name)" + } + + def iterablePurs(tpe: TypeRepr): IterablePurs = + tpe.typeArgs match + case x :: Nil => ArrayPurs(x) + case x :: y :: Nil => ArrayTuplePurs(x, y) + case _ => throw new Exception(s"too many type args for $tpe") diff --git a/purs/src/main/scala-3/decoders.scala b/purs/src/main/scala-3/decoders.scala new file mode 100644 index 0000000..5b6ae5f --- /dev/null +++ b/purs/src/main/scala-3/decoders.scala @@ -0,0 +1,308 @@ +package proto +package purs + +import scala.annotation.unused +import scala.quoted.* + +trait Decoders extends Ops with Purs: + implicit val qctx: Quotes + import qctx.reflect.* + + def decodersFrom(types: Seq[Tpe]): Seq[Coder] = + types.map{ + case TraitType(tpe, name, children, true) => + val cases = children.map{ case ChildMeta(name1, _, n, noargs, rec) => + if (noargs) s"$n -> decode (decode$name1 _xs_ pos1) \\_ -> $name1" + else if (rec) s"$n -> decode (decode$name1 _xs_ pos1) $name1''" + else s"$n -> decode (decode$name1 _xs_ pos1) $name1" + } + val tmpl = + s"""|decode$name :: Uint8Array -> Decode.Result $name + |decode$name _xs_ = do + | { pos: pos1, val: tag } <- Decode.unsignedVarint32 _xs_ 0 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | i -> Left $$ Decode.BadType i + | where + | decode :: forall a. Decode.Result a -> (a -> $name) -> Decode.Result $name + | decode res f = map (\\{ pos, val } -> { pos, val: f val }) res""".stripMargin + Coder(tmpl, Some(s"decode$name")) + case TraitType(tpe, name, children, false) => + val cases = children.flatMap{ case ChildMeta(name1, _, n, noargs, rec) => + if (noargs) s"$n -> decodeFieldLoop end (decode$name1 _xs_ pos2) \\_ -> Just $name1" :: Nil + else if (rec) s"$n -> decodeFieldLoop end (decode$name1 _xs_ pos2) (Just <<< $name1'')" :: Nil + else s"$n -> decodeFieldLoop end (decode$name1 _xs_ pos2) (Just <<< $name1)" :: Nil + } + val tmpl = + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | tailRecM3 decode (pos + msglen) Nothing pos + | where + | decode :: Int -> Maybe $name -> Int -> Decode.Result' (Step { a :: Int, b :: Maybe $name, c :: Int } { pos :: Int, val :: $name }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end (Just acc) pos1 = pure $$ Done { pos: pos1, val: acc } + | decode end acc@Nothing pos1 = Left $$ Decode.MissingFields "$name"""".stripMargin + Coder(tmpl, None) + case TupleType(tpe, tupleName, tpe_1, tpe_2) => + val fun = "decode" + tupleName + val cases = List(("first", tpe_1, 1, NoDef), ("second", tpe_2, 2, NoDef)).flatMap(decodeFieldLoop.tupled) + val tmpl = + s"""|$fun :: Uint8Array -> Int -> Decode.Result (Tuple ${pursTypePars(tpe_1)._1} ${pursTypePars(tpe_2)._1}) + |$fun _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | { pos: pos1, val } <- tailRecM3 decode (pos + msglen) { ${nothingValue("first", tpe_1)}, ${nothingValue("second", tpe_2)} } pos + | case val of + | { ${justValue("first", tpe_1)}, ${justValue("second", tpe_2)} } -> pure { pos: pos1, val: Tuple first second } + | _ -> Left $$ Decode.MissingFields "$fun" + | where + | decode :: Int -> { first :: ${pursType(tpe_1)._2}, second :: ${pursType(tpe_2)._2} } -> Int -> Decode.Result' (Step { a :: Int, b :: { first :: ${pursType(tpe_1)._2}, second :: ${pursType(tpe_2)._2} }, c :: Int } { pos :: Int, val :: { first :: ${pursType(tpe_1)._2}, second :: ${pursType(tpe_2)._2} } }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) (\\_ -> acc) + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + Coder(tmpl, None) + case NoargsType(tpe, name) => + val tmpl = + s"""|decode$name :: Uint8Array -> Int -> Decode.Result Unit + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | pure { pos: pos + msglen, val: unit }""".stripMargin + Coder(tmpl, None) + case RecursiveType(tpe, name) => + val fs = fields(tpe) + val defObj: String = fs.map{ case (name, tpe, _, _) => nothingValue(name, tpe) }.mkString("{ ", ", ", " }") + val justObj: String = fs.map{ case (name, tpe, _, _) => justValue(name, tpe) }.mkString("{ ", ", ", " }") + val unObj: String = fs.map{ + case (name,_,_,_) => name + }.mkString("{ ", ", ", " }") + val simplify = fs.forall{ case (name, tpe, _, defval) => defval.isInstanceOf[HasDefFun] } + val hasDefval = fs.exists(_._4.isInstanceOf[FillDef]) + val tmpl = + if (simplify) { + if (hasDefval) { + val defs = fs.map{ + case (name, _, _, defval: FillDef) => s"$name: fromMaybe ${defval.value} $name" + case (name, _, _, _) => name + }.mkString(", ") + val cases = fs.flatMap(decodeFieldLoop.tupled) + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | { pos: pos1, val: $unObj } <- tailRecM3 decode (pos + msglen) $defObj pos + | pure { pos: pos1, val: $name { $defs }} + | where + | decode :: Int -> $name' -> Int -> Decode.Result' (Step { a :: Int, b :: $name', c :: Int } { pos :: Int, val :: $name' }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } else { + val cases = fs.flatMap(decodeFieldLoopNewtype(name).tupled) + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | tailRecM3 decode (pos + msglen) ($name $defObj) pos + | where + | decode :: Int -> $name -> Int -> Decode.Result' (Step { a :: Int, b :: $name, c :: Int } { pos :: Int, val :: $name }) + | decode end acc'@($name acc) pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc' + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } + } else { + if (hasDefval) { + val name = tpe.typeSymbol.name + val name1 = tpe.typeSymbol.name + "'" + val cases = fs.flatMap(decodeFieldLoop.tupled) + val defs = fs.map{ + case (name, _, _, defval: FillDef) => s"$name: fromMaybe ${defval.value} $name" + case (name, _, _, _) => name + }.mkString(", ") + val case1 = fs.map{ + case (name, tpe, _, _: FillDef) => name + case (name, tpe, _, _) => justValue(name, tpe) + }.mkString(", ") + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | { pos: pos1, val } <- tailRecM3 decode (pos + msglen) $defObj pos + | case val of + | { $case1 } -> pure { pos: pos1, val: $name { $defs }} + | _ -> Left $$ Decode.MissingFields "$name" + | where + | decode :: Int -> $name1 -> Int -> Decode.Result' (Step { a :: Int, b :: $name1, c :: Int } { pos :: Int, val :: $name1 }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } else { + val name = tpe.typeSymbol.name + val name1 = tpe.typeSymbol.name + "'" + val cases = fs.flatMap(decodeFieldLoop.tupled) + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | { pos: pos1, val } <- tailRecM3 decode (pos + msglen) $defObj pos + | case val of + | $justObj -> pure { pos: pos1, val: $name $unObj } + | _ -> Left $$ Decode.MissingFields "$name" + | where + | decode :: Int -> $name1 -> Int -> Decode.Result' (Step { a :: Int, b :: $name1, c :: Int } { pos :: Int, val :: $name1 }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } + } + Coder(tmpl, None) + case RegularType(tpe, name) => + val fs = fields(tpe) + val defObj: String = fs.map{ case (name, tpe, _, _) => nothingValue(name, tpe) }.mkString("{ ", ", ", " }") + val justObj: String = fs.map{ case (name, tpe, _, _) => justValue(name, tpe) }.mkString("{ ", ", ", " }") + val unObj: String = fs.map{ + case (name,_,_,_) => name + }.mkString("{ ", ", ", " }") + val cases = fs.flatMap(decodeFieldLoop.tupled) + val simplify = fs.forall{ case (name, tpe, _, defval) => defval.isInstanceOf[HasDefFun] } + val hasDefval = fs.exists(_._4.isInstanceOf[FillDef]) + val tmpl = + if (simplify) { + if (hasDefval) { + val defs = fs.map{ + case (name, _, _, defval: FillDef) => s"$name: fromMaybe ${defval.value} $name" + case (name, _, _, _) => name + }.mkString(", ") + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | { pos: pos1, val: $unObj } <- tailRecM3 decode (pos + msglen) $defObj pos + | pure { pos: pos1, val: { ${defs} }} + | where + | decode :: Int -> $name' -> Int -> Decode.Result' (Step { a :: Int, b :: $name', c :: Int } { pos :: Int, val :: $name' }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } else + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | tailRecM3 decode (pos + msglen) $defObj pos + | where + | decode :: Int -> $name -> Int -> Decode.Result' (Step { a :: Int, b :: $name, c :: Int } { pos :: Int, val :: $name }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } else { + if (hasDefval) { + val defs = fs.map{ + case (name, _, _, defval: FillDef) => s"$name: fromMaybe ${defval.value} $name" + case (name, _, _, _) => name + }.mkString(", ") + val case1 = fs.map{ + case (name, tpe, _, _: FillDef) => name + case (name, tpe, _, _) => justValue(name, tpe) + }.mkString(", ") + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | { pos: pos1, val } <- tailRecM3 decode (pos + msglen) $defObj pos + | case val of + | { $case1 } -> pure { pos: pos1, val: { $defs }} + | _ -> Left $$ Decode.MissingFields "$name" + | where + | decode :: Int -> $name' -> Int -> Decode.Result' (Step { a :: Int, b :: $name', c :: Int } { pos :: Int, val :: $name' }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } else { + s"""|decode$name :: Uint8Array -> Int -> Decode.Result $name + |decode$name _xs_ pos0 = do + | { pos, val: msglen } <- Decode.unsignedVarint32 _xs_ pos0 + | { pos: pos1, val } <- tailRecM3 decode (pos + msglen) $defObj pos + | case val of + | $justObj -> pure { pos: pos1, val: $unObj } + | _ -> Left $$ Decode.MissingFields "$name" + | where + | decode :: Int -> $name' -> Int -> Decode.Result' (Step { a :: Int, b :: $name', c :: Int } { pos :: Int, val :: $name' }) + | decode end acc pos1 | pos1 < end = do + | { pos: pos2, val: tag } <- Decode.unsignedVarint32 _xs_ pos1 + | case tag `zshr` 3 of${cases.map("\n "+_).mkString("")} + | _ -> decodeFieldLoop end (Decode.skipType _xs_ pos2 $$ tag .&. 7) \\_ -> acc + | decode end acc pos1 = pure $$ Done { pos: pos1, val: acc }""".stripMargin + } + } + Coder(tmpl, None) + }.distinct + + private def decodeFieldLoop(name: String, tpe: TypeRepr, n: Int, defval: DefVal): List[String] = { + decodeFieldLoopTmpl{ case (n: Int, fun: String, mod: String) => + s"$n -> decodeFieldLoop end ($fun _xs_ pos2) \\val -> acc { $mod }" :: Nil + }(name, tpe, n, defval) + } + + private def decodeFieldLoopNewtype(newtype: String)(name: String, tpe: TypeRepr, n: Int, defval: DefVal): List[String] = { + decodeFieldLoopTmpl{ case (n: Int, fun: String, mod: String) => + s"$n -> decodeFieldLoop end ($fun _xs_ pos2) \\val -> $newtype $$ acc { $mod }" :: Nil + }(name, tpe, n, defval) + } + + private def decodeFieldLoopTmpl(tmpl: (Int, String, String) => List[String])(name: String, tpe: TypeRepr, n: Int, @unused defval: DefVal): List[String] = { + if (tpe.isString) { + tmpl(n, "Decode.string", s"$name = Just val") + } else if (tpe.isInt) { + tmpl(n, "Decode.signedVarint32", s"$name = Just val") + } else if (tpe.isLong) { + tmpl(n, "Decode.bigInt", s"$name = Just val") + } else if (tpe.isBoolean) { + tmpl(n, "Decode.boolean", s"$name = Just val") + } else if (tpe.isDouble) { + tmpl(n, "Decode.double", s"$name = Just val") + } else if (tpe.isOption) { + val tpe1 = tpe.optionArgument + if (tpe1.isString) { + tmpl(n, "Decode.string", s"$name = Just val") + } else if (tpe1.isInt) { + tmpl(n, "Decode.signedVarint32", s"$name = Just val") + } else if (tpe1.isLong) { + tmpl(n, "Decode.bigInt", s"$name = Just val") + } else if (tpe1.isBoolean) { + tmpl(n, "Decode.boolean", s"$name = Just val") + } else if (tpe1.isDouble) { + tmpl(n, "Decode.double", s"$name = Just val") + } else { + val typeArgName = tpe1.typeSymbol.name + tmpl(n, s"decode$typeArgName", s"$name = Just val") + } + } else if (tpe.isArrayByte) { + tmpl(n, "Decode.bytes", s"$name = Just val") + } else if (tpe.isRepeated) { + iterablePurs(tpe) match { + case ArrayPurs(tpe1) => + if (tpe1.isString) { + tmpl(n, "Decode.string", s"$name = snoc acc.$name val") + } else { + val tpeName = tpe1.typeSymbol.name + tmpl(n, s"decode$tpeName", s"$name = snoc acc.$name val") + } + case ArrayTuplePurs(tpe1, tpe2) => + tmpl(n, s"decode${tupleFunName(tpe1, tpe2)}", s"$name = snoc acc.$name val") + } + } else { + val name1 = tpe.typeSymbol.name + tmpl(n, s"decode${name1}", s"$name = Just val") + } + } \ No newline at end of file diff --git a/purs/src/main/scala-3/encoders.scala b/purs/src/main/scala-3/encoders.scala new file mode 100644 index 0000000..3a1732e --- /dev/null +++ b/purs/src/main/scala-3/encoders.scala @@ -0,0 +1,154 @@ +package proto +package purs + +import scala.quoted.* +import scala.annotation.unused + +trait Encoders extends Ops: + implicit val qctx: Quotes + import qctx.reflect.* + + def encodersFrom(types: Seq[Tpe]): Seq[Coder] = + types.map{ + case TraitType(tpe, name, children, true) => + val cases = children.map{ case ChildMeta(name1, _, n, noargs, rec) => + if (noargs) + s"encode$name $name1 = concatAll [ Encode.unsignedVarint32 ${(n << 3) + 2}, encode$name1 ]" :: Nil + else { + val pursTypeName = if (rec) s"$name1''" else name1 + s"encode$name ($pursTypeName x) = concatAll [ Encode.unsignedVarint32 ${(n << 3) + 2}, encode$name1 x ]" :: Nil + } + } + val tmpl = + s"""|encode$name :: $name -> Uint8Array + |${cases.map(_.mkString("\n")).mkString("\n")}""".stripMargin + Coder(tmpl, Some(s"encode$name")) + case TraitType(tpe, name, children, false) => + val cases = children.map{ case ChildMeta(name1, _, n, noargs, rec) => + if (noargs) + List( + s"encode$name $name1 = do" + , s" let xs = concatAll [ Encode.unsignedVarint32 ${(n << 3) + 2}, encode$name1 ]" + , s" concatAll [ Encode.unsignedVarint32 $$ length xs, xs ]" + ).mkString("\n") :: Nil + else { + val pursTypeName = if (rec) s"$name1''" else name1 + List( + s"encode$name ($pursTypeName x) = do" + , s" let xs = concatAll [ Encode.unsignedVarint32 ${(n << 3) + 2}, encode$name1 x ]" + , s" concatAll [ Encode.unsignedVarint32 $$ length xs, xs ]" + ).mkString("\n") :: Nil + } + } + val tmpl = + s"""|encode$name :: $name -> Uint8Array + |${cases.map(_.mkString("\n")).mkString("\n")}""".stripMargin + Coder(tmpl, None) + case TupleType(tpe, tupleName, tpe_1, tpe_2) => + val fun = "encode" + tupleName + val tmpl = + s"""|$fun :: Tuple ${pursTypePars(tpe_1)._1} ${pursTypePars(tpe_2)._1} -> Uint8Array + |$fun (Tuple _1 _2) = do + | let msg = { _1, _2 } + | let xs = concatAll + | ${List(("_1", tpe_1, 1, NoDef), ("_2", tpe_2, 2, NoDef)).flatMap(encodeField.tupled).mkString(" [ ", "\n , ", "\n ]")} + | concatAll [ Encode.unsignedVarint32 $$ length xs, xs ]""".stripMargin + Coder(tmpl, None) + case NoargsType(tpe, name) => + val tmpl = + s"""|encode$name :: Uint8Array + |encode$name = Encode.unsignedVarint32 0""".stripMargin + Coder(tmpl, None) + case RecursiveType(tpe, name) => + val encodeFields = fields(tpe).flatMap(encodeField.tupled) + val tmpl = + if (encodeFields.nonEmpty) { + s"""|encode$name :: $name -> Uint8Array + |encode$name ($name msg) = do + | let xs = concatAll + | ${encodeFields.mkString(" [ ", "\n , ", "\n ]")} + | concatAll [ Encode.unsignedVarint32 $$ length xs, xs ]""".stripMargin + } else { + s"""|encode$name :: $name -> Uint8Array + |encode$name _ = Encode.unsignedVarint32 0""".stripMargin + } + Coder(tmpl, None) + case RegularType(tpe, name) => + val encodeFields = fields(tpe).flatMap(encodeField.tupled) + val tmpl = + if (encodeFields.nonEmpty) { + s"""|encode$name :: $name -> Uint8Array + |encode$name msg = do + | let xs = concatAll + | ${encodeFields.mkString(" [ ", "\n , ", "\n ]")} + | concatAll [ Encode.unsignedVarint32 $$ length xs, xs ]""".stripMargin + } else { + s"""|encode$name :: $name -> Uint8Array + |encode$name _ = Encode.unsignedVarint32 0""".stripMargin + } + Coder(tmpl, None) + }.distinct + + def encodeField(name: String, tpe: TypeRepr, n: Int, @unused defval: DefVal): List[String] = + if tpe.isString then + List( + s"""Encode.unsignedVarint32 ${(n<<3)+2}""" + , s"""Encode.string msg.$name""" + ) + else if tpe.isInt then + List( + s"""Encode.unsignedVarint32 ${(n<<3)+0}""" + , s"""Encode.signedVarint32 msg.$name""" + ) + else if tpe.isLong then + List( + s"""Encode.unsignedVarint32 ${(n<<3)+0}""" + , s"""Encode.bigInt msg.$name""" + ) + else if tpe.isBoolean then + List( + s"""Encode.unsignedVarint32 ${(n<<3)+0}""" + , s"""Encode.boolean msg.$name""" + ) + else if tpe.isDouble then + List( + s"""Encode.unsignedVarint32 ${(n<<3)+1}""" + , s"""Encode.double msg.$name""" + ) + else if tpe.isOption then + val tpe1 = tpe.optionArgument + if tpe1.isString then + s"""fromMaybe (fromArray []) $$ map (\\x -> concatAll [ Encode.unsignedVarint32 ${(n<<3)+2}, Encode.string x ]) msg.$name""" :: Nil + else if tpe1.isInt then + s"""fromMaybe (fromArray []) $$ map (\\x -> concatAll [ Encode.unsignedVarint32 ${(n<<3)+0}, Encode.signedVarint32 x ]) msg.$name""" :: Nil + else if tpe1.isLong then + s"""fromMaybe (fromArray []) $$ map (\\x -> concatAll [ Encode.unsignedVarint32 ${(n<<3)+0}, Encode.bigInt x ]) msg.$name""" :: Nil + else if tpe1.isBoolean then + s"""fromMaybe (fromArray []) $$ map (\\x -> concatAll [ Encode.unsignedVarint32 ${(n<<3)+0}, Encode.boolean x ]) msg.$name""" :: Nil + else if tpe1.isDouble then + s"""fromMaybe (fromArray []) $$ map (\\x -> concatAll [ Encode.unsignedVarint32 ${(n<<3)+1}, Encode.double x ]) msg.$name""" :: Nil + else + val typeArgName = tpe1.typeSymbol.name + s"""fromMaybe (fromArray []) $$ map (\\x -> concatAll [ Encode.unsignedVarint32 ${(n<<3)+2}, encode$typeArgName x ]) msg.$name""" :: Nil + else if tpe.isArrayByte then + List( + s"""Encode.unsignedVarint32 ${(n<<3)+2}""" + , s"""Encode.bytes msg.$name""" + ) + else if tpe.isRepeated then + iterablePurs(tpe) match { + case ArrayPurs(x) => + if (x.isString) { + s"""concatAll $$ concatMap (\\x -> [ Encode.unsignedVarint32 ${(n<<3)+2}, Encode.string x ]) msg.$name""" :: Nil + } else { + s"""concatAll $$ concatMap (\\x -> [ Encode.unsignedVarint32 ${(n<<3)+2}, encode${x.typeSymbol.name} x ]) msg.$name""" :: Nil + } + case ArrayTuplePurs(tpe1, tpe2) => + s"concatAll $$ concatMap (\\x -> [ Encode.unsignedVarint32 ${(n<<3)+2}, encode${tupleFunName(tpe1, tpe2)} x ]) msg.$name" :: Nil + } + else + val tpeName = tpe.typeSymbol.name + List( + s"""Encode.unsignedVarint32 ${(n<<3)+2}""" + , s"""encode${tpeName} msg.$name""" + ) diff --git a/purs/src/main/scala-3/io.scala b/purs/src/main/scala-3/io.scala new file mode 100644 index 0000000..29dbfb7 --- /dev/null +++ b/purs/src/main/scala-3/io.scala @@ -0,0 +1,13 @@ +package proto +package purs + +object io { + def writeToFile(path: String, res: String): Unit = { + if res.nonEmpty then + import java.io.{BufferedWriter, FileWriter} + val w = new BufferedWriter(new FileWriter(path)) + w.write(res) + w.close() + else () + } +} \ No newline at end of file diff --git a/purs/src/main/scala-3/main.scala b/purs/src/main/scala-3/main.scala new file mode 100644 index 0000000..b215120 --- /dev/null +++ b/purs/src/main/scala-3/main.scala @@ -0,0 +1,172 @@ +package proto +package purs + +import scala.quoted.* + +type Content = String +type ModuleName = String + +case class GenResRaw( + purs: Map[ModuleName, List[String]], + defaultValues: Map[String, Any] +) { + def toGenRes: GenRes = + val purs = this.purs.map( (moduleName, contents) => + var content: String = contents.mkString + defaultValues.foreach{ (id, value) => + content = content.replace(id, value.toString).nn + } + moduleName -> content + ) + GenRes(purs = purs) +} + +case class GenRes( + purs: Map[ModuleName, Content] +) + +inline def generate[D, E]( + inline moduleEncode: ModuleName + , inline moduleDecode: ModuleName + , inline moduleCommon: ModuleName + ): GenRes = + ${Macro.generate[D, E]('moduleEncode, 'moduleDecode, 'moduleCommon)} + +object Macro: + def generate[D: Type, E: Type]( + moduleEncodeExpr: Expr[ModuleName] + , moduleDecodeExpr: Expr[ModuleName] + , moduleCommonExpr: Expr[ModuleName] + )(using qctx: Quotes): Expr[GenRes] = + Impl().generate[D, E](moduleEncodeExpr, moduleDecodeExpr, moduleCommonExpr) +end Macro + +private class Impl(using val qctx: Quotes) extends Ops with Encoders with Decoders: + import qctx.reflect.* + + def generate[D: Type, E: Type]( + moduleEncodeExpr: Expr[ModuleName] + , moduleDecodeExpr: Expr[ModuleName] + , moduleCommonExpr: Expr[ModuleName] + ): Expr[GenRes] = + + val moduleEncode = moduleEncodeExpr.valueOrAbort + val moduleDecode = moduleDecodeExpr.valueOrAbort + val moduleCommon = moduleCommonExpr.valueOrAbort + + val tpeD = TypeRepr.of[D] + val tpeE = TypeRepr.of[E] + + val decodeTpes = collectTpes(tpeD) + val decoders = decodersFrom(decodeTpes) + + val defExpr = collectDefExpressions(tpeD) ++ collectDefExpressions(tpeE) + val namesExpr: Expr[List[String]] = Expr.ofList(defExpr.map(_._1).toList) + val valuesExpr: Expr[List[Any]] = Expr.ofList(defExpr.map(_._2).toList) + + val encodeTpes = collectTpes(tpeE) + val encoders = encodersFrom(encodeTpes) + + val commonTpes = encodeTpes intersect decodeTpes + val commonPursTypes = makePursTypes(commonTpes, genMaybe=false) + val decodePursTypes = makePursTypes(decodeTpes, genMaybe=true) diff commonPursTypes + val encodePursTypes = makePursTypes(encodeTpes diff commonTpes, genMaybe=false) + + val moduleCommonValue = { + if (commonPursTypes.nonEmpty) { + val code = commonPursTypes.flatMap(_.tmpl).mkString("\n") + val is = List( + if (code.contains(" :: Eq ")) Some("Data.Eq" -> "class Eq") else None + , if (code.contains("Nothing")) Some("Data.Maybe" -> "Maybe(Nothing)") else None + , if (code.contains("Tuple ")) Some("Data.Tuple" -> "Tuple") else None + , if (code.contains("BigInt ")) Some("Proto.BigInt" -> "BigInt") else None + ).flatten.groupMapReduce(_._1)(_._2)(_ + ", " + _).map(x => "import " + x._1 + " (" + x._2 + ")").to(List).sorted.mkString("\n") + s"""|module $moduleCommon + | ( ${commonPursTypes.flatMap(_.`export`).mkString("\n , ")} + | ) where + | + |$is + | + |$code""".stripMargin + } else "" + } + val moduleEncodeValue = { + val code = encodePursTypes.flatMap(_.tmpl).mkString("\n") + "\n\n" + encoders.map(_.tmpl).mkString("\n\n") + val is = List( + if (code.contains("concatMap ")) Some("Data.Array" -> "concatMap") else None + , if (code.contains("Uint8Array")) Some("Proto.Uint8Array" -> "Uint8Array") else None + , if (code.contains("BigInt")) Some("Proto.BigInt" -> "BigInt") else None + , if (code.contains(" :: Eq ")) Some("Data.Eq" -> "class Eq") else None + , if (raw"\WMaybe ".r.findFirstIn(code).isDefined) Some("Data.Maybe" -> "Maybe(..)") else None + // , if (code contains "Nothing") Some("Data.Maybe" -> "Maybe(Nothing)") else None + , if (code.contains("fromMaybe ")) Some("Data.Maybe" -> "fromMaybe") else None + , if (code.contains("Tuple ")) Some("Data.Tuple" -> "Tuple(Tuple)") else None + , if (code.contains("map ")) Some("Prelude" -> "map") else None + , if (code.contains(" $ ")) Some("Prelude" -> "($)") else None + , if (code.contains("Encode.")) Some("Proto.Encode as Encode" -> "") else None + , if (code.contains("length ")) Some("Proto.Uint8Array" -> "length") else None + , if (code.contains("concatAll ")) Some("Proto.Uint8Array" -> "concatAll") else None + , if (code.contains("fromArray ")) Some("Proto.Uint8Array" -> "fromArray") else None + ).flatten.groupMapReduce(_._1)(_._2)(_ + ", " + _).map(x => "import " + x._1 + Some(x._2).filter(_.nonEmpty).map(" ("+_+")").getOrElse("")).to(List).sorted.mkString("\n") + s"""|module $moduleEncode + | ( ${(encodePursTypes.flatMap(_.`export`)++encoders.flatMap(_.`export`)).mkString("\n , ")} + | ) where + | + |$is + |${if (commonPursTypes.nonEmpty) s"import $moduleCommon" else ""} + | + |$code""".stripMargin + } + + val moduleDecodeValue = { + val code = + s"""|decodeFieldLoop :: forall a b c. Int -> Decode.Result a -> (a -> b) -> Decode.Result' (Step { a :: Int, b :: b, c :: Int } { pos :: Int, val :: c }) + |decodeFieldLoop end res f = map (\\{ pos, val } -> Loop { a: end, b: f val, c: pos }) res""".stripMargin + "\n\n" + decodePursTypes.flatMap(_.tmpl).mkString("\n") + "\n\n" + decoders.map(_.tmpl).mkString("\n\n") + val is = List( + if (code.contains("Loop ") || code.contains("Done ")) Some("Control.Monad.Rec.Class" -> "Step(Loop, Done)") else None + , if (code.contains("tailRecM3 ")) Some("Control.Monad.Rec.Class" -> "tailRecM3") else None + , if (code.contains("snoc ")) Some("Data.Array" -> "snoc") else None + , if (code.contains("Uint8Array ")) Some("Proto.Uint8Array" -> "Uint8Array") else None + , if (code.contains("BigInt")) Some("Proto.BigInt" -> "BigInt") else None + , if (code.contains("Left ")) Some("Data.Either" -> "Either(Left)") else None + , if (code.contains(" :: Eq ")) Some("Data.Eq" -> "class Eq") else None + , if (code.contains("zshr")) Some("Data.Int.Bits" -> "zshr") else None + , if (code.contains(" .&. ")) Some("Data.Int.Bits" -> "(.&.)") else None + , if (code.contains("Just ") || code.contains("Nothing")) Some("Data.Maybe" -> "Maybe(Just, Nothing)") else None + , if (code.contains("fromMaybe ")) Some("Data.Maybe" -> "fromMaybe") else None + , if (code.contains("Tuple ")) Some("Data.Tuple" -> "Tuple(Tuple)") else None + , if (code.contains("Unit")) Some("Data.Unit" -> "Unit") else None + , if (code.contains("unit")) Some("Data.Unit" -> "unit") else None + , if (code.contains("map ")) Some("Prelude" -> "map") else None + , if (code.contains(" do")) Some("Prelude" -> "bind") else None + , if (code.contains("pure ")) Some("Prelude" -> "pure") else None + , if (code.contains(" $ ")) Some("Prelude" -> "($)") else None + , if (code.contains(" + ")) Some("Prelude" -> "(+)") else None + , if (code.contains(" < ")) Some("Prelude" -> "(<)") else None + , if (code.contains(" <<< ")) Some("Prelude" -> "(<<<)") else None + , if (code.contains("Decode.")) Some("Proto.Decode as Decode" -> "") else None + ).flatten.groupMapReduce(_._1)(_._2)(_ + ", " + _).map(x => "import " + x._1 + Some(x._2).filter(_.nonEmpty).map(" ("+_+")").getOrElse("")).to(List).sorted.mkString("\n") + s"""|module $moduleDecode + | ( ${(decodePursTypes.flatMap(_.`export`)++decoders.flatMap(_.`export`)).mkString("\n , ")} + | ) where + | + |$is + |${if (commonPursTypes.nonEmpty) s"import $moduleCommon" else ""} + | + |$code""".stripMargin + } + + val strToContentExpr: String => Expr[List[String]] = str => Expr.ofList(str.grouped(1024).toList.map(Expr.apply)) + + '{ + GenResRaw( + purs = Map( + ${moduleCommonExpr} -> ${strToContentExpr(moduleCommonValue)} + , ${moduleEncodeExpr} -> ${strToContentExpr(moduleEncodeValue)} + , ${moduleDecodeExpr} -> ${strToContentExpr(moduleDecodeValue)} + ), + defaultValues = $namesExpr.zip($valuesExpr).toMap + ).toGenRes + } + +end Impl \ No newline at end of file diff --git a/purs/src/main/scala-3/purs.scala b/purs/src/main/scala-3/purs.scala new file mode 100644 index 0000000..7c07779 --- /dev/null +++ b/purs/src/main/scala-3/purs.scala @@ -0,0 +1,85 @@ +package proto +package purs + +import scala.quoted.* +import compiletime.asMatchable + +trait Purs extends Ops: + implicit val qctx: Quotes + import qctx.reflect.* + + def nothingValue(name: String, tpe: TypeRepr): String = { + if (tpe.isRepeated) { + iterablePurs(tpe) match { + case _: ArrayPurs => s"$name: []" + case _: ArrayTuplePurs => s"$name: []" + } + } else s"$name: Nothing" + } + + def justValue(name: String, tpe: TypeRepr): String = { + if (tpe.isOption) name + else if (tpe.isRepeated) name + else { + s"$name: Just $name" + } + } + + def makePursTypes(types: Seq[Tpe], genMaybe: Boolean): Seq[PursType] = { + types.flatMap{ + case TraitType(tpe, name, children, _) => + List(PursType(List( + s"data $name = ${children.map{ + case x if x.noargs => x.name + case x if x.rec => s"${x.name}'' ${x.name}" + case x => s"${x.name} ${x.name}" + }.mkString(" | ")}" + , s"derive instance eq$name :: Eq $name" + ), `export`=Some(s"$name(..)"))) + case _: TupleType => Nil + case _: NoargsType => Nil + case RecursiveType(tpe, name) => + val f = fields(tpe) + val fs = f.map{ case (name1, tpe, _, _) => name1 -> pursType(tpe) } + val params = fs.map{ case (name1, tpe) => s"$name1 :: ${tpe._1}" }.mkString(", ") + val x = s"newtype $name = $name { $params }" + val eq = s"derive instance eq$name :: Eq $name" + val defaults = f.collect{ case (name1, tpe1, _, v: HasDefFun) => (name1, pursType(tpe1)._1, v) } + Seq( + Some(PursType(List(x, eq), Some(s"$name($name)"))) + , if (defaults.nonEmpty) { + val tmpl1 = s"default$name :: { ${defaults.map{ case (name1, tpe1, _) => s"$name1 :: $tpe1" }.mkString(", ")} }" + val tmpl2 = s"default$name = { ${defaults.map{ case (name1, _, v) => s"$name1: ${v.value}" }.mkString(", ")} }" + Some(PursType(Seq(tmpl1, tmpl2), Some(s"default$name"))) + } else None + , if (genMaybe) { + val params1 = fs.map{ case (name1, tpe) => s"$name1 :: ${tpe._2}" }.mkString(", ") + if (params != params1) { + val x1 = s"type $name' = { $params1 }" + Some(PursType(List(x1), None)) + } else None + } else None + ).flatten + case RegularType(tpe, name) => + val f = fields(tpe) + val fs = f.map{ case (name1, tpe1, _, _) => name1 -> pursType(tpe1) } + val params = fs.map{ case (name1, tpe1) => s"$name1 :: ${tpe1._1}" }.mkString(", ") + val x = if (params.nonEmpty) s"type $name = { $params }" else s"type $name = {}" + val defaults = f.collect{ case (name1, tpe1, _, v: HasDefFun) => (name1, pursType(tpe1)._1, v) } + Seq( + Some(PursType(Seq(x), Some(s"$name"))) + , if (defaults.nonEmpty) { + val tmpl1 = s"default$name :: { ${defaults.map{ case (name1, tpe1, _) => s"$name1 :: $tpe1" }.mkString(", ")} }" + val tmpl2 = s"default$name = { ${defaults.map{ case (name1, _, v) => s"$name1: ${v.value}" }.mkString(", ")} }" + Some(PursType(Seq(tmpl1, tmpl2), Some(s"default$name"))) + } else None + , if (genMaybe) { + val params1 = fs.map{ case (name1, tpe1) => s"$name1 :: ${tpe1._2}" }.mkString(", ") + if (params != params1) { + val x1 = s"type $name' = { $params1 }" + Some(PursType(Seq(x1), None)) + } else None + } else None + ).flatten + }.distinct + } \ No newline at end of file diff --git a/purs/src/main/scala-3/tpe.scala b/purs/src/main/scala-3/tpe.scala new file mode 100644 index 0000000..22bd3aa --- /dev/null +++ b/purs/src/main/scala-3/tpe.scala @@ -0,0 +1,6 @@ +package proto +package purs + +final case class PursType(tmpl: Seq[String], `export`: Option[String]) + +final case class Coder(tmpl: String, `export`: Option[String]) \ No newline at end of file