Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Only handle ByteArrays as bytes or fixed, and collection of Byte as arrays of int #234

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ import kotlinx.serialization.SerializationException
import kotlinx.serialization.SerializationStrategy
import kotlinx.serialization.builtins.ByteArraySerializer
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.internal.AbstractCollectionSerializer
Expand Down Expand Up @@ -56,14 +54,12 @@ internal object SerializerLocatorMiddleware {

// Returns the descriptor to use in place of [descriptor]. Branches are tried top to bottom and the
// first matching one wins; a descriptor matching none of them is returned unchanged.
// NOTE(review): this span looks like a rendered diff with removed and added lines interleaved: the
// `==` branches for String/Duration are followed by `===` branches for the same descriptors, and a
// collection-of-bytes branch sits next to a ByteArray identity branch. Confirm which set of branches
// is the current one before relying on this listing.
fun apply(descriptor: SerialDescriptor): SerialDescriptor {
return when {
// Any list-kind descriptor with a single BYTE element is delegated to the Avro byte array serializer.
descriptor.isCollectionOfBytes() -> SerialDescriptorWithAvroSchemaDelegate(descriptor, AvroByteArraySerializer)
// Structural-equality matches on the built-in String and Duration descriptors.
descriptor == String.serializer().descriptor -> AvroStringSerialDescriptor
descriptor == Duration.serializer().descriptor -> KotlinDurationSerializer.descriptor
// Identity matches: only the exact built-in descriptor instances are replaced.
descriptor === ByteArraySerializer().descriptor -> AvroByteArraySerializer.descriptor
descriptor === String.serializer().descriptor -> AvroStringSerialDescriptor
descriptor === Duration.serializer().descriptor -> KotlinDurationSerializer.descriptor
else -> descriptor
}
}

/**
 * Tells whether this descriptor is a list-kind structure declaring exactly one element
 * descriptor whose kind is [PrimitiveKind.BYTE].
 */
private fun SerialDescriptor.isCollectionOfBytes(): Boolean {
    if (kind !== StructureKind.LIST) return false
    if (elementsCount != 1) return false
    return getElementDescriptor(0).kind === PrimitiveKind.BYTE
}
}

private val AvroStringSerialDescriptor: SerialDescriptor =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,38 +61,18 @@ internal abstract class AbstractAvroDirectDecoder(
override fun beginStructure(descriptor: SerialDescriptor): CompositeDecoder {
return when (descriptor.kind) {
StructureKind.LIST ->
decodeResolvingAny({
UnexpectedDecodeSchemaError(
descriptor.nonNullSerialName,
Schema.Type.ARRAY,
Schema.Type.BYTES,
Schema.Type.FIXED
)
}) {
decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.ARRAY) }) {
when (it.type) {
Schema.Type.ARRAY -> {
AnyValueDecoder { ArrayBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) }
}

Schema.Type.BYTES -> {
AnyValueDecoder { BytesDirectDecoder(avro, binaryDecoder) }
}

Schema.Type.FIXED -> {
AnyValueDecoder { FixedDirectDecoder(avro, it.fixedSize, binaryDecoder) }
}

else -> null
}
}

