Skip to content

Commit

Permalink
feat: Handle contextual map keys #114
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuckame committed Jan 28, 2024
1 parent 563c20f commit 91dea8e
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
3 changes: 1 addition & 2 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ end_of_line = lf
indent_size = 4
indent_style = space
insert_final_newline = false
max_line_length = 120
max_line_length = 180
tab_width = 4
ij_continuation_indent_size = 8
ij_formatter_off_tag = @formatter:off
Expand All @@ -23,7 +23,6 @@ ij_editorconfig_spaces_around_assignment_operators = true

[{*.kt,*.kts}]
ktlint_standard_filename = disabled
max_line_length = 180
ij_kotlin_align_in_columns_case_branch = false
ij_kotlin_align_multiline_binary_operation = false
ij_kotlin_align_multiline_extends_list = false
Expand Down
30 changes: 17 additions & 13 deletions src/main/kotlin/com/github/avrokotlin/avro4k/schema/SchemaFor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,24 @@ class MapSchemaFor(
private val resolvedSchemas: MutableMap<RecordNaming, Schema>,
) : SchemaFor {
override fun schema(): Schema {
val keyType = descriptor.getElementDescriptor(0).unwrapValueClass
if (keyType.kind !is PrimitiveKind && keyType.kind != SerialKind.ENUM) {
throw RuntimeException("Avro4k only supports primitive and enum kinds as the map key. Actual: ${descriptor.getElementDescriptor(0)}")
val keyType =
descriptor.getElementDescriptor(0).unwrapValueClass.let {
if (it.kind == SerialKind.CONTEXTUAL) serializersModule.getContextualDescriptor(it)?.unwrapValueClass else it
}
if (keyType != null) {
if (keyType.kind is PrimitiveKind || keyType.kind == SerialKind.ENUM) {
val valueSchema =
schemaFor(
serializersModule,
descriptor.getElementDescriptor(1),
descriptor.getElementAnnotations(1),
configuration,
resolvedSchemas
).schema()
return Schema.createMap(valueSchema)
}
}

val valueSchema =
schemaFor(
serializersModule,
descriptor.getElementDescriptor(1),
descriptor.getElementAnnotations(1),
configuration,
resolvedSchemas
).schema()
return Schema.createMap(valueSchema)
throw RuntimeException("Avro4k only supports primitive and enum kinds as the map key. Actual: ${descriptor.getElementDescriptor(0)}")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
package com.github.avrokotlin.avro4k.endecode

import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.record
import io.kotest.core.factory.TestFactory
import io.kotest.core.spec.style.StringSpec
import io.kotest.core.spec.style.stringSpec
import kotlinx.serialization.Contextual
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.InternalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.buildSerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.modules.serializersModuleOf
import java.nio.ByteBuffer

class MapEncoderTest : StringSpec({
Expand Down Expand Up @@ -41,6 +51,16 @@ fun mapEncoderTests(enDecoder: EnDecoder): TestFactory {
)
}

"encode/decode a Map<non serializable key, String>" {
@Serializable
data class StringStringTest(val a: Map<@Contextual NonSerializableKey, String>)
enDecoder.avro = Avro(serializersModule = serializersModuleOf(NonSerializableKey::class, NonSerializableKeyKSerializer()))
enDecoder.testEncodeDecode(
StringStringTest(mapOf(NonSerializableKey("a") to "x", NonSerializableKey("b") to "y", NonSerializableKey("c") to "z")),
record(mapOf("a" to "x", "b" to "y", "c" to "z"))
)
}

"encode/decode a Map<int value class, String>" {
@Serializable
data class StringStringTest(val a: Map<MapIntKey, String>)
Expand Down Expand Up @@ -114,4 +134,21 @@ private enum class MyEnum {

@SerialName("z")
C,
}

data class NonSerializableKey(val value: String)

@OptIn(ExperimentalSerializationApi::class)
class NonSerializableKeyKSerializer : KSerializer<NonSerializableKey> {
@OptIn(InternalSerializationApi::class)
override val descriptor = buildSerialDescriptor("NonSerializableKey", PrimitiveKind.STRING)

override fun deserialize(decoder: Decoder) = NonSerializableKey(decoder.decodeString())

override fun serialize(
encoder: Encoder,
value: NonSerializableKey,
) {
encoder.encodeString(value.value)
}
}

0 comments on commit 91dea8e

Please sign in to comment.