diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3c2a9791..aed41016 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -16,7 +16,7 @@ jobs: - name: Checkout the repo uses: actions/checkout@v4 - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: Run tests run: ./gradlew check @@ -38,7 +38,7 @@ jobs: - name: Checkout the repo uses: actions/checkout@v4 - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: deploy to sonatype snapshots run: ./gradlew publish diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 2ca72284..468d86fe 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -13,7 +13,7 @@ jobs: - name: Checkout the repo uses: actions/checkout@v4 - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: Run tests run: ./gradlew check diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c366f09b..c41cf817 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: - name: Checkout the repo uses: actions/checkout@v4 - name: Setup Gradle - uses: gradle/gradle-build-action@v3 + uses: gradle/actions/setup-gradle@v3 - name: publish release run: ./gradlew publishToSonatype closeAndReleaseSonatypeStagingRepository diff --git a/build.gradle.kts b/build.gradle.kts index 160ccf30..e23426d3 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -17,6 +17,7 @@ plugins { id("maven-publish") signing alias(libs.plugins.dokka) + alias(libs.plugins.kover) alias(libs.plugins.kotest) alias(libs.plugins.github.versions) alias(libs.plugins.nexus.publish) @@ -35,7 +36,7 @@ dependencies { api(libs.apache.avro) api(libs.kotlinx.serialization.core) implementation(libs.kotlinx.serialization.json) - implementation(libs.xerial.snappy) + implementation(kotlin("reflect")) testImplementation(libs.kotest.junit5) testImplementation(libs.kotest.core) testImplementation(libs.kotest.json) @@ -46,7 +47,7 @@ tasks.withType().configureEach { kotlinOptions.jvmTarget = "1.8" kotlinOptions.apiVersion = "1.6" kotlinOptions.languageVersion = "1.6" - kotlinOptions.freeCompilerArgs += "-opt-in=kotlin.RequiresOptIn" + kotlinOptions.freeCompilerArgs += listOf("-opt-in=kotlinx.serialization.ExperimentalSerializationApi", "-opt-in=kotlin.RequiresOptIn", "-Xcontext-receivers") } java { sourceCompatibility = JavaVersion.VERSION_1_8 diff --git a/settings.gradle.kts b/settings.gradle.kts index ccedc97f..2aa66f79 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -13,7 +13,6 @@ dependencyResolutionManagement { version("kotlin", "1.9.22") version("jvm", "18") - library("xerial-snappy", "org.xerial.snappy", "snappy-java").version("1.1.10.1") library("apache-avro", "org.apache.avro", "avro").version("1.11.3") val kotlinxSerialization = "1.6.2" @@ -31,6 +30,7 @@ dependencyResolutionManagement { plugin("github-versions", "com.github.ben-manes.versions").version("0.46.0") plugin("nexus-publish", "io.github.gradle-nexus.publish-plugin").version("1.3.0") plugin("spotless", "com.diffplug.spotless").version("6.25.0") + plugin("kover", "org.jetbrains.kotlinx.kover").version("0.7.6") } } @Suppress("UnstableApiUsage") diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt index 0e405a7f..53a9af85 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/Avro.kt @@ -9,11 +9,13 @@ import com.github.avrokotlin.avro4k.io.AvroEncodeFormat import com.github.avrokotlin.avro4k.io.AvroFormat import com.github.avrokotlin.avro4k.io.AvroInputStream import com.github.avrokotlin.avro4k.io.AvroOutputStream -import com.github.avrokotlin.avro4k.schema.schemaFor +import com.github.avrokotlin.avro4k.schema.ValueVisitor +import com.github.avrokotlin.avro4k.serializer.BigDecimalAsStringSerializer +import com.github.avrokotlin.avro4k.serializer.BigIntegerSerializer +import com.github.avrokotlin.avro4k.serializer.URLSerializer import com.github.avrokotlin.avro4k.serializer.UUIDSerializer import kotlinx.serialization.BinaryFormat import kotlinx.serialization.DeserializationStrategy -import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerialFormat import kotlinx.serialization.SerializationStrategy import kotlinx.serialization.descriptors.SerialDescriptor @@ -31,6 +33,7 @@ import java.nio.ByteBuffer import java.nio.file.Files import java.nio.file.Path import java.nio.file.Paths +import java.util.concurrent.ConcurrentHashMap open class AvroInputStreamBuilder( private val converter: (Any) -> T, @@ -60,10 +63,12 @@ open class AvroInputStreamBuilder( val wschema = writerSchema ?: error("Writer schema needs to be supplied for Json format") AvroDecodeFormat.Json(wschema, readerSchema ?: wschema) } + is AvroFormat.BinaryFormat -> { val wschema = writerSchema ?: error("Writer schema needs to be supplied for Binary format") AvroDecodeFormat.Binary(wschema, readerSchema ?: wschema) } + is AvroFormat.DataFormat -> AvroDecodeFormat.Data(writerSchema, readerSchema) } } @@ -152,29 +157,28 @@ class AvroOutputStreamBuilder( } } -@OptIn(ExperimentalSerializationApi::class) class Avro internal constructor( - internal val configuration: AvroInternalConfiguration, - override val serializersModule: SerializersModule, + override val serializersModule: SerializersModule = defaultModule, + internal val configuration: AvroConfiguration = AvroConfiguration(), ) : SerialFormat, BinaryFormat { - constructor( - serializersModule: SerializersModule = defaultModule, - configuration: AvroConfiguration = AvroConfiguration(), - ) : this(AvroInternalConfiguration(configuration), serializersModule) + internal val schemaCache: MutableMap = ConcurrentHashMap() constructor(configuration: AvroConfiguration) : this(defaultModule, configuration) companion object { val defaultModule = SerializersModule { - contextual(UUIDSerializer()) + contextual(UUIDSerializer) + contextual(BigDecimalAsStringSerializer) + contextual(BigIntegerSerializer) + contextual(URLSerializer) } val default = Avro(defaultModule) /** * Use this constant if you want to explicitly set a default value of a field to avro null */ - const val NULL = "com.github.avrokotlin.avro4k.Avro.AVRO_NULL_DEFAULT" + const val NULL = "null" } /** @@ -303,14 +307,13 @@ class Avro internal constructor( ) } - fun schema(descriptor: SerialDescriptor): Schema = - schemaFor( - serializersModule, - descriptor, - descriptor.annotations, - configuration, - mutableMapOf() - ).schema() + fun schema(descriptor: SerialDescriptor): Schema { + return schemaCache.getOrPut(descriptor) { + lateinit var output: Schema + ValueVisitor(this) { output = it }.visitValue(descriptor) + return output + } + } fun schema(serializer: SerializationStrategy): Schema { return schema(serializer.descriptor) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt index ffaaa984..5d9ffbda 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/AvroConfiguration.kt @@ -1,10 +1,7 @@ package com.github.avrokotlin.avro4k import com.github.avrokotlin.avro4k.schema.FieldNamingStrategy -import com.github.avrokotlin.avro4k.schema.RecordName import com.github.avrokotlin.avro4k.schema.RecordNamingStrategy -import kotlinx.serialization.descriptors.SerialDescriptor -import java.util.concurrent.ConcurrentHashMap data class AvroConfiguration( /** @@ -24,57 +21,4 @@ data class AvroConfiguration( * When set to `true`, the nullable fields that haven't any default value are set as null if the value is missing. It also adds `"default": null` to those fields when generating schema using avro4k. */ val implicitNulls: Boolean = false, - /** - * Enable caching of resolved names. - * - * Default: `true` - */ - val namingCacheEnabled: Boolean = true, -) - -class AvroInternalConfiguration private constructor( - val recordNamingStrategy: RecordNamingStrategy, - val fieldNamingStrategy: FieldNamingStrategy, - val implicitNulls: Boolean, -) { - constructor(configuration: AvroConfiguration) : this( - recordNamingStrategy = configuration.recordNamingStrategy.cachedIfNecessary(configuration.namingCacheEnabled), - fieldNamingStrategy = configuration.fieldNamingStrategy.cachedIfNecessary(configuration.namingCacheEnabled), - implicitNulls = configuration.implicitNulls - ) -} - -internal fun RecordNamingStrategy.cachedIfNecessary(cacheEnabled: Boolean): RecordNamingStrategy = - if (!cacheEnabled) { - this - } else { - object : RecordNamingStrategy { - private val cache = ConcurrentHashMap() - - override fun resolve( - descriptor: SerialDescriptor, - serialName: String, - ): RecordName = - cache.getOrPut(descriptor) { - this@cachedIfNecessary.resolve(descriptor, serialName) - } - } - } - -internal fun FieldNamingStrategy.cachedIfNecessary(cacheEnabled: Boolean): FieldNamingStrategy = - if (!cacheEnabled) { - this - } else { - object : FieldNamingStrategy { - private val cache = ConcurrentHashMap, String>() - - override fun resolve( - descriptor: SerialDescriptor, - elementIndex: Int, - serialName: String, - ): String = - cache.getOrPut(descriptor to elementIndex) { - this@cachedIfNecessary.resolve(descriptor, elementIndex, serialName) - } - } - } \ No newline at end of file +) \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/annotations.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/annotations.kt index 544f75c3..d0442f68 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/annotations.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/annotations.kt @@ -5,24 +5,40 @@ package com.github.avrokotlin.avro4k import com.github.avrokotlin.avro4k.serializer.BigDecimalSerializer import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerialInfo -import kotlinx.serialization.descriptors.PrimitiveKind -import org.apache.avro.LogicalTypes +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.LogicalType import org.apache.avro.Schema -import org.apache.avro.SchemaBuilder import org.intellij.lang.annotations.Language +import kotlin.reflect.KClass /** - * When annotated on a property, overrides the namespace for the nested record. + * When annotated on a property, deeply overrides the namespace for all the nested named types (records, enums and fixed). + * + * Works with standard classes and inline classes. */ @SerialInfo @Target(AnnotationTarget.PROPERTY) -annotation class AvroNamespaceOverride(val value: String) +annotation class AvroNamespaceOverride( + val value: String, +) +/** + * Adds a property to the Avro schema or field. + * + * Ignored in inline classes. + */ @SerialInfo +@Repeatable @Target(AnnotationTarget.PROPERTY, AnnotationTarget.CLASS) annotation class AvroProp(val key: String, val value: String) +/** + * Adds a json property to the Avro schema or field. + * + * Ignored in inline classes. + */ @SerialInfo +@Repeatable @Target(AnnotationTarget.PROPERTY, AnnotationTarget.CLASS) annotation class AvroJsonProp( val key: String, @@ -31,67 +47,33 @@ annotation class AvroJsonProp( /** * To be used with [BigDecimalSerializer] to specify the scale, precision, type and rounding mode of the decimal value. + * + * Can be used with [AvroFixed] to serialize value as a fixed type. */ @SerialInfo @Target(AnnotationTarget.PROPERTY) -annotation class AvroDecimalLogicalType( +annotation class AvroDecimal( val scale: Int = 2, val precision: Int = 8, - val schema: LogicalDecimalTypeEnum = LogicalDecimalTypeEnum.BYTES, ) -enum class LogicalDecimalTypeEnum { - BYTES, - STRING, - - /** - * Fixed requires the field annotated with [AvroFixed] - */ - FIXED, -} - -@SerialInfo -@Target(AnnotationTarget.PROPERTY) -annotation class AvroUuidLogicalType - -@SerialInfo -@Target(AnnotationTarget.PROPERTY) -annotation class AvroTimeLogicalType(val type: LogicalTimeTypeEnum) - -enum class LogicalTimeTypeEnum(val kind: PrimitiveKind, val schemaFor: () -> Schema) { - DATE(PrimitiveKind.INT, { LogicalTypes.date().addToSchema(SchemaBuilder.builder().intType()) }), - TIME_MILLIS( - PrimitiveKind.INT, - { LogicalTypes.timeMillis().addToSchema(SchemaBuilder.builder().intType()) } - ), - TIME_MICROS( - PrimitiveKind.LONG, - { LogicalTypes.timeMicros().addToSchema(SchemaBuilder.builder().longType()) } - ), - TIMESTAMP_MILLIS( - PrimitiveKind.LONG, - { LogicalTypes.timestampMillis().addToSchema(SchemaBuilder.builder().longType()) } - ), - TIMESTAMP_MICROS( - PrimitiveKind.LONG, - { LogicalTypes.timestampMicros().addToSchema(SchemaBuilder.builder().longType()) } - ), - LOCAL_TIMESTAMP_MILLIS( - PrimitiveKind.LONG, - { LogicalTypes.localTimestampMillis().addToSchema(SchemaBuilder.builder().longType()) } - ), - LOCAL_TIMESTAMP_MICROS( - PrimitiveKind.LONG, - { LogicalTypes.localTimestampMicros().addToSchema(SchemaBuilder.builder().longType()) } - ), -} - +/** + * Adds documentation to: + * - a record's field + * - a record + * - an enum + * + * Ignored in inline classes. + */ @SerialInfo @Target(AnnotationTarget.PROPERTY, AnnotationTarget.CLASS) annotation class AvroDoc(val value: String) /** * Adds aliases to a field of a record. It helps to allow having different names for the same field for better compatibility when changing a schema. + * + * Ignored in inline classes. + * * @param value The aliases for the annotated property. Note that the given aliases won't be changed by the configured [AvroConfiguration.fieldNamingStrategy]. */ @SerialInfo @@ -106,12 +88,49 @@ annotation class AvroAlias(vararg val value: String) @Target(AnnotationTarget.PROPERTY) annotation class AvroFixed(val size: Int) +/** + * Sets the default avro value for a record's field. + * + * Ignored in inline classes. + */ @SerialInfo @Target(AnnotationTarget.PROPERTY) annotation class AvroDefault( @Language("JSON") val value: String, ) +/** + * This annotation indicates that the annotated enum class should be serialized as an Avro enum with the given default value. + * + * It must be annotated on an enum class. Otherwise, it will be ignored. + */ @SerialInfo @Target(AnnotationTarget.CLASS) -annotation class AvroEnumDefault(val value: String) \ No newline at end of file +annotation class AvroEnumDefault(val value: String) + +/** + * Allows to specify the schema of a property. + */ +@SerialInfo +@Target(AnnotationTarget.PROPERTY) +annotation class AvroSchema(val value: KClass) + +interface AvroSchemaSupplier { + fun getSchema(stack: List): Schema +} + +/** + * Allows to specify the logical type applied on the generated schema of a property. + */ +@SerialInfo +@Target(AnnotationTarget.PROPERTY) +annotation class AvroLogicalType(val value: KClass) + +interface AvroLogicalTypeSupplier { + fun getLogicalType(inlinedStack: List): LogicalType +} + +interface AnnotatedLocation { + val descriptor: SerialDescriptor + val elementIndex: Int? +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt index 8c227266..f06e4c74 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/FromAvroValue.kt @@ -2,35 +2,17 @@ package com.github.avrokotlin.avro4k.decoder import kotlinx.serialization.SerializationException import org.apache.avro.generic.GenericData -import org.apache.avro.generic.GenericEnumSymbol -import org.apache.avro.util.Utf8 import java.nio.ByteBuffer -interface FromAvroValue { - fun fromValue(value: T): R -} - -object StringFromAvroValue : FromAvroValue { - override fun fromValue(value: Any?): String { +object StringFromAvroValue { + fun fromValue(value: Any?): String { return when (value) { - is String -> value - is Utf8 -> value.toString() + is CharSequence -> value.toString() is GenericData.Fixed -> String(value.bytes()) is ByteArray -> String(value) - is CharSequence -> value.toString() is ByteBuffer -> String(value.array()) null -> throw SerializationException("Cannot decode as a string") else -> throw SerializationException("Unsupported type for String [is ${value::class.qualifiedName}]") } } -} - -object EnumFromAvroValue : FromAvroValue { - override fun fromValue(value: Any): String { - return when (value) { - is GenericEnumSymbol<*> -> value.toString() - is String -> value - else -> value.toString() - } - } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt index 42bed488..c25d640e 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/ListDecoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.decoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.descriptors.PolymorphicKind @@ -19,7 +19,7 @@ class ListDecoder( private val schema: Schema, private val array: List, override val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, + private val configuration: AvroConfiguration, ) : AbstractDecoder(), FieldDecoder { init { require(schema.type == Schema.Type.ARRAY) @@ -72,7 +72,7 @@ class ListDecoder( override fun fieldSchema(): Schema = schema.elementType override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - val symbol = EnumFromAvroValue.fromValue(array[index]!!) + val symbol = array[index]!!.toString() return (0 until enumDescriptor.elementsCount).find { enumDescriptor.getElementName(it) == symbol } ?: -1 } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt index 4f6e7202..37f1b271 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/MapDecoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.decoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PolymorphicKind @@ -20,7 +20,7 @@ class MapDecoder( private val schema: Schema, map: Map<*, *>, override val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, + private val configuration: AvroConfiguration, ) : AbstractDecoder(), CompositeDecoder { init { require(schema.type == Schema.Type.MAP) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt index c85ef2f4..2e74a857 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RecordDecoder.kt @@ -1,7 +1,7 @@ package com.github.avrokotlin.avro4k.decoder import com.github.avrokotlin.avro4k.AnnotationExtractor -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import com.github.avrokotlin.avro4k.schema.extractNonNull import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerializationException @@ -31,7 +31,7 @@ class RecordDecoder( private val desc: SerialDescriptor, private val record: GenericRecord, override val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, + private val configuration: AvroConfiguration, ) : AbstractDecoder(), FieldDecoder { private var currentIndex = -1 @@ -135,7 +135,7 @@ class RecordDecoder( } override fun decodeEnum(enumDescriptor: SerialDescriptor): Int { - val symbol = EnumFromAvroValue.fromValue(fieldValue()!!) + val symbol = fieldValue()!!.toString() val enumValueByEnumName = (0 until enumDescriptor.elementsCount).associateBy { enumDescriptor.getElementName(it) } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt index a14cce62..9c53b9c4 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/RootRecordDecoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.decoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PolymorphicKind @@ -16,7 +16,7 @@ import org.apache.avro.generic.GenericRecord class RootRecordDecoder( private val record: GenericRecord, override val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, + private val configuration: AvroConfiguration, ) : AbstractDecoder() { var decoded = false diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt index 2253ef96..a8bdcb7d 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/decoder/UnionDecoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.decoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import com.github.avrokotlin.avro4k.possibleSerializationSubclasses import com.github.avrokotlin.avro4k.schema.RecordName import kotlinx.serialization.DeserializationStrategy @@ -18,7 +18,7 @@ class UnionDecoder( descriptor: SerialDescriptor, private val value: GenericRecord, override val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, + private val configuration: AvroConfiguration, ) : AbstractDecoder(), FieldDecoder { private enum class DecoderState(val index: Int) { BEFORE(0), diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt index ef2a9b8d..3cac4e51 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/ListEncoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.encoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.encoding.AbstractEncoder @@ -15,7 +15,7 @@ import java.nio.ByteBuffer class ListEncoder( private val schema: Schema, override val serializersModule: SerializersModule, - override val configuration: AvroInternalConfiguration, + override val configuration: AvroConfiguration, private val callback: (GenericData.Array) -> Unit, ) : AbstractEncoder(), StructureEncoder { private val list = mutableListOf() diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/MapEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/MapEncoder.kt index 76ec1743..c6599179 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/MapEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/MapEncoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.encoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor @@ -16,7 +16,7 @@ import java.nio.ByteBuffer class MapEncoder( schema: Schema, override val serializersModule: SerializersModule, - override val configuration: AvroInternalConfiguration, + override val configuration: AvroConfiguration, private val callback: (Map) -> Unit, ) : AbstractEncoder(), CompositeEncoder, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RecordEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RecordEncoder.kt index f49a94d8..df763b6a 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RecordEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RecordEncoder.kt @@ -1,10 +1,9 @@ package com.github.avrokotlin.avro4k.encoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import com.github.avrokotlin.avro4k.ListRecord import com.github.avrokotlin.avro4k.Record import com.github.avrokotlin.avro4k.schema.extractNonNull -import com.github.avrokotlin.avro4k.schema.unwrapValueClass import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.PolymorphicKind @@ -20,7 +19,7 @@ import java.nio.ByteBuffer @ExperimentalSerializationApi interface StructureEncoder : FieldEncoder { - val configuration: AvroInternalConfiguration + val configuration: AvroConfiguration override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { return when (descriptor.kind) { @@ -38,11 +37,15 @@ interface StructureEncoder : FieldEncoder { } } +@ExperimentalSerializationApi +internal val SerialDescriptor.unwrapValueClass: SerialDescriptor + get() = if (isInline) getElementDescriptor(0) else this + @ExperimentalSerializationApi class RecordEncoder( private val schema: Schema, override val serializersModule: SerializersModule, - override val configuration: AvroInternalConfiguration, + override val configuration: AvroConfiguration, val callback: (Record) -> Unit, ) : AbstractEncoder(), StructureEncoder { private val builder = RecordBuilder(schema) diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RootRecordEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RootRecordEncoder.kt index 5bacb035..05600355 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RootRecordEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/RootRecordEncoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.encoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import com.github.avrokotlin.avro4k.Record import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerializationException @@ -16,7 +16,7 @@ import org.apache.avro.Schema class RootRecordEncoder( private val schema: Schema, override val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, + private val configuration: AvroConfiguration, private val callback: (Record) -> Unit, ) : AbstractEncoder() { override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt index 2e5908fd..4397451e 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/encoder/UnionEncoder.kt @@ -1,6 +1,6 @@ package com.github.avrokotlin.avro4k.encoder -import com.github.avrokotlin.avro4k.AvroInternalConfiguration +import com.github.avrokotlin.avro4k.AvroConfiguration import com.github.avrokotlin.avro4k.Record import com.github.avrokotlin.avro4k.schema.RecordName import kotlinx.serialization.ExperimentalSerializationApi @@ -16,7 +16,7 @@ import org.apache.avro.Schema class UnionEncoder( private val unionSchema: Schema, override val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, + private val configuration: AvroConfiguration, private val callback: (Record) -> Unit, ) : AbstractEncoder() { override fun encodeString(value: String) { diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/io/DefaultAvroOutputStream.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/io/DefaultAvroOutputStream.kt index 463e87c0..e36d2a28 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/io/DefaultAvroOutputStream.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/io/DefaultAvroOutputStream.kt @@ -1,7 +1,6 @@ package com.github.avrokotlin.avro4k.io import com.github.avrokotlin.avro4k.Avro -import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.SerializationStrategy import org.apache.avro.Schema import org.apache.avro.generic.GenericDatumWriter @@ -36,7 +35,6 @@ abstract class DefaultAvroOutputStream( } } -@OptIn(ExperimentalSerializationApi::class) class AvroBinaryOutputStream( output: OutputStream, converter: (T) -> GenericRecord, @@ -52,7 +50,6 @@ class AvroBinaryOutputStream( override val encoder: BinaryEncoder = EncoderFactory.get().binaryEncoder(output, null) } -@OptIn(ExperimentalSerializationApi::class) class AvroJsonOutputStream( output: OutputStream, converter: (T) -> GenericRecord, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/AvroSchemaGenerationException.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/AvroSchemaGenerationException.kt new file mode 100644 index 00000000..7e5f8b6f --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/AvroSchemaGenerationException.kt @@ -0,0 +1,5 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.SerializationException + +class AvroSchemaGenerationException(message: String) : SerializationException(message) \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassSchemaFor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassSchemaFor.kt deleted file mode 100644 index f1dd4cb0..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassSchemaFor.kt +++ /dev/null @@ -1,157 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import com.github.avrokotlin.avro4k.AnnotationExtractor -import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroAlias -import com.github.avrokotlin.avro4k.AvroInternalConfiguration -import com.github.avrokotlin.avro4k.AvroJsonProp -import com.github.avrokotlin.avro4k.AvroNamespaceOverride -import com.github.avrokotlin.avro4k.AvroProp -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonArray -import kotlinx.serialization.json.JsonElement -import kotlinx.serialization.json.JsonNull -import kotlinx.serialization.json.JsonObject -import kotlinx.serialization.json.JsonPrimitive -import kotlinx.serialization.json.boolean -import kotlinx.serialization.json.booleanOrNull -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.JsonProperties -import org.apache.avro.Schema -import org.apache.avro.SchemaBuilder - -@ExperimentalSerializationApi -class ClassSchemaFor( - private val descriptor: SerialDescriptor, - private val configuration: AvroInternalConfiguration, - private val serializersModule: SerializersModule, - private val resolvedSchemas: MutableMap, -) : SchemaFor { - private val entityAnnotations = AnnotationExtractor(descriptor.annotations) - private val naming = configuration.recordNamingStrategy.resolve(descriptor, descriptor.serialName) - private val json by lazy { - Json { - serializersModule = this@ClassSchemaFor.serializersModule - } - } - - override fun schema(): Schema = - if (descriptor.isInline) { - buildField(0).schema() - } else { - dataClassSchema() - } - - private fun dataClassSchema(): Schema { - // return schema if already resolved - recursive circuit breaker - resolvedSchemas[naming]?.let { return it } - - // create new schema without fields - val record = Schema.createRecord(naming.name, entityAnnotations.doc(), naming.namespace, false) - - // add schema without fields right now, so that fields could recursively use it - resolvedSchemas[naming] = record - - val fields = - (0 until descriptor.elementsCount) - .map { index -> buildField(index) } - - record.fields = fields - entityAnnotations.aliases().forEach { record.addAlias(it) } - entityAnnotations.props().forEach { (k, v) -> record.addProp(k, v) } - entityAnnotations.jsonProps().forEach { (k, v) -> record.addProp(k, json.parseToJsonElement(v).convertToAvroDefault()) } - - return record - } - - private fun buildField(index: Int): Schema.Field { - val fieldTypeDescriptor = descriptor.getElementDescriptor(index) - val annos = AnnotationExtractor(descriptor.getElementAnnotations(index)) - val fieldSpecificNamespace: String? = descriptor.getElementAnnotations(index).filterIsInstance().firstOrNull()?.value - val fieldName = configuration.fieldNamingStrategy.resolve(descriptor, index, descriptor.getElementName(index)) - val schema = - getFixedSchema(fieldName, annos) ?: schemaFor( - serializersModule, - fieldTypeDescriptor, - descriptor.getElementAnnotations(index), - configuration, - resolvedSchemas - ).schema() - - // If the field is annotated with a specific namespace, then we need to override the namespace of the field's schema - val schemaWithResolvedNamespace = fieldSpecificNamespace?.let { schema.overrideNamespace(it) } ?: schema - - val default: Any? = getDefaultValue(annos, schemaWithResolvedNamespace, fieldTypeDescriptor) - - val field = Schema.Field(fieldName, schemaWithResolvedNamespace, annos.doc(), default) - field.mutateFieldFromAnnotations(this.descriptor.getElementAnnotations(index)) - return field - } - - private fun getFixedSchema( - fieldName: String, - annos: AnnotationExtractor, - ): Schema? { - val size = annos.fixed() ?: return null - return SchemaBuilder.fixed(fieldName) - .doc(annos.doc()) - .namespace(naming.namespace) - .size(size) - } - - private fun Schema.Field.mutateFieldFromAnnotations(annotations: List) = - annotations.forEach { - when (it) { - is AvroProp -> this.addProp(it.key, it.value) - is AvroJsonProp -> this.addProp(it.key, json.parseToJsonElement(it.jsonValue).convertToAvroDefault()) - is AvroAlias -> it.value.forEach { this.addAlias(it) } - } - } - - private fun getDefaultValue( - annos: AnnotationExtractor, - schemaWithResolvedNamespace: Schema, - fieldTypeDescriptor: SerialDescriptor, - ) = annos.default()?.let { annotationDefaultValue -> - when { - annotationDefaultValue == Avro.NULL -> Schema.Field.NULL_DEFAULT_VALUE - schemaWithResolvedNamespace.extractNonNull().type in - listOf( - Schema.Type.FIXED, - Schema.Type.BYTES, - Schema.Type.STRING, - Schema.Type.ENUM - ) - -> annotationDefaultValue - - else -> json.parseToJsonElement(annotationDefaultValue).convertToAvroDefault() - } - } ?: if (configuration.implicitNulls && fieldTypeDescriptor.isNullable) { - Schema.Field.NULL_DEFAULT_VALUE - } else { - null - } - - private fun JsonElement.convertToAvroDefault(): Any { - return when (this) { - is JsonNull -> JsonProperties.NULL_VALUE - is JsonObject -> this.map { Pair(it.key, it.value.convertToAvroDefault()) }.toMap() - is JsonArray -> this.map { it.convertToAvroDefault() }.toList() - is JsonPrimitive -> - when { - this.isString -> this.content - this.booleanOrNull != null -> this.boolean - else -> { - val number = this.content.toBigDecimal() - if (number.scale() <= 0) { - number.toBigInteger() - } else { - number - } - } - } - } - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassVisitor.kt new file mode 100644 index 00000000..71d367f4 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ClassVisitor.kt @@ -0,0 +1,149 @@ +package com.github.avrokotlin.avro4k.schema + +import com.github.avrokotlin.avro4k.AvroDefault +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.JsonProperties +import org.apache.avro.Schema + +internal class ClassVisitor( + descriptor: SerialDescriptor, + override val context: VisitorContext, + private val onSchemaBuilt: (Schema) -> Unit, +) : SerialDescriptorClassVisitor, AvroVisitorContextAware { + private val fields = mutableListOf() + private val schemaAlreadyResolved: Boolean + private val schema: Schema + + init { + val recordName = descriptor.getAvroName() + var schemaAlreadyResolved = true + schema = + context.resolvedSchemas.getOrPut(recordName) { + schemaAlreadyResolved = false + + val annotations = TypeAnnotations(descriptor) + val schema = + Schema.createRecord( + // name = + recordName.name, + // doc = + annotations.doc?.value, + // namespace = + recordName.namespace, + // isError = + false + ) + annotations.aliases?.value?.forEach { schema.addAlias(it) } + annotations.props.forEach { schema.addProp(it.key, it.value) } + annotations.jsonProps.forEach { schema.addProp(it.key, it.jsonNode) } + schema + } + this.schemaAlreadyResolved = schemaAlreadyResolved + } + + override fun visitClassElement( + descriptor: SerialDescriptor, + elementIndex: Int, + ): SerialDescriptorValueVisitor? { + if (schemaAlreadyResolved) { + return null + } + return ValueVisitor( + context.copy( + inlinedAnnotations = ValueAnnotations(descriptor, elementIndex) + ) + ) { + fields.add( + createField( + descriptor.getElementAvroName(elementIndex), + FieldAnnotations(descriptor, elementIndex), + it + ) + ) + } + } + + override fun endClassVisit(descriptor: SerialDescriptor) { + if (!schemaAlreadyResolved) { + schema.fields = fields + } + onSchemaBuilt(schema) + } + + /** + * Create a field with the given annotations. + * Here are managed the generic field level annotations: + * - namespaceOverride + * - default (also sort unions according to the default value) + * - aliases + * - doc + * - props & json props + */ + private fun createField( + fieldName: String, + annotations: FieldAnnotations, + elementSchema: Schema, + ): Schema.Field { + var finalSchema: Schema = annotations.namespaceOverride?.value?.let { elementSchema.overrideNamespace(it) } ?: elementSchema + + val fieldDefault = getFieldDefault(annotations.default, finalSchema) + + if (fieldDefault != null) { + reorderUnionIfNeeded(fieldDefault, finalSchema)?.let { + finalSchema = it + } + } + + val field = + Schema.Field( + // name = + fieldName, + // schema = + finalSchema, + // doc = + annotations.doc?.value, + // defaultValue = + fieldDefault + ) + annotations.aliases.flatMap { it.value.asSequence() }.forEach { field.addAlias(it) } + annotations.props.forEach { field.addProp(it.key, it.value) } + annotations.jsonProps.forEach { field.addProp(it.key, it.jsonNode) } + return field + } + + /** + * Reorder the union to put the non-null first if the default value is non-null. + */ + private fun reorderUnionIfNeeded( + fieldDefault: Any, + finalSchema: Schema, + ): Schema? { + if (finalSchema.isUnion && finalSchema.isNullable) { + var nullNotFirst = false + if (fieldDefault is Collection<*>) { + nullNotFirst = fieldDefault.any { it != JsonProperties.NULL_VALUE } + } else if (fieldDefault != JsonProperties.NULL_VALUE) { + nullNotFirst = true + } + if (nullNotFirst) { + val nullIndex = finalSchema.types.indexOfFirst { it.type == Schema.Type.NULL } + val nonNullTypes = finalSchema.types.toMutableList() + val nullType = nonNullTypes.removeAt(nullIndex) + return Schema.createUnion(nonNullTypes + nullType) + } + } + return null + } + + private fun getFieldDefault( + default: AvroDefault?, + fieldSchema: Schema, + ): Any? { + val defaultValue = default?.jsonValue + + if (defaultValue == null && context.avro.configuration.implicitNulls && fieldSchema.isNullable) { + return JsonProperties.NULL_VALUE + } + return defaultValue + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/namingStrategy.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/FieldNamingStrategy.kt similarity index 62% rename from src/main/kotlin/com/github/avrokotlin/avro4k/schema/namingStrategy.kt rename to src/main/kotlin/com/github/avrokotlin/avro4k/schema/FieldNamingStrategy.kt index e1e9e93e..851424bd 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/namingStrategy.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/FieldNamingStrategy.kt @@ -2,36 +2,6 @@ package com.github.avrokotlin.avro4k.schema import kotlinx.serialization.descriptors.SerialDescriptor -interface RecordNamingStrategy { - fun resolve( - descriptor: SerialDescriptor, - serialName: String, - ): RecordName - - companion object Builtins { - /** - * Extract the record name from the fully qualified class name by taking the last part of the class name as the record name and the rest as the namespace. - * - * If there is no dot, then the namespace is null. - */ - object FullyQualified : RecordNamingStrategy { - override fun resolve( - descriptor: SerialDescriptor, - serialName: String, - ): RecordName { - val lastDot = serialName.lastIndexOf('.').takeIf { it >= 0 && it + 1 < serialName.length } - val lastIndex = if (serialName.endsWith('?')) serialName.length - 1 else serialName.length - return RecordName( - name = lastDot?.let { serialName.substring(lastDot + 1, lastIndex) } ?: serialName, - namespace = lastDot?.let { serialName.substring(0, lastDot) }?.takeIf { it.isNotEmpty() } - ) - } - } - } -} - -data class RecordName(val name: String, val namespace: String?) - interface FieldNamingStrategy { fun resolve( descriptor: SerialDescriptor, @@ -103,16 +73,5 @@ interface FieldNamingStrategy { serialName: String, ): String = serialName.replaceFirstChar { it.uppercaseChar() } } - - /** - * Enforce camelCase naming strategy by lower-casing the first field name letter. - */ - object CamelCase : FieldNamingStrategy { - override fun resolve( - descriptor: SerialDescriptor, - elementIndex: Int, - serialName: String, - ): String = serialName.replaceFirstChar { it.lowercaseChar() } - } } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/InlineClassVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/InlineClassVisitor.kt new file mode 100644 index 00000000..a66b74b7 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/InlineClassVisitor.kt @@ -0,0 +1,30 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.Schema + +internal class InlineClassVisitor( + override val context: VisitorContext, + private val onSchemaBuilt: (Schema) -> Unit, +) : SerialDescriptorInlineClassVisitor, AvroVisitorContextAware { + override fun visitInlineClassElement( + inlineClassDescriptor: SerialDescriptor, + inlineElementIndex: Int, + ): SerialDescriptorValueVisitor { + val inlinedAnnotations = + context.inlinedAnnotations.appendAnnotations( + ValueAnnotations( + inlineClassDescriptor, + inlineElementIndex + ) + ) + return ValueVisitor(context.copy(inlinedAnnotations = inlinedAnnotations)) { + val annotations = InlineClassFieldAnnotations(inlineClassDescriptor, inlineElementIndex) + if (annotations.namespaceOverride != null) { + onSchemaBuilt(it.overrideNamespace(annotations.namespaceOverride.value)) + } else { + onSchemaBuilt(it) + } + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ListVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ListVisitor.kt new file mode 100644 index 00000000..2cc0d87f --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ListVisitor.kt @@ -0,0 +1,24 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.Schema + +internal class ListVisitor( + override val context: VisitorContext, + private val onSchemaBuilt: (Schema) -> Unit, +) : SerialDescriptorListVisitor, AvroVisitorContextAware { + private lateinit var itemSchema: Schema + + override fun visitListItem( + listDescriptor: SerialDescriptor, + itemElementIndex: Int, + ): SerialDescriptorValueVisitor { + return ValueVisitor(context) { + itemSchema = it + } + } + + override fun endListVisit(descriptor: SerialDescriptor) { + onSchemaBuilt(Schema.createArray(itemSchema)) + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/MapVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/MapVisitor.kt new file mode 100644 index 00000000..cb3af283 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/MapVisitor.kt @@ -0,0 +1,63 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.Schema + +internal class MapVisitor( + override val context: VisitorContext, + private val onSchemaBuilt: (Schema) -> Unit, +) : SerialDescriptorMapVisitor, AvroVisitorContextAware { + private lateinit var valueSchema: Schema + + override fun visitMapKey( + mapDescriptor: SerialDescriptor, + keyElementIndex: Int, + ) = ValueVisitor(context) { + // In avro, the map key must be a string. + // Here we just delegate the schema building to the value visitor + // and then check if the output schema is about a type that we can + // stringify (e.g. when .toString() makes sense). + // Here we are just checking if the schema is string-compatible. We don't need to + // store the schema as it is a string. + if (it.isNullable()) { + throw AvroSchemaGenerationException("Map key cannot be nullable. Actual generated map key schema: $it") + } + if (!it.isStringable()) { + throw AvroSchemaGenerationException("Map key must be string-able (boolean, number, enum, or string). Actual generated map key schema: $it") + } + } + + override fun visitMapValue( + mapDescriptor: SerialDescriptor, + valueElementIndex: Int, + ) = ValueVisitor(context) { + valueSchema = it + } + + override fun endMapVisit(descriptor: SerialDescriptor) { + onSchemaBuilt(Schema.createMap(valueSchema)) + } +} + +private fun Schema.isStringable(): Boolean = + when (type) { + Schema.Type.BOOLEAN, + Schema.Type.INT, + Schema.Type.LONG, + Schema.Type.FLOAT, + Schema.Type.DOUBLE, + Schema.Type.STRING, + Schema.Type.ENUM, + -> true + + Schema.Type.NULL, + Schema.Type.BYTES, // bytes could be stringified, but it's not a good idea as it can produce unreadable strings. + Schema.Type.FIXED, // same, just bytes. Btw, if the user wants to stringify it, he can use @Contextual or custom @Serializable serializer. + Schema.Type.ARRAY, + Schema.Type.MAP, + Schema.Type.RECORD, + null, + -> false + + Schema.Type.UNION -> types.all { it.isStringable() } + } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicVisitor.kt new file mode 100644 index 00000000..553559de --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicVisitor.kt @@ -0,0 +1,29 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.descriptors.SerialDescriptor +import org.apache.avro.Schema + +internal class PolymorphicVisitor( + override val context: VisitorContext, + private val onSchemaBuilt: (Schema) -> Unit, +) : SerialDescriptorPolymorphicVisitor, AvroVisitorContextAware { + private val possibleSchemas = mutableListOf() + + override fun visitPolymorphicFoundDescriptor(descriptor: SerialDescriptor): SerialDescriptorValueVisitor { + return ValueVisitor(context) { + possibleSchemas += it + } + } + + override fun endPolymorphicVisit(descriptor: SerialDescriptor) { + if (possibleSchemas.isEmpty()) { + throw AvroSchemaGenerationException("Polymorphic descriptor must have at least one possible schema") + } + if (possibleSchemas.size == 1) { + // flatten the useless union schema + onSchemaBuilt(possibleSchemas.first()) + } else { + onSchemaBuilt(Schema.createUnion(possibleSchemas)) + } + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/RecordNamingStrategy.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/RecordNamingStrategy.kt new file mode 100644 index 00000000..ec95ee90 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/RecordNamingStrategy.kt @@ -0,0 +1,33 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.descriptors.SerialDescriptor + +interface RecordNamingStrategy { + fun resolve( + descriptor: SerialDescriptor, + serialName: String, + ): RecordName + + companion object Builtins { + /** + * Extract the record name from the fully qualified class name by taking the last part of the class name as the record name and the rest as the namespace. + * + * If there is no dot, then the namespace is null. + */ + object FullyQualified : RecordNamingStrategy { + override fun resolve( + descriptor: SerialDescriptor, + serialName: String, + ): RecordName { + val lastDot = serialName.lastIndexOf('.').takeIf { it >= 0 && it + 1 < serialName.length } + val lastIndex = if (serialName.endsWith('?')) serialName.length - 1 else serialName.length + return RecordName( + name = lastDot?.let { serialName.substring(lastDot + 1, lastIndex) } ?: serialName.substring(0, lastIndex), + namespace = lastDot?.let { serialName.substring(0, lastDot) }?.takeIf { it.isNotEmpty() } + ) + } + } + } +} + +data class RecordName(val name: String, val namespace: String?) \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SchemaFor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SchemaFor.kt deleted file mode 100644 index 21a8eaee..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SchemaFor.kt +++ /dev/null @@ -1,247 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import com.github.avrokotlin.avro4k.AnnotationExtractor -import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroDecimalLogicalType -import com.github.avrokotlin.avro4k.AvroFixed -import com.github.avrokotlin.avro4k.AvroInternalConfiguration -import com.github.avrokotlin.avro4k.AvroTimeLogicalType -import com.github.avrokotlin.avro4k.AvroUuidLogicalType -import com.github.avrokotlin.avro4k.LogicalDecimalTypeEnum -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.InternalSerializationApi -import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.PolymorphicKind -import kotlinx.serialization.descriptors.PrimitiveKind -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.SerialKind -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.descriptors.capturedKClass -import kotlinx.serialization.descriptors.elementNames -import kotlinx.serialization.descriptors.getContextualDescriptor -import kotlinx.serialization.modules.SerializersModule -import kotlinx.serialization.serializerOrNull -import org.apache.avro.LogicalTypes -import org.apache.avro.Schema -import org.apache.avro.SchemaBuilder - -interface SchemaFor { - fun schema(): Schema - - companion object { - /** - * Creates a [SchemaFor] that always returns the given constant schema. - */ - fun const(schema: Schema) = - object : SchemaFor { - override fun schema() = schema - } - - val StringSchemaFor: SchemaFor = const(SchemaBuilder.builder().stringType()) - val LongSchemaFor: SchemaFor = const(SchemaBuilder.builder().longType()) - val IntSchemaFor: SchemaFor = const(SchemaBuilder.builder().intType()) - val ShortSchemaFor: SchemaFor = const(SchemaBuilder.builder().intType()) - val ByteSchemaFor: SchemaFor = const(SchemaBuilder.builder().intType()) - val DoubleSchemaFor: SchemaFor = const(SchemaBuilder.builder().doubleType()) - val FloatSchemaFor: SchemaFor = const(SchemaBuilder.builder().floatType()) - val BooleanSchemaFor: SchemaFor = const(SchemaBuilder.builder().booleanType()) - } -} - -@ExperimentalSerializationApi -class EnumSchemaFor( - private val descriptor: SerialDescriptor, - private val configuration: AvroInternalConfiguration, -) : SchemaFor { - override fun schema(): Schema { - val naming = configuration.recordNamingStrategy.resolve(descriptor, descriptor.serialName) - val entityAnnotations = AnnotationExtractor(descriptor.annotations) - val symbols = (0 until descriptor.elementsCount).map { descriptor.getElementName(it) } - - val defaultSymbol = - entityAnnotations.enumDefault()?.let { enumDefault -> - descriptor.elementNames.firstOrNull { it == enumDefault } ?: error( - "Could not use: $enumDefault to resolve the enum class ${descriptor.serialName}" - ) - } - - val enumSchema = - SchemaBuilder.enumeration(naming.name).doc(entityAnnotations.doc()) - .namespace(naming.namespace) - .defaultSymbol(defaultSymbol) - .symbols(*symbols.toTypedArray()) - - entityAnnotations.aliases().forEach { enumSchema.addAlias(it) } - - return enumSchema - } -} - -@ExperimentalSerializationApi -class ListSchemaFor( - private val descriptor: SerialDescriptor, - private val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, - private val resolvedSchemas: MutableMap, -) : SchemaFor { - override fun schema(): Schema { - val elementType = descriptor.getElementDescriptor(0) // don't use unwrapValueClass to prevent losing serial annotations - return when (descriptor.unwrapValueClass.getElementDescriptor(0).kind) { - PrimitiveKind.BYTE -> SchemaBuilder.builder().bytesType() - else -> { - val elementSchema = - schemaFor( - serializersModule, - elementType, - descriptor.getElementAnnotations(0), - configuration, - resolvedSchemas - ).schema() - return Schema.createArray(elementSchema) - } - } - } -} - -@ExperimentalSerializationApi -class MapSchemaFor( - private val descriptor: SerialDescriptor, - private val serializersModule: SerializersModule, - private val configuration: AvroInternalConfiguration, - private val resolvedSchemas: MutableMap, -) : SchemaFor { - override fun schema(): Schema { - val keyType = - descriptor.getElementDescriptor(0).unwrapValueClass.let { - if (it.kind == SerialKind.CONTEXTUAL) serializersModule.getContextualDescriptor(it)?.unwrapValueClass else it - } - if (keyType != null) { - if (keyType.kind is PrimitiveKind || keyType.kind == SerialKind.ENUM) { - val valueSchema = - schemaFor( - serializersModule, - descriptor.getElementDescriptor(1), - descriptor.getElementAnnotations(1), - configuration, - resolvedSchemas - ).schema() - return Schema.createMap(valueSchema) - } - } - throw SerializationException("Avro4k only supports primitive and enum kinds as the map key. Actual: ${descriptor.getElementDescriptor(0)}") - } -} - -@ExperimentalSerializationApi -class NullableSchemaFor( - private val schemaFor: SchemaFor, - private val annotations: List, -) : SchemaFor { - private val nullFirst by lazy { - // The default value can only be of the first type in the union definition. - // Therefore we have to check the default value in order to decide the order of types within the union. - // If no default is set, or if the default value is of type "null", nulls will be first. - val default = AnnotationExtractor(annotations).default() - default == null || default == Avro.NULL - } - - override fun schema(): Schema { - val elementSchema = schemaFor.schema() - val nullSchema = SchemaBuilder.builder().nullType() - return createSafeUnion(nullFirst, elementSchema, nullSchema) - } -} - -@OptIn(InternalSerializationApi::class) -@ExperimentalSerializationApi -fun schemaFor( - serializersModule: SerializersModule, - descriptor: SerialDescriptor, - annos: List, - configuration: AvroInternalConfiguration, - resolvedSchemas: MutableMap, -): SchemaFor { - val schemaFor: SchemaFor = - schemaForLogicalTypes(descriptor, annos, configuration)?.let(SchemaFor::const) - ?: when (descriptor.unwrapValueClass.kind) { - PrimitiveKind.STRING -> SchemaFor.StringSchemaFor - PrimitiveKind.LONG -> SchemaFor.LongSchemaFor - PrimitiveKind.INT -> SchemaFor.IntSchemaFor - PrimitiveKind.SHORT -> SchemaFor.ShortSchemaFor - PrimitiveKind.BYTE -> SchemaFor.ByteSchemaFor - PrimitiveKind.DOUBLE -> SchemaFor.DoubleSchemaFor - PrimitiveKind.FLOAT -> SchemaFor.FloatSchemaFor - PrimitiveKind.BOOLEAN -> SchemaFor.BooleanSchemaFor - SerialKind.ENUM -> EnumSchemaFor(descriptor, configuration) - SerialKind.CONTEXTUAL -> - schemaFor( - serializersModule, - requireNotNull( - serializersModule.getContextualDescriptor(descriptor.unwrapValueClass) - ?: descriptor.capturedKClass?.serializerOrNull()?.descriptor - ) { - "Contextual or default serializer not found for $descriptor " - }, - annos, - configuration, - resolvedSchemas - ) - - StructureKind.CLASS, StructureKind.OBJECT -> ClassSchemaFor(descriptor, configuration, serializersModule, resolvedSchemas) - StructureKind.LIST -> ListSchemaFor(descriptor, serializersModule, configuration, resolvedSchemas) - StructureKind.MAP -> MapSchemaFor(descriptor, serializersModule, configuration, resolvedSchemas) - is PolymorphicKind -> UnionSchemaFor(descriptor, configuration, serializersModule, resolvedSchemas) - else -> throw SerializationException("Unsupported type ${descriptor.serialName} of ${descriptor.kind}") - } - - return if (descriptor.isNullable) NullableSchemaFor(schemaFor, annos) else schemaFor -} - -@ExperimentalSerializationApi -private fun schemaForLogicalTypes( - descriptor: SerialDescriptor, - annos: List, - configuration: AvroInternalConfiguration, -): Schema? { - val annotations = - annos + descriptor.annotations + (if (descriptor.isInline) descriptor.unwrapValueClass.annotations else emptyList()) - - for (annotation in annotations) { - when (annotation) { - is AvroDecimalLogicalType -> { - val schema = - when (annotation.schema) { - LogicalDecimalTypeEnum.BYTES -> SchemaBuilder.builder().bytesType() - LogicalDecimalTypeEnum.STRING -> SchemaBuilder.builder().stringType() - LogicalDecimalTypeEnum.FIXED -> { - val fixedSize = - annotations.filterIsInstance().firstOrNull()?.size - ?: throw UnsupportedOperationException("Fixed size must be specified for FIXED decimal type with @AvroFixed annotation") - createFixedSchema(descriptor, fixedSize, configuration) - } - } - return LogicalTypes.decimal(annotation.precision, annotation.scale).addToSchema(schema) - } - is AvroUuidLogicalType -> return LogicalTypes.uuid().addToSchema(SchemaBuilder.builder().stringType()) - is AvroTimeLogicalType -> return annotation.type.schemaFor() - is AvroFixed -> return createFixedSchema(descriptor, annotation.size, configuration) - } - } - return null -} - -@OptIn(ExperimentalSerializationApi::class) -private fun createFixedSchema( - descriptor: SerialDescriptor, - fixedSize: Int, - configuration: AvroInternalConfiguration, -): Schema { - return configuration.recordNamingStrategy.resolve(descriptor, descriptor.serialName).let { - SchemaBuilder.fixed(it.name).namespace(it.namespace).size(fixedSize) - } -} - -// copy-paste from kotlinx serialization because it internal -@ExperimentalSerializationApi -internal val SerialDescriptor.unwrapValueClass: SerialDescriptor - get() = if (isInline) getElementDescriptor(0) else this \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SerialDescriptorVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SerialDescriptorVisitor.kt new file mode 100644 index 00000000..60551f7f --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/SerialDescriptorVisitor.kt @@ -0,0 +1,206 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.descriptors.PolymorphicKind +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.SerialKind +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.descriptors.capturedKClass +import kotlinx.serialization.descriptors.elementDescriptors +import kotlinx.serialization.descriptors.getContextualDescriptor +import kotlinx.serialization.descriptors.getPolymorphicDescriptors +import kotlinx.serialization.modules.SerializersModule +import kotlinx.serialization.serializerOrNull + +@ExperimentalSerializationApi +interface SerialDescriptorValueVisitor { + val serializersModule: SerializersModule + + /** + * Called when the [descriptor]'s kind is a [PrimitiveKind]. + */ + fun visitPrimitive( + descriptor: SerialDescriptor, + kind: PrimitiveKind, + ) + + /** + * Called when the [descriptor]'s kind is an [SerialKind.ENUM]. + */ + fun visitEnum(descriptor: SerialDescriptor) + + /** + * Called when the [descriptor]'s kind is an [StructureKind.OBJECT]. + */ + fun visitObject(descriptor: SerialDescriptor) + + /** + * Called when the [descriptor]'s kind is a [PolymorphicKind]. + * @return null if we don't want to visit the polymorphic type + */ + fun visitPolymorphic( + descriptor: SerialDescriptor, + kind: PolymorphicKind, + ): SerialDescriptorPolymorphicVisitor? + + /** + * Called when the [descriptor]'s kind is a [StructureKind.CLASS]. + * Note that when the [descriptor] is an inline class, [visitInlineClass] is called instead. + * @return null if we don't want to visit the class + */ + fun visitClass(descriptor: SerialDescriptor): SerialDescriptorClassVisitor? + + /** + * Called when the [descriptor]'s kind is a [StructureKind.LIST]. + * @return null if we don't want to visit the list + */ + fun visitList(descriptor: SerialDescriptor): SerialDescriptorListVisitor? + + /** + * Called when the [descriptor]'s kind is a [StructureKind.MAP]. + * @return null if we don't want to visit the map + */ + fun visitMap(descriptor: SerialDescriptor): SerialDescriptorMapVisitor? + + /** + * Called when the [descriptor] is about a value class (e.g. its kind is a [StructureKind.CLASS] and [SerialDescriptor.isInline] is true). + * @return null if we don't want to visit the inline class + */ + fun visitInlineClass(descriptor: SerialDescriptor): SerialDescriptorInlineClassVisitor? + + fun visitValue(descriptor: SerialDescriptor) { + if (descriptor.isInline) { + visitInlineClass(descriptor)?.apply { + visitInlineClassElement(descriptor, 0)?.visitValue(descriptor.getElementDescriptor(0)) + } + } else { + when (descriptor.kind) { + is PrimitiveKind -> visitPrimitive(descriptor, descriptor.kind as PrimitiveKind) + SerialKind.ENUM -> visitEnum(descriptor) + SerialKind.CONTEXTUAL -> visitValue(descriptor.getNonNullContextualDescriptor(serializersModule)) + StructureKind.CLASS -> + visitClass(descriptor)?.apply { + for (elementIndex in (0 until descriptor.elementsCount)) { + visitClassElement(descriptor, elementIndex)?.visitValue(descriptor.getElementDescriptor(elementIndex)) + } + }?.endClassVisit(descriptor) + + StructureKind.LIST -> + visitList(descriptor)?.apply { + visitListItem(descriptor, 0)?.visitValue(descriptor.getElementDescriptor(0)) + }?.endListVisit(descriptor) + + StructureKind.MAP -> + visitMap(descriptor)?.apply { + visitMapKey(descriptor, 0)?.visitValue(descriptor.getElementDescriptor(0)) + visitMapValue(descriptor, 1)?.visitValue(descriptor.getElementDescriptor(1)) + }?.endMapVisit(descriptor) + + is PolymorphicKind -> + visitPolymorphic(descriptor, descriptor.kind as PolymorphicKind)?.apply { + descriptor.possibleSerializationSubclasses(serializersModule).sortedBy { it.serialName }.forEach { implementationDescriptor -> + visitPolymorphicFoundDescriptor(implementationDescriptor)?.visitValue(implementationDescriptor) + } + }?.endPolymorphicVisit(descriptor) + + StructureKind.OBJECT -> visitObject(descriptor) + } + } + } +} + +@ExperimentalSerializationApi +interface SerialDescriptorMapVisitor { + /** + * @return null if we don't want to visit the map key + */ + fun visitMapKey( + mapDescriptor: SerialDescriptor, + keyElementIndex: Int, + ): SerialDescriptorValueVisitor? + + /** + * @return null if we don't want to visit the map value + */ + fun visitMapValue( + mapDescriptor: SerialDescriptor, + valueElementIndex: Int, + ): SerialDescriptorValueVisitor? + + fun endMapVisit(descriptor: SerialDescriptor) +} + +@ExperimentalSerializationApi +interface SerialDescriptorListVisitor { + /** + * @return null if we don't want to visit the list item + */ + fun visitListItem( + listDescriptor: SerialDescriptor, + itemElementIndex: Int, + ): SerialDescriptorValueVisitor? + + fun endListVisit(descriptor: SerialDescriptor) +} + +@ExperimentalSerializationApi +interface SerialDescriptorPolymorphicVisitor { + /** + * @return null if we don't want to visit the found polymorphic descriptor + */ + fun visitPolymorphicFoundDescriptor(descriptor: SerialDescriptor): SerialDescriptorValueVisitor? + + fun endPolymorphicVisit(descriptor: SerialDescriptor) +} + +@ExperimentalSerializationApi +interface SerialDescriptorClassVisitor { + /** + * @return null if we don't want to visit the class element + */ + fun visitClassElement( + descriptor: SerialDescriptor, + elementIndex: Int, + ): SerialDescriptorValueVisitor? + + fun endClassVisit(descriptor: SerialDescriptor) +} + +@ExperimentalSerializationApi +interface SerialDescriptorInlineClassVisitor { + /** + * @return null if we don't want to visit the inline class element + */ + fun visitInlineClassElement( + inlineClassDescriptor: SerialDescriptor, + inlineElementIndex: Int, + ): SerialDescriptorValueVisitor? +} + +@ExperimentalSerializationApi +@OptIn(InternalSerializationApi::class) +private fun SerialDescriptor.getNonNullContextualDescriptor(serializersModule: SerializersModule) = + requireNotNull(serializersModule.getContextualDescriptor(this) ?: this.capturedKClass?.serializerOrNull()?.descriptor) { + "No descriptor found in serialization context for $this" + } + +@ExperimentalSerializationApi +private fun SerialDescriptor.possibleSerializationSubclasses(serializersModule: SerializersModule): Sequence { + return when (this.kind) { + PolymorphicKind.SEALED -> + elementDescriptors.asSequence() + .filter { it.kind == SerialKind.CONTEXTUAL } + .flatMap { it.elementDescriptors } + .flatMap { it.possibleSerializationSubclasses(serializersModule) } + + PolymorphicKind.OPEN -> + serializersModule.getPolymorphicDescriptors(this@possibleSerializationSubclasses).asSequence() + .flatMap { it.possibleSerializationSubclasses(serializersModule) } + + SerialKind.CONTEXTUAL -> sequenceOf(getNonNullContextualDescriptor(serializersModule)) + + else -> sequenceOf(this) + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/UnionSchemaFor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/UnionSchemaFor.kt deleted file mode 100644 index 130635c1..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/UnionSchemaFor.kt +++ /dev/null @@ -1,26 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import com.github.avrokotlin.avro4k.AvroInternalConfiguration -import com.github.avrokotlin.avro4k.possibleSerializationSubclasses -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.modules.SerializersModule -import org.apache.avro.Schema - -@ExperimentalSerializationApi -class UnionSchemaFor( - private val descriptor: SerialDescriptor, - private val configuration: AvroInternalConfiguration, - private val serializersModule: SerializersModule, - private val resolvedSchemas: MutableMap, -) : SchemaFor { - override fun schema(): Schema { - val leafSerialDescriptors = - descriptor.possibleSerializationSubclasses(serializersModule).sortedBy { it.serialName } - return Schema.createUnion( - leafSerialDescriptors.map { - ClassSchemaFor(it, configuration, serializersModule, resolvedSchemas).schema() - } - ) - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ValueVisitor.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ValueVisitor.kt new file mode 100644 index 00000000..13f2e1a0 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/ValueVisitor.kt @@ -0,0 +1,147 @@ +package com.github.avrokotlin.avro4k.schema + +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.AvroFixed +import com.github.avrokotlin.avro4k.AvroLogicalType +import com.github.avrokotlin.avro4k.AvroSchema +import kotlinx.serialization.descriptors.PolymorphicKind +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.SerialKind +import kotlinx.serialization.json.Json +import kotlinx.serialization.modules.SerializersModule +import org.apache.avro.LogicalType +import org.apache.avro.Schema +import org.apache.avro.SchemaBuilder +import kotlin.reflect.KClass + +internal class ValueVisitor internal constructor( + override val context: VisitorContext, + private val onSchemaBuilt: (Schema) -> Unit, +) : SerialDescriptorValueVisitor, AvroVisitorContextAware { + private var isNullable: Boolean = false + private var logicalType: LogicalType? = null + + override val serializersModule: SerializersModule + get() = context.avro.serializersModule + + constructor(avro: Avro, onSchemaBuilt: (Schema) -> Unit) : this( + VisitorContext( + avro, + mutableMapOf(), + Json { serializersModule = avro.serializersModule } + ), + onSchemaBuilt = onSchemaBuilt + ) + + override fun visitPrimitive( + descriptor: SerialDescriptor, + kind: PrimitiveKind, + ) = setSchema(Schema.create(kind.toAvroType())) + + override fun visitEnum(descriptor: SerialDescriptor) { + val enumName = descriptor.getAvroName() + + val annotations = TypeAnnotations(descriptor) + val schema = + SchemaBuilder.enumeration(enumName.name) + .namespace(enumName.namespace) + .doc(annotations.doc?.value) + .defaultSymbol(annotations.enumDefault?.value) + .symbols(*descriptor.elementNamesArray) + + annotations.aliases?.value?.forEach { schema.addAlias(it) } + annotations.props.forEach { schema.addProp(it.key, it.value) } + annotations.jsonProps.forEach { schema.addProp(it.key, it.jsonNode) } + + setSchema(schema) + } + + private val SerialDescriptor.elementNamesArray: Array + get() = Array(elementsCount) { getElementName(it) } + + override fun visitObject(descriptor: SerialDescriptor) { + // we consider objects as records without fields since the beginning. Is it really a good idea ? + visitClass(descriptor).endClassVisit(descriptor) + } + + override fun visitClass(descriptor: SerialDescriptor) = ClassVisitor(descriptor, context.resetNesting()) { setSchema(it) } + + override fun visitPolymorphic( + descriptor: SerialDescriptor, + kind: PolymorphicKind, + ) = PolymorphicVisitor(context) { setSchema(it) } + + override fun visitList(descriptor: SerialDescriptor) = ListVisitor(context.copy(inlinedAnnotations = null)) { setSchema(it) } + + override fun visitMap(descriptor: SerialDescriptor) = MapVisitor(context.copy(inlinedAnnotations = null)) { setSchema(it) } + + override fun visitInlineClass(descriptor: SerialDescriptor) = InlineClassVisitor(context) { setSchema(it) } + + private fun setSchema(schema: Schema) { + val finalSchema = logicalType?.addToSchema(schema) ?: schema + if (isNullable && !finalSchema.isNullable) { + onSchemaBuilt(finalSchema.toNullableSchema()) + } else { + onSchemaBuilt(finalSchema) + } + } + + private fun visitByteArray() { + setSchema(Schema.create(Schema.Type.BYTES)) + } + + private fun visitFixed(fixed: AnnotatedElementOrType) { + val parentFieldName = + fixed.elementIndex?.let { fixed.descriptor.getElementName(it) } + ?: throw AvroSchemaGenerationException("@AvroFixed must be used on a field") + val parentNamespace = fixed.descriptor.getAvroName().namespace + + setSchema( + SchemaBuilder.fixed(parentFieldName) + .namespace(parentNamespace) + .size(fixed.annotation.size) + ) + } + + override fun visitValue(descriptor: SerialDescriptor) { + if (descriptor.isNullable) { + isNullable = true + } + if (descriptor.kind == SerialKind.CONTEXTUAL) { + super.visitValue(descriptor) + return + } + val annotations = context.inlinedAnnotations.appendAnnotations(ValueAnnotations(descriptor)) + + if (annotations.logicalType != null) { + logicalType = annotations.logicalType.getLogicalType(annotations) + } + when { + annotations.customSchema != null -> setSchema(annotations.customSchema.getSchema(annotations)) + annotations.fixed != null -> visitFixed(annotations.fixed) + descriptor.isByteArray() -> visitByteArray() + else -> super.visitValue(descriptor) + } + } + + private fun AnnotatedElementOrType.getLogicalType(valueAnnotations: ValueAnnotations): LogicalType { + return this.annotation.value.newObjectInstance().getLogicalType(valueAnnotations.stack) + } + + private fun AnnotatedElementOrType.getSchema(valueAnnotations: ValueAnnotations): Schema { + return this.annotation.value.newObjectInstance().getSchema(valueAnnotations.stack) + } +} + +private fun KClass.newObjectInstance(): T { + return this.objectInstance ?: throw AvroSchemaGenerationException("${this.qualifiedName} must be an object") +} + +private fun Schema.toNullableSchema(): Schema { + return if (this.type == Schema.Type.UNION) { + Schema.createUnion(listOf(Schema.create(Schema.Type.NULL)) + this.types) + } else { + Schema.createUnion(Schema.create(Schema.Type.NULL), this) + } +} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/VisitorContext.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/VisitorContext.kt new file mode 100644 index 00000000..e74695e8 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/VisitorContext.kt @@ -0,0 +1,247 @@ +package com.github.avrokotlin.avro4k.schema + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.ArrayNode +import com.fasterxml.jackson.databind.node.JsonNodeFactory +import com.fasterxml.jackson.databind.node.NullNode +import com.fasterxml.jackson.databind.node.ObjectNode +import com.fasterxml.jackson.databind.node.TextNode +import com.github.avrokotlin.avro4k.AnnotatedLocation +import com.github.avrokotlin.avro4k.Avro +import com.github.avrokotlin.avro4k.AvroAlias +import com.github.avrokotlin.avro4k.AvroDefault +import com.github.avrokotlin.avro4k.AvroDoc +import com.github.avrokotlin.avro4k.AvroEnumDefault +import com.github.avrokotlin.avro4k.AvroFixed +import com.github.avrokotlin.avro4k.AvroJsonProp +import com.github.avrokotlin.avro4k.AvroLogicalType +import com.github.avrokotlin.avro4k.AvroNamespaceOverride +import com.github.avrokotlin.avro4k.AvroProp +import com.github.avrokotlin.avro4k.AvroSchema +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.SerialKind +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.boolean +import kotlinx.serialization.json.booleanOrNull +import org.apache.avro.JsonProperties +import org.apache.avro.Schema + +internal data class VisitorContext( + val avro: Avro, + val resolvedSchemas: MutableMap, + val json: Json, + val inlinedAnnotations: ValueAnnotations? = null, +) + +internal fun VisitorContext.resetNesting() = copy(inlinedAnnotations = null) + +internal interface AvroVisitorContextAware { + val context: VisitorContext +} + +/** + * Contains all the annotations for a field of a class (kind == CLASS && isInline == true). + */ +internal data class InlineClassFieldAnnotations( + val namespaceOverride: AvroNamespaceOverride?, +) { + constructor(descriptor: SerialDescriptor, elementIndex: Int) : this( + descriptor.findElementAnnotation(elementIndex) + ) { + require(descriptor.isInline) { + "${InlineClassFieldAnnotations::class.qualifiedName} is only for inline classes, but trying at element index $elementIndex with non-inline class descriptor $descriptor" + } + } +} + +/** + * Contains all the annotations for a field of a class (kind == CLASS && isInline == false). + */ +internal data class FieldAnnotations( + val props: Sequence, + val jsonProps: Sequence, + val aliases: Sequence, + val doc: AvroDoc?, + val default: AvroDefault?, + val namespaceOverride: AvroNamespaceOverride?, +) { + constructor(descriptor: SerialDescriptor, elementIndex: Int) : this( + descriptor.findElementAnnotations(elementIndex).asSequence(), + descriptor.findElementAnnotations(elementIndex).asSequence(), + descriptor.findElementAnnotations(elementIndex).asSequence(), + descriptor.findElementAnnotation(elementIndex), + descriptor.findElementAnnotation(elementIndex), + descriptor.findElementAnnotation(elementIndex) + ) { + require(descriptor.kind == StructureKind.CLASS) { + "${FieldAnnotations::class.qualifiedName} is only for classes, but trying at element index $elementIndex with non class descriptor $descriptor" + } + } +} + +/** + * Contains all the annotations for a field of a class, inline or not (kind == CLASS). + * Helpful when nesting multiple inline classes to get the first annotation. + */ +internal data class ValueAnnotations( + val stack: List, + val fixed: AnnotatedElementOrType?, + val customSchema: AnnotatedElementOrType?, + val logicalType: AnnotatedElementOrType?, +) { + constructor(descriptor: SerialDescriptor, elementIndex: Int) : this( + listOf(SimpleAnnotatedLocation(descriptor, elementIndex)), + AnnotatedElementOrType(descriptor, elementIndex), + AnnotatedElementOrType(descriptor, elementIndex), + AnnotatedElementOrType(descriptor, elementIndex) + ) + + constructor(descriptor: SerialDescriptor) : this( + listOf(SimpleAnnotatedLocation(descriptor)), + AnnotatedElementOrType(descriptor), + AnnotatedElementOrType(descriptor), + AnnotatedElementOrType(descriptor) + ) +} + +internal data class AnnotatedElementOrType( + override val descriptor: SerialDescriptor, + override val elementIndex: Int?, + val annotation: T, +) : AnnotatedLocation { + companion object { + inline operator fun invoke( + descriptor: SerialDescriptor, + elementIndex: Int, + ) = descriptor.findElementAnnotation(elementIndex)?.let { AnnotatedElementOrType(descriptor, elementIndex, it) } + + inline operator fun invoke(descriptor: SerialDescriptor) = descriptor.findAnnotation()?.let { AnnotatedElementOrType(descriptor, null, it) } + } +} + +internal data class SimpleAnnotatedLocation( + override val descriptor: SerialDescriptor, + override val elementIndex: Int? = null, +) : AnnotatedLocation + +/** + * Contains all the annotations for a class, object or enum (kind == CLASS || kind == OBJECT || kind == ENUM). + */ +internal data class TypeAnnotations( + val props: Sequence, + val jsonProps: Sequence, + val aliases: AvroAlias?, + val doc: AvroDoc?, + val enumDefault: AvroEnumDefault?, +) { + constructor(descriptor: SerialDescriptor) : this( + descriptor.findAnnotations().asSequence(), + descriptor.findAnnotations().asSequence(), + descriptor.findAnnotation(), + descriptor.findAnnotation(), + descriptor.findAnnotation() + ) { + if (enumDefault != null) { + require(descriptor.kind == SerialKind.ENUM) { "@AvroEnumDefault can only be used on enums. Actual: $descriptor" } + } else { + require(descriptor.kind == StructureKind.CLASS || descriptor.kind == StructureKind.OBJECT || descriptor.kind == SerialKind.ENUM) { + "TypeAnnotations are only for classes, objects and enums. Actual: $descriptor" + } + } + } +} + +/** + * Keep the top-est annotation. If the current element details annotation is null, it will be replaced by the new annotation. + * If the current element details annotation is not null, it will be kept. + */ +internal fun ValueAnnotations?.appendAnnotations(other: ValueAnnotations) = + ValueAnnotations( + fixed = this?.fixed ?: other.fixed, + logicalType = this?.logicalType ?: other.logicalType, + customSchema = this?.customSchema ?: other.customSchema, + stack = (this?.stack ?: emptyList()) + other.stack + ) + +context(AvroVisitorContextAware) +internal val AvroJsonProp.jsonNode: JsonNode + get() { + if (jsonValue.isStartingAsJson()) { + return context.json.parseToJsonElement(jsonValue).toJacksonNode() + } + return TextNode.valueOf(jsonValue) + } + +context(AvroVisitorContextAware) +internal val AvroDefault.jsonValue: Any + get() { + if (value.isStartingAsJson()) { + return context.json.parseToJsonElement(value).toAvroObject() + } + return value + } + +/** + * Returns true if the given content is starting with `"`, {`, `[`, a digit or equals to `null`. + * It doesn't check if the content is valid json. + * It skips the whitespaces at the beginning of the content. + */ +internal fun String.isStartingAsJson(): Boolean { + val i = this.indexOfFirst { !it.isWhitespace() } + if (i == -1) { + return false + } + val c = this[i] + return c == '{' || c == '"' || c.isDigit() || c == '[' || this == "null" || this == "true" || this == "false" +} + +private fun JsonElement.toAvroObject(): Any = + when (this) { + is JsonNull -> JsonProperties.NULL_VALUE + is JsonObject -> this.entries.associate { it.key to it.value.toAvroObject() } + is JsonArray -> this.map { it.toAvroObject() } + is JsonPrimitive -> + when { + this.isString -> this.content + this.booleanOrNull != null -> this.boolean + else -> { + this.content.toBigDecimal().stripTrailingZeros().let { + if (it.scale() <= 0) it.toBigInteger() else it + } + } + } + } + +private fun JsonElement.toJacksonNode(): JsonNode = + when (this) { + is JsonNull -> NullNode.instance + is JsonObject -> ObjectNode(JsonNodeFactory.instance, this.entries.associate { it.key to it.value.toJacksonNode() }) + is JsonArray -> ArrayNode(JsonNodeFactory.instance, this.map { it.toJacksonNode() }) + is JsonPrimitive -> + when { + this.isString -> JsonNodeFactory.instance.textNode(this.content) + this.booleanOrNull != null -> JsonNodeFactory.instance.booleanNode(this.boolean) + else -> + this.content.toBigDecimal().let { + if (it.scale() <= 0) JsonNodeFactory.instance.numberNode(it.toBigInteger()) else JsonNodeFactory.instance.numberNode(it) + } + } + } + +/** + * Get the record/enum name using the configured record naming strategy. + */ +context(AvroVisitorContextAware) +internal fun SerialDescriptor.getAvroName() = context.avro.configuration.recordNamingStrategy.resolve(this, serialName) + +/** + * Get the field name using the configured field naming strategy. + */ +context(AvroVisitorContextAware) +internal fun SerialDescriptor.getElementAvroName(elementIndex: Int) = context.avro.configuration.fieldNamingStrategy.resolve(this, elementIndex, getElementName(elementIndex)) \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/helpers.kt new file mode 100644 index 00000000..1a641704 --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/helpers.kt @@ -0,0 +1,65 @@ +package com.github.avrokotlin.avro4k.schema + +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import org.apache.avro.Schema + +inline fun SerialDescriptor.findAnnotation() = annotations.asSequence().filterIsInstance().firstOrNull() + +inline fun SerialDescriptor.findAnnotations() = annotations.filterIsInstance() + +inline fun SerialDescriptor.findElementAnnotation(elementIndex: Int) = getElementAnnotations(elementIndex).asSequence().filterIsInstance().firstOrNull() + +inline fun SerialDescriptor.findElementAnnotations(elementIndex: Int) = getElementAnnotations(elementIndex).filterIsInstance() + +internal fun Schema.extractNonNull(): Schema = + when (this.type) { + Schema.Type.UNION -> this.types.filter { it.type != Schema.Type.NULL }.let { if (it.size > 1) Schema.createUnion(it) else it[0] } + else -> this + } + +/** + * Overrides the namespace of a [Schema] with the given namespace. + */ +internal fun Schema.overrideNamespace(namespaceOverride: String): Schema { + return when (type) { + Schema.Type.RECORD -> { + val fields = + fields.map { field -> + Schema.Field( + field.name(), + field.schema().overrideNamespace(namespaceOverride), + field.doc(), + field.defaultVal(), + field.order() + ) + } + val copy = Schema.createRecord(name, doc, namespaceOverride, isError, fields) + aliases.forEach { copy.addAlias(it) } + this.objectProps.forEach { copy.addProp(it.key, it.value) } + copy + } + Schema.Type.UNION -> Schema.createUnion(types.map { it.overrideNamespace(namespaceOverride) }) + Schema.Type.ENUM -> Schema.createEnum(name, doc, namespaceOverride, enumSymbols, enumDefault) + Schema.Type.FIXED -> Schema.createFixed(name, doc, namespaceOverride, fixedSize) + Schema.Type.MAP -> Schema.createMap(valueType.overrideNamespace(namespaceOverride)) + Schema.Type.ARRAY -> Schema.createArray(elementType.overrideNamespace(namespaceOverride)) + else -> this + } +} + +internal fun SerialDescriptor.isByteArray(): Boolean = kind == StructureKind.LIST && getElementDescriptor(0).let { !it.isNullable && it.kind == PrimitiveKind.BYTE } + +internal fun PrimitiveKind.toAvroType() = + when (this) { + PrimitiveKind.BOOLEAN -> Schema.Type.BOOLEAN + PrimitiveKind.CHAR -> Schema.Type.INT + PrimitiveKind.BYTE -> Schema.Type.INT + PrimitiveKind.SHORT -> Schema.Type.INT + PrimitiveKind.INT -> Schema.Type.INT + PrimitiveKind.LONG -> Schema.Type.LONG + PrimitiveKind.FLOAT -> Schema.Type.FLOAT + PrimitiveKind.DOUBLE -> Schema.Type.DOUBLE + PrimitiveKind.STRING -> Schema.Type.STRING + } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/schemas.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/schema/schemas.kt deleted file mode 100644 index c584d89d..00000000 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/schema/schemas.kt +++ /dev/null @@ -1,52 +0,0 @@ -package com.github.avrokotlin.avro4k.schema - -import org.apache.avro.Schema - -// creates a union schema type, with nested unions extracted, and duplicate nulls stripped -// union schemas can't contain other union schemas as a direct -// child, so whenever we create a union, we need to check if our -// children are unions and flatten -fun createSafeUnion( - nullFirst: Boolean, - vararg schemas: Schema, -): Schema { - val flattened = schemas.flatMap { schema -> runCatching { schema.types }.getOrElse { listOf(schema) } } - val (nulls, rest) = flattened.partition { it.type == Schema.Type.NULL } - return Schema.createUnion(if (nullFirst) nulls + rest else rest + nulls) -} - -fun Schema.extractNonNull(): Schema = - when (this.type) { - Schema.Type.UNION -> this.types.filter { it.type != Schema.Type.NULL }.let { if (it.size > 1) Schema.createUnion(it) else it[0] } - else -> this - } - -/** - * Overrides the namespace of a [Schema] with the given namespace. - */ -fun Schema.overrideNamespace(namespace: String): Schema { - return when (type) { - Schema.Type.RECORD -> { - val fields = - fields.map { field -> - Schema.Field( - field.name(), - field.schema().overrideNamespace(namespace), - field.doc(), - field.defaultVal(), - field.order() - ) - } - val copy = Schema.createRecord(name, doc, namespace, isError, fields) - aliases.forEach { copy.addAlias(it) } - this.objectProps.forEach { copy.addProp(it.key, it.value) } - copy - } - Schema.Type.UNION -> Schema.createUnion(types.map { it.overrideNamespace(namespace) }) - Schema.Type.ENUM -> Schema.createEnum(name, doc, namespace, enumSymbols, enumDefault) - Schema.Type.FIXED -> Schema.createFixed(name, doc, namespace, fixedSize) - Schema.Type.MAP -> Schema.createMap(valueType.overrideNamespace(namespace)) - Schema.Type.ARRAY -> Schema.createArray(elementType.overrideNamespace(namespace)) - else -> this - } -} \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt index 1aed0a65..58a5469e 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/AvroSerializer.kt @@ -15,23 +15,26 @@ abstract class AvroSerializer : KSerializer { encoder: Encoder, value: T, ) { - val schema = (encoder as FieldEncoder).fieldSchema() - // we may be encoding a nullable schema - val subschema = - when (schema.type) { - Schema.Type.UNION -> schema.extractNonNull() - else -> schema + val schema = + (encoder as FieldEncoder).fieldSchema().let { + if (!this.descriptor.isNullable && it.isNullable) { + it.extractNonNull() + } else { + it + } } - encodeAvroValue(subschema, encoder, value) + encodeAvroValue(schema, encoder, value) } final override fun deserialize(decoder: Decoder): T { - val schema = (decoder as FieldDecoder).fieldSchema() -// // we may be coming from a nullable schema aka a union -// val subschema = when (schema.type) { -// Schema.Type.UNION -> schema.extractNonNull() -// else -> schema -// } + val schema = + (decoder as FieldDecoder).fieldSchema().let { + if (!this.descriptor.isNullable && it.isNullable) { + it.extractNonNull() + } else { + it + } + } return decodeAvroValue(schema, decoder) } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt index 04be45f9..209e42c9 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigDecimalSerializer.kt @@ -1,74 +1,112 @@ package com.github.avrokotlin.avro4k.serializer -import com.github.avrokotlin.avro4k.AvroDecimalLogicalType +import com.github.avrokotlin.avro4k.AnnotatedLocation +import com.github.avrokotlin.avro4k.AvroDecimal +import com.github.avrokotlin.avro4k.AvroLogicalType +import com.github.avrokotlin.avro4k.AvroLogicalTypeSupplier import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.InternalSerializationApi +import com.github.avrokotlin.avro4k.schema.findElementAnnotation import kotlinx.serialization.SerializationException -import kotlinx.serialization.descriptors.StructureKind -import kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor import org.apache.avro.Conversions +import org.apache.avro.LogicalType import org.apache.avro.LogicalTypes import org.apache.avro.Schema import org.apache.avro.generic.GenericFixed import java.math.BigDecimal import java.nio.ByteBuffer -@OptIn(ExperimentalSerializationApi::class) -class BigDecimalSerializer : AvroSerializer() { - private val converter = Conversions.DecimalConversion() +private val converter = Conversions.DecimalConversion() +private val defaultAnnotation = AvroDecimal() + +object BigDecimalSerializer : AvroSerializer(), AvroLogicalTypeSupplier { + override fun getLogicalType(inlinedStack: List): LogicalType { + return inlinedStack.firstNotNullOfOrNull { + it.descriptor.findElementAnnotation(it.elementIndex ?: return@firstNotNullOfOrNull null)?.logicalType + } ?: defaultAnnotation.logicalType + } - @OptIn(InternalSerializationApi::class) override val descriptor = - buildSerialDescriptor(BigDecimal::class.qualifiedName!!, StructureKind.OBJECT) { - annotations = listOf(AvroDecimalLogicalType()) - } + buildByteArraySerialDescriptor( + BigDecimal::class.qualifiedName!!, + AvroLogicalType(BigDecimalSerializer::class) + ) override fun encodeAvroValue( schema: Schema, encoder: ExtendedEncoder, obj: BigDecimal, - ) { - // we support encoding big decimals in three ways - fixed, bytes or as a String, depending on the schema passed in - // the scale and precision should come from the schema and the rounding mode from the implicit - - return when (schema.type) { - Schema.Type.STRING -> encoder.encodeString(obj.toString()) - Schema.Type.BYTES -> { - when (val logical = schema.logicalType) { - is LogicalTypes.Decimal -> encoder.encodeByteArray(converter.toBytes(obj, schema, logical)) - else -> throw SerializationException("Cannot encode BigDecimal to FIXED for logical type $logical") - } - } - - Schema.Type.FIXED -> { - when (val logical = schema.logicalType) { - is LogicalTypes.Decimal -> encoder.encodeFixed(converter.toFixed(obj, schema, logical)) - else -> throw SerializationException("Cannot encode BigDecimal to FIXED for logical type $logical") - } - } - - else -> throw SerializationException("Cannot encode BigDecimal as ${schema.type}") + ) = encodeBigDecimal(schema, encoder, obj) + + override fun decodeAvroValue( + schema: Schema, + decoder: ExtendedDecoder, + ) = decodeBigDecimal(decoder, schema) + + private val AvroDecimal.logicalType: LogicalType + get() { + return LogicalTypes.decimal(precision, scale) } - } +} + +object BigDecimalAsStringSerializer : AvroSerializer() { + override val descriptor = PrimitiveSerialDescriptor(BigDecimal::class.qualifiedName!!, PrimitiveKind.STRING) + + override fun encodeAvroValue( + schema: Schema, + encoder: ExtendedEncoder, + obj: BigDecimal, + ) = encodeBigDecimal(schema, encoder, obj) override fun decodeAvroValue( schema: Schema, decoder: ExtendedDecoder, - ): BigDecimal { - fun logical() = - when (val l = schema.logicalType) { - is LogicalTypes.Decimal -> l - else -> throw SerializationException("Cannot decode to BigDecimal when field schema [$schema] does not define Decimal logical type [$l]") - } - - return when (val v = decoder.decodeAny()) { - is CharSequence -> BigDecimal(v.toString()) - is ByteArray -> converter.fromBytes(ByteBuffer.wrap(v), schema, logical()) - is ByteBuffer -> converter.fromBytes(v, schema, logical()) - is GenericFixed -> converter.fromFixed(v, schema, logical()) - else -> throw SerializationException("Unsupported BigDecimal type [$v]") + ) = decodeBigDecimal(decoder, schema) +} + +private fun encodeBigDecimal( + schema: Schema, + encoder: ExtendedEncoder, + value: BigDecimal, +) { + when (schema.type) { + Schema.Type.STRING -> encoder.encodeString(value.toString()) + Schema.Type.BYTES -> { + encoder.encodeByteArray(converter.toBytes(value, schema, schema.getDecimalLogicalType())) + } + + Schema.Type.FIXED -> { + encoder.encodeFixed(converter.toFixed(value, schema, schema.getDecimalLogicalType())) } + + Schema.Type.INT -> encoder.encodeInt(value.intValueExact()) + Schema.Type.LONG -> encoder.encodeLong(value.longValueExact()) + Schema.Type.FLOAT -> encoder.encodeFloat(value.toFloat()) + Schema.Type.DOUBLE -> encoder.encodeDouble(value.toDouble()) + + else -> throw SerializationException("Cannot encode BigDecimal as ${schema.type}") + } +} + +private fun decodeBigDecimal( + decoder: ExtendedDecoder, + schema: Schema, +): BigDecimal = + // TODO we should use the schema instead of this generic decodeAny() + when (val v = decoder.decodeAny()) { + is CharSequence -> BigDecimal(v.toString()) + is ByteArray -> converter.fromBytes(ByteBuffer.wrap(v), schema, schema.getDecimalLogicalType()) + is ByteBuffer -> converter.fromBytes(v, schema, schema.getDecimalLogicalType()) + is GenericFixed -> converter.fromFixed(v, schema, schema.getDecimalLogicalType()) + else -> throw SerializationException("Unsupported BigDecimal type [$v]") + } + +private fun Schema.getDecimalLogicalType(): LogicalTypes.Decimal { + val l = logicalType + return when (l) { + is LogicalTypes.Decimal -> l + else -> throw SerializationException("Expected to find a decimal logical type for BigDecimal but found $l") } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt index f35c010e..15d3ade5 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/BigIntegerSerializer.kt @@ -2,30 +2,39 @@ package com.github.avrokotlin.avro4k.serializer import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.InternalSerializationApi -import kotlinx.serialization.Serializer import kotlinx.serialization.descriptors.PrimitiveKind -import kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor import org.apache.avro.Schema import java.math.BigInteger -@OptIn(ExperimentalSerializationApi::class) -@Serializer(forClass = BigInteger::class) -class BigIntegerSerializer : AvroSerializer() { - @OptIn(InternalSerializationApi::class) - override val descriptor = buildSerialDescriptor(BigInteger::class.qualifiedName!!, PrimitiveKind.STRING) +object BigIntegerSerializer : AvroSerializer() { + override val descriptor = PrimitiveSerialDescriptor(BigInteger::class.qualifiedName!!, PrimitiveKind.STRING) override fun encodeAvroValue( schema: Schema, encoder: ExtendedEncoder, obj: BigInteger, - ) = encoder.encodeString(obj.toString()) + ) = when (schema.type) { + Schema.Type.STRING -> encoder.encodeString(obj.toString()) + Schema.Type.INT -> encoder.encodeInt(obj.intValueExact()) + Schema.Type.LONG -> encoder.encodeLong(obj.longValueExact()) + Schema.Type.FLOAT -> encoder.encodeFloat(obj.toFloat()) + Schema.Type.DOUBLE -> encoder.encodeDouble(obj.toDouble()) + + else -> throw UnsupportedOperationException("Unsupported schema type: $schema") + } override fun decodeAvroValue( schema: Schema, decoder: ExtendedDecoder, - ): BigInteger { - return BigInteger(decoder.decodeString()) - } + ): BigInteger = + when (schema.type) { + Schema.Type.STRING -> BigInteger(decoder.decodeString()) + Schema.Type.INT -> BigInteger.valueOf(decoder.decodeInt().toLong()) + Schema.Type.LONG -> BigInteger.valueOf(decoder.decodeLong()) + Schema.Type.FLOAT -> BigInteger.valueOf(decoder.decodeFloat().toLong()) + Schema.Type.DOUBLE -> BigInteger.valueOf(decoder.decodeDouble().toLong()) + + else -> throw UnsupportedOperationException("Unsupported schema type for BigInteger: $schema") + } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt index 29b28e41..d9941623 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/URLSerializer.kt @@ -2,21 +2,14 @@ package com.github.avrokotlin.avro4k.serializer import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.SerializationException -import kotlinx.serialization.Serializer import kotlinx.serialization.descriptors.PrimitiveKind -import kotlinx.serialization.descriptors.buildSerialDescriptor +import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor import org.apache.avro.Schema -import org.apache.avro.util.Utf8 import java.net.URL -@OptIn(ExperimentalSerializationApi::class) -@Serializer(forClass = URL::class) -class URLSerializer : AvroSerializer() { - @OptIn(InternalSerializationApi::class) - override val descriptor = buildSerialDescriptor(URL::class.qualifiedName!!, PrimitiveKind.STRING) +object URLSerializer : AvroSerializer() { + override val descriptor = PrimitiveSerialDescriptor(URL::class.qualifiedName!!, PrimitiveKind.STRING) override fun encodeAvroValue( schema: Schema, @@ -31,8 +24,7 @@ class URLSerializer : AvroSerializer() { decoder: ExtendedDecoder, ): URL { return when (val v = decoder.decodeAny()) { - is Utf8 -> URL(v.toString()) - is String -> URL(v) + is CharSequence -> URL(v.toString()) null -> throw SerializationException("Cannot decode as URL") else -> throw SerializationException("Unsupported URL type [$v : ${v::class.qualifiedName}]") } diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt index 6c822f8c..0dd3bb9c 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/UUIDSerializer.kt @@ -1,27 +1,29 @@ package com.github.avrokotlin.avro4k.serializer -import com.github.avrokotlin.avro4k.AvroUuidLogicalType +import com.github.avrokotlin.avro4k.AnnotatedLocation +import com.github.avrokotlin.avro4k.AvroLogicalType +import com.github.avrokotlin.avro4k.AvroLogicalTypeSupplier import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder -import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi -import kotlinx.serialization.Serializer import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.buildSerialDescriptor +import org.apache.avro.LogicalType +import org.apache.avro.LogicalTypes import org.apache.avro.Schema import java.util.UUID -@OptIn(ExperimentalSerializationApi::class) -@Serializer(forClass = UUID::class) -class UUIDSerializer : AvroSerializer() { - private val avroUuidLogicalTypeAnnotation = AvroUuidLogicalType() - +object UUIDSerializer : AvroSerializer(), AvroLogicalTypeSupplier { @OptIn(InternalSerializationApi::class) override val descriptor = buildSerialDescriptor("uuid", PrimitiveKind.STRING) { - annotations = listOf(avroUuidLogicalTypeAnnotation) + annotations = listOf(AvroLogicalType(UUIDSerializer::class)) } + override fun getLogicalType(inlinedStack: List): LogicalType { + return LogicalTypes.uuid() + } + override fun encodeAvroValue( schema: Schema, encoder: ExtendedEncoder, diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt index 57188621..c7348496 100644 --- a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/date.kt @@ -1,16 +1,16 @@ -@file:OptIn(ExperimentalSerializationApi::class) - package com.github.avrokotlin.avro4k.serializer -import com.github.avrokotlin.avro4k.AvroTimeLogicalType -import com.github.avrokotlin.avro4k.LogicalTimeTypeEnum +import com.github.avrokotlin.avro4k.AnnotatedLocation +import com.github.avrokotlin.avro4k.AvroLogicalType +import com.github.avrokotlin.avro4k.AvroLogicalTypeSupplier import com.github.avrokotlin.avro4k.decoder.ExtendedDecoder import com.github.avrokotlin.avro4k.encoder.ExtendedEncoder import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.SerializationException -import kotlinx.serialization.Serializer +import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.buildSerialDescriptor +import org.apache.avro.LogicalType import org.apache.avro.LogicalTypes import org.apache.avro.Schema import java.sql.Timestamp @@ -22,17 +22,10 @@ import java.time.ZoneOffset import java.time.temporal.ChronoUnit import kotlin.reflect.KClass -@OptIn(InternalSerializationApi::class) -private fun buildTimeSerialDescriptor( - clazz: KClass<*>, - type: LogicalTimeTypeEnum, -) = buildSerialDescriptor(clazz.qualifiedName!!, type.kind) { - annotations = listOf(AvroTimeLogicalType(type)) -} - -@Serializer(forClass = LocalDate::class) -class LocalDateSerializer : AvroSerializer() { - override val descriptor = buildTimeSerialDescriptor(LocalDate::class, LogicalTimeTypeEnum.DATE) +object LocalDateSerializer : AvroTimeSerializer(LocalDate::class, PrimitiveKind.INT) { + override fun getLogicalType(inlinedStack: List): LogicalType { + return LogicalTypes.date() + } override fun encodeAvroValue( schema: Schema, @@ -43,12 +36,13 @@ class LocalDateSerializer : AvroSerializer() { override fun decodeAvroValue( schema: Schema, decoder: ExtendedDecoder, - ): LocalDate = LocalDate.ofEpochDay(decoder.decodeLong()) + ): LocalDate = LocalDate.ofEpochDay(decoder.decodeInt().toLong()) } -@Serializer(forClass = LocalTime::class) -class LocalTimeSerializer : AvroSerializer() { - override val descriptor = buildTimeSerialDescriptor(LocalTime::class, LogicalTimeTypeEnum.TIME_MILLIS) +object LocalTimeSerializer : AvroTimeSerializer(LocalTime::class, PrimitiveKind.INT) { + override fun getLogicalType(inlinedStack: List): LogicalType { + return LogicalTypes.timeMillis() + } override fun encodeAvroValue( schema: Schema, @@ -69,15 +63,16 @@ class LocalTimeSerializer : AvroSerializer() { } } -@Serializer(forClass = LocalDateTime::class) -class LocalDateTimeSerializer : AvroSerializer() { - override val descriptor = buildTimeSerialDescriptor(LocalDateTime::class, LogicalTimeTypeEnum.TIMESTAMP_MILLIS) +object LocalDateTimeSerializer : AvroTimeSerializer(LocalDateTime::class, PrimitiveKind.LONG) { + override fun getLogicalType(inlinedStack: List): LogicalType { + return LogicalTypes.timestampMillis() + } override fun encodeAvroValue( schema: Schema, encoder: ExtendedEncoder, obj: LocalDateTime, - ) = InstantSerializer().encodeAvroValue(schema, encoder, obj.toInstant(ZoneOffset.UTC)) + ) = InstantSerializer.encodeAvroValue(schema, encoder, obj.toInstant(ZoneOffset.UTC)) override fun decodeAvroValue( schema: Schema, @@ -85,15 +80,16 @@ class LocalDateTimeSerializer : AvroSerializer() { ): LocalDateTime = LocalDateTime.ofInstant(Instant.ofEpochMilli(decoder.decodeLong()), ZoneOffset.UTC) } -@Serializer(forClass = Timestamp::class) -class TimestampSerializer : AvroSerializer() { - override val descriptor = buildTimeSerialDescriptor(Timestamp::class, LogicalTimeTypeEnum.TIMESTAMP_MILLIS) +object TimestampSerializer : AvroTimeSerializer(Timestamp::class, PrimitiveKind.LONG) { + override fun getLogicalType(inlinedStack: List): LogicalType { + return LogicalTypes.timestampMillis() + } override fun encodeAvroValue( schema: Schema, encoder: ExtendedEncoder, obj: Timestamp, - ) = InstantSerializer().encodeAvroValue(schema, encoder, obj.toInstant()) + ) = InstantSerializer.encodeAvroValue(schema, encoder, obj.toInstant()) override fun decodeAvroValue( schema: Schema, @@ -101,9 +97,10 @@ class TimestampSerializer : AvroSerializer() { ): Timestamp = Timestamp(decoder.decodeLong()) } -@Serializer(forClass = Instant::class) -class InstantSerializer : AvroSerializer() { - override val descriptor = buildTimeSerialDescriptor(Instant::class, LogicalTimeTypeEnum.TIMESTAMP_MILLIS) +object InstantSerializer : AvroTimeSerializer(Instant::class, PrimitiveKind.LONG) { + override fun getLogicalType(inlinedStack: List): LogicalType { + return LogicalTypes.timestampMillis() + } override fun encodeAvroValue( schema: Schema, @@ -117,9 +114,10 @@ class InstantSerializer : AvroSerializer() { ): Instant = Instant.ofEpochMilli(decoder.decodeLong()) } -@Serializer(forClass = Instant::class) -class InstantToMicroSerializer : AvroSerializer() { - override val descriptor = buildTimeSerialDescriptor(Instant::class, LogicalTimeTypeEnum.TIMESTAMP_MICROS) +object InstantToMicroSerializer : AvroTimeSerializer(Instant::class, PrimitiveKind.LONG) { + override fun getLogicalType(inlinedStack: List): LogicalType { + return LogicalTypes.timestampMicros() + } override fun encodeAvroValue( schema: Schema, @@ -131,4 +129,15 @@ class InstantToMicroSerializer : AvroSerializer() { schema: Schema, decoder: ExtendedDecoder, ): Instant = Instant.EPOCH.plus(decoder.decodeLong(), ChronoUnit.MICROS) +} + +@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) +abstract class AvroTimeSerializer( + klass: KClass, + kind: PrimitiveKind, +) : AvroSerializer(), AvroLogicalTypeSupplier { + override val descriptor = + buildSerialDescriptor(klass.qualifiedName!!, kind) { + annotations = listOf(AvroLogicalType(this@AvroTimeSerializer::class)) + } } \ No newline at end of file diff --git a/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/helpers.kt b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/helpers.kt new file mode 100644 index 00000000..982db7af --- /dev/null +++ b/src/main/kotlin/com/github/avrokotlin/avro4k/serializer/helpers.kt @@ -0,0 +1,20 @@ +package com.github.avrokotlin.avro4k.serializer + +import kotlinx.serialization.ExperimentalSerializationApi +import kotlinx.serialization.InternalSerializationApi +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.descriptors.buildSerialDescriptor + +@OptIn(InternalSerializationApi::class, ExperimentalSerializationApi::class) +fun buildByteArraySerialDescriptor( + serialName: String, + vararg annotations: Annotation, +) = buildSerialDescriptor(serialName, StructureKind.LIST) { + element("item", buildSerialDescriptor("item", PrimitiveKind.BYTE)) + this.annotations = listOf(*annotations) +} + +fun Long.toIntExact(): Int { + return Math.toIntExact(this) +} \ No newline at end of file diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroSerializationAssertThat.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroSerializationAssertThat.kt index 6f0f78dc..911b9ab5 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/AvroSerializationAssertThat.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/AvroSerializationAssertThat.kt @@ -24,7 +24,7 @@ import java.nio.file.Path class AvroSerializationAssertThat(private val valueToEncode: T, private val serializer: KSerializer) { private var serializersModule: SerializersModule = Avro.default.serializersModule - private var config: AvroInternalConfiguration = Avro.default.configuration + private var config: AvroConfiguration = Avro.default.configuration private var readerSchema: Schema = avro.schema(serializer) private lateinit var writerSchema: Schema private val avro: Avro @@ -36,7 +36,7 @@ class AvroSerializationAssertThat(private val valueToEncode: T, private val s } fun withConfig(config: AvroConfiguration): AvroSerializationAssertThat { - this.config = AvroInternalConfiguration(config) + this.config = config return this } diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDefaultValuesDecoderTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDefaultValuesDecoderTest.kt index 9bc4eeab..30bf7128 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDefaultValuesDecoderTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/decoder/AvroDefaultValuesDecoderTest.kt @@ -1,7 +1,7 @@ package com.github.avrokotlin.avro4k.decoder import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroDecimalLogicalType +import com.github.avrokotlin.avro4k.AvroDecimal import com.github.avrokotlin.avro4k.AvroDefault import com.github.avrokotlin.avro4k.AvroEnumDefault import com.github.avrokotlin.avro4k.io.AvroDecodeFormat @@ -49,7 +49,7 @@ data class ContainerWithDefaultFields( @AvroDefault("""[{"content":"bar"}]""") val filledFooList: List, @AvroDefault("\u0000") - @AvroDecimalLogicalType(0, 10) + @AvroDecimal(0, 10) @Serializable(BigDecimalSerializer::class) val bigDecimal: BigDecimal, ) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/endecode/MapEncoderTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/endecode/MapEncoderTest.kt index 3ffd9f35..9402c5aa 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/endecode/MapEncoderTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/endecode/MapEncoderTest.kt @@ -6,7 +6,6 @@ import io.kotest.core.factory.TestFactory import io.kotest.core.spec.style.StringSpec import io.kotest.core.spec.style.stringSpec import kotlinx.serialization.Contextual -import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.InternalSerializationApi import kotlinx.serialization.KSerializer import kotlinx.serialization.SerialName @@ -138,7 +137,6 @@ private enum class MyEnum { data class NonSerializableKey(val value: String) -@OptIn(ExperimentalSerializationApi::class) class NonSerializableKeyKSerializer : KSerializer { @OptIn(InternalSerializationApi::class) override val descriptor = buildSerialDescriptor("NonSerializableKey", PrimitiveKind.STRING) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/io/AvroDataOutputStreamCodecTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/io/AvroDataOutputStreamCodecTest.kt index 482de4ed..678d4208 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/io/AvroDataOutputStreamCodecTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/io/AvroDataOutputStreamCodecTest.kt @@ -22,17 +22,6 @@ class AvroDataOutputStreamCodecTest : StringSpec({ String(baos.toByteArray()) should contain("compositions") } - "include snappy coded in metadata when serialized with snappy" { - - val baos = ByteArrayOutputStream() - Avro.default.openOutputStream(Composer.serializer()) { - encodeFormat = AvroEncodeFormat.Data(CodecFactory.snappyCodec()) - }.to(baos).write(ennio).close() - String(baos.toByteArray()) should contain("snappy") - String(baos.toByteArray()) shouldNot contain("bzip2") - String(baos.toByteArray()) shouldNot contain("deflate") - } - "include deflate coded in metadata when serialized with deflate" { val baos = ByteArrayOutputStream() Avro.default.openOutputStream(Composer.serializer()) { diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroAliasSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroAliasSchemaTest.kt index 2e1b291a..eb3f993c 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroAliasSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroAliasSchemaTest.kt @@ -47,7 +47,13 @@ class AvroAliasSchemaTest : WordSpec({ data class FieldAnnotated( @AvroAlias("cold") val str: String, @AvroAlias("kate") val long: Long, - val int: Int, + val int: IntValue, + ) + + @Serializable + @JvmInline + value class IntValue( + @AvroAlias("ignoredAlias") val value: Int, ) @Serializable diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroDefaultSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroDefaultSchemaTest.kt index b0999bbb..f6e55158 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroDefaultSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/AvroDefaultSchemaTest.kt @@ -12,7 +12,6 @@ import kotlinx.serialization.Serializable import org.apache.avro.AvroTypeException import java.math.BigDecimal -@Suppress("BlockingMethodInNonBlockingContext") class AvroDefaultSchemaTest : FunSpec() { init { test("schema for data class with @AvroDefault should include default value as a string") { diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/BigDecimalSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/BigDecimalSchemaTest.kt index a9485852..15333bb9 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/BigDecimalSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/BigDecimalSchemaTest.kt @@ -3,7 +3,7 @@ package com.github.avrokotlin.avro4k.schema import com.github.avrokotlin.avro4k.Avro -import com.github.avrokotlin.avro4k.AvroDecimalLogicalType +import com.github.avrokotlin.avro4k.AvroDecimal import com.github.avrokotlin.avro4k.serializer.BigDecimalSerializer import io.kotest.core.spec.style.FunSpec import io.kotest.matchers.shouldBe @@ -37,7 +37,7 @@ class BigDecimalSchemaTest : FunSpec({ @Serializable data class BigDecimalPrecisionTest( - @AvroDecimalLogicalType(1, 4) val decimal: BigDecimal, + @AvroDecimal(1, 4) val decimal: BigDecimal, ) @Serializable diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ContextualSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ContextualSchemaTest.kt index 5699331a..70da6325 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ContextualSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ContextualSchemaTest.kt @@ -11,7 +11,6 @@ import kotlinx.serialization.Serializable import kotlinx.serialization.modules.SerializersModule import kotlinx.serialization.modules.contextual import org.apache.avro.Schema -import java.lang.IllegalArgumentException import java.time.Instant class ContextualSchemaTest : StringSpec({ @@ -22,7 +21,7 @@ class ContextualSchemaTest : StringSpec({ Avro( serializersModule = SerializersModule { - contextual(InstantSerializer()) + contextual(InstantSerializer) } ) @@ -30,7 +29,7 @@ class ContextualSchemaTest : StringSpec({ Avro( serializersModule = SerializersModule { - contextual(InstantToMicroSerializer()) + contextual(InstantToMicroSerializer) } ) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt index 0f59ef16..d41c2cd0 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/EnumSchemaTest.kt @@ -9,6 +9,7 @@ import io.kotest.assertions.throwables.shouldThrow import io.kotest.core.spec.style.WordSpec import io.kotest.matchers.shouldBe import kotlinx.serialization.Serializable +import org.apache.avro.SchemaParseException class EnumSchemaTest : WordSpec({ @@ -57,7 +58,7 @@ class EnumSchemaTest : WordSpec({ schemaWithNewNameSpace.toString(true) shouldBe expected.toString(true) } "fail with unknown values" { - shouldThrow { + shouldThrow { Avro.default.schema(EnumWithUnknownDefaultTest.serializer()) } } diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicClassSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicClassSchemaTest.kt index 2873f906..495e67c4 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicClassSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/PolymorphicClassSchemaTest.kt @@ -46,19 +46,6 @@ val polymorphicModule = } class PolymorphicClassSchemaTest : StringSpec({ - "schema for polymorphic hierarchy" { - val module = - SerializersModule { - polymorphic(UnsealedPolymorphicRoot::class) { - subclass(UnsealedChildOne::class) - subclass(SealedChildTwo::class) - } - } - val schema = Avro(serializersModule = module).schema(UnsealedPolymorphicRoot.serializer()) - val expected = Schema.Parser().parse(javaClass.getResourceAsStream("/polymorphic.json")) - schema shouldBe expected - } - "supports polymorphic references / nested fields" { val schema = Avro(serializersModule = polymorphicModule).schema(ReferencingPolymorphicRoot.serializer()) val expected = Schema.Parser().parse(javaClass.getResourceAsStream("/polymorphic_reference.json")) diff --git a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ValueClassSchemaTest.kt b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ValueClassSchemaTest.kt index 3a636cc0..b03c4e4b 100644 --- a/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ValueClassSchemaTest.kt +++ b/src/test/kotlin/com/github/avrokotlin/avro4k/schema/ValueClassSchemaTest.kt @@ -1,12 +1,10 @@ -@file:UseContextualSerialization(forClasses = [UUID::class]) - package com.github.avrokotlin.avro4k.schema import com.github.avrokotlin.avro4k.Avro import io.kotest.core.spec.style.StringSpec import io.kotest.matchers.shouldBe +import kotlinx.serialization.Contextual import kotlinx.serialization.Serializable -import kotlinx.serialization.UseContextualSerialization import java.util.UUID class ValueClassSchemaTest : StringSpec({ @@ -24,7 +22,7 @@ class ValueClassSchemaTest : StringSpec({ @Serializable @JvmInline - value class UuidWrapper(val uuid: UUID) + value class UuidWrapper(val uuid: @Contextual UUID) @Serializable data class ContainsInlineTest(val id: StringWrapper, val uuid: UuidWrapper)