diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java index 4cb7eb4eac45f..d014f08ac2b8c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/PipelineOptionsFactory.java @@ -36,11 +36,13 @@ import com.fasterxml.jackson.databind.annotation.JsonSerialize; import com.fasterxml.jackson.databind.deser.DefaultDeserializationContext; import com.fasterxml.jackson.databind.deser.impl.MethodProperty; +import com.fasterxml.jackson.databind.deser.impl.TypeWrappedDeserializer; import com.fasterxml.jackson.databind.introspect.AnnotatedMember; import com.fasterxml.jackson.databind.introspect.AnnotatedMethod; import com.fasterxml.jackson.databind.introspect.AnnotationCollector; import com.fasterxml.jackson.databind.introspect.BeanPropertyDefinition; import com.fasterxml.jackson.databind.introspect.TypeResolutionContext; +import com.fasterxml.jackson.databind.jsontype.TypeDeserializer; import com.fasterxml.jackson.databind.node.TreeTraversingParser; import com.fasterxml.jackson.databind.ser.DefaultSerializerProvider; import com.fasterxml.jackson.databind.type.TypeBindings; @@ -1730,21 +1732,23 @@ private static JsonDeserializer computeDeserializerForMethod(Method meth BeanProperty prop = createBeanProperty(method); AnnotatedMember annotatedMethod = prop.getMember(); + DefaultDeserializationContext context = DESERIALIZATION_CONTEXT.get(); Object maybeDeserializerClass = - DESERIALIZATION_CONTEXT - .get() - .getAnnotationIntrospector() - .findDeserializer(annotatedMethod); + context.getAnnotationIntrospector().findDeserializer(annotatedMethod); JsonDeserializer jsonDeserializer = - DESERIALIZATION_CONTEXT - .get() - .deserializerInstance(annotatedMethod, maybeDeserializerClass); + context.deserializerInstance(annotatedMethod, maybeDeserializerClass); if (jsonDeserializer == null) { - jsonDeserializer = - DESERIALIZATION_CONTEXT.get().findContextualValueDeserializer(prop.getType(), prop); + jsonDeserializer = context.findContextualValueDeserializer(prop.getType(), prop); } + + TypeDeserializer typeDeserializer = + context.getFactory().findTypeDeserializer(context.getConfig(), prop.getType()); + if (typeDeserializer != null) { + jsonDeserializer = new TypeWrappedDeserializer(typeDeserializer, jsonDeserializer); + } + return jsonDeserializer; } catch (JsonMappingException e) { throw new RuntimeException(e); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java index 94fd3f41faacd..ffdfbc8681a13 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/options/PipelineOptionsFactoryTest.java @@ -39,6 +39,8 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonProcessingException; @@ -60,7 +62,6 @@ import java.util.Collection; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.Set; import org.apache.beam.model.jobmanagement.v1.JobApi.PipelineOptionDescriptor; import org.apache.beam.model.jobmanagement.v1.JobApi.PipelineOptionType; @@ -1070,6 +1071,53 @@ public void testComplexTypes() { assertEquals("value2", options.getObjectValue().get().value2); } + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME) + @JsonSubTypes({ + @JsonSubTypes.Type(value = PolymorphicTypeOne.class, name = "one"), + @JsonSubTypes.Type(value = PolymorphicTypeTwo.class, name = "two") + }) + public abstract static class PolymorphicType { + String key; + + @JsonProperty("key") + public String getKey() { + return key; + } + + public void setKey(String key) { + this.key = key; + } + } + + public static class PolymorphicTypeOne extends PolymorphicType {} + + public static class PolymorphicTypeTwo extends PolymorphicType {} + + public interface PolymorphicTypes extends PipelineOptions { + PolymorphicType getObject(); + + void setObject(PolymorphicType value); + + ValueProvider getObjectValue(); + + void setObjectValue(ValueProvider value); + } + + @Test + public void testPolymorphicType() { + String[] args = + new String[] { + "--object={\"key\":\"value\",\"@type\":\"one\"}", + "--objectValue={\"key\":\"value\",\"@type\":\"two\"}" + }; + PolymorphicTypes options = PipelineOptionsFactory.fromArgs(args).as(PolymorphicTypes.class); + assertEquals("value", options.getObject().key); + assertEquals(PolymorphicTypeOne.class, options.getObject().getClass()); + + assertEquals("value", options.getObjectValue().get().key); + assertEquals(PolymorphicTypeTwo.class, options.getObjectValue().get().getClass()); + } + @Test public void testMissingArgument() { String[] args = new String[] {};