StructureKind.MAP ->
decodeResolvingAny({
UnexpectedDecodeSchemaError(
descriptor.nonNullSerialName,
Schema.Type.MAP
)
}) {
decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.MAP) }) {
when (it.type) {
Schema.Type.MAP -> {
AnyValueDecoder { MapBlockDirectDecoder(it, decodeFirstBlock = decodedCollectionSize == -1, { decodedCollectionSize = it }, avro, binaryDecoder) }
Expand All @@ -103,12 +83,7 @@ internal abstract class AbstractAvroDirectDecoder(
}

StructureKind.CLASS, StructureKind.OBJECT ->
decodeResolvingAny({
UnexpectedDecodeSchemaError(
descriptor.nonNullSerialName,
Schema.Type.RECORD
)
}) {
decodeResolvingAny({ UnexpectedDecodeSchemaError(descriptor.nonNullSerialName, Schema.Type.RECORD) }) {
when (it.type) {
Schema.Type.RECORD -> {
AnyValueDecoder { RecordDirectDecoder(it, descriptor, avro, binaryDecoder) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,60 +3,8 @@ package com.github.avrokotlin.avro4k.internal.decoder.direct
import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

/**
 * Sequential decoder over a single bytes value read from [binaryDecoder].
 *
 * The value is read once, when the decoder is constructed; every [decodeByte] call then consumes
 * the next byte of that buffer. Indexed access is not supported.
 */
internal class BytesDirectDecoder(
    private val avro: Avro,
    binaryDecoder: org.apache.avro.io.Decoder,
) : AbstractDecoder() {
    // Read eagerly: the binary decoder is not kept as a field, so this is the only read performed here.
    private val payload = binaryDecoder.readBytes(null)

    override val serializersModule: SerializersModule
        get() = avro.serializersModule

    /** Elements are always consumed in order, without element indexes. */
    override fun decodeSequentially() = true

    /** The collection size is the number of bytes not yet consumed from the buffer. */
    override fun decodeCollectionSize(descriptor: SerialDescriptor): Int = payload.remaining()

    override fun decodeByte(): Byte = payload.get()

    /** Always throws: this decoder only supports sequential decoding. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int = throw IllegalIndexedAccessError()
}

/**
 * Sequential decoder over a fixed value of [fixedSize] bytes read from [binaryDecoder].
 *
 * All bytes are read at construction time into an array; [decodeByte] then walks that array
 * from the first byte onwards. Indexed access is not supported.
 */
internal class FixedDirectDecoder(
    private val avro: Avro,
    fixedSize: Int,
    binaryDecoder: org.apache.avro.io.Decoder,
) : AbstractDecoder() {
    private val content = ByteArray(fixedSize)

    // Index of the byte the next decodeByte() call returns.
    private var cursor = 0

    init {
        // Fill the whole array up front; the binary decoder is not retained afterwards.
        binaryDecoder.readFixed(content)
    }

    override val serializersModule: SerializersModule
        get() = avro.serializersModule

    /** Elements are always consumed in order, without element indexes. */
    override fun decodeSequentially() = true

    override fun decodeCollectionSize(descriptor: SerialDescriptor): Int = content.size

    override fun decodeByte(): Byte = content[cursor++]

    /** Always throws: this decoder only supports sequential decoding. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int = throw IllegalIndexedAccessError()
}

internal class ArrayBlockDirectDecoder(
private val arraySchema: Schema,
private val decodeFirstBlock: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import com.github.avrokotlin.avro4k.Avro
import com.github.avrokotlin.avro4k.internal.DecodedNullError
import com.github.avrokotlin.avro4k.internal.IllegalIndexedAccessError
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractDecoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

internal class MapGenericDecoder(
Expand Down Expand Up @@ -90,26 +88,4 @@ internal class ArrayGenericDecoder(
override fun decodeCollectionSize(descriptor: SerialDescriptor) = collection.size

override fun decodeSequentially() = true
}

/**
 * Sequential decoder exposing an already materialized [ByteArray] as a collection of bytes.
 *
 * The collection size is the array length and each [decodeByte] call yields the next byte
 * of the array. Indexed access is not supported.
 */
internal class ByteArrayGenericDecoder(
    private val avro: Avro,
    private val bytes: ByteArray,
) : AbstractDecoder() {
    // Primitive iterator, advanced by one on every decodeByte() call.
    private val pending = bytes.iterator()

    override val serializersModule: SerializersModule
        get() = avro.serializersModule

    /** Elements are always consumed in order, without element indexes. */
    override fun decodeSequentially() = true

    override fun decodeCollectionSize(descriptor: SerialDescriptor): Int = bytes.size

    override fun decodeByte(): Byte = pending.nextByte()

    /** Always throws: this decoder only supports sequential decoding. */
    override fun decodeElementIndex(descriptor: SerialDescriptor): Int = throw IllegalIndexedAccessError()
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,30 +74,18 @@ internal sealed class AbstractAvroDirectEncoder(
): CompositeEncoder {
return when (descriptor.kind) {
StructureKind.LIST ->
encodeResolving(
{ BadEncodedValueError(emptyList<Any?>(), currentWriterSchema, Schema.Type.ARRAY, Schema.Type.BYTES, Schema.Type.FIXED) }
) { schema ->
encodeResolving({ BadEncodedValueError(emptyList<Any?>(), currentWriterSchema, Schema.Type.ARRAY) }) { schema ->
when (schema.type) {
Schema.Type.ARRAY -> {
{ ArrayDirectEncoder(schema, collectionSize, avro, binaryEncoder) }
}

Schema.Type.BYTES -> {
{ BytesDirectEncoder(avro, binaryEncoder, collectionSize) }
}

Schema.Type.FIXED -> {
{ FixedDirectEncoder(schema, collectionSize, avro, binaryEncoder) }
}

else -> null
}
}

StructureKind.MAP ->
encodeResolving(
{ BadEncodedValueError(emptyMap<String, Any?>(), currentWriterSchema, Schema.Type.MAP) }
) { schema ->
encodeResolving({ BadEncodedValueError(emptyMap<String, Any?>(), currentWriterSchema, Schema.Type.MAP) }) { schema ->
when (schema.type) {
Schema.Type.MAP -> {
{ MapDirectEncoder(schema, collectionSize, avro, binaryEncoder) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package com.github.avrokotlin.avro4k.internal.encoder.direct

import com.github.avrokotlin.avro4k.Avro
import kotlinx.serialization.SerializationException
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.AbstractEncoder
import kotlinx.serialization.modules.SerializersModule
import org.apache.avro.Schema

internal class MapDirectEncoder(private val schema: Schema, mapSize: Int, avro: Avro, binaryEncoder: org.apache.avro.io.Encoder) :
Expand Down Expand Up @@ -69,42 +66,4 @@ internal class ArrayDirectEncoder(
override fun endStructure(descriptor: SerialDescriptor) {
binaryEncoder.writeArrayEnd()
}
}

/**
 * Encoder that gathers the bytes of a collection and writes them as one Avro fixed value.
 *
 * The internal buffer always has `schema.fixedSize` bytes. Writing starts at offset
 * `schema.fixedSize - arraySize`, so when the collection is shorter than the fixed size the
 * leading bytes keep their initial zero value. The buffer is written out in [endStructure].
 *
 * @param schema schema whose `fixedSize` gives the length of the written value
 * @param arraySize number of bytes announced for the collection being encoded
 * @throws SerializationException if [arraySize] is greater than the fixed size of [schema]
 */
internal class FixedDirectEncoder(schema: Schema, arraySize: Int, private val avro: Avro, private val binaryEncoder: org.apache.avro.io.Encoder) : AbstractEncoder() {
    init {
        // Validate first, before the buffer is allocated or the start offset (which would be negative) is computed.
        if (arraySize > schema.fixedSize) {
            // Report the numeric fixed size the message announces, and keep the schema itself for context.
            throw SerializationException(
                "Actual collection size $arraySize is greater than schema fixed size ${schema.fixedSize} (schema: $schema)"
            )
        }
    }

    private val buffer = ByteArray(schema.fixedSize)

    // Next write offset; starts past the zero padding when the collection is shorter than the fixed size.
    private var pos = schema.fixedSize - arraySize

    override val serializersModule: SerializersModule
        get() = avro.serializersModule

    override fun encodeByte(value: Byte) {
        buffer[pos++] = value
    }

    /** Writes the whole buffer, padding included, as a single fixed value. */
    override fun endStructure(descriptor: SerialDescriptor) {
        binaryEncoder.writeFixed(buffer)
    }
}

/**
 * Encoder that gathers the bytes of a collection and writes them as one Avro bytes value.
 *
 * A buffer of [collectionSize] bytes is filled in order by [encodeByte] and written out
 * in [endStructure].
 */
internal class BytesDirectEncoder(
    private val avro: Avro,
    private val binaryEncoder: org.apache.avro.io.Encoder,
    collectionSize: Int,
) : AbstractEncoder() {
    private val content = ByteArray(collectionSize)

    // Offset at which the next encoded byte is stored.
    private var writeIndex = 0

    override val serializersModule: SerializersModule
        get() = avro.serializersModule

    override fun encodeByte(value: Byte) {
        content[writeIndex] = value
        writeIndex += 1
    }

    /** Flushes the collected bytes to the binary encoder as a single bytes value. */
    override fun endStructure(descriptor: SerialDescriptor) {
        binaryEncoder.writeBytes(content)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,8 @@ internal class ValueVisitor internal constructor(
val finalDescriptor = SerializerLocatorMiddleware.apply(unwrapNullable(descriptor))

(finalDescriptor.nonNullOriginal as? AvroSchemaSupplier)
?.getSchema(context)
?.let {
setSchema(it)
return
}
super.visitValue(finalDescriptor)
?.getSchema(context)?.let { setSchema(it) }
?: super.visitValue(finalDescriptor)
}

private fun unwrapNullable(descriptor: SerialDescriptor): SerialDescriptor {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ internal class AvroEncodingAssertions<T>(
return this
}

/**
 * Asserts that the generated schema matches the schema stored in a classpath resource.
 *
 * The resource is parsed as an Avro schema, passed through [schemaTransformer], and the result is
 * handed to the `generatesSchema` overload taking a [Schema].
 *
 * @param expectedSchemaResourcePath path of the classpath resource holding the expected schema
 * @param schemaTransformer optional adjustment applied to the parsed schema before comparing; identity by default
 * @return this assertions object, to allow chaining
 * @throws IllegalArgumentException if no resource exists at [expectedSchemaResourcePath]
 */
fun generatesSchema(
    expectedSchemaResourcePath: Path,
    schemaTransformer: (Schema) -> Schema = { it },
): AvroEncodingAssertions<T> {
    val resourceName = expectedSchemaResourcePath.toString()
    // getResourceAsStream returns null for an unknown resource: fail with an explicit message
    // instead of handing null to the parser, and make sure the stream is closed once parsed.
    val expectedSchema =
        requireNotNull(javaClass.getResourceAsStream(resourceName)) { "Schema resource not found on the classpath: $resourceName" }
            .use { stream -> Schema.Parser().parse(stream) }
    generatesSchema(schemaTransformer(expectedSchema))
    return this
}

fun isEncodedAs(
expectedEncodedGenericValue: Any?,
expectedDecodedValue: T = valueToEncode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ internal class AvroFixedEncodingTest : StringSpec({
.isEncodedAs(record(GenericData.Fixed(schema, "1234567".toByteArray())))
}

"support fixed on value classes" {
"support fixed on string value classes" {
AvroAssertions.assertThat<FixedNestedStringField>()
.generatesSchema(Path("/fixed_string.json"))

Expand All @@ -36,6 +36,12 @@ internal class AvroFixedEncodingTest : StringSpec({
.isEncodedAs(GenericData.Fixed(Avro.schema<FixedStringValueClass>(), "1234567".toByteArray()))
}

"support @AvroFixed on ByteArray" {
AvroAssertions.assertThat(FixedByteArrayField("1234567".toByteArray()))
.generatesSchema(Path("/fixed_string.json"))
.isEncodedAs(record(GenericData.Fixed(Avro.schema<FixedByteArrayField>().fields[0].schema(), "1234567".toByteArray())))
}

"top-est @AvroFixed annotation takes precedence over nested @AvroFixed annotations" {
AvroAssertions.assertThat<FieldPriorToValueClass>()
.generatesSchema(Path("/fixed_string_5.json"))
Expand Down Expand Up @@ -83,6 +89,25 @@ internal class AvroFixedEncodingTest : StringSpec({
@AvroFixed(7) val mystring: String,
)

/**
 * Test fixture: a record named "Fixed" with one [ByteArray] field annotated with `@AvroFixed(7)`.
 *
 * Equality and hashing are overridden to compare the array by content, since the generated
 * data class implementations would compare the array by reference.
 */
@Serializable
@SerialName("Fixed")
private data class FixedByteArrayField(
    @AvroFixed(7) val mystring: ByteArray,
) {
    override fun equals(other: Any?): Boolean =
        when {
            this === other -> true
            // Rejects null as well as instances of any other class.
            other == null || javaClass != other.javaClass -> false
            else -> mystring.contentEquals((other as FixedByteArrayField).mystring)
        }

    override fun hashCode(): Int = mystring.contentHashCode()
}

@Serializable
@SerialName("Fixed")
private data class FixedNestedStringField(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import kotlinx.serialization.Serializable
import org.apache.avro.Schema

internal class BytesEncodingTest : StringSpec({
"encode/decode nullable ByteArray" {
"encode/decode nullable ByteArray to BYTES" {
AvroAssertions.assertThat(NullableByteArrayTest(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))
AvroAssertions.assertThat(NullableByteArrayTest(null))
Expand All @@ -22,7 +22,7 @@ internal class BytesEncodingTest : StringSpec({
.isEncodedAs(null)
}

"encode/decode ByteArray" {
"encode/decode ByteArray to BYTES" {
AvroAssertions.assertThat(ByteArrayTest(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))

Expand All @@ -32,24 +32,24 @@ internal class BytesEncodingTest : StringSpec({
.isEncodedAs(byteArrayOf(1, 4, 9))
}

"encode/decode List<Byte>" {
"encode/decode List<Byte> to ARRAY[INT]" {
AvroAssertions.assertThat(ListByteTest(listOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(listOf(1, 4, 9)))

AvroAssertions.assertThat<List<Byte>>()
.generatesSchema(Schema.create(Schema.Type.BYTES))
.generatesSchema(Schema.createArray(Schema.create(Schema.Type.INT)))
AvroAssertions.assertThat(listOf<Byte>(1, 4, 9))
.isEncodedAs(byteArrayOf(1, 4, 9))
.isEncodedAs(listOf(1, 4, 9))
}

"encode/decode Array<Byte> to ByteBuffer" {
"encode/decode Array<Byte> to ARRAY[INT]" {
AvroAssertions.assertThat(ArrayByteTest(arrayOf(1, 4, 9)))
.isEncodedAs(record(byteArrayOf(1, 4, 9)))
.isEncodedAs(record(listOf(1, 4, 9)))

AvroAssertions.assertThat<Array<Byte>>()
.generatesSchema(Schema.create(Schema.Type.BYTES))
.generatesSchema(Schema.createArray(Schema.create(Schema.Type.INT)))
AvroAssertions.assertThat(arrayOf<Byte>(1, 4, 9))
.isEncodedAs(byteArrayOf(1, 4, 9))
.isEncodedAs(listOf(1, 4, 9))
}
}) {
@Serializable
Expand Down
Loading