Skip to content

Commit

Permalink
scala 2.12 support #28
Browse files Browse the repository at this point in the history
  • Loading branch information
YuriyMazepin committed Jan 1, 2022
1 parent 5dd98f8 commit fd2a0dd
Show file tree
Hide file tree
Showing 8 changed files with 1,214 additions and 1 deletion.
7 changes: 6 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ val `proto-parent` = project.in(file(".")).settings(
).aggregate(proto, protosyntax, protopurs, prototex, protoops, bench)

ThisBuild / scalaVersion := "3.1.1-RC2"
ThisBuild / crossScalaVersions := "3.1.1-RC2" :: "2.13.7" :: Nil
ThisBuild / crossScalaVersions := "3.1.1-RC2" :: "2.13.7" :: "2.12.15" :: Nil

lazy val proto = project.in(file("proto")).settings(
name := "proto",
Expand All @@ -28,6 +28,7 @@ lazy val protoops = project.in(file("ops")).settings(
, libraryDependencies ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, 13)) => Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value)
case Some((2, 12)) => Seq("org.scala-lang" % "scala-reflect" % scalaVersion.value)
case _ => Nil
}
}
Expand Down Expand Up @@ -69,6 +70,10 @@ ThisBuild / isSnapshot := true

ThisBuild / scalacOptions ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, 12)) => Seq(
"-deprecation"
, "-Wconf:cat=deprecation&msg=Auto-application:silent"
)
case Some((2, 13)) => Seq(
"-deprecation"
, "-Wconf:cat=deprecation&msg=Auto-application:silent"
Expand Down
303 changes: 303 additions & 0 deletions proto/src/main/scala-2.12/BuildCodec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
package proto

import com.google.protobuf.CodedOutputStream
import scala.reflect.macros.blackbox.Context

trait BuildCodec extends Common {
val c: Context
import c.universe._
import definitions._

def sizeBasic(field: FieldInfo, sizeAcc: TermName): Option[List[c.Tree]] =
common.sizeFun(field).map(fun => List(
q"${sizeAcc} = ${sizeAcc} + ${CodedOutputStream.computeTagSize(field.num)} + ${fun}"
))

def sizeOption(field: FieldInfo, sizeAcc: TermName): Option[List[c.Tree]] =
if (field.tpe.isOption) {
val tpe1 = field.tpe.typeArgs(0)
val field1 = field.copy(getter=q"${field.getter}.get", tpe=tpe1, num=field.num)
common.sizeFun(field1).map(v => List(
q"if (${field.getter}.isDefined) ${sizeAcc} = ${sizeAcc} + ${CodedOutputStream.computeTagSize(field.num)} + ${v}"
)).orElse(Some(List(
q"""val ${field.prepareName}: ${OptionClass}[${PrepareType}] = if (${field.getter}.isDefined) {
val p: ${PrepareType} = implicitly[${messageCodecFor(tpe1)}].prepare(${field.getter}.get)
${sizeAcc} = ${sizeAcc} + ${CodedOutputStream.computeTagSize(field.num)} + ${CodedOutputStreamType.typeSymbol.companion}.computeUInt32SizeNoTag(p.size) + p.size
Some(p)
} else {
None
}"""
)))
} else {
None
}

def sizeCollection(field: FieldInfo, sizeAcc: TermName): Option[List[c.Tree]] =
if (field.tpe.isIterable) {
val value = TermName("value")
val tpe1 = field.tpe.iterableArgument
val field1 = field.copy(getter=q"${value}", tpe=tpe1, num=field.num)
common.sizeFun(field1).map{ v =>
if ( field1.tpe =:= StringClass.selfType
|| field1.tpe =:= ArrayByteType
|| field1.tpe =:= BytesType
) {
val tagSizeName = TermName(s"${field.sizeName}TagSize")
List(
q"val ${tagSizeName} = ${CodedOutputStream.computeTagSize(field.num)}"
, q"var ${field.sizeName}: Int = 0"
, q"${field.getter}.foreach((${value}: ${tpe1}) => ${field.sizeName} = ${field.sizeName} + ${v} + ${tagSizeName})"
, q"${sizeAcc} = ${sizeAcc} + ${field.sizeName}"
)
} else {
List(
q"var ${field.sizeName}: Int = 0"
, q"${field.getter}.foreach((${value}: ${tpe1}) => ${field.sizeName} = ${field.sizeName} + ${v})"
, q"${sizeAcc} = ${sizeAcc} + ${CodedOutputStream.computeTagSize(field.num)} + ${CodedOutputStreamType.typeSymbol.companion}.computeUInt32SizeNoTag(${field.sizeName}) + ${field.sizeName}"
)
}
}.orElse{
val counter = TermName(c.freshName("n"))
Some(List(
q"val ${field.prepareName}: ${ArrayClass}[${PrepareType}] = new ${ArrayClass}[${PrepareType}](${field.getter}.size)"
, q"var ${counter} = 0"
, q"""${field.getter}.foreach((${value}: ${tpe1}) => {
val p: ${PrepareType} = implicitly[${messageCodecFor(tpe1)}].prepare(${value})
${field.prepareName}(${counter}) = p
${sizeAcc} = ${sizeAcc} + ${CodedOutputStream.computeTagSize(field.num)} + ${CodedOutputStreamType.typeSymbol.companion}.computeUInt32SizeNoTag(p.size) + p.size
${counter} = ${counter} + 1
})"""
))
}
} else {
None
}

def sizeMessage(field: FieldInfo, sizeAcc: TermName): List[c.Tree] =
List(
q"val ${field.prepareName} = implicitly[${messageCodecFor(field.tpe)}].prepare(${field.getter})"
, q"${sizeAcc} = ${sizeAcc} + ${CodedOutputStream.computeTagSize(field.num)} + ${CodedOutputStreamType.typeSymbol.companion}.computeUInt32SizeNoTag(${field.prepareName}.size) + ${field.prepareName}.size"
)

def size(params: List[FieldInfo], sizeAcc: TermName): List[c.Tree] =
params.map{field =>
sizeBasic(field, sizeAcc)
.orElse(sizeOption(field, sizeAcc))
.orElse(sizeCollection(field, sizeAcc))
.getOrElse(sizeMessage(field, sizeAcc))
}.flatten

def writeBasic(field: FieldInfo, os: TermName): Option[List[c.Tree]] =
common.writeFun(field, os).map(fun => List(
q"${os}.writeUInt32NoTag(${common.tag(field)})"
, q"${fun}"
))

def writeOption(field: FieldInfo, os: TermName): Option[List[c.Tree]] =
if (field.tpe.isOption) {
val tpe1 = field.tpe.typeArgs(0)
val field1 = field.copy(getter=q"${field.getter}.get", tpe=tpe1, num=field.num)
common.writeFun(field1, os).map(v => List(
q"""if (${field.getter}.isDefined) {
${os}.writeUInt32NoTag(${common.tag(field)})
${v}
}"""
)).orElse(Some(List(
q"""if (${field.prepareName}.isDefined) {
val p = ${field.prepareName}.get
${os}.writeUInt32NoTag(${common.tag(field)})
${os}.writeUInt32NoTag(p.size)
p.write(${os})
}"""
)))
} else {
None
}

def writeCollection(field: FieldInfo, os: TermName): Option[List[c.Tree]] =
if (field.tpe.isIterable) {
val value = TermName("value")
val tpe1 = field.tpe.iterableArgument
val field1 = field.copy(getter=q"${value}", tpe=tpe1, num=field.num)
common.writeFun(field1, os).map(v =>
if ( field1.tpe =:= StringClass.selfType
|| field1.tpe =:= ArrayByteType
|| field1.tpe =:= BytesType
) {
List(
q"""${field.getter}.foreach((${value}: ${tpe1}) => {
${os}.writeUInt32NoTag(${common.tag(field)})
${v}
})"""
)
} else {
List(
q"${os}.writeUInt32NoTag(${common.tag(field)})"
, q"${os}.writeUInt32NoTag(${field.sizeName})"
, q"${field.getter}.foreach((${value}: ${tpe1}) => ${v})"
)
}
).orElse{
val counter = TermName(c.freshName("n"))
Some(List(
q"var ${counter} = 0"
, q"""while(${counter} < ${field.prepareName}.length) {
val p = ${field.prepareName}(${counter})
${os}.writeUInt32NoTag(${common.tag(field)})
${os}.writeUInt32NoTag(p.size)
p.write(${os})
${counter} = ${counter} + 1
}"""
))
}
} else {
None
}

def writeMessage(field: FieldInfo, os: TermName): List[c.Tree] =
List(
q"${os}.writeUInt32NoTag(${common.tag(field)})"
, q"${os}.writeUInt32NoTag(${field.prepareName}.size)"
, q"${field.prepareName}.write(${os})"
)

def write(params: List[FieldInfo], os: TermName): List[c.Tree] =
params.map(field =>
writeBasic(field, os)
.orElse(writeOption(field, os))
.orElse(writeCollection(field, os))
.getOrElse(writeMessage(field, os))
).flatten

def prepare(params: List[FieldInfo], aType: c.Type, aName: TermName, sizeAcc: TermName, osName: TermName): List[c.Tree] = {
val PrepareName = TypeName(s"Prepare${aType.typeSymbol.name}")
List(
q"""class ${PrepareName}(${aName}: ${aType}) extends ${PrepareType} {
var ${sizeAcc}: ${IntClass} = 0
..${size(params, sizeAcc)}
val size: ${IntClass} = ${sizeAcc}
def write(${osName}: ${CodedOutputStreamType}): ${UnitClass} = {
..${write(params, osName)}
}
}"""
, q"def prepare(${aName}: ${aType}): ${PrepareType} = new ${PrepareName}(${aName})"
)
}

def initArg(field: FieldInfo): c.Tree =
if (field.tpe.isOption) {
q"${field.readName}"
} else if (field.tpe.isIterable) {
q"${field.readName}.result"
} else {
val err: String = s"missing required field `${field.name}: ${field.tpe}`"
val orElse = field.defaultValue.getOrElse(q"throw new RuntimeException(${err})")
q"${field.readName}.getOrElse($orElse)"
}

def read(params: List[FieldInfo], t: c.Type, buildResult: Option[c.Tree]=None): c.Tree = {
val init: List[c.Tree] =
if (t.isTrait) {
List(q"var readRes: ${OptionClass}[${t}] = ${NoneModule}")
} else {
params.map(field =>
if (field.tpe.isOption) {
q"var ${field.readName}: ${field.tpe} = ${NoneModule}"
} else if (field.tpe.isIterable) {
val tpe1 = field.tpe.iterableArgument
q"var ${field.readName}: ${builder(tpe1, field.tpe)} = ${field.tpe.typeConstructor.typeSymbol.companion}.newBuilder"
} else {
q"var ${field.readName}: ${OptionClass}[${field.tpe}] = ${NoneModule}"
}
)
}

val result = {
buildResult match {
case Some(fun1) => q"${fun1}"
case None =>
if (t.isTrait) {
val err: String = s"missing one of required field ${params.map(field => field.name.toString + ": " + field.tpe).mkString("`", "` or `", "`")}"
q"readRes.getOrElse(throw new RuntimeException(${err}))"
} else if (t.typeSymbol.companion == NoSymbol) {
q"${t.typeSymbol.owner.asClass.selfType.member(TermName(t.typeSymbol.name.encodedName.toString))}"
} else {
val args = params.map(field => q"${field.name} = ${initArg(field)}")
q"${t.typeSymbol.companion}.apply(..${args})"
}
}
}

def putLimit(is: TermName, read: c.Tree): List[c.Tree] =
List(
q"val readSize: Int = ${is}.readRawVarint32"
, q"val limit = ${is}.pushLimit(readSize)"
, q"${read}"
, q"${is}.popLimit(limit)"
)

def tagMatch(is: TermName): List[c.Tree] =
params.flatMap(p =>
if (p.tpe.isIterable && p.tpe.iterableArgument.isCommonType)
p :: p.copy(nonPacked = true) :: Nil
else p :: Nil
).map{ field =>
val readContent: List[c.Tree] = common.readFun(field, is).map(readFun => List(
q"${field.readName} = Some(${readFun})"
)).getOrElse{
val value = TermName("value")
if (field.tpe.isOption) {
val tpe1 = field.tpe.typeArgs(0)
val field1 = field.copy(getter=q"${value}", tpe=tpe1, num=field.num)
common.readFun(field1, is).map(readFun => List(
q"${field.readName} = Some(${readFun})"
)).getOrElse(
putLimit(is, q"${field.readName} = Some(implicitly[${messageCodecFor(tpe1)}].read(${is}))")
)
} else if (field.tpe.isIterable) {
val tpe1 = field.tpe.iterableArgument
val field1: FieldInfo = field.copy(getter=q"${value}", tpe=tpe1, num=field.num)
common.readFun(field1, is).map(readFun =>
if ( field1.tpe =:= StringClass.selfType
|| field1.tpe =:= ArrayByteType
|| field1.tpe =:= BytesType
|| field1.nonPacked
) {
q"${field.readName} += ${readFun}" :: Nil
} else {
val readMessage: c.Tree = q"""while (${is}.getBytesUntilLimit > 0) {
${field.readName} += ${readFun}
}"""
putLimit(is, readMessage)
}
).getOrElse{
val readMessage: c.Tree = q"${field.readName} += implicitly[${messageCodecFor(tpe1)}].read(${is})"
putLimit(is, readMessage)
}
} else {
putLimit(is, q"${field.readName} = Some(implicitly[${messageCodecFor(field.tpe)}].read(${is}))")
}
}
cq"${common.tag(field)} => ..${readContent}"
}

val isName = TermName("is")

if (params.size > 0) {
q"""def read(${isName}: ${CodedInputStreamType}): ${t} = {
var done = false
..${init}
while(done == false) {
${isName}.readTag match {
case ..${tagMatch(isName)}
case 0 => done = true
case skip => ${isName}.skipField(skip)
}
}
${result}
}"""
} else {
q"def read(${isName}: ${CodedInputStreamType}): ${t} = ${result}"
}
}
}
Loading

0 comments on commit fd2a0dd

Please sign in to comment.