From 0e9be7adaa973464d0e4815db02b8964daf0136f Mon Sep 17 00:00:00 2001 From: Chuckame Date: Sun, 15 Sep 2024 10:06:31 +0200 Subject: [PATCH] refactor: rework direct decoding for more clear & compact resolving unions --- .../github/avrokotlin/avro4k/AvroDecoder.kt | 11 + .../direct/AbstractAvroDirectDecoder.kt | 414 +++++------------- 2 files changed, 121 insertions(+), 304 deletions(-) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt index 569bf6a..7305c4e 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroDecoder.kt @@ -1,6 +1,7 @@ package com.github.avrokotlin.avro4k import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.SerializationException import kotlinx.serialization.encoding.Decoder import org.apache.avro.Schema import org.apache.avro.generic.GenericFixed @@ -317,4 +318,14 @@ internal inline fun AvroDecoder.findValueDecoder( resolver(schema) } return foundResolver ?: throw error() +} + +internal fun AvroDecoder.unsupportedWriterTypeError( + mainType: Schema.Type, + vararg fallbackTypes: Schema.Type, +): Throwable { + val fallbacksStr = if (fallbackTypes.isNotEmpty()) ", and also not matching to any compatible type (one of ${fallbackTypes.joinToString()})." else "" + return SerializationException( + "Unsupported schema '${currentWriterSchema.fullName}' for decoded type of ${mainType.getName()}$fallbacksStr. Actual schema: $currentWriterSchema" + ) } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt index c8fdde6..7bc3cc5 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/internal/decoder/direct/AbstractAvroDirectDecoder.kt @@ -1,21 +1,7 @@ package com.github.avrokotlin.avro4k.internal.decoder.direct -import com.github.avrokotlin.avro4k.AnyValueDecoder import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.BooleanValueDecoder -import com.github.avrokotlin.avro4k.CharValueDecoder -import com.github.avrokotlin.avro4k.DoubleValueDecoder -import com.github.avrokotlin.avro4k.FloatValueDecoder -import com.github.avrokotlin.avro4k.IntValueDecoder -import com.github.avrokotlin.avro4k.LongValueDecoder import com.github.avrokotlin.avro4k.UnionDecoder -import com.github.avrokotlin.avro4k.decodeResolvingAny -import com.github.avrokotlin.avro4k.decodeResolvingBoolean -import com.github.avrokotlin.avro4k.decodeResolvingChar -import com.github.avrokotlin.avro4k.decodeResolvingDouble -import com.github.avrokotlin.avro4k.decodeResolvingFloat -import com.github.avrokotlin.avro4k.decodeResolvingInt -import com.github.avrokotlin.avro4k.decodeResolvingLong import com.github.avrokotlin.avro4k.internal.SerializerLocatorMiddleware import com.github.avrokotlin.avro4k.internal.UnexpectedDecodeSchemaError import com.github.avrokotlin.avro4k.internal.decoder.AbstractPolymorphicDecoder @@ -23,9 +9,9 @@ import com.github.avrokotlin.avro4k.internal.getElementIndexNullable import com.github.avrokotlin.avro4k.internal.isFullNameOrAliasMatch import com.github.avrokotlin.avro4k.internal.nonNullSerialName import com.github.avrokotlin.avro4k.internal.toByteExact -import com.github.avrokotlin.avro4k.internal.toFloatExact import com.github.avrokotlin.avro4k.internal.toIntExact import com.github.avrokotlin.avro4k.internal.toShortExact +import com.github.avrokotlin.avro4k.unsupportedWriterTypeError import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PolymorphicKind @@ -59,38 +45,39 @@ internal abstract class AbstractAvroDirectDecoder( } override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder { + decodeAndResolveUnion() + return when (descriptor.kind) { StructureKind.LIST -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.ARRAY) }) { - when (it.type) { - Schema.Type.ARRAY -> { - AnyValueDecoder { ArrayBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.ARRAY -> + ArrayBlockDirectDecoder( + currentWriterSchema, + decodeFirstBlock = decodedCollectionSize == -1, + { decodedCollectionSize = it }, + avro, + binaryDecoder + ) + else -> throw unsupportedWriterTypeError(Schema.Type.ARRAY) } StructureKind.MAP -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.MAP) }) { - when (it.type) { - Schema.Type.MAP -> { - AnyValueDecoder { MapBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.MAP -> + MapBlockDirectDecoder( + currentWriterSchema, + decodeFirstBlock = decodedCollectionSize == -1, + { decodedCollectionSize = it }, + avro, + binaryDecoder + ) + else -> throw unsupportedWriterTypeError(Schema.Type.MAP) } StructureKind.CLASS, StructureKind.OBJECT -> - decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.RECORD) }) { - when (it.type) { - Schema.Type.RECORD -> { - AnyValueDecoder { RecordDirectDecoder(it, descriptor, avro, binaryDecoder) } - } - - else -> null - } + when (currentWriterSchema.type) { + Schema.Type.RECORD -> RecordDirectDecoder(currentWriterSchema, descriptor, avro, binaryDecoder) + else -> throw unsupportedWriterTypeError(Schema.Type.RECORD) } is PolymorphicKind -> PolymorphicDecoder(avro, descriptor, currentWriterSchema, binaryDecoder) @@ -106,46 +93,27 @@ internal abstract class AbstractAvroDirectDecoder( override fun decodeNotNullMark(): Boolean { decodeAndResolveUnion() + return currentWriterSchema.type != Schema.Type.NULL } override fun decodeNull(): Nothing? { - decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "null", - Schema.Type.NULL - ) - }) { - when (it.type) { - Schema.Type.NULL -> { - AnyValueDecoder { binaryDecoder.readNull() } - } + decodeAndResolveUnion() - else -> null - } + if (currentWriterSchema.type != Schema.Type.NULL) { + throw unsupportedWriterTypeError(Schema.Type.NULL) } + binaryDecoder.readNull() return null } override fun decodeBoolean(): Boolean { - return decodeResolvingBoolean({ - UnexpectedDecodeSchemaError( - "boolean", - Schema.Type.BOOLEAN, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.BOOLEAN -> { - BooleanValueDecoder { binaryDecoder.readBoolean() } - } - - Schema.Type.STRING -> { - BooleanValueDecoder { binaryDecoder.readString().toBooleanStrict() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BOOLEAN -> binaryDecoder.readBoolean() + Schema.Type.STRING -> binaryDecoder.readString().toBooleanStrict() + else -> throw unsupportedWriterTypeError(Schema.Type.BOOLEAN, Schema.Type.STRING) } } @@ -158,284 +126,122 @@ internal abstract class AbstractAvroDirectDecoder( } override fun decodeInt(): Int { - return decodeResolvingInt({ - UnexpectedDecodeSchemaError( - "int", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - IntValueDecoder { binaryDecoder.readInt() } - } - - Schema.Type.LONG -> { - IntValueDecoder { binaryDecoder.readLong().toIntExact() } - } - - Schema.Type.FLOAT -> { - IntValueDecoder { binaryDecoder.readDouble().toInt() } - } - - Schema.Type.DOUBLE -> { - IntValueDecoder { binaryDecoder.readDouble().toInt() } - } - - Schema.Type.STRING -> { - IntValueDecoder { binaryDecoder.readString().toInt() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt() + Schema.Type.LONG -> binaryDecoder.readLong().toIntExact() + Schema.Type.STRING -> binaryDecoder.readString().toInt() + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } override fun decodeLong(): Long { - return decodeResolvingLong({ - UnexpectedDecodeSchemaError( - "long", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - LongValueDecoder { binaryDecoder.readInt().toLong() } - } - - Schema.Type.LONG -> { - LongValueDecoder { binaryDecoder.readLong() } - } - - Schema.Type.FLOAT -> { - LongValueDecoder { binaryDecoder.readFloat().toLong() } - } - - Schema.Type.DOUBLE -> { - LongValueDecoder { binaryDecoder.readDouble().toLong() } - } - - Schema.Type.STRING -> { - LongValueDecoder { binaryDecoder.readString().toLong() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toLong() + Schema.Type.LONG -> binaryDecoder.readLong() + Schema.Type.STRING -> binaryDecoder.readString().toLong() + else -> throw unsupportedWriterTypeError(Schema.Type.LONG, Schema.Type.INT, Schema.Type.STRING) } } override fun decodeFloat(): Float { - return decodeResolvingFloat({ - UnexpectedDecodeSchemaError( - "float", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - FloatValueDecoder { binaryDecoder.readInt().toFloat() } - } - - Schema.Type.LONG -> { - FloatValueDecoder { binaryDecoder.readLong().toFloat() } - } - - Schema.Type.FLOAT -> { - FloatValueDecoder { binaryDecoder.readFloat() } - } - - Schema.Type.DOUBLE -> { - FloatValueDecoder { binaryDecoder.readDouble().toFloatExact() } - } - - Schema.Type.STRING -> { - FloatValueDecoder { binaryDecoder.readString().toFloat() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toFloat() + Schema.Type.LONG -> binaryDecoder.readLong().toFloat() + Schema.Type.FLOAT -> binaryDecoder.readFloat() + Schema.Type.STRING -> binaryDecoder.readString().toFloat() + else -> throw unsupportedWriterTypeError(Schema.Type.FLOAT, Schema.Type.INT, Schema.Type.LONG, Schema.Type.STRING) } } override fun decodeDouble(): Double { - return decodeResolvingDouble({ - UnexpectedDecodeSchemaError( - "double", - Schema.Type.INT, - Schema.Type.LONG, - Schema.Type.FLOAT, - Schema.Type.DOUBLE, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - DoubleValueDecoder { binaryDecoder.readInt().toDouble() } - } - - Schema.Type.LONG -> { - DoubleValueDecoder { binaryDecoder.readLong().toDouble() } - } - - Schema.Type.FLOAT -> { - DoubleValueDecoder { binaryDecoder.readFloat().toDouble() } - } - - Schema.Type.DOUBLE -> { - DoubleValueDecoder { binaryDecoder.readDouble() } - } - - Schema.Type.STRING -> { - DoubleValueDecoder { binaryDecoder.readString().toDouble() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toDouble() + Schema.Type.LONG -> binaryDecoder.readLong().toDouble() + Schema.Type.FLOAT -> binaryDecoder.readFloat().toDouble() + Schema.Type.DOUBLE -> binaryDecoder.readDouble() + Schema.Type.STRING -> binaryDecoder.readString().toDouble() + else -> throw unsupportedWriterTypeError(Schema.Type.DOUBLE, Schema.Type.INT, Schema.Type.LONG, Schema.Type.FLOAT, Schema.Type.STRING) } } override fun decodeChar(): Char { - return decodeResolvingChar({ - UnexpectedDecodeSchemaError( - "char", - Schema.Type.INT, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.INT -> { - CharValueDecoder { binaryDecoder.readInt().toChar() } - } - - Schema.Type.STRING -> { - CharValueDecoder { binaryDecoder.readString(null).single() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.INT -> binaryDecoder.readInt().toChar() + Schema.Type.STRING -> binaryDecoder.readString(null).single() + else -> throw unsupportedWriterTypeError(Schema.Type.INT, Schema.Type.STRING) } } override fun decodeString(): String { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "string", - Schema.Type.STRING, - Schema.Type.BYTES, - Schema.Type.FIXED - ) - }) { - when (it.type) { - Schema.Type.STRING, - Schema.Type.BYTES, - -> { - AnyValueDecoder { binaryDecoder.readString() } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.STRING -> binaryDecoder.readString(null).toString() + Schema.Type.BYTES -> binaryDecoder.readBytes(null).array().decodeToString() + Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }.decodeToString() + else -> throw unsupportedWriterTypeError(Schema.Type.STRING, Schema.Type.BYTES, Schema.Type.FIXED) } } override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - return decodeResolvingInt({ - UnexpectedDecodeSchemaError( - enumDescriptor.nonNullSerialName, - Schema.Type.ENUM, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.ENUM -> - if (it.isFullNameOrAliasMatch(enumDescriptor)) { - IntValueDecoder { - val enumName = it.enumSymbols[binaryDecoder.readEnum()] - enumDescriptor.getElementIndexNullable(enumName) - ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException( - "Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema" - ) - } - } else { - null - } - - Schema.Type.STRING -> { - IntValueDecoder { - val enumSymbol = binaryDecoder.readString() - enumDescriptor.getElementIndex(enumSymbol) - .takeIf { index -> index >= 0 } - ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) - ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}'") - } - } + decodeAndResolveUnion() - else -> null + return when (currentWriterSchema.type) { + Schema.Type.ENUM -> + if (currentWriterSchema.isFullNameOrAliasMatch(enumDescriptor)) { + val enumName = currentWriterSchema.enumSymbols[binaryDecoder.readEnum()] + enumDescriptor.getElementIndexNullable(enumName) + ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) + ?: throw SerializationException( + "Unknown enum symbol name '$enumName' for Enum '${enumDescriptor.serialName}' for writer schema $currentWriterSchema" + ) + } else { + throw UnexpectedDecodeSchemaError( + enumDescriptor.nonNullSerialName, + Schema.Type.ENUM, + Schema.Type.STRING + ) + } + + Schema.Type.STRING -> { + val enumSymbol = binaryDecoder.readString() + enumDescriptor.getElementIndex(enumSymbol).takeIf { index -> index >= 0 } + ?: avro.enumResolver.getDefaultValueIndex(enumDescriptor) + ?: throw SerializationException("Unknown enum symbol '$enumSymbol' for Enum '${enumDescriptor.serialName}'") } + + else -> throw unsupportedWriterTypeError(Schema.Type.ENUM, Schema.Type.STRING) } } override fun decodeBytes(): ByteArray { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "ByteArray", - Schema.Type.BYTES, - Schema.Type.FIXED, - Schema.Type.STRING - ) - }) { - when (it.type) { - Schema.Type.BYTES -> { - AnyValueDecoder { binaryDecoder.readBytes(null).array() } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } } - } - - Schema.Type.STRING -> { - AnyValueDecoder { binaryDecoder.readString(null).bytes } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BYTES -> binaryDecoder.readBytes(null).array() + Schema.Type.FIXED -> ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) } + Schema.Type.STRING -> binaryDecoder.readString(null).bytes + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED, Schema.Type.STRING) } } override fun decodeFixed(): GenericFixed { - return decodeResolvingAny({ - UnexpectedDecodeSchemaError( - "GenericFixed", - Schema.Type.BYTES, - Schema.Type.FIXED - ) - }) { - when (it.type) { - Schema.Type.BYTES -> { - AnyValueDecoder { GenericData.Fixed(it, binaryDecoder.readBytes(null).array()) } - } - - Schema.Type.FIXED -> { - AnyValueDecoder { GenericData.Fixed(it, ByteArray(it.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) } - } + decodeAndResolveUnion() - else -> null - } + return when (currentWriterSchema.type) { + Schema.Type.BYTES -> GenericData.Fixed(currentWriterSchema, binaryDecoder.readBytes(null).array()) + Schema.Type.FIXED -> GenericData.Fixed(currentWriterSchema, ByteArray(currentWriterSchema.fixedSize).also { buf -> binaryDecoder.readFixed(buf) }) + else -> throw unsupportedWriterTypeError(Schema.Type.BYTES, Schema.Type.FIXED) } } }