Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Assume kotlin.Pair as a normal data class instead of an union #174

Merged
merged 1 commit into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 10 additions & 37 deletions src/main/kotlin/com/github/avrokotlin/avro4k/schema/SchemaFor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@ import com.github.avrokotlin.avro4k.RecordNaming
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.*
import kotlinx.serialization.descriptors.PolymorphicKind
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.SerialKind
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.descriptors.capturedKClass
import kotlinx.serialization.descriptors.elementNames
import kotlinx.serialization.descriptors.getContextualDescriptor
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.serializerOrNull
import org.apache.avro.Schema
Expand Down Expand Up @@ -63,36 +70,6 @@ class EnumSchemaFor(
}
}

@ExperimentalSerializationApi
class PairSchemaFor(private val descriptor: SerialDescriptor,
private val configuration: AvroConfiguration,
private val serializersModule: SerializersModule,
private val resolvedSchemas: MutableMap<RecordNaming, Schema>
) : SchemaFor {

override fun schema(): Schema {
val a = schemaFor(
serializersModule,
descriptor.getElementDescriptor(0),
descriptor.getElementAnnotations(0),
configuration,
resolvedSchemas
)
val b = schemaFor(
serializersModule,
descriptor.getElementDescriptor(1),
descriptor.getElementAnnotations(1),
configuration,
resolvedSchemas
)
return SchemaBuilder.unionOf()
.type(a.schema())
.and()
.type(b.schema())
.endUnion()
}
}

@ExperimentalSerializationApi
class ListSchemaFor(private val descriptor: SerialDescriptor,
private val serializersModule: SerializersModule,
Expand Down Expand Up @@ -206,11 +183,7 @@ fun schemaFor(serializersModule: SerializersModule,
resolvedSchemas
)

StructureKind.CLASS, StructureKind.OBJECT -> when (descriptor.serialName) {
"kotlin.Pair" -> PairSchemaFor(descriptor, configuration, serializersModule, resolvedSchemas)
else -> ClassSchemaFor(descriptor, configuration, serializersModule, resolvedSchemas)
}

StructureKind.CLASS, StructureKind.OBJECT -> ClassSchemaFor(descriptor, configuration, serializersModule, resolvedSchemas)
StructureKind.LIST -> ListSchemaFor(descriptor, serializersModule, configuration, resolvedSchemas)
StructureKind.MAP -> MapSchemaFor(descriptor, serializersModule, configuration, resolvedSchemas)
is PolymorphicKind -> UnionSchemaFor(descriptor, configuration, serializersModule, resolvedSchemas)
Expand All @@ -224,4 +197,4 @@ fun schemaFor(serializersModule: SerializersModule,
// copy-paste from kotlinx serialization because it internal
@ExperimentalSerializationApi
internal val SerialDescriptor.unwrapValueClass: SerialDescriptor
get() = if (isInline) getElementDescriptor(0) else this
get() = if (isInline) getElementDescriptor(0) else this
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,9 @@ import com.github.avrokotlin.avro4k.schema.Level4
import com.github.avrokotlin.avro4k.schema.RecursiveClass
import com.github.avrokotlin.avro4k.schema.RecursiveListItem
import com.github.avrokotlin.avro4k.schema.RecursiveMapValue
import com.github.avrokotlin.avro4k.schema.RecursivePair
import io.kotest.matchers.shouldBe
import io.kotest.core.spec.style.StringSpec
import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeInstanceOf
import org.apache.avro.generic.GenericData
import org.apache.avro.generic.GenericRecord
import org.apache.avro.util.Utf8

Expand Down Expand Up @@ -54,24 +52,6 @@ class RecursiveIoTest : StringSpec({
}
}

"read / write direct recursive pair" {
writeRead(RecursivePair(1, (RecursivePair(2, null) to RecursivePair(3, null))), RecursivePair.serializer())
writeRead(RecursivePair(1, (RecursivePair(2, null) to RecursivePair(3, null))), RecursivePair.serializer()) {
it["payload"] shouldBe 1
it["pair"].shouldBeInstanceOf<GenericData.Record>()
val first = (it["pair"] as GenericData.Record)["first"]
first.shouldBeInstanceOf<GenericRecord>()
first.schema shouldBe Avro.default.schema(RecursivePair.serializer())
first["payload"] shouldBe 2
first["pair"] shouldBe null
val second = (it["pair"] as GenericData.Record)["second"]
second.shouldBeInstanceOf<GenericRecord>()
second.schema shouldBe Avro.default.schema(RecursivePair.serializer())
second["payload"] shouldBe 3
second["pair"] shouldBe null
}
}

"read / write nested recursive classes" {
writeRead(Level1(Level2(Level3(Level4(Level1(null))))), Level1.serializer())
writeRead(Level1(Level2(Level3(Level4(Level1(null))))), Level1.serializer()) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.github.avrokotlin.avro4k.schema

import com.github.avrokotlin.avro4k.Avro
import io.kotest.matchers.shouldBe
import io.kotest.core.spec.style.FunSpec
import io.kotest.matchers.shouldBe
import kotlinx.serialization.Serializable

@Serializable
Expand All @@ -14,9 +14,6 @@ data class RecursiveListItem(val payload: Int, val list: List<RecursiveListItem>
@Serializable
data class RecursiveMapValue(val payload: Int, val map: Map<String, RecursiveMapValue>?)

@Serializable
data class RecursivePair(val payload: Int, val pair: Pair<RecursivePair, RecursivePair>?)

@Serializable
data class Level4(val level1: Level1)

Expand Down Expand Up @@ -49,12 +46,6 @@ class RecursiveSchemaTest : FunSpec({
schema.toString(true) shouldBe expected.toString(true)
}

test("accept direct recursive pairs") {
val expected = org.apache.avro.Schema.Parser().parse(this::class.java.getResourceAsStream("/recursive_pair.json"))
val schema = Avro.default.schema(RecursivePair.serializer())
schema.toString(true) shouldBe expected.toString(true)
}

test("accept nested recursive classes") {
val expected = org.apache.avro.Schema.Parser().parse(this::class.java.getResourceAsStream("/recursive_nested.json"))
val schema = Avro.default.schema(Level1.serializer())
Expand Down
11 changes: 0 additions & 11 deletions src/test/resources/pair.json

This file was deleted.

14 changes: 0 additions & 14 deletions src/test/resources/pair_records.json

This file was deleted.

23 changes: 0 additions & 23 deletions src/test/resources/recursive_pair.json

This file was deleted.