From 6e8d5556f74acda7e8f0024ed9703d0c3b36d1b9 Mon Sep 17 00:00:00 2001 From: Chuckame Date: Fri, 13 Sep 2024 16:28:12 +0200 Subject: [PATCH] refactor: rework encoding for more clear & compact resolving unions --- api/avro4k-core.api | 21 +- benchmark/README.md | 16 +- .../com/github/avrokotlin/avro4k/Avro.kt | 2 +- .../github/avrokotlin/avro4k/AvroEncoder.kt | 199 ++++++-- .../avro4k/internal/RecordResolver.kt | 155 +++---- .../decoder/direct/RecordDirectDecoder.kt | 4 +- .../decoder/generic/RecordGenericDecoder.kt | 4 +- .../internal/encoder/AbstractAvroEncoder.kt | 372 +++++++++++++++ .../encoder/ReorderingCompositeEncoder.kt | 297 ++++++++++++ .../direct/AbstractAvroDirectEncoder.kt | 414 ++--------------- .../encoder/direct/RecordDirectEncoder.kt | 218 +++------ .../generic/AbstractAvroGenericEncoder.kt | 427 ++---------------- .../encoder/generic/ArrayGenericEncoder.kt | 16 +- .../generic/AvroValueGenericEncoder.kt | 15 +- .../encoder/generic/BytesGenericEncoder.kt | 26 -- .../encoder/generic/FixedGenericEncoder.kt | 37 -- .../encoder/generic/MapGenericEncoder.kt | 2 +- .../encoder/generic/RecordGenericEncoder.kt | 41 +- .../avrokotlin/avro4k/internal/exceptions.kt | 18 - .../avrokotlin/avro4k/internal/helpers.kt | 9 +- .../avro4k/serializer/AvroDuration.kt | 41 +- .../serializer/JavaStdLibSerializers.kt | 122 ++--- .../avro4k/serializer/JavaTimeSerializers.kt | 182 +++----- 23 files changed, 1186 insertions(+), 1452 deletions(-) create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt create mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt delete mode 100644 src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt diff --git a/api/avro4k-core.api b/api/avro4k-core.api index d1c52c6..389560b 100644 --- a/api/avro4k-core.api +++ b/api/avro4k-core.api @@ -10,7 +10,7 @@ public abstract class com/github/avrokotlin/avro4k/Avro : kotlinx/serialization/ public fun encodeToByteArray (Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)[B public final fun encodeToByteArray (Lorg/apache/avro/Schema;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)[B public final fun getConfiguration ()Lcom/github/avrokotlin/avro4k/AvroConfiguration; - public fun getSerializersModule ()Lkotlinx/serialization/modules/SerializersModule; + public final fun getSerializersModule ()Lkotlinx/serialization/modules/SerializersModule; public final fun schema (Lkotlinx/serialization/descriptors/SerialDescriptor;)Lorg/apache/avro/Schema; } @@ -113,10 +113,9 @@ public synthetic class com/github/avrokotlin/avro4k/AvroDoc$Impl : com/github/av } public abstract interface class com/github/avrokotlin/avro4k/AvroEncoder : kotlinx/serialization/encoding/Encoder { - public abstract fun encodeBytes (Ljava/nio/ByteBuffer;)V public abstract fun encodeBytes ([B)V - public abstract fun encodeFixed (Lorg/apache/avro/generic/GenericFixed;)V public abstract fun encodeFixed ([B)V + public abstract fun encodeUnionIndex (I)V public abstract fun getCurrentWriterSchema ()Lorg/apache/avro/Schema; } @@ -127,11 +126,6 @@ public final class com/github/avrokotlin/avro4k/AvroEncoder$DefaultImpls { public static fun encodeSerializableValue (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V } -public final class com/github/avrokotlin/avro4k/AvroEncoderKt { - public static final fun encodeResolving (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; - public static final fun resolveUnion (Lcom/github/avrokotlin/avro4k/AvroEncoder;Lorg/apache/avro/Schema;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; -} - public abstract interface annotation class com/github/avrokotlin/avro4k/AvroEnumDefault : java/lang/annotation/Annotation { } @@ -330,17 +324,6 @@ public final class com/github/avrokotlin/avro4k/UnionDecoder$DefaultImpls { public static fun decodeSerializableValue (Lcom/github/avrokotlin/avro4k/UnionDecoder;Lkotlinx/serialization/DeserializationStrategy;)Ljava/lang/Object; } -public abstract interface class com/github/avrokotlin/avro4k/UnionEncoder : com/github/avrokotlin/avro4k/AvroEncoder { - public abstract fun encodeUnionIndex (I)V -} - -public final class com/github/avrokotlin/avro4k/UnionEncoder$DefaultImpls { - public static fun beginCollection (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/descriptors/SerialDescriptor;I)Lkotlinx/serialization/encoding/CompositeEncoder; - public static fun encodeNotNullMark (Lcom/github/avrokotlin/avro4k/UnionEncoder;)V - public static fun encodeNullableSerializableValue (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V - public static fun encodeSerializableValue (Lcom/github/avrokotlin/avro4k/UnionEncoder;Lkotlinx/serialization/SerializationStrategy;Ljava/lang/Object;)V -} - public final class com/github/avrokotlin/avro4k/serializer/AvroDuration { public static final field Companion Lcom/github/avrokotlin/avro4k/serializer/AvroDuration$Companion; public synthetic fun (IIILkotlin/jvm/internal/DefaultConstructorMarker;)V diff --git a/benchmark/README.md b/benchmark/README.md index 12b91d5..211f15d 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -30,10 +30,10 @@ c.g.a.b.complex.Avro4kBenchmark.read thrpt 5 23 c.g.a.b.complex.ApacheAvroReflectBenchmark.read thrpt 5 21124.413 ± 274.425 ops/s -10.90% c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.read thrpt 5 14314.182 ± 455.019 ops/s -39.60% -c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 53483.657 ± 1015.416 ops/s 0.00% -c.g.a.b.complex.ApacheAvroReflectBenchmark.write thrpt 5 46724.347 ± 2060.184 ops/s -12.64% -c.g.a.b.complex.JacksonAvroBenchmark.write thrpt 5 36294.736 ± 378.844 ops/s -32.12% -c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 27472.078 ± 986.683 ops/s -48.63% +c.g.a.b.complex.Avro4kBenchmark.write thrpt 5 54341.631 ± 1033.605 ops/s 0.00% +c.g.a.b.complex.ApacheAvroReflectBenchmark.write thrpt 5 49805.980 ± 1783.130 ops/s -8.35% +c.g.a.b.complex.JacksonAvroBenchmark.write thrpt 5 34076.802 ± 1358.108 ops/s -37.31% +c.g.a.b.complex.Avro4kGenericWithApacheAvroBenchmark.write thrpt 5 23874.900 ± 7088.413 ops/s -56.06% c.g.a.b.simple.Avro4kSimpleBenchmark.read thrpt 5 144353.049 ± 3769.344 ops/s 0.00% @@ -41,10 +41,10 @@ c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.read thrpt 5 138 c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.read thrpt 5 108761.202 ± 2228.366 ops/s -24.65% c.g.a.b.simple.JacksonAvroSimpleBenchmark.read thrpt 5 67907.379 ± 1626.214 ops/s -52.98% -c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 383229.511 ± 8615.022 ops/s 0.00% -c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.write thrpt 5 241924.179 ± 6148.539 ops/s -36.88% -c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 151438.732 ± 5056.196 ops/s -60.48% -c.g.a.b.simple.JacksonAvroSimpleBenchmark.write thrpt 5 127715.707 ± 3748.254 ops/s -66.69% +c.g.a.b.simple.Avro4kSimpleBenchmark.write thrpt 5 403931.630 ± 5276.622 ops/s 0.00% +c.g.a.b.simple.ApacheAvroReflectSimpleBenchmark.write thrpt 5 244455.414 ± 3681.089 ops/s -39.46% +c.g.a.b.simple.Avro4kGenericWithApacheAvroSimpleBenchmark.write thrpt 5 153565.472 ± 1900.814 ops/s -61.99% +c.g.a.b.simple.JacksonAvroSimpleBenchmark.write thrpt 5 129912.932 ± 2788.534 ops/s -67.84% ``` > [!WARNING] diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt index 0112f21..4906972 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt @@ -28,7 +28,7 @@ import java.io.ByteArrayInputStream */ public sealed class Avro( public val configuration: AvroConfiguration, - public override val serializersModule: SerializersModule, + public final override val serializersModule: SerializersModule, ) : BinaryFormat { // We use the identity hash map because we could have multiple descriptors with the same name, especially // when having 2 different version of the schema for the same name. kotlinx-serialization is instantiating the descriptors diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt index 1e791ec..609b0eb 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroEncoder.kt @@ -1,10 +1,13 @@ package com.github.avrokotlin.avro4k +import com.github.avrokotlin.avro4k.internal.aliases +import com.github.avrokotlin.avro4k.internal.isNamedSchema +import com.github.avrokotlin.avro4k.internal.nonNullSerialName import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.encoding.Encoder import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed -import java.nio.ByteBuffer /** * Interface to encode Avro values. @@ -22,9 +25,6 @@ import java.nio.ByteBuffer * - [encodeEnum] * - [encodeBytes] * - [encodeFixed] - * - * Use the following methods to allow complex encoding using raw values, mainly for logical types: - * - [encodeResolving] */ public interface AvroEncoder : Encoder { /** @@ -33,12 +33,6 @@ public interface AvroEncoder : Encoder { @ExperimentalSerializationApi public val currentWriterSchema: Schema - /** - * Encodes a [Schema.Type.BYTES] value from a [ByteBuffer]. - */ - @ExperimentalSerializationApi - public fun encodeBytes(value: ByteBuffer) - /** * Encodes a [Schema.Type.BYTES] value from a [ByteArray]. */ @@ -47,61 +41,168 @@ public interface AvroEncoder : Encoder { /** * Encodes a [Schema.Type.FIXED] value from a [ByteArray]. Its size must match the size of the fixed schema in [currentWriterSchema]. + * When many fixed schemas are in a union, the first one that matches the size is selected. To avoid this auto-selection, use [encodeUnionIndex] with the index of the expected fixed schema. */ @ExperimentalSerializationApi public fun encodeFixed(value: ByteArray) /** - * Encodes a [Schema.Type.FIXED] value from a [GenericFixed]. Its size must match the size of the fixed schema in [currentWriterSchema]. + * Selects the index of the union type to encode. Also sets [currentWriterSchema] to the selected type. */ @ExperimentalSerializationApi - public fun encodeFixed(value: GenericFixed) + public fun encodeUnionIndex(index: Int) } -@PublishedApi -internal interface UnionEncoder : AvroEncoder { - /** - * Encode the selected union schema and set the selected type in [currentWriterSchema]. - */ - fun encodeUnionIndex(index: Int) +internal fun AvroEncoder.namedSchemaNotFoundInUnionError( + expectedName: String, + possibleAliases: Set, + vararg fallbackTypes: Schema.Type, +): Throwable { + val aliasesStr = if (possibleAliases.isNotEmpty()) " (with aliases ${possibleAliases.joinToString()})" else "" + val fallbacksStr = if (fallbackTypes.isNotEmpty()) " Also no compatible type found (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException("Named schema $expectedName$aliasesStr not found in union.$fallbacksStr Actual schema: $currentWriterSchema") +} + +internal fun AvroEncoder.typeNotFoundInUnionError( + mainType: Schema.Type, + vararg fallbackTypes: Schema.Type, +): Throwable { + val fallbacksStr = if (fallbackTypes.isNotEmpty()) " Also no compatible type found (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException("${mainType.getName().replaceFirstChar { it.uppercase() }} type not found in union.$fallbacksStr Actual schema: $currentWriterSchema") +} + +internal fun AvroEncoder.unsupportedWriterTypeError( + mainType: Schema.Type, + vararg fallbackTypes: Schema.Type, +): Throwable { + val fallbacksStr = if (fallbackTypes.isNotEmpty()) ", and also not matching to any compatible type (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException( + "Unsupported schema '${currentWriterSchema.fullName}' for encoded type of ${mainType.getName()}$fallbacksStr. Actual schema: $currentWriterSchema" + ) +} + +internal fun AvroEncoder.ensureFixedSize(byteArray: ByteArray): ByteArray { + if (currentWriterSchema.fixedSize != byteArray.size) { + throw SerializationException("Fixed size mismatch for actual size of ${byteArray.size}. Actual schema: $currentWriterSchema") + } + return byteArray +} + +internal fun AvroEncoder.fullNameOrAliasMismatchError( + fullName: String, + aliases: Set, +): Throwable { + val aliasesStr = if (aliases.isNotEmpty()) " (with aliases ${aliases.joinToString()})" else "" + return SerializationException("The descriptor $fullName$aliasesStr doesn't match the schema $currentWriterSchema") +} + +internal fun AvroEncoder.logicalTypeMismatchError( + logicalType: String, + type: Schema.Type, +): Throwable { + return SerializationException("Expected schema type of ${type.getName()} with logical type $logicalType but had schema $currentWriterSchema") } /** - * Allows you to encode a value differently depending on the schema (generally its name, type, logicalType). - * If the [AvroEncoder.currentWriterSchema] is a union, it takes **the first matching encoder** as the final encoder. - * - * This reduces the need to manually resolve the type in a union **and** not in a union. - * - * For examples, see the [com.github.avrokotlin.avro4k.serializer.BigDecimalSerializer] as it resolves a lot of types and also logical types. - * - * @param resolver A lambda that returns a lambda (the encoding lambda) that contains the logic to encode the value only when the schema matches. The encoding **MUST** be done in the encoder lambda to avoid encoding the value if it is not the right schema. Return null when it is not matching the expected schema. - * @param error A lambda that throws an exception if the encoder cannot be resolved. + * @return true is union is nullable and non-null type was selected, false otherwise */ -@ExperimentalSerializationApi -public inline fun AvroEncoder.encodeResolving( - error: () -> Throwable, - resolver: (Schema) -> (() -> T)?, -): T { - val schema = currentWriterSchema - return if (schema.isUnion) { - resolveUnion(schema, error, resolver) +internal fun AvroEncoder.trySelectSingleNonNullTypeFromUnion(): Boolean { + return if (currentWriterSchema.types.size == 2) { + // optimization: A nullable union is very common + if (currentWriterSchema.types[0].type == Schema.Type.NULL) { + encodeUnionIndex(1) + true + } else if (currentWriterSchema.types[1].type == Schema.Type.NULL) { + encodeUnionIndex(0) + true + } else { + // we are in case of non-nullable union with only 2 types + false + } } else { - resolver(schema)?.invoke() ?: throw error() + false } } -@PublishedApi -internal inline fun AvroEncoder.resolveUnion( - schema: Schema, - error: () -> Throwable, - resolver: (Schema) -> (() -> T)?, -): T { - for (index in schema.types.indices) { - val subSchema = schema.types[index] - resolver(subSchema)?.let { - (this as UnionEncoder).encodeUnionIndex(index) - return it.invoke() +internal fun AvroEncoder.trySelectTypeFromUnion(vararg oneOf: Schema.Type): Boolean { + val index = + currentWriterSchema.getIndexTyped(*oneOf) + ?: return false + encodeUnionIndex(index) + return true +} + +internal fun AvroEncoder.trySelectFixedSchemaForSize(fixedSize: Int): Boolean { + currentWriterSchema.types.forEachIndexed { index, schema -> + if (schema.type == Schema.Type.FIXED && schema.fixedSize == fixedSize) { + encodeUnionIndex(index) + return true + } + } + return false +} + +internal fun AvroEncoder.trySelectEnumSchemaForSymbol(symbol: String): Boolean { + currentWriterSchema.types.forEachIndexed { index, schema -> + if (schema.type == Schema.Type.ENUM && schema.hasEnumSymbol(symbol)) { + encodeUnionIndex(index) + return true + } + } + return false +} + +internal fun AvroEncoder.trySelectNamedSchema(descriptor: SerialDescriptor): Boolean { + return trySelectNamedSchema(descriptor.nonNullSerialName, descriptor::aliases) +} + +internal fun AvroEncoder.trySelectNamedSchema( + name: String, + aliases: () -> Set = ::emptySet, +): Boolean { + val index = + currentWriterSchema.getIndexNamedOrAliased(name) + ?: aliases().firstNotNullOfOrNull { currentWriterSchema.getIndexNamedOrAliased(it) } + if (index != null) { + encodeUnionIndex(index) + return true + } + return false +} + +internal fun AvroEncoder.trySelectLogicalTypeFromUnion( + logicalTypeName: String, + vararg oneOf: Schema.Type, +): Boolean { + val index = + currentWriterSchema.getIndexLogicallyTyped(logicalTypeName, *oneOf) + ?: return false + encodeUnionIndex(index) + return true +} + +internal fun Schema.getIndexLogicallyTyped( + logicalTypeName: String, + vararg oneOf: Schema.Type, +): Int? { + return oneOf.firstNotNullOfOrNull { expectedType -> + when (expectedType) { + Schema.Type.FIXED, Schema.Type.RECORD, Schema.Type.ENUM -> types.indexOfFirst { it.type == expectedType && it.logicalType?.name == logicalTypeName } + else -> getIndexNamed(expectedType.getName())?.takeIf { types[it].logicalType?.name == logicalTypeName } + } + } +} + +internal fun Schema.getIndexNamedOrAliased(expectedName: String): Int? { + return getIndexNamed(expectedName) + ?: types.indexOfFirst { it.isNamedSchema() && it.aliases.contains(expectedName) }.takeIf { it >= 0 } +} + +internal fun Schema.getIndexTyped(vararg oneOf: Schema.Type): Int? { + return oneOf.firstNotNullOfOrNull { expectedType -> + when (expectedType) { + Schema.Type.FIXED, Schema.Type.RECORD, Schema.Type.ENUM -> types.indexOfFirst { it.type == expectedType } + else -> getIndexNamed(expectedType.getName()) } } - throw error() } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt index 1614fb4..9526941 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/RecordResolver.kt @@ -3,6 +3,7 @@ package com.github.avrokotlin.avro4k.internal import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.AvroAlias import com.github.avrokotlin.avro4k.AvroDefault +import com.github.avrokotlin.avro4k.internal.encoder.ReorderingCompositeEncoder import com.github.avrokotlin.avro4k.internal.schema.CHAR_LOGICAL_TYPE_NAME import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor @@ -31,7 +32,7 @@ internal class RecordResolver( * * Note: We use the descriptor in the key as we could have multiple descriptors for the same record schema, and multiple record schemas for the same descriptor. */ - private val fieldCache: MutableMap> = WeakIdentityHashMap() + private val fieldCache: MutableMap> = WeakIdentityHashMap() /** * Maps the class fields to the schema fields. @@ -48,9 +49,9 @@ internal class RecordResolver( fun resolveFields( writerSchema: Schema, classDescriptor: SerialDescriptor, - ): ClassDescriptorForWriterSchema { - if (classDescriptor.elementsCount == 0) { - return ClassDescriptorForWriterSchema.EMPTY + ): SerializationWorkflow { + if (classDescriptor.elementsCount == 0 && writerSchema.fields.isEmpty()) { + return SerializationWorkflow.EMPTY } return fieldCache.getOrPut(classDescriptor) { WeakHashMap() }.getOrPut(writerSchema) { loadCache(classDescriptor, writerSchema) @@ -69,45 +70,15 @@ internal class RecordResolver( private fun loadCache( classDescriptor: SerialDescriptor, writerSchema: Schema, - ): ClassDescriptorForWriterSchema { + ): SerializationWorkflow { val readerSchema = avro.schema(classDescriptor) - val encodingSteps = computeEncodingSteps(classDescriptor, writerSchema) - return ClassDescriptorForWriterSchema( - sequentialEncoding = encodingSteps.areWriterFieldsSequentiallyOrdered(), + return SerializationWorkflow( computeDecodingSteps(classDescriptor, writerSchema, readerSchema), - encodingSteps + computeEncodingWorkflow(classDescriptor, writerSchema) ) } - private fun Array.areWriterFieldsSequentiallyOrdered(): Boolean { - var lastWriterFieldIndex = -1 - forEach { step -> - when (step) { - is EncodingStep.SerializeWriterField -> { - if (step.writerFieldIndex > lastWriterFieldIndex) { - lastWriterFieldIndex = step.writerFieldIndex - } else { - return false - } - } - - is EncodingStep.MissingWriterFieldFailure -> { - if (step.writerFieldIndex > lastWriterFieldIndex) { - lastWriterFieldIndex = step.writerFieldIndex - } else { - return false - } - } - - is EncodingStep.IgnoreElement -> { - // nothing to check - } - } - } - return true - } - private fun computeDecodingSteps( classDescriptor: SerialDescriptor, writerSchema: Schema, @@ -175,18 +146,19 @@ internal class RecordResolver( return decodingSteps.toTypedArray() } - private fun Schema.isTypeOf(expectedType: Schema.Type): Boolean { - return asSchemaList().any { it.type === expectedType } - } + private fun Schema.isTypeOf(expectedType: Schema.Type): Boolean = asSchemaList().any { it.type === expectedType } - private fun computeEncodingSteps( + private fun computeEncodingWorkflow( classDescriptor: SerialDescriptor, writerSchema: Schema, - ): Array { + ): EncodingWorkflow { // Encoding steps are ordered regarding the class descriptor and not the writer schema. // Because kotlinx-serialization doesn't provide a way to encode non-sequentially elements. - val encodingSteps = mutableListOf() + val missingWriterFieldsIndexes = mutableListOf() val visitedWriterFields = BooleanArray(writerSchema.fields.size) { false } + val descriptorToWriterFieldIndex = IntArray(classDescriptor.elementsCount) { ReorderingCompositeEncoder.SKIP_ELEMENT_INDEX } + + var expectedNextWriterIndex = 0 classDescriptor.elementNames.forEachIndexed { elementIndex, _ -> val avroFieldName = avro.configuration.fieldNamingStrategy.resolve(classDescriptor, elementIndex) @@ -194,24 +166,32 @@ internal class RecordResolver( if (writerField != null) { visitedWriterFields[writerField.pos()] = true - encodingSteps += - EncodingStep.SerializeWriterField( - elementIndex = elementIndex, - writerFieldIndex = writerField.pos(), - schema = writerField.schema() - ) - } else { - encodingSteps += EncodingStep.IgnoreElement(elementIndex) + descriptorToWriterFieldIndex[elementIndex] = writerField.pos() + if (expectedNextWriterIndex != -1) { + if (writerField.pos() != expectedNextWriterIndex) { + expectedNextWriterIndex = -1 + } else { + expectedNextWriterIndex++ + } + } } } visitedWriterFields.forEachIndexed { writerFieldIndex, visited -> if (!visited) { - encodingSteps += EncodingStep.MissingWriterFieldFailure(writerFieldIndex) + missingWriterFieldsIndexes += writerFieldIndex } } - return encodingSteps.toTypedArray() + return if (missingWriterFieldsIndexes.isNotEmpty()) { + EncodingWorkflow.MissingWriterFields(missingWriterFieldsIndexes) + } else if (expectedNextWriterIndex == -1) { + EncodingWorkflow.NonContiguous(descriptorToWriterFieldIndex) + } else if (classDescriptor.elementsCount != writerSchema.fields.size) { + EncodingWorkflow.ContiguousWithSkips(descriptorToWriterFieldIndex.map { it == ReorderingCompositeEncoder.SKIP_ELEMENT_INDEX }.toBooleanArray()) + } else { + EncodingWorkflow.ExactMatch + } } private fun Schema.tryGetField( @@ -228,33 +208,44 @@ internal class RecordResolver( } } -internal class ClassDescriptorForWriterSchema( - /** - * If true, indicates that the encoding steps are ordered the same as the writer schema fields. - * If false, indicates that the encoding steps are **NOT** ordered the same as the writer schema fields. - */ - val sequentialEncoding: Boolean, +internal class SerializationWorkflow( /** * Decoding steps are ordered regarding the writer schema and not the class descriptor. */ - val decodingSteps: Array, + val decoding: Array, /** * Encoding steps are ordered regarding the class descriptor and not the writer schema. */ - val encodingSteps: Array, + val encoding: EncodingWorkflow, ) { - val hasMissingWriterField by lazy { encodingSteps.any { it is EncodingStep.MissingWriterFieldFailure } } - companion object { val EMPTY = - ClassDescriptorForWriterSchema( - sequentialEncoding = true, - decodingSteps = emptyArray(), - encodingSteps = emptyArray() + SerializationWorkflow( + decoding = emptyArray(), + encoding = EncodingWorkflow.ExactMatch ) } } +internal sealed interface EncodingWorkflow { + /** + * The descriptor elements exactly matches the writer schema fields as a 1-to-1 mapping. + */ + data object ExactMatch : EncodingWorkflow + + class ContiguousWithSkips( + val fieldsToSkip: BooleanArray, + ) : EncodingWorkflow + + class NonContiguous( + val descriptorToWriterFieldIndex: IntArray, + ) : EncodingWorkflow + + class MissingWriterFields( + val missingWriterFields: List, + ) : EncodingWorkflow +} + internal sealed interface DecodingStep { /** * This is a flag indicating that the element is deserializable. @@ -310,31 +301,6 @@ internal sealed interface DecodingStep { ) : DecodingStep } -internal sealed interface EncodingStep { - /** - * The element is present in the writer schema and the class descriptor. - */ - data class SerializeWriterField( - val elementIndex: Int, - val writerFieldIndex: Int, - val schema: Schema, - ) : EncodingStep - - /** - * The element is present in the class descriptor but not in the writer schema, so the element is ignored as nothing has to be serialized. - */ - data class IgnoreElement( - val elementIndex: Int, - ) : EncodingStep - - /** - * The writer field doesn't have a corresponding element in the class descriptor, so we aren't able to serialize a value. - */ - data class MissingWriterFieldFailure( - val writerFieldIndex: Int, - ) : EncodingStep -} - private fun AvroDefault.parseValueToGenericData(schema: Schema): Any? { if (value.isStartingAsJson()) { return Json.parseToJsonElement(value).convertDefaultToObject(schema) @@ -342,8 +308,8 @@ private fun AvroDefault.parseValueToGenericData(schema: Schema): Any? { return JsonPrimitive(value).convertDefaultToObject(schema) } -private fun JsonElement.convertDefaultToObject(schema: Schema): Any? { - return when (this) { +private fun JsonElement.convertDefaultToObject(schema: Schema): Any? = + when (this) { is JsonArray -> when (schema.type) { Schema.Type.ARRAY -> this.map { it.convertDefaultToObject(schema.elementType) } @@ -405,7 +371,6 @@ private fun JsonElement.convertDefaultToObject(schema: Schema): Any? { else -> throw SerializationException("Not a valid primitive value for schema $schema: $this") } } -} private fun Schema.resolveUnion( value: JsonElement?, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt index b5bd2cf..13bf266 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/RecordDirectDecoder.kt @@ -30,10 +30,10 @@ internal class RecordDirectDecoder( override fun decodeElementIndex(descriptor: SerialDescriptor): Int { var field: DecodingStep while (true) { - if (nextDecodingStepIndex == classDescriptor.decodingSteps.size) { + if (nextDecodingStepIndex == classDescriptor.decoding.size) { return CompositeDecoder.DECODE_DONE } - field = classDescriptor.decodingSteps[nextDecodingStepIndex++] + field = classDescriptor.decoding[nextDecodingStepIndex++] when (field) { is DecodingStep.IgnoreOptionalElement -> { // loop again to ignore the optional element diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt index 11d7740..84f06c8 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/generic/RecordGenericDecoder.kt @@ -37,10 +37,10 @@ internal class RecordGenericDecoder( override fun decodeElementIndex(descriptor: SerialDescriptor): Int { var field: DecodingStep do { - if (nextDecodingStep == classDescriptor.decodingSteps.size) { + if (nextDecodingStep == classDescriptor.decoding.size) { return CompositeDecoder.DECODE_DONE } - field = classDescriptor.decodingSteps[nextDecodingStep++] + field = classDescriptor.decoding[nextDecodingStep++] } while (field !is DecodingStep.ValidatedDecodingStep) currentElement = field return field.elementIndex diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt new file mode 100644 index 0000000..75bb1b8 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/AbstractAvroEncoder.kt @@ -0,0 +1,372 @@ +package com.github.avrokotlin.avro4k.internal.encoder + +import com.github.avrokotlin.avro4k.AvroEncoder +import com.github.avrokotlin.avro4k.ensureFixedSize +import com.github.avrokotlin.avro4k.fullNameOrAliasMismatchError +import com.github.avrokotlin.avro4k.getIndexTyped +import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware +import com.github.avrokotlin.avro4k.internal.aliases +import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch +import com.github.avrokotlin.avro4k.internal.nonNullSerialName +import com.github.avrokotlin.avro4k.namedSchemaNotFoundInUnionError +import com.github.avrokotlin.avro4k.trySelectEnumSchemaForSymbol +import com.github.avrokotlin.avro4k.trySelectFixedSchemaForSize +import com.github.avrokotlin.avro4k.trySelectNamedSchema +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError +import kotlinx.serialization.SerializationException +import kotlinx.serialization.SerializationStrategy +import kotlinx.serialization.descriptors.PolymorphicKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.encoding.AbstractEncoder +import kotlinx.serialization.encoding.CompositeEncoder +import org.apache.avro.Schema +import org.apache.avro.util.Utf8 + +internal abstract class AbstractAvroEncoder : AbstractEncoder(), AvroEncoder { + private var selectedUnionIndex: Int = -1 + + abstract override var currentWriterSchema: Schema + + abstract fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder + + abstract fun getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder + + abstract fun getMapEncoder( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder + + abstract fun getArrayEncoder( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder + + abstract fun encodeUnionIndexInternal(index: Int) + + abstract fun encodeNullUnchecked() + + abstract fun encodeBooleanUnchecked(value: Boolean) + + abstract fun encodeIntUnchecked(value: Int) + + abstract fun encodeLongUnchecked(value: Long) + + abstract fun encodeFloatUnchecked(value: Float) + + abstract fun encodeDoubleUnchecked(value: Double) + + abstract fun encodeStringUnchecked(value: Utf8) + + abstract fun encodeBytesUnchecked(value: ByteArray) + + abstract fun encodeFixedUnchecked(value: ByteArray) + + abstract fun encodeEnumUnchecked(symbol: String) + + override fun encodeSerializableValue( + serializer: SerializationStrategy, + value: T, + ) { + SerializerLocatorMiddleware.apply(serializer) + .serialize(this, value) + } + + override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { + return when (descriptor.kind) { + StructureKind.CLASS, + StructureKind.OBJECT, + -> { + val nameChecked: Boolean + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectNamedSchema(descriptor).also { nameChecked = it } || + throw namedSchemaNotFoundInUnionError(descriptor.nonNullSerialName, descriptor.aliases) + } else { + nameChecked = false + } + when (currentWriterSchema.type) { + Schema.Type.RECORD -> { + if (nameChecked || currentWriterSchema.isFullNameOrAliasMatch(descriptor)) { + getRecordEncoder(descriptor) + } else { + throw fullNameOrAliasMismatchError(descriptor.nonNullSerialName, descriptor.aliases) + } + } + + else -> throw unsupportedWriterTypeError(Schema.Type.RECORD) + } + } + + is PolymorphicKind -> getPolymorphicEncoder(descriptor) + else -> throw SerializationException("Unsupported structure kind: $descriptor") + } + } + + override fun beginCollection( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder { + return when (descriptor.kind) { + StructureKind.LIST -> { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.ARRAY) || throw typeNotFoundInUnionError(Schema.Type.ARRAY) + } + when (currentWriterSchema.type) { + Schema.Type.ARRAY -> getArrayEncoder(descriptor, collectionSize) + else -> throw unsupportedWriterTypeError(Schema.Type.ARRAY) + } + } + + StructureKind.MAP -> { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.MAP) || throw typeNotFoundInUnionError(Schema.Type.MAP) + } + when (currentWriterSchema.type) { + Schema.Type.MAP -> getMapEncoder(descriptor, collectionSize) + else -> throw unsupportedWriterTypeError(Schema.Type.MAP) + } + } + + else -> throw SerializationException("Unsupported collection kind: $descriptor") + } + } + + override fun encodeUnionIndex(index: Int) { + if (selectedUnionIndex > -1) { + throw SerializationException("Already selected union index: $selectedUnionIndex, got $index, for selected schema $currentWriterSchema") + } + currentWriterSchema = currentWriterSchema.types[index] + encodeUnionIndexInternal(index) + selectedUnionIndex = index + } + + override fun encodeElement( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + selectedUnionIndex = -1 + return true + } + + override fun encodeNull() { + if (currentWriterSchema.isUnion) { + // Generally, null types are the first or last in the union + if (currentWriterSchema.types.first().type == Schema.Type.NULL) { + encodeUnionIndex(0) + } else if (currentWriterSchema.types.last().type == Schema.Type.NULL) { + encodeUnionIndex(currentWriterSchema.types.size - 1) + } else { + val nullIndex = + currentWriterSchema.getIndexTyped(Schema.Type.NULL) + ?: throw SerializationException("Cannot encode null value for non-nullable schema: $currentWriterSchema") + encodeUnionIndex(nullIndex) + } + } else if (currentWriterSchema.type != Schema.Type.NULL) { + throw SerializationException("Cannot encode null value for non-null schema: $currentWriterSchema") + } + encodeNullUnchecked() + } + + override fun encodeBytes(value: ByteArray) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.BYTES, Schema.Type.STRING) || + trySelectFixedSchemaForSize(value.size) || + throw typeNotFoundInUnionError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.BYTES -> encodeBytesUnchecked(value) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value)) + Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value)) + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.STRING, Schema.Type.FIXED) + } + } + + override fun encodeFixed(value: ByteArray) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectFixedSchemaForSize(value.size) || + trySelectTypeFromUnion(Schema.Type.BYTES, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value)) + Schema.Type.BYTES -> encodeBytesUnchecked(value) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value)) + else -> throw unsupportedWriterTypeError(Schema.Type.FIXED, Schema.Type.BYTES, Schema.Type.STRING) + } + } + + override fun encodeEnum( + enumDescriptor: SerialDescriptor, + index: Int, + ) { + val nameChecked: Boolean + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectNamedSchema(enumDescriptor).also { nameChecked = it } || + trySelectTypeFromUnion(Schema.Type.STRING) || + throw namedSchemaNotFoundInUnionError( + enumDescriptor.nonNullSerialName, + enumDescriptor.aliases, + Schema.Type.STRING + ) + } else { + nameChecked = false + } + val enumName = enumDescriptor.getElementName(index) + when (currentWriterSchema.type) { + Schema.Type.ENUM -> + if (nameChecked || currentWriterSchema.isFullNameOrAliasMatch(enumDescriptor)) { + encodeEnumUnchecked(enumName) + } else { + throw fullNameOrAliasMismatchError(enumDescriptor.nonNullSerialName, enumDescriptor.aliases) + } + + Schema.Type.STRING -> encodeStringUnchecked(Utf8(enumName)) + else -> throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) + } + } + + override fun encodeBoolean(value: Boolean) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion( + Schema.Type.BOOLEAN, + Schema.Type.STRING + ) || throw typeNotFoundInUnionError(Schema.Type.BOOLEAN, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.BOOLEAN -> encodeBooleanUnchecked(value) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.BOOLEAN, Schema.Type.STRING) + } + } + + override fun encodeByte(value: Byte) { + encodeInt(value.toInt()) + } + + override fun encodeShort(value: Short) { + encodeInt(value.toInt()) + } + + override fun encodeInt(value: Int) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion( + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) || + throw typeNotFoundInUnionError( + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + when (currentWriterSchema.type) { + Schema.Type.INT -> encodeIntUnchecked(value) + Schema.Type.LONG -> encodeLongUnchecked(value.toLong()) + Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat()) + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError( + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + } + + override fun encodeLong(value: Long) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) || + throw typeNotFoundInUnionError( + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + when (currentWriterSchema.type) { + Schema.Type.LONG -> encodeLongUnchecked(value) + Schema.Type.FLOAT -> encodeFloatUnchecked(value.toFloat()) + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError( + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING + ) + } + } + + override fun encodeFloat(value: Float) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.FLOAT -> encodeFloatUnchecked(value) + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value.toDouble()) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) + } + } + + override fun encodeDouble(value: Double) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.DOUBLE, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.DOUBLE, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.DOUBLE -> encodeDoubleUnchecked(value) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.STRING) + } + } + + override fun encodeChar(value: Char) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.INT -> encodeIntUnchecked(value.code) + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value.toString())) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.STRING) + } + } + + override fun encodeString(value: String) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.BYTES) || + trySelectFixedSchemaForSize(value.length) || + trySelectEnumSchemaForSymbol(value) || + throw typeNotFoundInUnionError( + Schema.Type.STRING, + Schema.Type.BYTES, + Schema.Type.FIXED, + Schema.Type.ENUM + ) + } + when (currentWriterSchema.type) { + Schema.Type.STRING -> encodeStringUnchecked(Utf8(value)) + Schema.Type.BYTES -> encodeBytesUnchecked(value.encodeToByteArray()) + Schema.Type.FIXED -> encodeFixedUnchecked(ensureFixedSize(value.encodeToByteArray())) + Schema.Type.ENUM -> encodeEnumUnchecked(value) + else -> throw unsupportedWriterTypeError( + Schema.Type.BYTES, + Schema.Type.STRING, + Schema.Type.FIXED, + Schema.Type.ENUM + ) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt new file mode 100644 index 0000000..0f84226 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/ReorderingCompositeEncoder.kt @@ -0,0 +1,297 @@ +package com.github.avrokotlin.avro4k.internal.encoder + +import com.github.avrokotlin.avro4k.internal.encoder.ReorderingCompositeEncoder.Companion.SKIP_ELEMENT_INDEX +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerializationStrategy +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.encoding.CompositeEncoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.modules.EmptySerializersModule +import kotlinx.serialization.modules.SerializersModule + +/** + * Encodes composite elements in a specific order managed by [mapElementIndex]. + * + * This encoder will replicate the behavior of a standard encoding, but calling the `encode*Element` methods in + * the order defined by [mapElementIndex]. It first buffers each `encode*Element` calls in an array following + * the given indexes using [mapElementIndex], then when [endStructure] is called, it encodes the buffered calls + * in the expected order by replaying the previous calls on the given [compositeEncoderDelegate]. + * + * When [mapElementIndex] returns [SKIP_ELEMENT_INDEX], the element will be ignored and not encoded. + * + * This encoder is stateful and not designed to be reused. + * + * @param compositeEncoderDelegate the [CompositeEncoder] to be used to encode the given descriptor's elements in the expected order. + * @param encodedElementsCount The final number of elements to encode. If the mapper provides a smaller number of elements, an error will be thrown indicating the missing index. + * @param mapElementIndex maps the element index to a new positional zero-based index. If this mapper provides the same index for multiple elements, only the last one will be encoded as the previous ones will be overridden. The mapped index just helps to reorder the elements, but the reordered `encode*Element` method calls will still pass the original element index. + */ +@ExperimentalSerializationApi +internal class ReorderingCompositeEncoder( + encodedElementsCount: Int, + private val compositeEncoderDelegate: CompositeEncoder, + private val mapElementIndex: (SerialDescriptor, Int) -> Int, +) : CompositeEncoder { + private val bufferedCalls = Array(encodedElementsCount) { null } + + companion object { + @ExperimentalSerializationApi + const val SKIP_ELEMENT_INDEX: Int = -1 + } + + override val serializersModule: SerializersModule + // No need to return a serializers module as it's not used during buffering + get() = EmptySerializersModule() + + private data class BufferedCall( + val originalElementIndex: Int, + val encoder: () -> Unit, + ) + + private fun bufferEncoding( + descriptor: SerialDescriptor, + index: Int, + encoder: () -> Unit, + ) { + val newIndex = mapElementIndex(descriptor, index) + if (newIndex != SKIP_ELEMENT_INDEX) { + bufferedCalls[newIndex] = BufferedCall(index, encoder) + } + } + + override fun endStructure(descriptor: SerialDescriptor) { + bufferedCalls.forEach { fieldToEncode -> + // In case of skipped fields, overridden fields (mapped to same index) or too big [encodedElementsCount], + // the fieldToEncode may be null as no element was encoded for that index + fieldToEncode?.encoder?.invoke() + } + compositeEncoderDelegate.endStructure(descriptor) + } + + override fun encodeBooleanElement( + descriptor: SerialDescriptor, + index: Int, + value: Boolean, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeBooleanElement(descriptor, index, value) + } + } + + override fun encodeByteElement( + descriptor: SerialDescriptor, + index: Int, + value: Byte, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeByteElement(descriptor, index, value) + } + } + + override fun encodeCharElement( + descriptor: SerialDescriptor, + index: Int, + value: Char, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeCharElement(descriptor, index, value) + } + } + + override fun encodeDoubleElement( + descriptor: SerialDescriptor, + index: Int, + value: Double, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeDoubleElement(descriptor, index, value) + } + } + + override fun encodeFloatElement( + descriptor: SerialDescriptor, + index: Int, + value: Float, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeFloatElement(descriptor, index, value) + } + } + + override fun encodeIntElement( + descriptor: SerialDescriptor, + index: Int, + value: Int, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeIntElement(descriptor, index, value) + } + } + + override fun encodeLongElement( + descriptor: SerialDescriptor, + index: Int, + value: Long, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeLongElement(descriptor, index, value) + } + } + + override fun encodeNullableSerializableElement( + descriptor: SerialDescriptor, + index: Int, + serializer: SerializationStrategy, + value: T?, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeNullableSerializableElement(descriptor, index, serializer, value) + } + } + + override fun encodeSerializableElement( + descriptor: SerialDescriptor, + index: Int, + serializer: SerializationStrategy, + value: T, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeSerializableElement(descriptor, index, serializer, value) + } + } + + override fun encodeShortElement( + descriptor: SerialDescriptor, + index: Int, + value: Short, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeShortElement(descriptor, index, value) + } + } + + override fun encodeStringElement( + descriptor: SerialDescriptor, + index: Int, + value: String, + ) { + bufferEncoding(descriptor, index) { + compositeEncoderDelegate.encodeStringElement(descriptor, index, value) + } + } + + override fun encodeInlineElement( + descriptor: SerialDescriptor, + index: Int, + ): Encoder { + return BufferingInlineEncoder(descriptor, index) + } + + override fun shouldEncodeElementDefault( + descriptor: SerialDescriptor, + index: Int, + ): Boolean { + return compositeEncoderDelegate.shouldEncodeElementDefault(descriptor, index) + } + + private inner class BufferingInlineEncoder( + private val descriptor: SerialDescriptor, + private val elementIndex: Int, + ) : Encoder { + private var encodeNotNullMarkCalled = false + + override val serializersModule: SerializersModule + get() = this@ReorderingCompositeEncoder.serializersModule + + private fun bufferEncoding(encoder: Encoder.() -> Unit) { + bufferEncoding(descriptor, elementIndex) { + compositeEncoderDelegate.encodeInlineElement(descriptor, elementIndex).apply { + if (encodeNotNullMarkCalled) { + encodeNotNullMark() + } + encoder() + } + } + } + + override fun encodeNotNullMark() { + encodeNotNullMarkCalled = true + } + + override fun encodeNullableSerializableValue( + serializer: SerializationStrategy, + value: T?, + ) { + bufferEncoding { encodeNullableSerializableValue(serializer, value) } + } + + override fun encodeSerializableValue( + serializer: SerializationStrategy, + value: T, + ) { + bufferEncoding { encodeSerializableValue(serializer, value) } + } + + override fun encodeBoolean(value: Boolean) { + bufferEncoding { encodeBoolean(value) } + } + + override fun encodeByte(value: Byte) { + bufferEncoding { encodeByte(value) } + } + + override fun encodeChar(value: Char) { + bufferEncoding { encodeChar(value) } + } + + override fun encodeDouble(value: Double) { + bufferEncoding { encodeDouble(value) } + } + + override fun encodeEnum( + enumDescriptor: SerialDescriptor, + index: Int, + ) { + bufferEncoding { encodeEnum(enumDescriptor, index) } + } + + override fun encodeFloat(value: Float) { + bufferEncoding { encodeFloat(value) } + } + + override fun encodeInt(value: Int) { + bufferEncoding { encodeInt(value) } + } + + override fun encodeLong(value: Long) { + bufferEncoding { encodeLong(value) } + } + + @ExperimentalSerializationApi + override fun encodeNull() { + bufferEncoding { encodeNull() } + } + + override fun encodeShort(value: Short) { + bufferEncoding { encodeShort(value) } + } + + override fun encodeString(value: String) { + bufferEncoding { encodeString(value) } + } + + override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { + unexpectedCall(::beginStructure.name) + } + + override fun encodeInline(descriptor: SerialDescriptor): Encoder { + unexpectedCall(::encodeInline.name) + } + + private fun unexpectedCall(methodName: String): Nothing { + // This method is normally called from within encodeSerializableValue or encodeNullableSerializableValue which is buffered, so we should never go here during buffering as it will be delegated to the concrete CompositeEncoder + throw UnsupportedOperationException( + "Non-standard usage of ${CompositeEncoder::class.simpleName}: $methodName should be called from within encodeSerializableValue or encodeNullableSerializableValue" + ) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt index 87d36be..57e42aa 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/AbstractAvroDirectEncoder.kt @@ -1,24 +1,15 @@ package com.github.avrokotlin.avro4k.internal.encoder.direct import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroEncoder -import com.github.avrokotlin.avro4k.UnionEncoder -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError -import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware -import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch +import com.github.avrokotlin.avro4k.internal.encoder.AbstractAvroEncoder import kotlinx.serialization.SerializationException import kotlinx.serialization.SerializationStrategy -import kotlinx.serialization.descriptors.PolymorphicKind import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind import kotlinx.serialization.encoding.AbstractEncoder import kotlinx.serialization.encoding.CompositeEncoder import kotlinx.serialization.modules.SerializersModule import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed import org.apache.avro.util.Utf8 -import java.nio.ByteBuffer internal class AvroValueDirectEncoder( override var currentWriterSchema: Schema, @@ -29,407 +20,82 @@ internal class AvroValueDirectEncoder( internal sealed class AbstractAvroDirectEncoder( protected val avro: Avro, protected val binaryEncoder: org.apache.avro.io.Encoder, -) : AbstractEncoder(), AvroEncoder, UnionEncoder { - private var selectedUnionIndex: Int = -1 - - abstract override var currentWriterSchema: Schema - +) : AbstractAvroEncoder() { override val serializersModule: SerializersModule get() = avro.serializersModule - override fun encodeSerializableValue( - serializer: SerializationStrategy, - value: T, - ) { - SerializerLocatorMiddleware.apply(serializer) - .serialize(this, value) + override fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return RecordDirectEncoder(descriptor, currentWriterSchema, avro, binaryEncoder) } - override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.CLASS, - StructureKind.OBJECT, - -> - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.RECORD) } - ) { schema -> - if (schema.type == Schema.Type.RECORD && schema.isFullNameOrAliasMatch(descriptor)) { - { - val elementDescriptors = avro.recordResolver.resolveFields(schema, descriptor) - RecordDirectEncoder(elementDescriptors, schema, avro, binaryEncoder) - } - } else { - null - } - } - - is PolymorphicKind -> PolymorphicDirectEncoder(avro, currentWriterSchema, binaryEncoder) - else -> throw SerializationException("Unsupported structure kind: $descriptor") - } + override fun getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return PolymorphicDirectEncoder(avro, currentWriterSchema, binaryEncoder) } - override fun beginCollection( + override fun getArrayEncoder( descriptor: SerialDescriptor, collectionSize: Int, ): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.LIST -> - encodeResolving({ BadEncodedValueError(emptyList(), currentWriterSchema, Schema.Type.ARRAY) }) { schema -> - when (schema.type) { - Schema.Type.ARRAY -> { - { ArrayDirectEncoder(schema, collectionSize, avro, binaryEncoder) } - } - - else -> null - } - } - - StructureKind.MAP -> - encodeResolving({ BadEncodedValueError(emptyMap(), currentWriterSchema, Schema.Type.MAP) }) { schema -> - when (schema.type) { - Schema.Type.MAP -> { - { MapDirectEncoder(schema, collectionSize, avro, binaryEncoder) } - } - - else -> null - } - } - - else -> throw SerializationException("Unsupported collection kind: $descriptor") - } - } - - override fun encodeUnionIndex(index: Int) { - if (selectedUnionIndex > -1) { - throw SerializationException("Already selected union index: $selectedUnionIndex, got $index, for selected schema $currentWriterSchema") - } - if (currentWriterSchema.isUnion) { - binaryEncoder.writeIndex(index) - selectedUnionIndex = index - currentWriterSchema = currentWriterSchema.types[index] - } else { - throw SerializationException("Cannot select union index for non-union schema: $currentWriterSchema") - } + return ArrayDirectEncoder(currentWriterSchema, collectionSize, avro, binaryEncoder) } - override fun encodeElement( + override fun getMapEncoder( descriptor: SerialDescriptor, - index: Int, - ): Boolean { - selectedUnionIndex = -1 - return true - } - - override fun encodeNull() { - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.NULL) } - ) { - when (it.type) { - Schema.Type.NULL -> { - { binaryEncoder.writeNull() } - } - - else -> null - } - } - } - - override fun encodeBytes(value: ByteBuffer) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BYTES, Schema.Type.STRING, Schema.Type.FIXED) } - ) { - when (it.type) { - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value.array())) } - } - - Schema.Type.FIXED -> { - if (value.remaining() == it.fixedSize) { - { binaryEncoder.writeFixed(value.array()) } - } else { - null - } - } - - else -> null - } - } - } - - override fun encodeBytes(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BYTES, Schema.Type.STRING, Schema.Type.FIXED) } - ) { - when (it.type) { - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value)) } - } - - Schema.Type.FIXED -> { - if (value.size == it.fixedSize) { - { binaryEncoder.writeFixed(value) } - } else { - null - } - } - - else -> null - } - } + collectionSize: Int, + ): CompositeEncoder { + return MapDirectEncoder(currentWriterSchema, collectionSize, avro, binaryEncoder) } - override fun encodeFixed(value: GenericFixed) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.FIXED, Schema.Type.STRING, Schema.Type.BYTES) } - ) { - when (it.type) { - Schema.Type.FIXED -> { - if (it.fullName == value.schema.fullName && it.fixedSize == value.bytes().size) { - { binaryEncoder.writeFixed(value.bytes()) } - } else { - null - } - } - - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value.bytes()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value.bytes())) } - } - - else -> null - } - } + override fun encodeUnionIndexInternal(index: Int) { + binaryEncoder.writeIndex(index) } - override fun encodeFixed(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.FIXED, Schema.Type.STRING, Schema.Type.BYTES) } - ) { - when (it.type) { - Schema.Type.FIXED -> - if (it.fixedSize == value.size) { - { binaryEncoder.writeFixed(value) } - } else { - null - } - - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(Utf8(value)) } - } - - else -> null - } - } + override fun encodeNullUnchecked() { + binaryEncoder.writeNull() } - override fun encodeEnum( - enumDescriptor: SerialDescriptor, - index: Int, - ) { - val enumName = enumDescriptor.getElementName(index) - encodeResolving( - { BadEncodedValueError(index, currentWriterSchema, Schema.Type.ENUM, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.ENUM -> - if (it.isFullNameOrAliasMatch(enumDescriptor)) { - { binaryEncoder.writeEnum(it.getEnumOrdinal(enumName)) } - } else { - null - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(enumName) } - } - - else -> null - } - } + override fun encodeBytesUnchecked(value: ByteArray) { + binaryEncoder.writeBytes(value) } - override fun encodeBoolean(value: Boolean) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BOOLEAN, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.BOOLEAN -> { - { binaryEncoder.writeBoolean(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeBooleanUnchecked(value: Boolean) { + binaryEncoder.writeBoolean(value) } - override fun encodeByte(value: Byte) { - encodeInt(value.toInt()) + override fun encodeIntUnchecked(value: Int) { + binaryEncoder.writeInt(value) } - override fun encodeShort(value: Short) { - encodeInt(value.toInt()) + override fun encodeLongUnchecked(value: Long) { + binaryEncoder.writeLong(value) } - override fun encodeInt(value: Int) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.INT -> { - { binaryEncoder.writeInt(value) } - } - - Schema.Type.LONG -> { - { binaryEncoder.writeLong(value.toLong()) } - } - - Schema.Type.FLOAT -> { - { binaryEncoder.writeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value.toDouble()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeFloatUnchecked(value: Float) { + binaryEncoder.writeFloat(value) } - override fun encodeLong(value: Long) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.LONG -> { - { binaryEncoder.writeLong(value) } - } - - Schema.Type.FLOAT -> { - { binaryEncoder.writeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value.toDouble()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeDoubleUnchecked(value: Double) { + binaryEncoder.writeDouble(value) } - override fun encodeFloat(value: Float) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.FLOAT -> { - { binaryEncoder.writeFloat(value) } - } - - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value.toDouble()) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeStringUnchecked(value: Utf8) { + binaryEncoder.writeString(value) } - override fun encodeDouble(value: Double) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.DOUBLE -> { - { binaryEncoder.writeDouble(value) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeEnumUnchecked(symbol: String) { + binaryEncoder.writeEnum(currentWriterSchema.getEnumOrdinalChecked(symbol)) } - override fun encodeChar(value: Char) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.INT, Schema.Type.STRING) } - ) { - when (it.type) { - Schema.Type.INT -> { - { binaryEncoder.writeInt(value.code) } - } - - Schema.Type.STRING -> { - { binaryEncoder.writeString(value.toString()) } - } - - else -> null - } - } + override fun encodeFixedUnchecked(value: ByteArray) { + binaryEncoder.writeFixed(value) } +} - override fun encodeString(value: String) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.ENUM) } - ) { - when (it.type) { - Schema.Type.STRING -> { - { binaryEncoder.writeString(value) } - } - - Schema.Type.BYTES -> { - { binaryEncoder.writeBytes(value.encodeToByteArray()) } - } - - Schema.Type.FIXED -> { - if (value.length == it.fixedSize) { - { binaryEncoder.writeFixed(value.encodeToByteArray()) } - } else { - null - } - } - - Schema.Type.ENUM -> { - { binaryEncoder.writeEnum(it.getEnumOrdinal(value)) } - } - - else -> null - } - } +private fun Schema.getEnumOrdinalChecked(symbol: String): Int { + return try { + getEnumOrdinal(symbol) + } catch (e: NullPointerException) { + throw SerializationException("Enum symbol $symbol not found in schema $this", e) } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt index 5c0e1fc..5c1f67e 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/direct/RecordDirectEncoder.kt @@ -1,39 +1,42 @@ package com.github.avrokotlin.avro4k.internal.encoder.direct import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroEncoder -import com.github.avrokotlin.avro4k.UnionEncoder -import com.github.avrokotlin.avro4k.internal.ClassDescriptorForWriterSchema -import com.github.avrokotlin.avro4k.internal.EncodingStep +import com.github.avrokotlin.avro4k.internal.EncodingWorkflow +import com.github.avrokotlin.avro4k.internal.encoder.ReorderingCompositeEncoder import kotlinx.serialization.SerializationException -import kotlinx.serialization.SerializationStrategy import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractEncoder import kotlinx.serialization.encoding.CompositeEncoder -import kotlinx.serialization.modules.SerializersModule import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed -import java.nio.ByteBuffer @Suppress("FunctionName") internal fun RecordDirectEncoder( - classDescriptor: ClassDescriptorForWriterSchema, + descriptor: SerialDescriptor, schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, ): CompositeEncoder { - return if (classDescriptor.sequentialEncoding) { - RecordSequentialDirectEncoder(classDescriptor, schema, avro, binaryEncoder) - } else { - RecordBadOrderDirectEncoder(classDescriptor, schema, avro, binaryEncoder) + val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding + when (encodingWorkflow) { + is EncodingWorkflow.ExactMatch -> return RecordExactDirectEncoder(schema, avro, binaryEncoder) + is EncodingWorkflow.ContiguousWithSkips -> return RecordSkippingDirectEncoder(encodingWorkflow.fieldsToSkip, schema, avro, binaryEncoder) + is EncodingWorkflow.NonContiguous -> return ReorderingCompositeEncoder( + schema.fields.size, + RecordNonContiguousDirectEncoder( + encodingWorkflow.descriptorToWriterFieldIndex, + schema, + avro, + binaryEncoder + ) + ) { _, index -> + encodingWorkflow.descriptorToWriterFieldIndex[index] + } + + is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") } } -/** - * Consider that the descriptor elements are in the same order as the schema fields, and all the fields are represented by an element. - */ -private class RecordSequentialDirectEncoder( - private val classDescriptor: ClassDescriptorForWriterSchema, +private class RecordNonContiguousDirectEncoder( + private val descriptorToWriterFieldIndex: IntArray, private val schema: Schema, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder, @@ -44,167 +47,50 @@ private class RecordSequentialDirectEncoder( descriptor: SerialDescriptor, index: Int, ): Boolean { - super.encodeElement(descriptor, index) - // index == elementIndex == writerFieldIndex, so the written field is already in the good order - return when (val step = classDescriptor.encodingSteps[index]) { - is EncodingStep.SerializeWriterField -> { - currentWriterSchema = schema.fields[step.writerFieldIndex].schema() - true - } - - is EncodingStep.IgnoreElement -> { - false - } - - is EncodingStep.MissingWriterFieldFailure -> { - throw SerializationException("No serializable element found for writer field ${step.writerFieldIndex} in schema $schema") - } - } - } - - override fun endStructure(descriptor: SerialDescriptor) { - if (classDescriptor.hasMissingWriterField) { - throw SerializationException("The descriptor is not writing all the expected fields of writer schema. Schema: $schema, descriptor: $descriptor") + val writerFieldIndex = descriptorToWriterFieldIndex[index] + if (writerFieldIndex == -1) { + return false } + super.encodeElement(descriptor, index) + currentWriterSchema = schema.fields[writerFieldIndex].schema() + return true } } -/** - * This handles the case where the descriptor elements are not in the same order as the schema fields. - * - * First we buffer all the element encodings to the corresponding field indexes, then we encode them for real in the correct order using [RecordSequentialDirectEncoder]. - * - * Not implementing [UnionEncoder] as all the encoding is delegated to the [RecordSequentialDirectEncoder] which already handles union encoding. - */ -private class RecordBadOrderDirectEncoder( - private val classDescriptor: ClassDescriptorForWriterSchema, +private class RecordSkippingDirectEncoder( + private val skippedElements: BooleanArray, private val schema: Schema, - private val avro: Avro, - private val binaryEncoder: org.apache.avro.io.Encoder, -) : AbstractEncoder(), AvroEncoder { - // Each time we encode a field, if the next expected schema field index is not the good one, it is buffered until it's the time to encode it - private var bufferedFields = Array(schema.fields.size) { null } - private lateinit var encodingStepToBuffer: EncodingStep.SerializeWriterField - - data class BufferedField( - val step: EncodingStep.SerializeWriterField, - val encoder: AvroEncoder.() -> Unit, - ) - - override val currentWriterSchema: Schema - get() = encodingStepToBuffer.schema - - override val serializersModule: SerializersModule - get() = avro.serializersModule + avro: Avro, + binaryEncoder: org.apache.avro.io.Encoder, +) : AbstractAvroDirectEncoder(avro, binaryEncoder) { + override lateinit var currentWriterSchema: Schema override fun encodeElement( descriptor: SerialDescriptor, index: Int, ): Boolean { - return when (val step = classDescriptor.encodingSteps[index]) { - is EncodingStep.SerializeWriterField -> { - encodingStepToBuffer = step - true - } - - is EncodingStep.IgnoreElement -> { - false - } - - is EncodingStep.MissingWriterFieldFailure -> { - throw SerializationException("No serializable element found for writer field ${step.writerFieldIndex} in schema $schema") - } - } - } - - private inline fun bufferEncoding(crossinline encoder: AvroEncoder.() -> Unit) { - bufferedFields[encodingStepToBuffer.writerFieldIndex] = BufferedField(encodingStepToBuffer) { encoder() } - } - - override fun endStructure(descriptor: SerialDescriptor) { - encodeBufferedFields(descriptor) - } - - private fun encodeBufferedFields(descriptor: SerialDescriptor) { - val recordEncoder = RecordSequentialDirectEncoder(classDescriptor, schema, avro, binaryEncoder) - bufferedFields.forEach { fieldToEncode -> - if (fieldToEncode == null) { - throw SerializationException("The writer field is missing in the buffered fields, it hasn't been encoded yet") - } - // To simulate the behavior of regular element encoding - // We don't use the return of encodeElement because we know it's always true - recordEncoder.encodeElement(descriptor, fieldToEncode.step.elementIndex) - fieldToEncode.encoder(recordEncoder) + if (skippedElements[index]) { + return false } + super.encodeElement(descriptor, index) + currentWriterSchema = schema.fields[index].schema() + return true } +} - override fun encodeSerializableValue( - serializer: SerializationStrategy, - value: T, - ) { - bufferEncoding { encodeSerializableValue(serializer, value) } - } - - override fun encodeNull() { - bufferEncoding { encodeNull() } - } - - override fun encodeBytes(value: ByteArray) { - bufferEncoding { encodeBytes(value) } - } - - override fun encodeBytes(value: ByteBuffer) { - bufferEncoding { encodeBytes(value) } - } - - override fun encodeFixed(value: GenericFixed) { - bufferEncoding { encodeFixed(value) } - } - - override fun encodeFixed(value: ByteArray) { - bufferEncoding { encodeFixed(value) } - } - - override fun encodeBoolean(value: Boolean) { - bufferEncoding { encodeBoolean(value) } - } - - override fun encodeByte(value: Byte) { - bufferEncoding { encodeByte(value) } - } - - override fun encodeShort(value: Short) { - bufferEncoding { encodeShort(value) } - } - - override fun encodeInt(value: Int) { - bufferEncoding { encodeInt(value) } - } - - override fun encodeLong(value: Long) { - bufferEncoding { encodeLong(value) } - } - - override fun encodeFloat(value: Float) { - bufferEncoding { encodeFloat(value) } - } - - override fun encodeDouble(value: Double) { - bufferEncoding { encodeDouble(value) } - } - - override fun encodeChar(value: Char) { - bufferEncoding { encodeChar(value) } - } - - override fun encodeString(value: String) { - bufferEncoding { encodeString(value) } - } +private class RecordExactDirectEncoder( + private val schema: Schema, + avro: Avro, + binaryEncoder: org.apache.avro.io.Encoder, +) : AbstractAvroDirectEncoder(avro, binaryEncoder) { + override lateinit var currentWriterSchema: Schema - override fun encodeEnum( - enumDescriptor: SerialDescriptor, + override fun encodeElement( + descriptor: SerialDescriptor, index: Int, - ) { - bufferEncoding { encodeEnum(enumDescriptor, index) } + ): Boolean { + super.encodeElement(descriptor, index) + currentWriterSchema = schema.fields[index].schema() + return true } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt index a064891..06a5f41 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AbstractAvroGenericEncoder.kt @@ -1,440 +1,81 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroEncoder -import com.github.avrokotlin.avro4k.UnionEncoder -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError -import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware -import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch -import com.github.avrokotlin.avro4k.internal.toIntExact -import kotlinx.serialization.SerializationException -import kotlinx.serialization.SerializationStrategy -import kotlinx.serialization.descriptors.PolymorphicKind +import com.github.avrokotlin.avro4k.internal.encoder.AbstractAvroEncoder import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.encoding.AbstractEncoder import kotlinx.serialization.encoding.CompositeEncoder import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema import org.apache.avro.generic.GenericData -import org.apache.avro.generic.GenericFixed +import org.apache.avro.util.Utf8 import java.nio.ByteBuffer -internal abstract class AbstractAvroGenericEncoder : AbstractEncoder(), AvroEncoder, UnionEncoder { +internal abstract class AbstractAvroGenericEncoder : AbstractAvroEncoder() { abstract val avro: Avro - abstract override var currentWriterSchema: Schema - abstract override fun encodeValue(value: Any) - abstract override fun encodeNull() - - override fun encodeElement( - descriptor: SerialDescriptor, - index: Int, - ): Boolean { - selectedUnionIndex = -1 - return true - } - - private var selectedUnionIndex: Int = -1 - - override fun encodeUnionIndex(index: Int) { - if (selectedUnionIndex > -1) { - throw SerializationException("Already selected union index: $selectedUnionIndex, got $index, for selected schema $currentWriterSchema") - } - if (currentWriterSchema.isUnion) { - selectedUnionIndex = index - currentWriterSchema = currentWriterSchema.types[index] - } else { - throw SerializationException("Cannot select union index for non-union schema: $currentWriterSchema") - } - } - override val serializersModule: SerializersModule get() = avro.serializersModule - override fun encodeSerializableValue( - serializer: SerializationStrategy, - value: T, - ) { - SerializerLocatorMiddleware.apply(serializer) - .serialize(this, value) + override fun getRecordEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return RecordGenericEncoder(avro, descriptor, currentWriterSchema) { encodeValue(it) } } - override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.CLASS, - StructureKind.OBJECT, - -> - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.RECORD) } - ) { schema -> - if (schema.type == Schema.Type.RECORD && schema.isFullNameOrAliasMatch(descriptor)) { - { RecordGenericEncoder(avro, descriptor, schema) { encodeValue(it) } } - } else { - null - } - } - - is PolymorphicKind -> - PolymorphicEncoder(avro, currentWriterSchema) { - encodeValue(it) - } - - else -> throw SerializationException("Unsupported structure kind: $descriptor") - } + override fun getPolymorphicEncoder(descriptor: SerialDescriptor): CompositeEncoder { + return PolymorphicEncoder(avro, currentWriterSchema) { encodeValue(it) } } - override fun beginCollection( + override fun getArrayEncoder( descriptor: SerialDescriptor, collectionSize: Int, ): CompositeEncoder { - return when (descriptor.kind) { - StructureKind.LIST -> - encodeResolving( - { BadEncodedValueError(emptyList(), currentWriterSchema, Schema.Type.ARRAY, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.ARRAY -> { - { ArrayGenericEncoder(avro, collectionSize, schema) { encodeValue(it) } } - } - - Schema.Type.BYTES -> { - { BytesGenericEncoder(avro, collectionSize) { encodeValue(it) } } - } - - Schema.Type.FIXED -> { - { FixedGenericEncoder(avro, collectionSize, schema) { encodeValue(it) } } - } - - else -> null - } - } - - StructureKind.MAP -> - encodeResolving( - { BadEncodedValueError(emptyMap(), currentWriterSchema, Schema.Type.MAP) } - ) { schema -> - when (schema.type) { - Schema.Type.MAP -> { - { MapGenericEncoder(avro, collectionSize, schema) { encodeValue(it) } } - } - - else -> null - } - } - - else -> throw SerializationException("Unsupported collection kind: $descriptor") - } + return ArrayGenericEncoder(avro, collectionSize, currentWriterSchema) { encodeValue(it) } } - override fun encodeBytes(value: ByteBuffer) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.BYTES -> { - { encodeValue(value) } - } - - Schema.Type.FIXED -> { - if (value.remaining() == schema.fixedSize) { - { encodeValue(value.array()) } - } else { - null - } - } - - Schema.Type.STRING -> { - { encodeValue(value.array().decodeToString()) } - } - - else -> null - } - } - } - - override fun encodeBytes(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.BYTES -> { - { encodeValue(ByteBuffer.wrap(value)) } - } - - Schema.Type.FIXED -> { - if (value.size == schema.fixedSize) { - { encodeValue(value) } - } else { - null - } - } - - Schema.Type.STRING -> { - { encodeValue(value.decodeToString()) } - } - - else -> null - } - } - } - - override fun encodeFixed(value: GenericFixed) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.FIXED -> - if (schema.fullName == value.schema.fullName && schema.fixedSize == value.bytes().size) { - { encodeValue(value) } - } else { - null - } - - Schema.Type.BYTES -> { - { encodeValue(ByteBuffer.wrap(value.bytes())) } - } - - Schema.Type.STRING -> { - { encodeValue(value.bytes().decodeToString()) } - } - - else -> null - } - } - } - - override fun encodeFixed(value: ByteArray) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } - ) { schema -> - when (schema.type) { - Schema.Type.FIXED -> { - if (value.size == schema.fixedSize) { - { encodeValue(value) } - } else { - null - } - } - - Schema.Type.BYTES -> { - { encodeValue(ByteBuffer.wrap(value)) } - } - - Schema.Type.STRING -> { - { encodeValue(value.decodeToString()) } - } - - else -> null - } - } + override fun getMapEncoder( + descriptor: SerialDescriptor, + collectionSize: Int, + ): CompositeEncoder { + return MapGenericEncoder(avro, collectionSize, currentWriterSchema) { encodeValue(it) } } - override fun encodeBoolean(value: Boolean) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.BOOLEAN, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.BOOLEAN -> { - { encodeValue(value) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeBytesUnchecked(value: ByteArray) { + encodeValue(ByteBuffer.wrap(value)) } - override fun encodeByte(value: Byte) { - encodeInt(value.toInt()) + override fun encodeBooleanUnchecked(value: Boolean) { + encodeValue(value) } - override fun encodeShort(value: Short) { - encodeInt(value.toInt()) + override fun encodeStringUnchecked(value: Utf8) { + encodeValue(value) } - override fun encodeInt(value: Int) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.LONG, Schema.Type.INT, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.INT -> { - { encodeValue(value) } - } - - Schema.Type.LONG -> { - { encodeValue(value.toLong()) } - } - - Schema.Type.FLOAT -> { - { encodeValue(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encodeValue(value.toDouble()) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeUnionIndexInternal(index: Int) { + // nothing to do } - override fun encodeLong(value: Long) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.LONG, Schema.Type.INT, Schema.Type.FLOAT, Schema.Type.DOUBLE, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.LONG -> { - { encodeValue(value) } - } - - Schema.Type.INT -> { - { encodeValue(value.toIntExact()) } - } - - Schema.Type.FLOAT -> { - { encodeValue(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encodeValue(value.toDouble()) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeFixedUnchecked(value: ByteArray) { + encodeValue(GenericData.Fixed(currentWriterSchema, value)) } - override fun encodeFloat(value: Float) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.DOUBLE, Schema.Type.FLOAT) } - ) { schema -> - when (schema.type) { - Schema.Type.FLOAT -> { - { encodeValue(value) } - } - - Schema.Type.DOUBLE -> { - { encodeValue(value.toDouble()) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeIntUnchecked(value: Int) { + encodeValue(value) } - override fun encodeDouble(value: Double) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.DOUBLE) } - ) { schema -> - when (schema.type) { - Schema.Type.DOUBLE -> { - { encodeValue(value) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeLongUnchecked(value: Long) { + encodeValue(value) } - override fun encodeChar(value: Char) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.INT, Schema.Type.STRING) } - ) { schema -> - when (schema.type) { - Schema.Type.INT -> { - { encodeValue(value.code) } - } - - Schema.Type.STRING -> { - { encodeValue(value.toString()) } - } - - else -> null - } - } + override fun encodeFloatUnchecked(value: Float) { + encodeValue(value) } - override fun encodeString(value: String) { - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.ENUM) } - ) { schema -> - when (schema.type) { - Schema.Type.STRING -> { - { encodeValue(value) } - } - - Schema.Type.BYTES -> { - { encodeValue(value.encodeToByteArray()) } - } - - Schema.Type.FIXED -> { - if (value.length == schema.fixedSize) { - { encodeValue(value.encodeToByteArray()) } - } else { - null - } - } - - Schema.Type.ENUM -> { - { encodeValue(GenericData.EnumSymbol(schema, value)) } - } - - else -> null - } - } + override fun encodeDoubleUnchecked(value: Double) { + encodeValue(value) } - override fun encodeEnum( - enumDescriptor: SerialDescriptor, - index: Int, - ) { - /* - We allow enums as ENUM (must match the descriptor's full name), STRING or UNION. - For UNION, we look for an enum with the descriptor's full name, otherwise a string. - */ - val value = enumDescriptor.getElementName(index) - - encodeResolving( - { BadEncodedValueError(value, currentWriterSchema, Schema.Type.STRING, Schema.Type.ENUM) } - ) { schema -> - when (schema.type) { - Schema.Type.STRING -> { - { encodeValue(value) } - } - - Schema.Type.ENUM -> { - if (schema.isFullNameOrAliasMatch(enumDescriptor)) { - { encodeValue(GenericData.EnumSymbol(schema, value)) } - } else { - null - } - } - - else -> null - } - } + override fun encodeEnumUnchecked(symbol: String) { + encodeValue(GenericData.EnumSymbol(currentWriterSchema, symbol)) } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt index 925d252..b00a241 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/ArrayGenericEncoder.kt @@ -1,8 +1,6 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema import org.apache.avro.generic.GenericArray @@ -36,17 +34,7 @@ internal class ArrayGenericEncoder( values[index++] = value } - override fun encodeNull() { - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.NULL) } - ) { - when (it.type) { - Schema.Type.NULL -> { - { values[index++] = null } - } - - else -> null - } - } + override fun encodeNullUnchecked() { + values[index++] = null } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt index eea2670..20901c9 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/AvroValueGenericEncoder.kt @@ -1,8 +1,6 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import org.apache.avro.Schema internal class AvroValueGenericEncoder( @@ -14,16 +12,7 @@ internal class AvroValueGenericEncoder( onEncoded(value) } - override fun encodeNull() { - encodeResolving( - { BadEncodedValueError(null, currentWriterSchema, Schema.Type.NULL) } - ) { - when (it.type) { - Schema.Type.NULL -> { - { onEncoded(null) } - } - else -> null - } - } + override fun encodeNullUnchecked() { + onEncoded(null) } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt deleted file mode 100644 index 8a0f916..0000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/BytesGenericEncoder.kt +++ /dev/null @@ -1,26 +0,0 @@ -package com.github.avrokotlin.avro4k.internal.encoder.generic - -import com.github.avrokotlin.avro4k.Avro -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractEncoder -import kotlinx.serialization.modules.SerializersModule -import java.nio.ByteBuffer - -internal class BytesGenericEncoder( - private val avro: Avro, - arraySize: Int, - private val onEncoded: (ByteBuffer) -> Unit, -) : AbstractEncoder() { - private val output: ByteBuffer = ByteBuffer.allocate(arraySize) - - override val serializersModule: SerializersModule - get() = avro.serializersModule - - override fun endStructure(descriptor: SerialDescriptor) { - onEncoded(output.rewind()) - } - - override fun encodeByte(value: Byte) { - output.put(value) - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt deleted file mode 100644 index e6f02b1..0000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/FixedGenericEncoder.kt +++ /dev/null @@ -1,37 +0,0 @@ -package com.github.avrokotlin.avro4k.internal.encoder.generic - -import com.github.avrokotlin.avro4k.Avro -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.encoding.AbstractEncoder -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema -import org.apache.avro.generic.GenericData -import org.apache.avro.generic.GenericFixed - -internal class FixedGenericEncoder( - private val avro: Avro, - arraySize: Int, - private val schema: Schema, - private val onEncoded: (GenericFixed) -> Unit, -) : AbstractEncoder() { - private val buffer = ByteArray(schema.fixedSize) - private var pos = 0 - - init { - if (arraySize != schema.fixedSize) { - throw SerializationException("Actual collection size $arraySize is greater than schema fixed size $schema") - } - } - - override val serializersModule: SerializersModule - get() = avro.serializersModule - - override fun endStructure(descriptor: SerialDescriptor) { - onEncoded(GenericData.Fixed(schema, buffer)) - } - - override fun encodeByte(value: Byte) { - buffer[pos++] = value - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt index c5f98f4..3954d49 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/MapGenericEncoder.kt @@ -46,7 +46,7 @@ internal class MapGenericEncoder( } } - override fun encodeNull() { + override fun encodeNullUnchecked() { val key = currentKey ?: throw SerializationException("Map key cannot be null") entries.add(key to null) } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt index 1ae05fa..5f7706a 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/encoder/generic/RecordGenericEncoder.kt @@ -2,7 +2,7 @@ package com.github.avrokotlin.avro4k.internal.encoder.generic import com.github.avrokotlin.avro4k.Avro import com.github.avrokotlin.avro4k.ListRecord -import com.github.avrokotlin.avro4k.internal.EncodingStep +import com.github.avrokotlin.avro4k.internal.EncodingWorkflow import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import org.apache.avro.Schema @@ -16,7 +16,7 @@ internal class RecordGenericEncoder( ) : AbstractAvroGenericEncoder() { private val fieldValues: Array = Array(schema.fields.size) { null } - private val classDescriptor = avro.recordResolver.resolveFields(schema, descriptor) + private val encodingWorkflow = avro.recordResolver.resolveFields(schema, descriptor).encoding private lateinit var currentField: Schema.Field override lateinit var currentWriterSchema: Schema @@ -26,22 +26,31 @@ internal class RecordGenericEncoder( index: Int, ): Boolean { super.encodeElement(descriptor, index) - return when (val step = classDescriptor.encodingSteps[index]) { - is EncodingStep.SerializeWriterField -> { - val field = schema.fields[step.writerFieldIndex] - currentField = field - currentWriterSchema = field.schema() - true - } + val writerFieldIndex = + when (encodingWorkflow) { + EncodingWorkflow.ExactMatch -> index - is EncodingStep.IgnoreElement -> { - false - } + is EncodingWorkflow.ContiguousWithSkips -> { + if (encodingWorkflow.fieldsToSkip[index]) { + return false + } + index + } + + is EncodingWorkflow.NonContiguous -> { + val writerFieldIndex = encodingWorkflow.descriptorToWriterFieldIndex[index] + if (writerFieldIndex == -1) { + return false + } + writerFieldIndex + } - is EncodingStep.MissingWriterFieldFailure -> { - throw SerializationException("No serializable element found for writer field ${step.writerFieldIndex} in schema $schema") + is EncodingWorkflow.MissingWriterFields -> throw SerializationException("Invalid encoding workflow") } - } + val field = schema.fields[writerFieldIndex] + currentField = field + currentWriterSchema = field.schema() + return true } override fun endStructure(descriptor: SerialDescriptor) { @@ -52,7 +61,7 @@ internal class RecordGenericEncoder( fieldValues[currentField.pos()] = value } - override fun encodeNull() { + override fun encodeNullUnchecked() { fieldValues[currentField.pos()] = null } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt index 1455fd3..aac2364 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/exceptions.kt @@ -7,7 +7,6 @@ import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.SerialKind import kotlinx.serialization.encoding.Decoder -import kotlinx.serialization.encoding.Encoder import org.apache.avro.Schema import kotlin.reflect.KClass @@ -73,21 +72,4 @@ internal fun AvroDecoder.UnexpectedDecodeSchemaError( return SerializationException( "For $actualType, expected type one of $allExpectedTypes, but had writer schema $currentWriterSchema" ) -} - -context(Encoder) -internal fun BadEncodedValueError( - value: Any?, - writerSchema: Schema, - firstExpectedType: Schema.Type, - vararg expectedTypes: Schema.Type, -): SerializationException { - val allExpectedTypes = listOf(firstExpectedType) + expectedTypes - return if (value == null) { - SerializationException("Encoded null value, expected one of $allExpectedTypes, actual writer schema $writerSchema") - } else { - SerializationException( - "Encoded value '$value' of type ${value::class.qualifiedName}, expected one of $allExpectedTypes, actual writer schema $writerSchema" - ) - } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt index 3ecfdd3..079cc23 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/helpers.kt @@ -52,7 +52,14 @@ internal fun Schema.isNamedSchema(): Boolean { } internal fun Schema.isFullNameOrAliasMatch(descriptor: SerialDescriptor): Boolean { - return isFullNameMatch(descriptor.nonNullSerialName) || descriptor.aliases.any { isFullNameMatch(it) } + return isFullNameOrAliasMatch(descriptor.nonNullSerialName, descriptor::aliases) +} + +internal fun Schema.isFullNameOrAliasMatch( + fullName: String, + aliases: () -> Set, +): Boolean { + return isFullNameMatch(fullName) || aliases().any { isFullNameMatch(it) } } internal fun Schema.isFullNameMatch(fullNameToMatch: String): Boolean { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt index 21c8e43..7761a3c 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroDuration.kt @@ -4,9 +4,16 @@ import com.github.avrokotlin.avro4k.AnyValueDecoder import com.github.avrokotlin.avro4k.AvroDecoder import com.github.avrokotlin.avro4k.AvroEncoder import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError +import com.github.avrokotlin.avro4k.ensureFixedSize +import com.github.avrokotlin.avro4k.fullNameOrAliasMismatchError import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError +import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch +import com.github.avrokotlin.avro4k.trySelectLogicalTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectNamedSchema +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException @@ -116,8 +123,9 @@ public class AvroDurationParseException(value: String) : SerializationException( internal object AvroDurationSerializer : AvroSerializer(AvroDuration::class.qualifiedName!!) { private const val LOGICAL_TYPE_NAME = "duration" private const val DURATION_BYTES = 12 + private const val DEFAULT_DURATION_FULL_NAME = "time.Duration" internal val DURATION_SCHEMA = - Schema.createFixed("time.Duration", "A 12-byte byte array encoding a duration in months, days and milliseconds.", null, DURATION_BYTES).also { + Schema.createFixed(DEFAULT_DURATION_FULL_NAME, "A 12-byte byte array encoding a duration in months, days and milliseconds.", null, DURATION_BYTES).also { LogicalType(LOGICAL_TYPE_NAME).addToSchema(it) } @@ -132,21 +140,22 @@ internal object AvroDurationSerializer : AvroSerializer(AvroDurati value: AvroDuration, ) { with(encoder) { - encodeResolving({ BadEncodedValueError(value, currentWriterSchema, Schema.Type.FIXED, Schema.Type.STRING) }) { - when (it.type) { - Schema.Type.FIXED -> - if (it.logicalType?.name == LOGICAL_TYPE_NAME && it.fixedSize == DURATION_BYTES) { - { encodeFixed(encodeDuration(value)) } - } else { - null - } - - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectNamedSchema(DEFAULT_DURATION_FULL_NAME) || + trySelectLogicalTypeFromUnion(LOGICAL_TYPE_NAME, Schema.Type.FIXED) || + trySelectTypeFromUnion(Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.FIXED, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.FIXED -> + if (currentWriterSchema.logicalType?.name == LOGICAL_TYPE_NAME || currentWriterSchema.isFullNameOrAliasMatch(DEFAULT_DURATION_FULL_NAME, ::emptySet)) { + encodeFixed(ensureFixedSize(encodeDuration(value))) + } else { + throw fullNameOrAliasMismatchError(DEFAULT_DURATION_FULL_NAME, emptySet()) } - else -> null - } + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.FIXED, Schema.Type.STRING) } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt index 5eea0d1..f4cf796 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaStdLibSerializers.kt @@ -5,11 +5,14 @@ import com.github.avrokotlin.avro4k.AvroDecimal import com.github.avrokotlin.avro4k.AvroDecoder import com.github.avrokotlin.avro4k.AvroEncoder import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.encodeResolving import com.github.avrokotlin.avro4k.internal.AvroSchemaGenerationException -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError import com.github.avrokotlin.avro4k.internal.copy +import com.github.avrokotlin.avro4k.trySelectLogicalTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.KSerializer import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor @@ -91,41 +94,18 @@ public object BigIntegerSerializer : AvroSerializer(BigInteger::clas encoder: AvroEncoder, value: BigInteger, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError( - value, - encoder.currentWriterSchema, - Schema.Type.STRING, - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE - ) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) || + throw typeNotFoundInUnionError(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) } - }) { schema -> - when (schema.type) { - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - Schema.Type.INT -> { - { encoder.encodeInt(value.intValueExact()) } - } - - Schema.Type.LONG -> { - { encoder.encodeLong(value.longValueExact()) } - } - - Schema.Type.FLOAT -> { - { encoder.encodeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encoder.encodeDouble(value.toDouble()) } - } - - else -> null + when (currentWriterSchema.type) { + Schema.Type.STRING -> encodeString(value.toString()) + Schema.Type.INT -> encodeInt(value.intValueExact()) + Schema.Type.LONG -> encodeLong(value.longValueExact()) + Schema.Type.FLOAT -> encodeFloat(value.toFloat()) + Schema.Type.DOUBLE -> encodeDouble(value.toDouble()) + else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) } } } @@ -202,11 +182,29 @@ public object BigDecimalSerializer : AvroSerializer(BigDecimal::clas encoder: AvroEncoder, value: BigDecimal, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError( - value, - encoder.currentWriterSchema, + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectLogicalTypeFromUnion(converter.logicalTypeName, Schema.Type.BYTES, Schema.Type.FIXED) || + trySelectTypeFromUnion(Schema.Type.STRING, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.DOUBLE) || + throw typeNotFoundInUnionError( + Schema.Type.BYTES, + Schema.Type.FIXED, + Schema.Type.STRING, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE + ) + } + when (currentWriterSchema.type) { + Schema.Type.BYTES -> encodeBytes(converter.toBytes(value, currentWriterSchema, currentWriterSchema.logicalType).array()) + Schema.Type.FIXED -> encodeFixed(converter.toFixed(value, currentWriterSchema, currentWriterSchema.logicalType).bytes()) + Schema.Type.STRING -> encodeString(value.toString()) + Schema.Type.INT -> encodeInt(value.intValueExact()) + Schema.Type.LONG -> encodeLong(value.longValueExact()) + Schema.Type.FLOAT -> encodeFloat(value.toFloat()) + Schema.Type.DOUBLE -> encodeDouble(value.toDouble()) + else -> throw unsupportedWriterTypeError( Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.STRING, @@ -216,48 +214,6 @@ public object BigDecimalSerializer : AvroSerializer(BigDecimal::clas Schema.Type.DOUBLE ) } - }) { schema -> - when (schema.type) { - Schema.Type.BYTES -> - when (schema.logicalType) { - is LogicalTypes.Decimal -> { - { encoder.encodeBytes(converter.toBytes(value, schema, schema.logicalType)) } - } - - else -> null - } - - Schema.Type.FIXED -> - when (schema.logicalType) { - is LogicalTypes.Decimal -> { - { encoder.encodeFixed(converter.toFixed(value, schema, schema.logicalType)) } - } - - else -> null - } - - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - Schema.Type.INT -> { - { encoder.encodeInt(value.intValueExact()) } - } - - Schema.Type.LONG -> { - { encoder.encodeLong(value.longValueExact()) } - } - - Schema.Type.FLOAT -> { - { encoder.encodeFloat(value.toFloat()) } - } - - Schema.Type.DOUBLE -> { - { encoder.encodeDouble(value.toDouble()) } - } - - else -> null - } } } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt index 9d7d711..4d369a4 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/JavaTimeSerializers.kt @@ -4,10 +4,13 @@ import com.github.avrokotlin.avro4k.AnyValueDecoder import com.github.avrokotlin.avro4k.AvroDecoder import com.github.avrokotlin.avro4k.AvroEncoder import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.encodeResolving -import com.github.avrokotlin.avro4k.internal.BadEncodedValueError import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError import com.github.avrokotlin.avro4k.internal.copy +import com.github.avrokotlin.avro4k.logicalTypeMismatchError +import com.github.avrokotlin.avro4k.trySelectSingleNonNullTypeFromUnion +import com.github.avrokotlin.avro4k.trySelectTypeFromUnion +import com.github.avrokotlin.avro4k.typeNotFoundInUnionError +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.SerializationException import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder @@ -51,36 +54,20 @@ public object LocalDateSerializer : AvroSerializer(LocalDate::class.q encoder: AvroEncoder, value: LocalDate, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.INT, Schema.Type.LONG) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.STRING) } - }) { schema -> - when (schema.type) { + when (currentWriterSchema.type) { Schema.Type.INT -> - when (schema.logicalType?.name) { - LOGICAL_TYPE_NAME_DATE, null -> { - { encoder.encodeInt(value.toEpochDay().toInt()) } - } - - else -> null - } - - Schema.Type.LONG -> - when (schema.logicalType) { - // Date is not compatible with LONG, so we require a null logical type to encode the timestamp - null -> { - { encoder.encodeLong(value.toEpochDay()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_DATE -> encodeInt(value.toEpochDay().toInt()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_DATE, Schema.Type.INT) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.STRING) } } } @@ -148,39 +135,25 @@ public object LocalTimeSerializer : AvroSerializer(LocalTime::class.q value: LocalTime, ) { with(encoder) { - encodeResolving({ - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) - }) { schema -> - when (schema.type) { - Schema.Type.INT -> - when (schema.logicalType?.name) { - LOGICAL_TYPE_NAME_TIME_MILLIS, null -> { - { encoder.encodeInt(value.toMillisOfDay()) } - } - - else -> null - } - - Schema.Type.LONG -> - when (schema.logicalType?.name) { - // TimeMillis is not compatible with LONG, so we require a null logical type to encode the timestamp - null -> { - { encoder.encodeLong(value.toMillisOfDay().toLong()) } - } - - LOGICAL_TYPE_NAME_TIME_MICROS -> { - { encoder.encodeLong(value.toMicroOfDay()) } - } - - else -> null - } + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) + } + when (currentWriterSchema.type) { + Schema.Type.INT -> + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIME_MILLIS -> encodeInt(value.toMillisOfDay()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIME_MILLIS, Schema.Type.INT) + } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } + Schema.Type.LONG -> + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIME_MICROS -> encodeLong(value.toMicroOfDay()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIME_MICROS, Schema.Type.LONG) } - else -> null - } + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } } @@ -257,30 +230,21 @@ public object LocalDateTimeSerializer : AvroSerializer(LocalDateT encoder: AvroEncoder, value: LocalDateTime, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.LONG, Schema.Type.STRING) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.STRING) } - }) { - when (it.type) { + when (currentWriterSchema.type) { Schema.Type.LONG -> - when (it.logicalType?.name) { - LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, null -> { - { encoder.encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMilli()) } - } - - LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> { - { encoder.encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMicros()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMicros()) + LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> encodeLong(value.toInstant(ZoneOffset.UTC).toEpochMilli()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, Schema.Type.LONG) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.STRING) } } } @@ -335,30 +299,21 @@ public object InstantSerializer : AvroSerializer(Instant::class.qualifi encoder: AvroEncoder, value: Instant, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.LONG, Schema.Type.STRING) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.STRING) } - }) { - when (it.type) { + when (currentWriterSchema.type) { Schema.Type.LONG -> - when (it.logicalType?.name) { - LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, null -> { - { encoder.encodeLong(value.toEpochMilli()) } - } - - LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> { - { encoder.encodeLong(value.toEpochMicros()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> encodeLong(value.toEpochMicros()) + LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> encodeLong(value.toEpochMilli()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS, Schema.Type.LONG) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.STRING) } } } @@ -412,30 +367,21 @@ public object InstantToMicroSerializer : AvroSerializer(Instant::class. encoder: AvroEncoder, value: Instant, ) { - encoder.encodeResolving({ - with(encoder) { - BadEncodedValueError(value, encoder.currentWriterSchema, Schema.Type.LONG, Schema.Type.STRING) + with(encoder) { + if (currentWriterSchema.isUnion && !trySelectSingleNonNullTypeFromUnion()) { + trySelectTypeFromUnion(Schema.Type.LONG, Schema.Type.STRING) || + throw typeNotFoundInUnionError(Schema.Type.LONG, Schema.Type.STRING) } - }) { - when (it.type) { + when (currentWriterSchema.type) { Schema.Type.LONG -> - when (it.logicalType?.name) { - LOGICAL_TYPE_NAME_TIMESTAMP_MICROS, null -> { - { encoder.encodeLong(value.toEpochMicros()) } - } - - LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> { - { encoder.encodeLong(value.toEpochMilli()) } - } - - else -> null + when (currentWriterSchema.logicalType?.name) { + LOGICAL_TYPE_NAME_TIMESTAMP_MICROS -> encodeLong(value.toEpochMicros()) + LOGICAL_TYPE_NAME_TIMESTAMP_MILLIS -> encodeLong(value.toEpochMilli()) + else -> throw logicalTypeMismatchError(LOGICAL_TYPE_NAME_TIMESTAMP_MICROS, Schema.Type.LONG) } - Schema.Type.STRING -> { - { encoder.encodeString(value.toString()) } - } - - else -> null + Schema.Type.STRING -> encodeString(value.toString()) + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.STRING) } } }