Skip to content

Commit

Permalink
refactor: rework encoding for more clear & compact resolving unions
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuckame committed Sep 15, 2024
1 parent 3384199 commit 6e8d555
Show file tree
Hide file tree
Showing 23 changed files with 1,186 additions and 1,452 deletions.
21 changes: 2 additions & 19 deletions api/avro4k-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public abstract class com/github/avrokotlin/avro4k/Avro : kotlinx/serialization/
public fun encodeToByteArray (Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)[B
public final fun encodeToByteArray (Lorg/apache/avro/Schema;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)[B
public final fun getConfiguration ()Lcom/github/avrokotlin/avro4k/AvroConfiguration;
public fun getSerializersModule ()Lkotlinx/serialization/modules/SerializersModule;
public final fun getSerializersModule ()Lkotlinx/serialization/modules/SerializersModule;
public final fun schema (Lkotlinx/serialization/descriptors/SerialDescriptor;)Lorg/apache/avro/Schema;
}

Expand Down Expand Up @@ -113,10 +113,9 @@ public synthetic class com/github/avrokotlin/avro4k/AvroDoc$Impl : com/github/av
}

public abstract interface class com/github/avrokotlin/avro4k/AvroEncoder : kotlinx/serialization/encoding/Encoder {
public abstract fun encodeBytes (Ljava/nio/ByteBuffer;)V
public abstract fun encodeBytes ([B)V
public abstract fun encodeFixed (Lorg/apache/avro/generic/GenericFixed;)V
public abstract fun encodeFixed ([B)V
public abstract fun encodeUnionIndex (I)V
public abstract fun getCurrentWriterSchema ()Lorg/apache/avro/Schema;
}

Expand All @@ -127,11 +126,6 @@ public final class com/github/avrokotlin/avro4k/AvroEncoder$DefaultImpls {
public static fun encodeSerializableValue (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V
}

public final class com/github/avrokotlin/avro4k/AvroEncoderKt {
public static final fun encodeResolving (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public static final fun resolveUnion (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lorg/apache/avro/Schema;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
}

public abstract interface annotation class com/github/avrokotlin/avro4k/AvroEnumDefault : java/lang/annotation/Annotation {
}

Expand Down Expand Up @@ -330,17 +324,6 @@ public final class com/github/avrokotlin/avro4k/UnionDecoder$DefaultImpls {
public static fun decodeSerializableValue (Lcom/github/avrokotlin/avro4k/UnionDecoder;Lkotlinx/serialization/DeserializationStrategy;)Ljava/lang/Object;
}

public abstract interface class com/github/avrokotlin/avro4k/UnionEncoder : com/github/avrokotlin/avro4k/AvroEncoder {
public abstract fun encodeUnionIndex (I)V
}

public final class com/github/avrokotlin/avro4k/UnionEncoder$DefaultImpls {
public static fun beginCollection (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/descriptors/SerialDescriptor;I)Lkotlinx/serialization/encoding/CompositeEncoder;
public static fun encodeNotNullMark (Lcom/github/avrokotlin/avro4k/UnionEncoder;)V
public static fun encodeNullableSerializableValue (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V
public static fun encodeSerializableValue (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V
}

public final class com/github/avrokotlin/avro4k/serializer/AvroDuration {
public static final field Companion Lcom/github/avrokotlin/avro4k/serializer/AvroDuration$Companion;
public synthetic fun <init> (IIILkotlin/jvm/internal/DefaultConstructorMarker;)V
Expand Down
16 changes: 8 additions & 8 deletions benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,21 @@ c.g.a.b.complex.Avro4kBenchmark.read thrpt 5 23
c.g.a.b.complex.ApacheAvroReflectBenchmark.read thrpt 5 21124.413 ± 274.425 ops/s -10.90%
c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.read thrpt 5 14314.182 ± 455.019 ops/s -39.60%
c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 53483.657 ± 1015.416 ops/s 0.00%
c.g.a.b.complex.ApacheAvroReflectBenchmark.write thrpt 5 46724.347 ± 2060.184 ops/s -12.64%
c.g.a.b.complex.JacksonAvroBenchmark.write thrpt 5 36294.736 ± 378.844 ops/s -32.12%
c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 27472.078 ± 986.683 ops/s -48.63%
c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 54341.631 ± 1033.605 ops/s 0.00%
c.g.a.b.complex.ApacheAvroReflectBenchmark.write thrpt 5 49805.980 ± 1783.130 ops/s -8.35%
c.g.a.b.complex.JacksonAvroBenchmark.write thrpt 5 34076.802 ± 1358.108 ops/s -37.31%
c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 23874.900 ± 7088.413 ops/s -56.06%
c.g.a.b.simple.Avro4kSimpleBenchmark.read thrpt 5 144353.049 ± 3769.344 ops/s 0.00%
c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.read thrpt 5 138120.480 ± 4272.476 ops/s -4.32%
c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.read thrpt 5 108761.202 ± 2228.366 ops/s -24.65%
c.g.a.b.simple.JacksonAvroSimpleBenchmark.read thrpt 5 67907.379 ± 1626.214 ops/s -52.98%
c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 383229.511 ± 8615.022 ops/s 0.00%
c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.write thrpt 5 241924.179 ± 6148.539 ops/s -36.88%
c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 151438.732 ± 5056.196 ops/s -60.48%
c.g.a.b.simple.JacksonAvroSimpleBenchmark.write thrpt 5 127715.707 ± 3748.254 ops/s -66.69%
c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 403931.630 ± 5276.622 ops/s 0.00%
c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.write thrpt 5 244455.414 ± 3681.089 ops/s -39.46%
c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 153565.472 ± 1900.814 ops/s -61.99%
c.g.a.b.simple.JacksonAvroSimpleBenchmark.write thrpt 5 129912.932 ± 2788.534 ops/s -67.84%
```

> [!WARNING]
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import java.io.ByteArrayInputStream
*/
public sealed class Avro(
public val configuration: AvroConfiguration,
public override val serializersModule: SerializersModule,
public final override val serializersModule: SerializersModule,
) : BinaryFormat {
// We use the identity hash map because we could have multiple descriptors with the same name, especially
// when having 2 different version of the schema for the same name. kotlinx-serialization is instantiating the descriptors
Expand Down
199 changes: 150 additions & 49 deletions src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package com.github.avrokotlin.avro4k

import com.github.avrokotlin.avro4k.internal.aliases
import com.github.avrokotlin.avro4k.internal.isNamedSchema
import com.github.avrokotlin.avro4k.internal.nonNullSerialName
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Encoder
import org.apache.avro.Schema
import org.apache.avro.generic.GenericFixed
import java.nio.ByteBuffer

/**
* Interface to encode Avro values.
Expand All @@ -22,9 +25,6 @@ import java.nio.ByteBuffer
* - [encodeEnum]
* - [encodeBytes]
* - [encodeFixed]
*
* Use the following methods to allow complex encoding using raw values, mainly for logical types:
* - [encodeResolving]
*/
public interface AvroEncoder : Encoder {
/**
Expand All @@ -33,12 +33,6 @@ public interface AvroEncoder : Encoder {
@ExperimentalSerializationApi
public val currentWriterSchema: Schema

/**
* Encodes a [Schema.Type.BYTES] value from a [ByteBuffer].
*/
@ExperimentalSerializationApi
public fun encodeBytes(value: ByteBuffer)

/**
* Encodes a [Schema.Type.BYTES] value from a [ByteArray].
*/
Expand All @@ -47,61 +41,168 @@ public interface AvroEncoder : Encoder {

/**
* Encodes a [Schema.Type.FIXED] value from a [ByteArray]. Its size must match the size of the fixed schema in [currentWriterSchema].
* When many fixed schemas are in a union, the first one that matches the size is selected. To avoid this auto-selection, use [encodeUnionIndex] with the index of the expected fixed schema.
*/
@ExperimentalSerializationApi
public fun encodeFixed(value: ByteArray)

/**
* Encodes a [Schema.Type.FIXED] value from a [GenericFixed]. Its size must match the size of the fixed schema in [currentWriterSchema].
* Selects the index of the union type to encode. Also sets [currentWriterSchema] to the selected type.
*/
@ExperimentalSerializationApi
public fun encodeFixed(value: GenericFixed)
public fun encodeUnionIndex(index: Int)
}

@PublishedApi
internal interface UnionEncoder : AvroEncoder {
/**
* Encode the selected union schema and set the selected type in [currentWriterSchema].
*/
fun encodeUnionIndex(index: Int)
internal fun AvroEncoder.namedSchemaNotFoundInUnionError(
expectedName: String,
possibleAliases: Set<String>,
vararg fallbackTypes: Schema.Type,
): Throwable {
val aliasesStr = if (possibleAliases.isNotEmpty()) " (with aliases ${possibleAliases.joinToString()})" else ""
val fallbacksStr = if (fallbackTypes.isNotEmpty()) " Also no compatible type found (one of ${fallbackTypes.joinToString()})." else ""
return SerializationException("Named schema $expectedName$aliasesStr not found in union.$fallbacksStr Actual schema: $currentWriterSchema")
}

internal fun AvroEncoder.typeNotFoundInUnionError(
mainType: Schema.Type,
vararg fallbackTypes: Schema.Type,
): Throwable {
val fallbacksStr = if (fallbackTypes.isNotEmpty()) " Also no compatible type found (one of ${fallbackTypes.joinToString()})." else ""
return SerializationException("${mainType.getName().replaceFirstChar { it.uppercase() }} type not found in union.$fallbacksStr Actual schema: $currentWriterSchema")
}

internal fun AvroEncoder.unsupportedWriterTypeError(
mainType: Schema.Type,
vararg fallbackTypes: Schema.Type,
): Throwable {
val fallbacksStr = if (fallbackTypes.isNotEmpty()) ", and also not matching to any compatible type (one of ${fallbackTypes.joinToString()})." else ""
return SerializationException(
"Unsupported schema '${currentWriterSchema.fullName}' for encoded type of ${mainType.getName()}$fallbacksStr. Actual schema: $currentWriterSchema"
)
}

internal fun AvroEncoder.ensureFixedSize(byteArray: ByteArray): ByteArray {
if (currentWriterSchema.fixedSize != byteArray.size) {
throw SerializationException("Fixed size mismatch for actual size of ${byteArray.size}. Actual schema: $currentWriterSchema")
}
return byteArray
}

internal fun AvroEncoder.fullNameOrAliasMismatchError(
fullName: String,
aliases: Set<String>,
): Throwable {
val aliasesStr = if (aliases.isNotEmpty()) " (with aliases ${aliases.joinToString()})" else ""
return SerializationException("The descriptor $fullName$aliasesStr doesn't match the schema $currentWriterSchema")
}

internal fun AvroEncoder.logicalTypeMismatchError(
logicalType: String,
type: Schema.Type,
): Throwable {
return SerializationException("Expected schema type of ${type.getName()} with logical type $logicalType but had schema $currentWriterSchema")
}

/**
* Allows you to encode a value differently depending on the schema (generally its name, type, logicalType).
* If the [AvroEncoder.currentWriterSchema] is a union, it takes **the first matching encoder** as the final encoder.
*
* This reduces the need to manually resolve the type in a union **and** not in a union.
*
* For examples, see the [com.github.avrokotlin.avro4k.serializer.BigDecimalSerializer] as it resolves a lot of types and also logical types.
*
* @param resolver A lambda that returns a lambda (the encoding lambda) that contains the logic to encode the value only when the schema matches. The encoding **MUST** be done in the encoder lambda to avoid encoding the value if it is not the right schema. Return null when it is not matching the expected schema.
* @param error A lambda that throws an exception if the encoder cannot be resolved.
* @return true is union is nullable and non-null type was selected, false otherwise
*/
@ExperimentalSerializationApi
public inline fun <T : Any> AvroEncoder.encodeResolving(
error: () -> Throwable,
resolver: (Schema) -> (() -> T)?,
): T {
val schema = currentWriterSchema
return if (schema.isUnion) {
resolveUnion(schema, error, resolver)
internal fun AvroEncoder.trySelectSingleNonNullTypeFromUnion(): Boolean {
return if (currentWriterSchema.types.size == 2) {
// optimization: A nullable union is very common
if (currentWriterSchema.types[0].type == Schema.Type.NULL) {
encodeUnionIndex(1)
true
} else if (currentWriterSchema.types[1].type == Schema.Type.NULL) {
encodeUnionIndex(0)
true
} else {
// we are in case of non-nullable union with only 2 types
false
}
} else {
resolver(schema)?.invoke() ?: throw error()
false
}
}

@PublishedApi
internal inline fun <T> AvroEncoder.resolveUnion(
schema: Schema,
error: () -> Throwable,
resolver: (Schema) -> (() -> T)?,
): T {
for (index in schema.types.indices) {
val subSchema = schema.types[index]
resolver(subSchema)?.let {
(this as UnionEncoder).encodeUnionIndex(index)
return it.invoke()
internal fun AvroEncoder.trySelectTypeFromUnion(vararg oneOf: Schema.Type): Boolean {
val index =
currentWriterSchema.getIndexTyped(*oneOf)
?: return false
encodeUnionIndex(index)
return true
}

internal fun AvroEncoder.trySelectFixedSchemaForSize(fixedSize: Int): Boolean {
currentWriterSchema.types.forEachIndexed { index, schema ->
if (schema.type == Schema.Type.FIXED && schema.fixedSize == fixedSize) {
encodeUnionIndex(index)
return true
}
}
return false
}

internal fun AvroEncoder.trySelectEnumSchemaForSymbol(symbol: String): Boolean {
currentWriterSchema.types.forEachIndexed { index, schema ->
if (schema.type == Schema.Type.ENUM && schema.hasEnumSymbol(symbol)) {
encodeUnionIndex(index)
return true
}
}
return false
}

internal fun AvroEncoder.trySelectNamedSchema(descriptor: SerialDescriptor): Boolean {
return trySelectNamedSchema(descriptor.nonNullSerialName, descriptor::aliases)
}

internal fun AvroEncoder.trySelectNamedSchema(
name: String,
aliases: () -> Set<String> = ::emptySet,
): Boolean {
val index =
currentWriterSchema.getIndexNamedOrAliased(name)
?: aliases().firstNotNullOfOrNull { currentWriterSchema.getIndexNamedOrAliased(it) }
if (index != null) {
encodeUnionIndex(index)
return true
}
return false
}

internal fun AvroEncoder.trySelectLogicalTypeFromUnion(
logicalTypeName: String,
vararg oneOf: Schema.Type,
): Boolean {
val index =
currentWriterSchema.getIndexLogicallyTyped(logicalTypeName, *oneOf)
?: return false
encodeUnionIndex(index)
return true
}

internal fun Schema.getIndexLogicallyTyped(
logicalTypeName: String,
vararg oneOf: Schema.Type,
): Int? {
return oneOf.firstNotNullOfOrNull { expectedType ->
when (expectedType) {
Schema.Type.FIXED, Schema.Type.RECORD, Schema.Type.ENUM -> types.indexOfFirst { it.type == expectedType && it.logicalType?.name == logicalTypeName }
else -> getIndexNamed(expectedType.getName())?.takeIf { types[it].logicalType?.name == logicalTypeName }
}
}
}

internal fun Schema.getIndexNamedOrAliased(expectedName: String): Int? {
return getIndexNamed(expectedName)
?: types.indexOfFirst { it.isNamedSchema() && it.aliases.contains(expectedName) }.takeIf { it >= 0 }
}

internal fun Schema.getIndexTyped(vararg oneOf: Schema.Type): Int? {
return oneOf.firstNotNullOfOrNull { expectedType ->
when (expectedType) {
Schema.Type.FIXED, Schema.Type.RECORD, Schema.Type.ENUM -> types.indexOfFirst { it.type == expectedType }
else -> getIndexNamed(expectedType.getName())
}
}
throw error()
}
Loading

0 comments on commit 6e8d555

Please sign in to comment.