diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java
index a6f66f6208dfd..96072d8ec29b9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java
@@ -24,6 +24,9 @@
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
@@ -129,6 +132,20 @@ public KV<K, V> apply(V element) {
       // TODO: Remove when we can set the coder inference context.
       result.setCoder(KvCoder.of(keyCoder, in.getCoder()));
     } catch (CannotProvideCoderException exc) {
+      if (keyType != null) {
+        try {
+          SchemaRegistry schemaRegistry = SchemaRegistry.createDefault();
+          SchemaCoder<K> schemaCoder =
+              SchemaCoder.of(
+                  schemaRegistry.getSchema(keyType),
+                  keyType,
+                  schemaRegistry.getToRowFunction(keyType),
+                  schemaRegistry.getFromRowFunction(keyType));
+          result.setCoder(KvCoder.of(schemaCoder, in.getCoder()));
+        } catch (NoSuchSchemaException ignored) {
+          // No schema is known for the key type either; leave the coder unset on purpose.
+        }
+      }
       // let lazy coder inference have a try
     }
 
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java
index 5a8da194f4e72..296a53f48e800 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java
@@ -22,7 +22,15 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.schemas.JavaBeanSchema;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+import org.apache.beam.sdk.schemas.SchemaCoder;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
@@ -30,6 +38,7 @@
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
@@ -175,4 +184,47 @@ public void withLambdaAndNoTypeDescriptorShouldThrow() {
 
     p.run();
   }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testKeySchemaCoderSet() throws NoSuchSchemaException {
+    PCollection<KV<Pojo, String>> pCollection =
+        p.apply(Create.of(Lists.newArrayList("1", "2", "3")).withType(TypeDescriptors.strings()))
+            .apply(
+                WithKeys.<Pojo, String>of(v -> new Pojo(1, v))
+                    .withKeyType(TypeDescriptor.of(Pojo.class)));
+
+    TypeDescriptor<Pojo> keyType = TypeDescriptor.of(Pojo.class);
+    SchemaRegistry schemaRegistry = SchemaRegistry.createDefault();
+    SchemaCoder<Pojo> schemaCoder =
+        SchemaCoder.of(
+            schemaRegistry.getSchema(keyType),
+            keyType,
+            schemaRegistry.getToRowFunction(keyType),
+            schemaRegistry.getFromRowFunction(keyType));
+    Coder<KV<Pojo, String>> expectedCoder = KvCoder.of(schemaCoder, StringUtf8Coder.of());
+    assertEquals(expectedCoder, pCollection.getCoder());
+
+    p.run();
+  }
+
+  @DefaultSchema(JavaBeanSchema.class)
+  private static class Pojo {
+    private final long num;
+    private final String str;
+
+    @SchemaCreate
+    public Pojo(long num, String str) {
+      this.num = num;
+      this.str = str;
+    }
+
+    public long getNum() {
+      return this.num;
+    }
+
+    public String getStr() {
+      return this.str;
+    }
+  }
 }