diff --git a/src/main/java/org/akhq/controllers/ErrorController.java b/src/main/java/org/akhq/controllers/ErrorController.java index e2afdf9d4..f08b6241d 100644 --- a/src/main/java/org/akhq/controllers/ErrorController.java +++ b/src/main/java/org/akhq/controllers/ErrorController.java @@ -22,6 +22,7 @@ import java.io.PrintWriter; import java.io.StringWriter; import java.net.URISyntaxException; +import java.util.regex.Pattern; @Secured(SecurityRule.IS_ANONYMOUS) @Slf4j @@ -33,6 +34,11 @@ public HttpResponse error(HttpRequest request, ApiException e) { return renderExecption(request, e); } + @Error(global = true) + public HttpResponse error(HttpRequest request, ClassCastException e) { + return extractAndRenderException(request, e); + } + // Registry @Error(global = true) public HttpResponse error(HttpRequest request, RestClientException e) { @@ -109,6 +115,7 @@ public HttpResponse notFound(HttpRequest request) throws URISyntaxExceptio return HttpResponse.notFound() .body(error); } + @Error(global = true) public HttpResponse error(HttpRequest request, InvalidClusterException e) { JsonError error = new JsonError(e.getMessage()) @@ -116,4 +123,34 @@ public HttpResponse error(HttpRequest request, InvalidClusterException e) return HttpResponse.status(HttpStatus.CONFLICT).body(error); } + + private HttpResponse extractAndRenderException(HttpRequest request, Exception e) { + String fieldRegex = "field\\s.*$"; + String expectedTypeRegex = "cannot be cast to\\sclass\\s[A-z.]+"; + String actualTypeRegex = "class\\s[A-z.]+"; + + String actualField = retrievePatternMatch(fieldRegex, e.getMessage()); + String expectedType = retrievePatternMatch(expectedTypeRegex, e.getMessage()).toLowerCase(); + String actualType = retrievePatternMatch(actualTypeRegex, e.getMessage()).toLowerCase(); + + var message = String.format("Field %s required %s but got %s", actualField, expectedType, actualType); + JsonError error = new JsonError(message) + .link(Link.SELF, Link.of(request.getUri())); + + return HttpResponse.status(HttpStatus.CONFLICT) + .body(error); + } + + private String retrievePatternMatch(String regex, String message) { + var compile = Pattern.compile(regex); + var matcher = compile.matcher(message); + if (matcher.find()) { + var group = matcher.group(); + var invalidWordRegex = "field\\s|cannot\\sbe\\scast\\sto\\sclass\\s|class\\s|java.lang."; + + return group.replaceAll(invalidWordRegex, ""); + } + + return ""; + } } diff --git a/src/main/java/org/akhq/utils/AvroSerializer.java b/src/main/java/org/akhq/utils/AvroSerializer.java index d21333ab8..3399564d6 100644 --- a/src/main/java/org/akhq/utils/AvroSerializer.java +++ b/src/main/java/org/akhq/utils/AvroSerializer.java @@ -8,6 +8,7 @@ import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; import org.apache.avro.util.Utf8; +import org.apache.commons.collections.CollectionUtils; import java.math.BigDecimal; import java.math.MathContext; @@ -20,11 +21,7 @@ import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; -import java.util.AbstractMap; -import java.util.Collection; -import java.util.Map; -import java.util.TimeZone; -import java.util.UUID; +import java.util.*; import java.util.stream.Collectors; public class AvroSerializer { @@ -59,6 +56,16 @@ public class AvroSerializer { public static GenericRecord recordSerializer(Map record, Schema schema) { GenericRecord returnValue = new GenericData.Record(schema); + Set schemaFields = schema.getFields().stream() + .map(Schema.Field::name).collect(Collectors.toSet()); + + Set recordFields = record.keySet(); + + if (schemaFields.size() != recordFields.size()) { + Object[] missingFields = CollectionUtils.disjunction(schemaFields, recordFields).stream().toArray(); + throw new IllegalArgumentException(" Record does not contain followings fields ".concat(Arrays.toString(missingFields))); + } + schema .getFields() .forEach(field -> { diff --git a/src/test/java/org/akhq/modules/AvroSchemaSerializerTest.java b/src/test/java/org/akhq/modules/AvroSchemaSerializerTest.java index 695717731..bb63be36e 100644 --- a/src/test/java/org/akhq/modules/AvroSchemaSerializerTest.java +++ b/src/test/java/org/akhq/modules/AvroSchemaSerializerTest.java @@ -1,7 +1,6 @@ package org.akhq.modules; import io.confluent.kafka.schemaregistry.avro.AvroSchema; -import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException; import org.akhq.configs.SchemaRegistryType; import org.akhq.modules.schemaregistry.AvroSerializer; import org.apache.avro.SchemaBuilder; @@ -10,7 +9,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; -import java.io.IOException; import java.nio.ByteBuffer; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -66,5 +64,4 @@ void shouldFailIfDoesntMatchSchemaId() { avroSerializer.serialize(INVALID_JSON); }); } - } diff --git a/src/test/java/org/akhq/utils/AvroDeserializerTest.java b/src/test/java/org/akhq/utils/AvroDeserializerTest.java index cd9a06f83..18d2e0eb4 100644 --- a/src/test/java/org/akhq/utils/AvroDeserializerTest.java +++ b/src/test/java/org/akhq/utils/AvroDeserializerTest.java @@ -177,15 +177,17 @@ void testDefaultValue() { + " {\"name\": \"arrayField\", \"type\": {\"type\": \"array\", \"items\": \"double\"}, \"default\": []}" + " ]" + "}"; + Map defaultValues = new HashMap<>(); + defaultValues.put("stringField", null); + defaultValues.put("arrayField", List.of()); + Schema schema = new Schema.Parser().parse(type); - GenericRecord expectedRecord = AvroSerializer.recordSerializer(Map.of(), schema); + GenericRecord expectedRecord = AvroSerializer.recordSerializer(defaultValues, schema); assert new GenericData().validate(schema, expectedRecord); Map result = AvroDeserializer.recordDeserializer(expectedRecord); - Map defaultValues = new HashMap<>(); - defaultValues.put("stringField", null); - defaultValues.put("arrayField", List.of()); + assertThat(result, is(defaultValues)); } } diff --git a/src/test/java/org/akhq/utils/AvroSerializerTest.java b/src/test/java/org/akhq/utils/AvroSerializerTest.java index 7ebc89548..c83401d1d 100644 --- a/src/test/java/org/akhq/utils/AvroSerializerTest.java +++ b/src/test/java/org/akhq/utils/AvroSerializerTest.java @@ -1,13 +1,16 @@ package org.akhq.utils; +import org.apache.avro.SchemaBuilder; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import java.time.Instant; import java.time.LocalDateTime; import java.time.ZoneId; +import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; class AvroSerializerTest { @@ -127,4 +130,19 @@ void testParseDateTime_minutes_local() { } + private final org.apache.avro.Schema SCHEMA = SchemaBuilder + .record("schema1").namespace("org.akhq") + .fields() + .name("title").type().stringType().noDefault() + .name("release_year").type().intType().noDefault() + .name("rating").type().doubleType().noDefault() + .endRecord(); + + @Test + void shouldThrowIfSchemaAndRecordFieldsAreNotEqual() { + assertThrows(IllegalArgumentException.class, () -> { + AvroSerializer.recordSerializer(Map.of("title", "akhq"), SCHEMA); + }); + } + }