Skip to content

Commit

Permalink
Cleaned up Map Version
Browse files Browse the repository at this point in the history
  • Loading branch information
luka.jurukovski committed Nov 19, 2018
1 parent 34432c0 commit c6bef31
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package com.sksamuel.avro4s.github

import java.io.ByteArrayOutputStream

import com.sksamuel.avro4s.{AvroFixed, AvroInputStream, AvroOutputStream, AvroSchema}
import com.sksamuel.avro4s._
import org.scalatest.{FunSuite, Matchers}

case class Data(uuid: Option[UUID])
@AvroName("Data")
case class Data(@AvroName("name")uuid: Option[UUID])
case class UUID(@AvroFixed(8) bytes: Array[Byte])

class GithubIssue193 extends FunSuite with Matchers {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,49 @@
package com.sksamuel.avro4s

case class Anno(className: String, args: Seq[Any])
case class Anno(className: String, args: Map[String,AnyRef])

class AnnotationExtractors(annos: Seq[Anno]) {

// returns the value of the first arg from the first annotation that matches the given class
private def findFirst(clazz: Class[_]): Option[String] = annos.find(c => clazz.isAssignableFrom(Class.forName(c.className))).map(_.args.head.toString)
private def findFirst(argument: String, clazz: Class[_]): Option[String] = {

val result = annos.find(c => clazz.isAssignableFrom(Class.forName(c.className))).flatMap(_.args.get(argument).map(_.toString))
result match {
case Some(x) => println(s"Found for $argument `$x` in ${clazz.getCanonicalName}")
case None if annos.nonEmpty =>
println(s"Found None for $argument in ${clazz.getCanonicalName}")
println(annos)
case _ => Nil
}
result
}


private def findAll(argument: String, clazz: Class[_]): Seq[String] = {
val result = annos.filter(c => clazz.isAssignableFrom(Class.forName(c.className))).flatMap(_.args.get(argument).map(_.toString))
result match {
case x if result.nonEmpty => println(s"Found for $argument `$x` in ${clazz.getCanonicalName}")
case _ if annos.nonEmpty =>
println(s"Found None for $argument in ${clazz.getCanonicalName}")
println(annos)
case _ => Nil
}
result
}

private def findAll(clazz: Class[_]): Seq[String] = annos.filter(c => clazz.isAssignableFrom(Class.forName(c.className))).map(_.args.head.toString)

private def exists(clazz: Class[_]): Boolean = annos.exists(c => clazz.isAssignableFrom(Class.forName(c.className)))

def namespace: Option[String] = findFirst(classOf[AvroNamespaceable])
def doc: Option[String] = findFirst(classOf[AvroDocumentable])
def aliases: Seq[String] = findAll(classOf[AvroAliasable])
def namespace: Option[String] = findFirst("namespace", classOf[AvroNamespaceable])
def doc: Option[String] = findFirst("doc", classOf[AvroDocumentable])
def aliases: Seq[String] = findAll("alias",classOf[AvroAliasable])

def fixed: Option[Int] = findFirst(classOf[AvroFixable]).map(_.toInt)
def fixed: Option[Int] = findFirst("size",classOf[AvroFixable]).map(_.toInt)

def name: Option[String] = findFirst(classOf[AvroNameable])
def name: Option[String] = findFirst("name",classOf[AvroNameable])

def props: Map[String, String] = annos.filter(c => classOf[AvroProperty].isAssignableFrom(Class.forName(c.className))).map { anno =>
anno.args.head.toString -> anno.args(1).toString
anno.args("key").toString -> anno.args("value").toString
}.toMap

def erased: Boolean = exists(classOf[AvroErasedName])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package com.sksamuel.avro4s
import scala.reflect.internal.{Definitions, StdNames, SymbolTable}
import scala.reflect.macros.whitebox
import scala.reflect.runtime.universe
import scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox

class ReflectHelper[C <: whitebox.Context](val c: C) {
import c.universe
Expand Down Expand Up @@ -134,7 +136,12 @@ class ReflectHelper[C <: whitebox.Context](val c: C) {

def annotations(sym: Symbol): List[Anno] = sym.annotations.map { a =>
val name = a.tree.tpe.typeSymbol.fullName
val args = a.tree.children.tail.map(_.toString.stripPrefix("\"").stripSuffix("\""))
val tb = currentMirror.mkToolBox()

val args = tb.compile(tb.parse(a.toString)).apply() match {
case c: AvroFieldReflection => c.getAllFields
case _ => Map.empty[String, AnyRef]
}
Anno(name, args)
}

Expand All @@ -158,7 +165,11 @@ object ReflectHelper {

def annotations(sym: Symbol): List[Anno] = sym.annotations.map { a =>
val name = a.tree.tpe.typeSymbol.fullName
val args = a.tree.children.tail.map(_.toString.stripPrefix("\"").stripSuffix("\""))
val tb = currentMirror.mkToolBox()
val args = tb.compile(tb.parse(a.toString)).apply() match {
case c: AvroFieldReflection => c.getAllFields
case _ => Map.empty[String, AnyRef]
}
Anno(name, args)
}

Expand Down
11 changes: 9 additions & 2 deletions avro4s-macros/src/main/scala/com/sksamuel/avro4s/SchemaFor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ import scala.reflect.ClassTag
import scala.reflect.internal.{Definitions, StdNames, SymbolTable}
import scala.reflect.macros.whitebox
import scala.reflect.runtime.universe._
import scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox


/**
* A [[SchemaFor]] generates an Avro Schema for a Scala or Java type.
Expand Down Expand Up @@ -162,7 +165,7 @@ object SchemaFor extends TupleSchemaFor with CoproductSchemaFor {
override def schema: Schema = {

val annos = tag.runtimeClass.getAnnotations.toList.map { a =>
Anno(a.annotationType.getClass.getName, Nil)
Anno(a.annotationType.getClass.getName, Map.empty)
}

val extractor = new AnnotationExtractors(annos)
Expand Down Expand Up @@ -193,7 +196,11 @@ object SchemaFor extends TupleSchemaFor with CoproductSchemaFor {

val annos = typeRef.pre.typeSymbol.annotations.map { a =>
val name = a.tree.tpe.typeSymbol.fullName
val args = a.tree.children.tail.map(_.toString.stripPrefix("\"").stripSuffix("\""))
val tb = currentMirror.mkToolBox()
val args = tb.compile(tb.parse(a.toString)).apply() match {
case c: AvroFieldReflection => c.getAllFields
case _ => Map.empty[String, AnyRef]
}
Anno(name, args)
}

Expand Down
58 changes: 39 additions & 19 deletions avro4s-macros/src/main/scala/com/sksamuel/avro4s/annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@ package com.sksamuel.avro4s

import scala.annotation.StaticAnnotation

case class AvroAlias(alias: String) extends AvroAliasable
case class AvroAlias(override val alias: String) extends AvroAliasable

trait AvroAliasable extends StaticAnnotation {
trait AvroAliasable extends AvroFieldReflection {
val alias: String
}

case class AvroDoc(doc: String) extends AvroDocumentable
case class AvroDoc(override val doc: String) extends AvroDocumentable

trait AvroDocumentable extends StaticAnnotation {
trait AvroDocumentable extends AvroFieldReflection {
val doc: String
}

Expand All @@ -21,17 +21,17 @@ trait AvroDocumentable extends StaticAnnotation {
*
* This annotation can be used in the following ways:
*
* - On a field, eg case class `Foo(@AvroField(10) name: String)`
* - On a field, eg class `Foo(@AvroField(10) name: String)`
* which results in the field `name` having schema type FIXED with
* a size of 10.
*
* - On a value type, eg `@AvroField(7) case class Foo(name: String) extends AnyVal`
* - On a value type, eg `@AvroField(7) class Foo(name: String) extends AnyVal`
* which results in all usages of the value type having schema
* FIXED with a size of 7 rather than the default.
*/
case class AvroFixed(size: Int) extends AvroFixable
case class AvroFixed(override val size: Int) extends AvroFixable

trait AvroFixable extends StaticAnnotation {
trait AvroFixable extends AvroFieldReflection {
val size: Int
}

Expand All @@ -40,12 +40,12 @@ trait AvroFixable extends StaticAnnotation {
* [[AvroName]] allows the name used by Avro to be different
* from what is defined in code.
*
* For example, if a case class defines a field z, such as
* `case class Foo(z: String)` then normally this will be
* For example, if a class defines a field z, such as
* ` class Foo(z: String)` then normally this will be
* serialized as an entry 'z' in the Avro Record.
*
* However, if the field is annotated such as
* `case class Foo(@AvroName("x") z: String)` then the entry
* ` class Foo(@AvroName("x") z: String)` then the entry
* in the Avro Record will be for 'x'.
*
* Similarly for deserialization, if a field is annotated then
Expand All @@ -64,23 +64,27 @@ trait AvroFixable extends StaticAnnotation {
* it will compare the name in the record to the annotated value.
*
*/
case class AvroName(name: String) extends AvroNameable
case class AvroName(override val name: String) extends AvroNameable

trait AvroNameable extends StaticAnnotation {
sealed trait AvroNameable extends AvroFieldReflection {
val name: String
}


case class AvroNamespace(namespace: String) extends AvroNamespaceable
case class AvroNamespace(override val namespace: String) extends AvroNamespaceable

trait AvroNamespaceable extends StaticAnnotation {
sealed trait AvroNamespaceable extends AvroFieldReflection {
val namespace: String
}

case class AvroProp(name: String, value:String) extends AvroProperty
case class AvroProp(override val key: String, override val value:String) extends AvroProperty

trait AvroProperty extends StaticAnnotation {
val name: String
case class AvroProp2(override val doc: String) extends AvroNameable with AvroDocumentable {
override val name: String = "static"
}

trait AvroProperty extends AvroFieldReflection{
val key: String
val value: String
}

Expand All @@ -97,4 +101,20 @@ trait AvroProperty extends StaticAnnotation {
* When this annotation is present on a type, the name used in the
* schema will simply be the raw type, eg `Foo`.
*/
class AvroErasedName extends StaticAnnotation
case class AvroErasedName() extends AvroFieldReflection


trait AvroFieldReflection extends StaticAnnotation {
private def getClassFields(clazz: Class[_]): Map[String, AnyRef] = {
val fields = clazz.getDeclaredFields.map(field => {
field setAccessible true
field.getName -> field.get(this)
}).toMap
if(clazz.getSuperclass != null){
fields ++ getClassFields(clazz.getSuperclass)
}
fields
}

def getAllFields = getClassFields(this.getClass)
}
1 change: 1 addition & 0 deletions project/GlobalPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ object GlobalPlugin extends AutoPlugin {
javacOptions := Seq("-source", "1.8", "-target", "1.8"),
libraryDependencies ++= Seq(
"org.scala-lang" % "scala-reflect" % scalaVersion.value,
"org.scala-lang" % "scala-compiler" % scalaVersion.value,
"org.apache.avro" % "avro" % AvroVersion,
"org.slf4j" % "slf4j-api" % Slf4jVersion % "test",
"log4j" % "log4j" % Log4jVersion % "test",
Expand Down

0 comments on commit c6bef31

Please sign in to comment.