Skip to content

Commit

Permalink
ZIO Schema Deriver refactoring (#69)
Browse files Browse the repository at this point in the history
* Refactor SchemaEncoder

* SchemaEncoder summoned with test

* Implement Derivers for encoders

* Implement derivers for decoders

* Clean up commented code

* WIP - refactoring tests

- issue with deriving of codecs for case classes

* WIP - fixed compilation error during encoders derivation

- add default primitive decoders for ValueVector

* Exhaustive list of primitive codecs with tests

* Refactor TabularSpec

* Refactor VectorSpec

* Refactor IpcSpec

* Refactor DataframeSpec

* Fix lint

* Add ammonite related stuff into .gitignore

* Use writer API for both cases to add primitive values

* Fix primitives for VectorSchemaRoot

* Refactor ValueEncoder.encodePrimitive methods (deduplicate)

* Working on list codecs for ValueVector (WIP)

* Working on optional codecs for ValueVector (WIP)

* Bring map/contramap methods back

* Check option list codec

* Make options work for list of options and option of lists (involving dirty hacks, must be refactored)

* Refactor decodeValue method

* Refactor list decoder

* Implement optional primitives for vector schema root

* Rename vector schema root's decoder/encoder constructors

* SchemaEncoderDeriver supports nested schemas

* Fix optional for vector schema root decoder

* Remove old encodePrimitive method

* Tests for "list"

* Tests for "struct"

* Introduce ValueVectorEncoder.overridePrimitive

* Simplify codec creation in tests

* Check if summoning for ValueVectorCodec works - it doesn't :(

* Return back VectorSchemaRoot codec's map/contramap/transform methods

* Kinda fix compilation errors for different scala versions

* scalafix

* Polish ValueVectorEncoder.encoder constructor

* Make primitive codec overriding work for VectorSchemaRootCodec!

* fix formatting
  • Loading branch information
grouzen authored Aug 21, 2024
1 parent e305c15 commit 0ecdb3b
Show file tree
Hide file tree
Showing 28 changed files with 2,828 additions and 1,809 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ private/
.metals/
.vscode/
metals.sbt
.ammonite/
*.sc
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ import scala.jdk.CollectionConverters._

object Tabular {

def empty[A](implicit schema: ZSchema[A]): RIO[Scope with BufferAllocator, VectorSchemaRoot] =
def empty[A: ZSchema](implicit
schemaEncoder: SchemaEncoder[A]
): RIO[Scope with BufferAllocator, VectorSchemaRoot] =
ZIO.fromAutoCloseable(
ZIO.serviceWithZIO[BufferAllocator] { implicit alloc =>
for {
schema0 <- ZIO.fromEither(SchemaEncoder.schemaRoot[A])
vectors <- ZIO.foreach(schema0.getFields.asScala.toList) { f =>
schema0 <- ZIO.fromEither(schemaEncoder.encode)
vectors <- ZIO.foreach(schema0.getFields.asScala.toList) { field =>
for {
vec <- ZIO.attempt(f.createVector(alloc))
vec <- ZIO.attempt(field.createVector(alloc))
_ <- ZIO.attempt(vec.allocateNew())
} yield vec
}
Expand All @@ -27,17 +29,15 @@ object Tabular {
}
)

def fromChunk[A](chunk: Chunk[A])(implicit
schema: ZSchema[A],
def fromChunk[A: ZSchema: SchemaEncoder](chunk: Chunk[A])(implicit
encoder: VectorSchemaRootEncoder[A]
): RIO[Scope with BufferAllocator, VectorSchemaRoot] =
for {
root <- empty
_ <- encoder.encodeZIO(chunk, root)
} yield root

def fromStream[R, A](stream: ZStream[R, Throwable, A])(implicit
schema: ZSchema[A],
def fromStream[R, A: ZSchema: SchemaEncoder](stream: ZStream[R, Throwable, A])(implicit
encoder: VectorSchemaRootEncoder[A]
): RIO[R with Scope with BufferAllocator, VectorSchemaRoot] =
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ object Vector {
new FromChunkPartiallyApplied[V]

final class FromChunkPartiallyApplied[V <: ValueVector](private val dummy: Boolean = true) extends AnyVal {
def apply[A](chunk: Chunk[A])(implicit encoder: ValueVectorEncoder[A, V]): RIO[Scope with BufferAllocator, V] =
def apply[A](chunk: Chunk[A])(implicit encoder: ValueVectorEncoder[V, A]): RIO[Scope with BufferAllocator, V] =
encoder.encodeZIO(chunk)
}

Expand All @@ -22,7 +22,7 @@ object Vector {
final class FromStreamPartiallyApplied[V <: ValueVector](private val dummy: Boolean = true) extends AnyVal {
def apply[R, A](
stream: ZStream[R, Throwable, A]
)(implicit encoder: ValueVectorEncoder[A, V]): RIO[R with Scope with BufferAllocator, V] =
)(implicit encoder: ValueVectorEncoder[V, A]): RIO[R with Scope with BufferAllocator, V] =
for {
chunk <- stream.runCollect
vec <- fromChunk(chunk)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,126 +1,48 @@
package me.mnedokushev.zio.apache.arrow.core.codec

import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.pojo.{ ArrowType, Field, FieldType, Schema }
import zio.schema.{ Schema => ZSchema, StandardType }
import org.apache.arrow.vector.types.pojo.{ ArrowType, Field, FieldType, Schema => JSchema }
import zio.schema.{ Deriver, Factory, Schema, StandardType }

import scala.annotation.tailrec
import scala.annotation.nowarn
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

/**
 * Type class describing how values of type `A` are represented in an Apache Arrow schema.
 *
 * @tparam A
 *   the Scala type whose Arrow representation is described
 */
trait SchemaEncoder[A] { self =>

  /**
   * Encodes a complete Arrow schema for `A`.
   *
   * This default implementation never succeeds: it always returns a `Left` holding an [[EncoderError]] stating that
   * the given ZIO schema must be a `Schema.Record[A]`.
   *
   * NOTE(review): presumably instances built for record schemas override this method — confirm against the deriver.
   *
   * @param schema
   *   the ZIO schema of `A`; only used to render the error message here
   * @return
   *   always `Left(EncoderError(...))` in this default implementation
   */
  def encode(implicit schema: Schema[A]): Either[Throwable, JSchema] =
    Left(EncoderError(s"Given ZIO schema $schema must be of type Schema.Record[A]"))

  /**
   * Builds the Arrow [[Field]] that represents `A` as a single named column or child.
   *
   * @param name
   *   the name to give the resulting field
   * @param nullable
   *   whether the resulting field should be declared nullable
   * @return
   *   the Arrow field for `A`
   */
  def encodeField(name: String, nullable: Boolean): Field

}

object SchemaEncoder {

/**
 * Builds an Arrow schema from the ZIO schema of `A`.
 *
 * Only the top-level schema is inspected for being a record: anything that does not match `ZSchema.Record` is
 * rejected. Every record field is translated to one Arrow field, starting out as non-nullable.
 *
 * @param schema
 *   the ZIO schema describing `A`
 * @return
 *   `Right` with the Arrow schema, or `Left` with an `EncoderError` when the schema is not a record, a field has an
 *   unsupported schema type, or any other non-fatal exception is raised while encoding (wrapped as the error's cause)
 */
def schemaRoot[A](implicit schema: ZSchema[A]): Either[Throwable, Schema] = {

  // Translates one schema node into an Arrow field. `Optional` and `Lazy` nodes are unwrapped by recursing;
  // an `Optional` forces the field to be nullable.
  // NOTE(review): records and sequences become bare Struct/List fields here — their element/field schemas are
  // not visited by this function.
  @tailrec
  def encodeSchema[A1](name: String, schemaField: ZSchema[A1], nullable: Boolean): Field =
    schemaField match {
      case ZSchema.Primitive(standardType, _) =>
        encodePrimitive(name, standardType, nullable)
      case ZSchema.Optional(schemaOpt, _)     =>
        encodeSchema(name, schemaOpt, true)
      case _: ZSchema.Record[_]               =>
        field(name, new ArrowType.Struct, nullable)
      case ZSchema.Sequence(_, _, _, _, _)    =>
        field(name, new ArrowType.List, nullable)
      case lzy: ZSchema.Lazy[_]               =>
        encodeSchema(name, lzy.schema, nullable)
      case other                              =>
        throw EncoderError(s"Unsupported ZIO Schema type $other")
    }

  // Exceptions thrown while encoding are converted into a `Left` below.
  try {
    val fields = schema match {
      case record: ZSchema.Record[A] =>
        record.fields.map { case ZSchema.Field(name, schemaField, _, _, _, _) =>
          encodeSchema(name, schemaField, nullable = false)
        }
      case _                         =>
        throw EncoderError("Given ZIO schema must be of type Schema.Record[Val]")
    }

    Right(new Schema(fields.toList.asJava))
  } catch {
    // Encoder errors are returned as-is; anything else non-fatal is wrapped with a generic message.
    case encodeError: EncoderError => Left(encodeError)
    case NonFatal(ex)              => Left(EncoderError("Error encoding schema", Some(ex)))
  }
}

private def encodePrimitive[A](name: String, standardType: StandardType[A], nullable: Boolean): Field = {
def namedField(arrowType: ArrowType) =
field(name, arrowType, nullable)

standardType match {
case StandardType.StringType =>
namedField(new ArrowType.Utf8)
case StandardType.BoolType =>
namedField(new ArrowType.Bool)
case StandardType.ByteType =>
namedField(new ArrowType.Int(8, false))
case StandardType.ShortType =>
namedField(new ArrowType.Int(16, true))
case StandardType.IntType =>
namedField(new ArrowType.Int(32, true))
case StandardType.LongType =>
namedField(new ArrowType.Int(64, true))
case StandardType.FloatType =>
namedField(new ArrowType.FloatingPoint(FloatingPointPrecision.HALF))
case StandardType.DoubleType =>
namedField(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))
case StandardType.BinaryType =>
namedField(new ArrowType.Binary)
case StandardType.CharType =>
namedField(new ArrowType.Int(16, false))
case StandardType.UUIDType =>
namedField(new ArrowType.FixedSizeBinary(8))
case StandardType.BigDecimalType =>
namedField(new ArrowType.Decimal(11, 2, 128))
case StandardType.BigIntegerType =>
namedField(new ArrowType.FixedSizeBinary(8))
case StandardType.DayOfWeekType =>
namedField(new ArrowType.Int(3, false))
case StandardType.MonthType =>
namedField(new ArrowType.Int(4, false))
case StandardType.MonthDayType =>
namedField(new ArrowType.Int(64, false))
case StandardType.PeriodType =>
namedField(new ArrowType.FixedSizeBinary(8))
case StandardType.YearType =>
namedField(new ArrowType.Int(16, false))
case StandardType.YearMonthType =>
namedField(new ArrowType.Int(64, false))
case StandardType.ZoneIdType =>
namedField(new ArrowType.Utf8)
case StandardType.ZoneOffsetType =>
namedField(new ArrowType.Utf8)
case StandardType.DurationType =>
namedField(new ArrowType.Int(64, false))
case StandardType.InstantType =>
namedField(new ArrowType.Int(64, false))
case StandardType.LocalDateType =>
namedField(new ArrowType.Utf8)
case StandardType.LocalTimeType =>
namedField(new ArrowType.Utf8)
case StandardType.LocalDateTimeType =>
namedField(new ArrowType.Utf8)
case StandardType.OffsetTimeType =>
namedField(new ArrowType.Utf8)
case StandardType.OffsetDateTimeType =>
namedField(new ArrowType.Utf8)
case StandardType.ZonedDateTimeType =>
namedField(new ArrowType.Utf8)
case other =>
throw EncoderError(s"Unsupported ZIO Schema StandardType $other")
/**
 * Creates a [[SchemaEncoder]] whose field encoding is delegated to the given function.
 *
 * @param encode0
 *   function receiving the field name and the nullable flag, returning the Arrow field
 * @param ev
 *   evidence that `A` is a ZIO Schema standard type; it is not referenced in the body, hence `@nowarn`
 * @return
 *   an encoder whose `encodeField` forwards its arguments to `encode0`
 */
def primitive[A](encode0: (String, Boolean) => Field)(implicit @nowarn ev: StandardType[A]): SchemaEncoder[A] = {
  val buildField = encode0

  new SchemaEncoder[A] {
    override def encodeField(name: String, nullable: Boolean): Field =
      buildField(name, nullable)
  }
}
}

/**
 * Builds an Arrow field from a name, an Arrow type and a nullability flag.
 *
 * `null` is passed for the third `FieldType` constructor argument and for the third `Field` constructor argument.
 */
private[codec] def field(name: String, arrowType: ArrowType, nullable: Boolean): Field = {
  val fieldType = new FieldType(nullable, arrowType, null)
  new Field(name, fieldType, null)
}
/**
 * Derives a [[SchemaEncoder]] for `A` by running the given deriver through the `Factory[A]` in implicit scope.
 *
 * NOTE(review): being an `implicit def` with a single explicit parameter, the compiler may also apply this as an
 * implicit view from `Deriver[SchemaEncoder]` to `SchemaEncoder[A]` — confirm that this is intended.
 *
 * @param deriver
 *   the deriver used to build the encoder
 */
implicit def encoder[A: Factory: Schema](deriver: Deriver[SchemaEncoder]): SchemaEncoder[A] = {
  val factory = implicitly[Factory[A]]
  factory.derive[SchemaEncoder](deriver)
}

/**
 * Derives a [[SchemaEncoder]] for `A` using the supplied deriver and the `Factory[A]` in implicit scope.
 *
 * @param deriver
 *   the deriver used to build the encoder
 */
def fromDeriver[A: Factory: Schema](deriver: Deriver[SchemaEncoder]): SchemaEncoder[A] = {
  val factory = implicitly[Factory[A]]
  factory.derive[SchemaEncoder](deriver)
}

/** Derives a [[SchemaEncoder]] for `A` by passing `SchemaEncoderDeriver.default` to `fromDeriver`. */
def fromDefaultDeriver[A: Factory: Schema]: SchemaEncoder[A] = {
  val deriver: Deriver[SchemaEncoder] = SchemaEncoderDeriver.default
  fromDeriver[A](deriver)
}

/** Derives a [[SchemaEncoder]] for `A` by passing `SchemaEncoderDeriver.summoned` to `fromDeriver`. */
def fromSummonedDeriver[A: Factory: Schema]: SchemaEncoder[A] = {
  val deriver: Deriver[SchemaEncoder] = SchemaEncoderDeriver.summoned
  fromDeriver[A](deriver)
}

/**
 * Builds an Arrow field of a primitive Arrow type.
 *
 * `null` is passed for the third `FieldType` constructor argument and for the third `Field` constructor argument.
 *
 * @param name
 *   the field name
 * @param tpe
 *   the primitive Arrow type of the field
 * @param nullable
 *   whether the field is declared nullable
 */
def primitiveField(name: String, tpe: ArrowType.PrimitiveType, nullable: Boolean): Field = {
  val fieldType = new FieldType(nullable, tpe, null)
  new Field(name, fieldType, null)
}

/** Shorthand for `field` with `nullable` fixed to `true`. */
private[codec] def fieldNullable(name: String, arrowType: ArrowType): Field = {
  val nullable = true
  field(name, arrowType, nullable)
}
/**
 * Builds an Arrow field of type `ArrowType.List` with exactly one child field.
 *
 * @param name
 *   the field name
 * @param child
 *   the single child field placed in the list field's children
 * @param nullable
 *   whether the list field itself is declared nullable
 */
def listField(name: String, child: Field, nullable: Boolean): Field = {
  val listType = new FieldType(nullable, new ArrowType.List, null)
  val children = List(child).asJava
  new Field(name, listType, children)
}

/** Shorthand for `field` with `nullable` fixed to `false`. */
private[codec] def fieldNotNullable(name: String, arrowType: ArrowType): Field = {
  val nullable = false
  field(name, arrowType, nullable)
}
/**
 * Builds an Arrow field of type `ArrowType.Struct` with the given child fields.
 *
 * @param name
 *   the field name
 * @param fields
 *   the child fields of the struct, in order
 * @param nullable
 *   whether the struct field itself is declared nullable
 */
def structField(name: String, fields: List[Field], nullable: Boolean): Field = {
  val structType = new FieldType(nullable, new ArrowType.Struct, null)
  val children   = fields.asJava
  new Field(name, structType, children)
}

}
Loading

0 comments on commit 0ecdb3b

Please sign in to comment.