Skip to content

Commit

Permalink
Avoid toString conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones committed Dec 19, 2024
1 parent e854bff commit d76d226
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -270,77 +270,164 @@ private[types] object ConverterProvider {
// Converter helpers
// =======================================================================
def cast(tree: Tree, tpe: Type): Tree = {
val msg = Constant(s"Cannot convert to ${tpe.typeSymbol.name}: ")
val fail = q"""throw new _root_.java.lang.IllegalArgumentException($msg + $tree)"""

def readBase64(term: TermName) =
q"_root_.com.google.common.io.BaseEncoding.base64().decode($term)"

val provider: OverrideTypeProvider =
OverrideTypeProviderFinder.getProvider
val s = q"$tree.toString"
tpe match {
case t if provider.shouldOverrideType(c)(t) =>
provider.createInstance(c)(t, q"$tree")
case t if t =:= typeOf[Boolean] => q"$s.toBoolean"
case t if t =:= typeOf[Int] => q"$s.toInt"
case t if t =:= typeOf[Long] => q"$s.toLong"
case t if t =:= typeOf[Float] => q"$s.toFloat"
case t if t =:= typeOf[Double] => q"$s.toDouble"
case t if t =:= typeOf[String] => q"$s"
case t if t =:= typeOf[Boolean] =>
q"""$tree match {
case b: ${typeOf[java.lang.Boolean]} => _root_.scala.Boolean.unbox(b)
case _ => $fail
}"""
case t if t =:= typeOf[Int] =>
q"""$tree match {
case i: ${typeOf[java.lang.Integer]} => _root_.scala.Int.unbox(i)
case s: ${typeOf[String]} => s.toInt
case _ => $fail
}"""
case t if t =:= typeOf[Long] =>
q"""$tree match {
case l: ${typeOf[java.lang.Long]} => _root_.scala.Long.unbox(l)
case i: ${typeOf[java.lang.Integer]} => _root_.scala.Long.unbox(i).toLong
case s: ${typeOf[String]} => s.toLong
case _ => $fail
}"""
case t if t =:= typeOf[Float] =>
q"""$tree match {
case f: ${typeOf[java.lang.Float]} => _root_.scala.Float.unbox(f)
case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d).toFloat
case s: ${typeOf[String]} => s.toFloat
case _ => $fail
}"""
case t if t =:= typeOf[Double] =>
q"""$tree match {
case d: ${typeOf[java.lang.Double]} => _root_.scala.Double.unbox(d)
case s: ${typeOf[String]} => s.toDouble
case _ => $fail
}"""
case t if t =:= typeOf[String] =>
q"""$tree match {
case s: ${typeOf[String]} => s
case _ => $fail
}"""
case t if t =:= typeOf[BigDecimal] =>
q"_root_.com.spotify.scio.bigquery.Numeric($s)"

q"""$tree match {
case bd: ${typeOf[BigDecimal]} =>
_root_.com.spotify.scio.bigquery.Numeric(bd)
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Numeric(s)
case _ => $fail
}"""
case t if t =:= typeOf[ByteString] =>
val b =
q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)"
q"_root_.com.google.protobuf.ByteString.copyFrom($b)"
val s = TermName("s")
q"""$tree match {
case bs: ${typeOf[ByteString]} => bs
case $s: ${typeOf[String]} =>
_root_.com.google.protobuf.ByteString.copyFrom(${readBase64(s)})
case _ => $fail
}"""
case t if t =:= typeOf[Array[Byte]] =>
q"_root_.com.google.common.io.BaseEncoding.base64().decode($s)"

val s = TermName("s")
q"""$tree match {
case bs: ${typeOf[Array[Byte]]} => bs
case $s: ${typeOf[String]} =>
${readBase64(s)}
case _ => $fail
}"""
case t if t =:= typeOf[Instant] =>
q"_root_.com.spotify.scio.bigquery.Timestamp.parse($s)"
q"""$tree match {
case i: ${typeOf[Instant]} => i
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Timestamp.parse(s)
case _ => $fail
}"""
case t if t =:= typeOf[LocalDate] =>
q"_root_.com.spotify.scio.bigquery.Date.parse($s)"
q"""$tree match {
case ld: ${typeOf[LocalDate]} => ld
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Date.parse(s)
case _ => $fail
}"""
case t if t =:= typeOf[LocalTime] =>
q"_root_.com.spotify.scio.bigquery.Time.parse($s)"
q"""$tree match {
case lt: ${typeOf[LocalTime]} => lt
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.Time.parse(s)
case _ => $fail
}"""
case t if t =:= typeOf[LocalDateTime] =>
q"_root_.com.spotify.scio.bigquery.DateTime.parse($s)"

q"""$tree match {
case ldt: ${typeOf[LocalDateTime]} => ldt
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.DateTime.parse(s)
case _ => $fail
}"""
// different than nested record match below, even though those are case classes
case t if t =:= typeOf[Geography] =>
q"_root_.com.spotify.scio.bigquery.types.Geography($s)"
q"""$tree match {
case g: ${typeOf[Geography]} => g
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.types.Geography(s)
case _ => $fail
}"""
case t if t =:= typeOf[Json] =>
q"_root_.com.spotify.scio.bigquery.types.Json($s)"
q"""$tree match {
case j: ${typeOf[Json]} => j
case tr: ${typeOf[TableRow]} =>
_root_.com.spotify.scio.bigquery.types.Json(tr)
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.types.Json(s)
case _ => $fail
}"""
case t if t =:= typeOf[BigNumeric] =>
q"_root_.com.spotify.scio.bigquery.types.BigNumeric($s)"

q"""$tree match {
case bn: ${typeOf[BigNumeric]} => bn
case bd: ${typeOf[BigDecimal]} =>
_root_.com.spotify.scio.bigquery.types.BigNumeric(bd)
case s: ${typeOf[String]} =>
_root_.com.spotify.scio.bigquery.types.BigNumeric(s)
case _ => $fail
}"""
case t if isCaseClass(c)(t) => // nested records
val fn = TermName("r" + t.typeSymbol.name)
q"""{
val $fn = $tree.asInstanceOf[_root_.java.util.Map[String, AnyRef]]
${constructor(t, fn)}
val r = TermName("record")
q"""$tree match {
case $r: ${typeOf[java.util.Map[String, AnyRef]]} => ${constructor(t, r)}
case _ => $fail
}
"""
case _ => c.abort(c.enclosingPosition, s"Unsupported type: $tpe")
}
}

def option(tree: Tree, tpe: Type): Tree =
q"if ($tree == null) None else Some(${cast(tree, tpe)})"
q"_root_.scala.Option($tree).map(x => ${cast(q"x", tpe)})"

def list(tree: Tree, tpe: Type): Tree = {
val jl = tq"_root_.java.util.List[AnyRef]"
q"asScala($tree.asInstanceOf[$jl].iterator).map(x => ${cast(q"x", tpe)}).toList"
}
def list(tree: Tree, tpe: Type): Tree =
q"asScala($tree.asInstanceOf[${typeOf[java.util.List[AnyRef]]}].iterator).map(x => ${cast(q"x", tpe)}).toList"

def field(symbol: Symbol, fn: TermName): Tree = {
val name = symbol.name.toString
val tpe = symbol.asMethod.returnType

val tree = q"$fn.get($name)"
def nonNullTree(fType: String) =
def nonNullTree(fType: String) = {
val msg = Constant(s"$fType field '$name' is null")
q"""{
val v = $fn.get($name)
if (v == null) {
throw new NullPointerException($fType + " field \"" + $name + "\" is null")
throw new NullPointerException($msg)
}
v
}"""
}

if (tpe.erasure =:= typeOf[Option[_]].erasure) {
option(tree, tpe.typeArgs.head)
} else if (tpe.erasure =:= typeOf[List[_]].erasure) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package com.spotify.scio.bigquery

import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature}
import com.fasterxml.jackson.datatype.joda.JodaModule
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import com.spotify.scio.coders.Coder
import org.apache.avro.Conversions.DecimalConversion
import org.apache.avro.LogicalTypes
import org.apache.beam.sdk.extensions.gcp.util.Transport
import org.typelevel.scalaccompat.annotation.nowarn

import java.math.MathContext
Expand Down Expand Up @@ -63,11 +65,15 @@ package object types {
*/
case class Json(wkt: String)
object Json {
@transient
private lazy val jsonFactory = Transport.getJsonFactory
// Use same mapper as the TableRowJsonCoder
private lazy val mapper = new ObjectMapper()
.registerModule(new JavaTimeModule())
.registerModule(new JodaModule())
.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS)
.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS);

def apply(row: TableRow): Json = Json(jsonFactory.toString(row))
def parse(json: Json): TableRow = jsonFactory.fromString(json.wkt, classOf[TableRow])
def apply(row: TableRow): Json = Json(mapper.writeValueAsString(row))
def parse(json: Json): TableRow = mapper.readValue(json.wkt, classOf[TableRow])
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ final class ConverterProviderSpec
}
implicit val arbJson: Arbitrary[Json] = Arbitrary(
for {
key <- Gen.alphaStr
// f is a field from TableRow.
// Jackson ObjectMapper will fail with such key
key <- Gen.alphaStr.retryUntil(_ != "f")
value <- Gen.alphaStr
} yield Json(s"""{"$key":"$value"}""")
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package com.spotify.scio.bigquery.types

import com.fasterxml.jackson.databind.node.{JsonNodeFactory, ObjectNode}
import com.google.protobuf.ByteString
import com.spotify.scio.bigquery._
import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime}
Expand All @@ -30,7 +29,7 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers {
"ConverterProvider" should "throw NPE with meaningful message for null in REQUIRED field" in {
the[NullPointerException] thrownBy {
Required.fromTableRow(TableRow())
} should have message """REQUIRED field "a" is null"""
} should have message "REQUIRED field 'a' is null"
}

it should "handle null in NULLABLE field" in {
Expand All @@ -40,7 +39,7 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers {
it should "throw NPE with meaningful message for null in REPEATED field" in {
the[NullPointerException] thrownBy {
Repeated.fromTableRow(TableRow())
} should have message """REPEATED field "a" is null"""
} should have message "REPEATED field 'a' is null"
}

it should "handle required geography type" in {
Expand All @@ -51,14 +50,12 @@ class ConverterProviderTest extends AnyFlatSpec with Matchers {

it should "handle required json type" in {
val wkt = """{"name":"Alice","age":30}"""
val jsNodeFactory = new JsonNodeFactory(false)
val jackson = jsNodeFactory
.objectNode()
.set[ObjectNode]("name", jsNodeFactory.textNode("Alice"))
.set[ObjectNode]("age", jsNodeFactory.numberNode(30))

RequiredJson.fromTableRow(TableRow("a" -> jackson)) shouldBe RequiredJson(Json(wkt))
BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> jackson)
val parsed = new TableRow()
.set("name", "Alice")
.set("age", 30)

RequiredJson.fromTableRow(TableRow("a" -> parsed)) shouldBe RequiredJson(Json(wkt))
BigQueryType.toTableRow[RequiredJson](RequiredJson(Json(wkt))) shouldBe TableRow("a" -> parsed)
}

it should "handle required big numeric type" in {
Expand Down

0 comments on commit d76d226

Please sign in to comment.