Merge pull request #503 from microsoft/fix_sink_bug_when_using_struct_record_value

Fixed sink bug when using struct type for record value
kushagraThapar authored Mar 24, 2023
2 parents 20227d1 + 4c7a3be commit 37dc369
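
For context on the change itself: when a record value arrives as a Kafka Connect Struct, the sink now converts it to a JSON-style Map and rebuilds the SinkRecord around the converted value, since SinkRecord offers no setters and cannot be updated in place. The sketch below is illustrative only, assuming StructToJsonMap lives in the com.azure.cosmos.kafka.connect.sink package (the tests in this commit use it unqualified from that package); the class name, topic name, and field values are made up.

import com.azure.cosmos.kafka.connect.sink.StructToJsonMap;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.sink.SinkRecord;

import java.util.Map;

public class StructValueSketch {
    public static void main(String[] args) {
        Schema schema = SchemaBuilder.struct().field("foo", Schema.STRING_SCHEMA).build();
        Struct value = new Struct(schema).put("foo", "fooz");
        SinkRecord record = new SinkRecord("made-up-topic", 0,
                Schema.STRING_SCHEMA, "nokey", schema, value, 0L);

        // Convert the Struct value, then copy every other field of the
        // original record; this mirrors what the fixed put() does below.
        Map<String, Object> converted = StructToJsonMap.toJsonMap((Struct) record.value());
        SinkRecord updated = new SinkRecord(record.topic(), record.kafkaPartition(),
                record.keySchema(), record.key(), record.valueSchema(), converted,
                record.kafkaOffset(), record.timestamp(), record.timestampType(),
                record.headers());

        System.out.println(updated.value() instanceof Map); // prints: true
    }
}

The ten-argument SinkRecord constructor used here is the same one the fix calls. Note that the value schema still describes a Struct after conversion, which is exactly what the TODO in the diff below flags.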
Showing 8 changed files with 169 additions and 64 deletions.
6 changes: 0 additions & 6 deletions pom.xml
@@ -133,12 +133,6 @@
<version>4.0.3</version>
<scope>test</scope>
</dependency>
- <dependency>
-     <groupId>com.google.guava</groupId>
-     <artifactId>guava</artifactId>
-     <version>29.0-jre</version>
-     <scope>test</scope>
- </dependency>
<dependency>
<groupId>io.confluent</groupId>
<artifactId>kafka-avro-serializer</artifactId>
@@ -111,12 +111,26 @@ public void put(Collection<SinkRecord> records) {
Object recordValue;
if (record.value() instanceof Struct) {
recordValue = StructToJsonMap.toJsonMap((Struct) record.value());
+ // TODO: Do we need to update the value schema to map or keep it struct?
} else {
recordValue = record.value();
}

maybeInsertId(recordValue, record);
- toBeWrittenRecordList.add(record);
+
+ // Create an updated record from the current record and the converted record value
+ final SinkRecord updatedRecord = new SinkRecord(record.topic(),
+         record.kafkaPartition(),
+         record.keySchema(),
+         record.key(),
+         record.valueSchema(),
+         recordValue,
+         record.kafkaOffset(),
+         record.timestamp(),
+         record.timestampType(),
+         record.headers());
+
+ toBeWrittenRecordList.add(updatedRecord);
}

try {
@@ -16,7 +16,6 @@
import com.azure.cosmos.models.CosmosItemOperation;
import com.azure.cosmos.models.PartitionKey;
import com.azure.cosmos.models.PartitionKeyDefinition;
- import com.google.common.collect.Iterators;
import org.apache.kafka.connect.data.ConnectSchema;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.sink.SinkRecord;
@@ -32,6 +31,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
+ import java.util.concurrent.atomic.AtomicInteger;

import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertTrue;
@@ -109,8 +109,12 @@ public void testBulkWriteWithNonTransientException() {
ArgumentCaptor<Iterable<CosmosItemOperation>> parameters = ArgumentCaptor.forClass(Iterable.class);
verify(container, times(1)).executeBulkOperations(parameters.capture());

+ AtomicInteger count = new AtomicInteger();
+ parameters.getValue().forEach(cosmosItemOperation -> {
+     count.incrementAndGet();
+ });
Iterator<CosmosItemOperation> bulkExecutionParameters = parameters.getValue().iterator();
- assertEquals(2, Iterators.size(bulkExecutionParameters));
+ assertEquals(2, getIteratorSize(bulkExecutionParameters));
}

@Test
@@ -141,9 +145,9 @@ public void testBulkWriteSucceededWithTransientException() {

List<Iterable<CosmosItemOperation>> allParameters = parameters.getAllValues();
assertEquals(3, allParameters.size());
- assertEquals(2, Iterators.size(allParameters.get(0).iterator()));
- assertEquals(1, Iterators.size(allParameters.get(1).iterator()));
- assertEquals(1, Iterators.size(allParameters.get(2).iterator()));
+ assertEquals(2, getIteratorSize(allParameters.get(0).iterator()));
+ assertEquals(1, getIteratorSize(allParameters.get(1).iterator()));
+ assertEquals(1, getIteratorSize(allParameters.get(2).iterator()));
}


@@ -210,4 +214,13 @@ private CosmosBulkOperationResponse mockFailedBulkOperationResponse(SinkRecord s

return mockedBulkOptionResponse;
}

+ private int getIteratorSize(Iterator<?> iterator) {
+     int count = 0;
+     while (iterator.hasNext()) {
+         iterator.next();
+         count++;
+     }
+     return count;
+ }
}
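
Aside: this file, and the other tests touched by this commit, replace the removed Guava test dependency with plain-JDK equivalents. A rough mapping, as an illustrative sketch (Map.of and List.of require Java 9+; getIteratorSize has the same shape as the helper added above):

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class GuavaToJdk {
    public static void main(String[] args) {
        Map<String, String> empty = Collections.emptyMap(); // was ImmutableMap.of()
        Map<String, String> single = Map.of("k", "v");      // was ImmutableMap.of("k", "v")
        List<Boolean> flags = List.of(false);               // was ImmutableList.of(false)

        // was Iterators.size(iterator)
        System.out.println(getIteratorSize(flags.iterator()) + empty.size() + single.size()); // prints: 2
    }

    private static int getIteratorSize(Iterator<?> iterator) {
        int count = 0;
        while (iterator.hasNext()) {
            iterator.next();
            count++;
        }
        return count;
    }
}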
@@ -4,12 +4,12 @@
package com.azure.cosmos.kafka.connect.sink;

import com.azure.cosmos.kafka.connect.CosmosDBConfig.CosmosClientBuilder;
- import com.google.common.collect.ImmutableMap;
import org.apache.kafka.common.config.Config;
import org.apache.kafka.common.config.ConfigValue;
import org.junit.Test;
import org.mockito.MockedStatic;

+ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -25,7 +25,7 @@ public class CosmosDBSinkConnectorTest {

@Test
public void testValidateEmptyConfigFailsRequiredFields() {
- Config config = new CosmosDBSinkConnector().validate(ImmutableMap.of());
+ Config config = new CosmosDBSinkConnector().validate(Collections.emptyMap());

Map<String, List<String>> errorMessages = config.configValues().stream()
.collect(Collectors.toMap(ConfigValue::name, ConfigValue::errorMessages));
@@ -46,7 +46,7 @@ public void testValidateCannotConnectToCosmos() {
.when(() -> CosmosClientBuilder.createClient(anyString(), anyString()))
.thenThrow(IllegalArgumentException.class);

- Config config = connector.validate(ImmutableMap.of(
+ Config config = connector.validate(Map.of(
CosmosDBSinkConfig.COSMOS_CONN_ENDPOINT_CONF, "https://endpoint:port/",
CosmosDBSinkConfig.COSMOS_CONN_KEY_CONF, "superSecretPassword",
CosmosDBSinkConfig.COSMOS_DATABASE_NAME_CONF, "superAwesomeDatabase",
@@ -71,7 +71,7 @@ public void testValidateHappyPath() {
.then(answerVoid((s1, s2) -> {
}));

- Config config = connector.validate(ImmutableMap.of(
+ Config config = connector.validate(Map.of(
CosmosDBSinkConfig.COSMOS_CONN_ENDPOINT_CONF,
"https://cosmos-instance.documents.azure.com:443/",
CosmosDBSinkConfig.COSMOS_CONN_KEY_CONF, "superSecretPassword",
@@ -106,7 +106,7 @@ public void testValidateTopicMapValidFormat() {
}

private void invalidTopicMapString(CosmosDBSinkConnector connector, String topicMapConfig) {
- Config config = connector.validate(ImmutableMap.of(
+ Config config = connector.validate(Map.of(
CosmosDBSinkConfig.COSMOS_CONN_ENDPOINT_CONF, "https://endpoint:port/",
CosmosDBSinkConfig.COSMOS_CONN_KEY_CONF, "superSecretPassword",
CosmosDBSinkConfig.COSMOS_DATABASE_NAME_CONF, "superAwesomeDatabase",
@@ -7,6 +7,9 @@
import com.azure.cosmos.CosmosContainer;
import com.azure.cosmos.CosmosDatabase;
import com.azure.cosmos.implementation.BadRequestException;
+ import com.azure.cosmos.kafka.connect.source.JsonToStruct;
+ import com.fasterxml.jackson.databind.node.JsonNodeFactory;
+ import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.apache.kafka.connect.data.ConnectSchema;
import org.apache.kafka.connect.data.Schema;
@@ -23,9 +26,11 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+ import java.util.concurrent.atomic.AtomicReference;

import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertNotNull;
+ import static junit.framework.TestCase.assertTrue;
import static junit.framework.TestCase.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyString;
@@ -108,7 +113,7 @@ public void sinkWriteFailed() {
}

try {
- testTask.put(Arrays.asList(record));
+ testTask.put(List.of(record));
fail("Expected ConnectException on bad message");
} catch (ConnectException ce) {

@@ -158,7 +163,7 @@ public void sinkWriteSucceeded() {
}

try {
- testTask.put(Arrays.asList(record));
+ testTask.put(List.of(record));
} catch (ConnectException ce) {
fail("Expected sink write to succeed, but got: " + ce.getMessage());
} catch (Throwable t) {
@@ -172,5 +177,83 @@ public void sinkWriteSucceeded() {
}
}
}

+ @Test
+ public void sinkWriteSucceededWithStructRecordValue() {
+     Schema stringSchema = new ConnectSchema(Schema.Type.STRING);
+     Schema structSchema = new ConnectSchema(Schema.Type.STRUCT);
+     ObjectNode objectNode = new ObjectNode(new JsonNodeFactory(false))
+         .put("foo", "fooz")
+         .put("bar", "baaz");
+     JsonToStruct jsonToStruct = new JsonToStruct();
+
+     Object recordValue = jsonToStruct.recordToSchemaAndValue(objectNode).value();
+
+     SinkRecord record = new SinkRecord(topicName, 1, stringSchema, "nokey", structSchema, recordValue, 0L);
+     assertNotNull(record.value());
+
+     SinkWriteResponse sinkWriteResponse = new SinkWriteResponse();
+     sinkWriteResponse.getSucceededRecords().add(record);
+
+     MockedConstruction<? extends SinkWriterBase> mockedWriterConstruction = null;
+     AtomicReference<List<SinkRecord>> sinkRecords = new AtomicReference<>();
+     try {
+         if (this.isBulkModeEnabled) {
+             mockedWriterConstruction = mockConstructionWithAnswer(BulkWriter.class, invocation -> {
+                 if (invocation.getMethod().equals(BulkWriter.class.getMethod("write", List.class))) {
+                     sinkRecords.set(invocation.getArgument(0));
+                     return sinkWriteResponse;
+                 }
+
+                 throw new IllegalStateException("Not implemented for method " + invocation.getMethod().getName());
+             });
+         } else {
+             mockedWriterConstruction = mockConstructionWithAnswer(PointWriter.class, invocation -> {
+                 if (invocation.getMethod().equals(PointWriter.class.getMethod("write", List.class))) {
+                     sinkRecords.set(invocation.getArgument(0));
+                     return sinkWriteResponse;
+                 }
+
+                 throw new IllegalStateException("Not implemented for method " + invocation.getMethod().getName());
+             });
+         }
+
+         try {
+             testTask.put(List.of(record));
+         } catch (ConnectException ce) {
+             fail("Expected sink write to succeed, but got: " + ce.getMessage());
+         } catch (Throwable t) {
+             fail("Expected sink write to succeed, but got: " + t.getClass().getName());
+         }
+
+         assertEquals(1, mockedWriterConstruction.constructed().size());
+
+         SinkRecord sinkRecord = sinkRecords.get().get(0);
+         assertRecordEquals(record, sinkRecord);
+
+         Object value = sinkRecord.value();
+         assertTrue(value instanceof Map);
+
+         assertEquals("fooz", ((Map<?, ?>) value).get("foo"));
+         assertEquals("baaz", ((Map<?, ?>) value).get("bar"));
+     } finally {
+         if (mockedWriterConstruction != null) {
+             mockedWriterConstruction.close();
+         }
+     }
+ }
+
+ private void assertRecordEquals(SinkRecord record, SinkRecord updatedRecord) {
+     assertEquals(record.kafkaOffset(), updatedRecord.kafkaOffset());
+     assertEquals(record.timestamp(), updatedRecord.timestamp());
+     assertEquals(record.timestampType(), updatedRecord.timestampType());
+     assertEquals(record.headers(), updatedRecord.headers());
+     assertEquals(record.keySchema(), updatedRecord.keySchema());
+     assertEquals(record.valueSchema(), updatedRecord.valueSchema());
+     assertEquals(record.key(), updatedRecord.key());
+     assertEquals(record.topic(), updatedRecord.topic());
+     assertEquals(record.kafkaPartition(), updatedRecord.kafkaPartition());
+ }
}

@@ -3,8 +3,6 @@

package com.azure.cosmos.kafka.connect.sink;

- import com.google.common.collect.ImmutableList;
- import com.google.common.collect.ImmutableMap;
import org.apache.kafka.connect.data.Date;
import org.apache.kafka.connect.data.Decimal;
import org.apache.kafka.connect.data.Schema;
@@ -28,7 +26,7 @@ public void emptyStructToEmptyMap() {
Schema schema = SchemaBuilder.struct()
.build();
Struct struct = new Struct(schema);
- assertEquals(ImmutableMap.of(), StructToJsonMap.toJsonMap(struct));
+ assertEquals(Map.of(), StructToJsonMap.toJsonMap(struct));
}


@@ -38,10 +36,10 @@ public void structWithEmptyArrayToMap() {
.field("array_of_boolean", SchemaBuilder.array(Schema.BOOLEAN_SCHEMA).build());

Struct struct = new Struct(schema)
- .put("array_of_boolean", ImmutableList.of());
+ .put("array_of_boolean", List.of());

Map<String, Object> converted = StructToJsonMap.toJsonMap(struct);
- assertEquals(ImmutableList.of(), ((List<Boolean>) converted.get("array_of_boolean")));
+ assertEquals(List.of(), converted.get("array_of_boolean"));
}

@Test
@@ -86,8 +84,8 @@ public void complexStructToMap() {
.put("string", quickBrownFox)
.put("struct", new Struct(embeddedSchema)
.put("embedded_string", quickBrownFox))
- .put("array_of_boolean", ImmutableList.of(false))
- .put("array_of_struct", ImmutableList.of(
+ .put("array_of_boolean", List.of(false))
+ .put("array_of_struct", List.of(
new Struct(embeddedSchema).put("embedded_string", quickBrownFox)));

Map<String, Object> converted = StructToJsonMap.toJsonMap(struct);
@@ -105,7 +103,7 @@ public void complexStructToMap() {
assertEquals(quickBrownFox, converted.get("string"));
assertEquals(quickBrownFox, ((Map<String, Object>) converted.get("struct")).get("embedded_string"));
assertEquals(false, ((List<Boolean>) converted.get("array_of_boolean")).get(0));
- assertEquals(ImmutableMap.of("embedded_string", quickBrownFox), ((List<Struct>) converted.get("array_of_struct")).get(0));
+ assertEquals(Map.of("embedded_string", quickBrownFox), ((List<Struct>) converted.get("array_of_struct")).get(0));
assertNull(converted.get("optional_string"));
}
