Skip to content

Commit

Permalink
Add custom enum deserializer to improve error messaging, improve byte…
Browse files Browse the repository at this point in the history
… count error mesages (opensearch-project#5076)

Signed-off-by: Taylor Gray <tylgry@amazon.com>
  • Loading branch information
graytaylor0 authored Oct 23, 2024
1 parent f10867f commit e9bffee
Show file tree
Hide file tree
Showing 9 changed files with 355 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ public static ByteCount parse(final String string) {
final String unitString = matcher.group("unit");

if(unitString == null) {
throw new ByteCountInvalidInputException("Byte counts must have a unit.");
throw new ByteCountInvalidInputException("Byte counts must have a unit. Valid byte units include: " +
Arrays.stream(Unit.values()).map(unitValue -> unitValue.unitString).collect(Collectors.toList()));
}

final Unit unit = Unit.fromString(unitString)
.orElseThrow(() -> new ByteCountInvalidInputException("Invalid byte unit: '" + unitString + "'"));
.orElseThrow(() -> new ByteCountInvalidInputException("Invalid byte unit: '" + unitString + "'. Valid byte units include: "
+ Arrays.stream(Unit.values()).map(unitValue -> unitValue.unitString).collect(Collectors.toList())));

final BigDecimal valueBigDecimal = new BigDecimal(valueString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public ByteCount deserialize(final JsonParser parser, final DeserializationConte
try {
return ByteCount.parse(byteString);
} catch (final Exception ex) {
throw new IllegalArgumentException(ex);
throw new IllegalArgumentException(ex.getMessage());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.opensearch.dataprepper.pipeline.parser;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.BeanProperty;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.ContextualDeserializer;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;


/**
* This deserializer is used for any Enum classes when converting the pipeline configuration file into the plugin model classes
* @since 2.11
*/
public class EnumDeserializer extends JsonDeserializer<Enum<?>> implements ContextualDeserializer {

static final String INVALID_ENUM_VALUE_ERROR_FORMAT = "Invalid value \"%s\". Valid options include %s.";

private Class<?> enumClass;

public EnumDeserializer() {}

public EnumDeserializer(final Class<?> enumClass) {
if (!enumClass.isEnum()) {
throw new IllegalArgumentException("The provided class is not an enum: " + enumClass.getName());
}

this.enumClass = enumClass;
}
@Override
public Enum<?> deserialize(final JsonParser p, final DeserializationContext ctxt) throws IOException {
final JsonNode node = p.getCodec().readTree(p);
final String enumValue = node.asText();

final Optional<Method> jsonCreator = findJsonCreatorMethod();

try {
jsonCreator.ifPresent(method -> method.setAccessible(true));

for (Object enumConstant : enumClass.getEnumConstants()) {
try {
if (jsonCreator.isPresent() && enumConstant.equals(jsonCreator.get().invoke(null, enumValue))) {
return (Enum<?>) enumConstant;
} else if (jsonCreator.isEmpty() && enumConstant.toString().toLowerCase().equals(enumValue)) {
return (Enum<?>) enumConstant;
}
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
} finally {
jsonCreator.ifPresent(method -> method.setAccessible(false));
}



final Optional<Method> jsonValueMethod = findJsonValueMethodForClass();
final List<Object> listOfEnums = jsonValueMethod.map(method -> Arrays.stream(enumClass.getEnumConstants())
.map(valueEnum -> {
try {
return method.invoke(valueEnum);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList())).orElseGet(() -> Arrays.stream(enumClass.getEnumConstants())
.map(valueEnum -> valueEnum.toString().toLowerCase())
.collect(Collectors.toList()));

throw new IllegalArgumentException(String.format(INVALID_ENUM_VALUE_ERROR_FORMAT, enumValue, listOfEnums));
}

@Override
public JsonDeserializer<?> createContextual(final DeserializationContext ctxt, final BeanProperty property) {
final JavaType javaType = property.getType();
final Class<?> rawClass = javaType.getRawClass();

return new EnumDeserializer(rawClass);
}

private Optional<Method> findJsonValueMethodForClass() {
for (final Method method : enumClass.getDeclaredMethods()) {
if (method.isAnnotationPresent(JsonValue.class)) {
return Optional.of(method);
}
}

return Optional.empty();
}

private Optional<Method> findJsonCreatorMethod() {
for (final Method method : enumClass.getDeclaredMethods()) {
if (method.isAnnotationPresent(JsonCreator.class)) {
return Optional.of(method);
}
}

return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.opensearch.dataprepper.model.types.ByteCount;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -31,9 +32,28 @@ void setUp() {
}

@ParameterizedTest
@ValueSource(strings = {"1", "1b 2b", "1vb", "bad"})
void convert_with_invalid_values_throws(final String invalidByteString) {
assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
@ValueSource(strings = {"1", "10"})
void convert_with_no_byte_unit_throws_expected_exception(final String invalidByteString) {
final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
assertThat(exception.getMessage(), containsString("Byte counts must have a unit. Valid byte units include: [b, kb, mb, gb]"));
}

@ParameterizedTest
@ValueSource(strings = {"10 2b", "bad"})
void convert_with_non_parseable_values_throws(final String invalidByteString) {
final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
assertThat(exception.getMessage(), containsString("Unable to parse bytes"));
}

@ParameterizedTest
@CsvSource({
"10f, f",
"1vb, vb",
"3g, g"
})
void convert_with_invalid_byte_units_throws(final String invalidByteString, final String invalidUnit) {
final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
assertThat(exception.getMessage(), containsString("Invalid byte unit: '" + invalidUnit + "'. Valid byte units include: [b, kb, mb, gb]"));
}

@ParameterizedTest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package org.opensearch.dataprepper.pipeline.parser;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.BeanProperty;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.TextNode;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.opensearch.dataprepper.model.event.HandleFailedEventsOption;

import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class EnumDeserializerTest {

private ObjectMapper objectMapper;

@BeforeEach
void setup() {
objectMapper = mock(ObjectMapper.class);
}

private EnumDeserializer createObjectUnderTest(final Class<?> enumClass) {
return new EnumDeserializer(enumClass);
}

@Test
void non_enum_class_throws_IllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> new EnumDeserializer(Duration.class));
}

@ParameterizedTest
@EnumSource(TestEnum.class)
void enum_class_with_json_creator_annotation_returns_expected_enum_constant(final TestEnum testEnumOption) throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnum.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(testEnumOption.toString()));

Enum<?> result = objectUnderTest.deserialize(jsonParser, deserializationContext);

assertThat(result, equalTo(testEnumOption));
}

@ParameterizedTest
@EnumSource(TestEnumWithoutJsonCreator.class)
void enum_class_without_json_creator_annotation_returns_expected_enum_constant(final TestEnumWithoutJsonCreator enumWithoutJsonCreator) throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnumWithoutJsonCreator.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(enumWithoutJsonCreator.toString()));

Enum<?> result = objectUnderTest.deserialize(jsonParser, deserializationContext);

assertThat(result, equalTo(enumWithoutJsonCreator));
}

@Test
void enum_class_with_invalid_value_and_jsonValue_annotation_throws_IllegalArgumentException() throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnum.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

final String invalidValue = UUID.randomUUID().toString();
when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(invalidValue));

final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () ->
objectUnderTest.deserialize(jsonParser, deserializationContext));

assertThat(exception, notNullValue());
final String expectedErrorMessage = "Invalid value \"" + invalidValue + "\". Valid options include";
assertThat(exception.getMessage(), Matchers.startsWith(expectedErrorMessage));
assertThat(exception.getMessage(), containsString("[test_display_one, test_display_two, test_display_three]"));
}

@Test
void enum_class_with_invalid_value_and_no_jsonValue_annotation_throws_IllegalArgumentException() throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnumWithoutJsonCreator.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

final String invalidValue = UUID.randomUUID().toString();
when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(invalidValue));

final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () ->
objectUnderTest.deserialize(jsonParser, deserializationContext));

assertThat(exception, notNullValue());
final String expectedErrorMessage = "Invalid value \"" + invalidValue + "\". Valid options include";
assertThat(exception.getMessage(), Matchers.startsWith(expectedErrorMessage));

}

@Test
void create_contextual_returns_expected_enum_deserializer() {
final DeserializationContext context = mock(DeserializationContext.class);
final BeanProperty property = mock(BeanProperty.class);

final ObjectMapper mapper = new ObjectMapper();
final JavaType javaType = mapper.constructType(HandleFailedEventsOption.class);
when(property.getType()).thenReturn(javaType);

final EnumDeserializer objectUnderTest = new EnumDeserializer();
JsonDeserializer<?> result = objectUnderTest.createContextual(context, property);

assertThat(result, instanceOf(EnumDeserializer.class));
}

private enum TestEnum {
TEST_ONE("test_display_one"),
TEST_TWO("test_display_two"),
TEST_THREE("test_display_three");
private static final Map<String, TestEnum> NAMES_MAP = Arrays.stream(TestEnum.values())
.collect(Collectors.toMap(TestEnum::toString, Function.identity()));
private final String name;
TestEnum(final String name) {
this.name = name;
}

@JsonValue
public String toString() {
return this.name;
}
@JsonCreator
static TestEnum fromOptionValue(final String option) {
return NAMES_MAP.get(option);
}
}

private enum TestEnumWithoutJsonCreator {
TEST("test");
private static final Map<String, TestEnumWithoutJsonCreator> NAMES_MAP = Arrays.stream(TestEnumWithoutJsonCreator.values())
.collect(Collectors.toMap(TestEnumWithoutJsonCreator::toString, Function.identity()));
private final String name;
TestEnumWithoutJsonCreator(final String name) {
this.name = name;
}
public String toString() {
return this.name;
}

static TestEnumWithoutJsonCreator fromOptionValue(final String option) {
return NAMES_MAP.get(option);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.dataprepper.model.types.ByteCount;
import org.opensearch.dataprepper.pipeline.parser.ByteCountDeserializer;
import org.opensearch.dataprepper.pipeline.parser.DataPrepperDurationDeserializer;
import org.opensearch.dataprepper.pipeline.parser.EnumDeserializer;
import org.opensearch.dataprepper.pipeline.parser.EventKeyDeserializer;
import org.springframework.context.annotation.Bean;

Expand All @@ -33,6 +34,7 @@ public class ObjectMapperConfiguration {
ObjectMapper extensionPluginConfigObjectMapper() {
final SimpleModule simpleModule = new SimpleModule();
simpleModule.addDeserializer(Duration.class, new DataPrepperDurationDeserializer());
simpleModule.addDeserializer(Enum.class, new EnumDeserializer());
simpleModule.addDeserializer(ByteCount.class, new ByteCountDeserializer());

return new ObjectMapper()
Expand All @@ -47,6 +49,7 @@ ObjectMapper pluginConfigObjectMapper(
final SimpleModule simpleModule = new SimpleModule();
simpleModule.addDeserializer(Duration.class, new DataPrepperDurationDeserializer());
simpleModule.addDeserializer(ByteCount.class, new ByteCountDeserializer());
simpleModule.addDeserializer(Enum.class, new EnumDeserializer());
simpleModule.addDeserializer(EventKey.class, new EventKeyDeserializer(eventKeyFactory));
TRANSLATE_VALUE_SUPPORTED_JAVA_TYPES.stream().forEach(clazz -> simpleModule.addDeserializer(
clazz, new DataPrepperScalarTypeDeserializer<>(variableExpander, clazz)));
Expand Down
Loading

0 comments on commit e9bffee

Please sign in to comment.