diff --git a/runners/flink/1.5/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/1.5/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index edc44f6543d7d..9924581d2133a 100644 --- a/runners/flink/1.5/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/1.5/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -27,6 +27,7 @@ import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -63,6 +64,17 @@ protected StateInternals createStateInternals() { } } + @Test + public void testGetKey() throws Exception { + KeyedStateBackend keyedStateBackend = createStateBackend(); + StringUtf8Coder coder = StringUtf8Coder.of(); + String key = "key"; + + keyedStateBackend.setCurrentKey(FlinkKeyUtils.encodeKey(key, coder)); + FlinkStateInternals stateInternals = new FlinkStateInternals<>(keyedStateBackend, coder); + assertThat(stateInternals.getKey(), is(key)); + } + @Test public void testWatermarkHoldsPersistence() throws Exception { KeyedStateBackend keyedStateBackend = createStateBackend(); diff --git a/runners/flink/1.8/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/1.8/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index 82d2c918b93e2..b5a485981793b 100644 --- a/runners/flink/1.8/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/1.8/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -28,6 +28,7 @@ import org.apache.beam.runners.core.StateNamespaces; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; +import org.apache.beam.runners.flink.translation.wrappers.streaming.FlinkKeyUtils; import org.apache.beam.runners.flink.translation.wrappers.streaming.state.FlinkStateInternals; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -66,6 +67,17 @@ protected StateInternals createStateInternals() { } } + @Test + public void testGetKey() throws Exception { + KeyedStateBackend keyedStateBackend = createStateBackend(); + StringUtf8Coder coder = StringUtf8Coder.of(); + String key = "key"; + + keyedStateBackend.setCurrentKey(FlinkKeyUtils.encodeKey(key, coder)); + FlinkStateInternals stateInternals = new FlinkStateInternals<>(keyedStateBackend, coder); + assertThat(stateInternals.getKey(), is(key)); + } + @Test public void testWatermarkHoldsPersistence() throws Exception { KeyedStateBackend keyedStateBackend = createStateBackend(); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtils.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtils.java index 61eaae8f22785..0e22ac9eeb418 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtils.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtils.java @@ -44,7 +44,10 @@ public static ByteBuffer encodeKey(K key, Coder keyCoder) { checkNotNull(keyCoder, "Provided coder must not be null"); final byte[] keyBytes; try { - keyBytes = CoderUtils.encodeToByteArray(keyCoder, key); + // We use nested here to make sure the same logic with the get key logic at the portability + // layer. We currently use NESTED everywhere in the Beam portability APIs since NESTED vs + // OUTER becomes quite complicated and extremely error prone. + keyBytes = CoderUtils.encodeToByteArray(keyCoder, key, Coder.Context.NESTED); } catch (Exception e) { throw new RuntimeException(String.format(Locale.ENGLISH, "Failed to encode key: %s", key), e); } @@ -59,7 +62,7 @@ static K decodeKey(ByteBuffer byteBuffer, Coder keyCoder) { @SuppressWarnings("ByteBufferBackingArray") final byte[] keyBytes = byteBuffer.array(); try { - return CoderUtils.decodeFromByteArray(keyCoder, keyBytes); + return CoderUtils.decodeFromByteArray(keyCoder, keyBytes, Coder.Context.NESTED); } catch (Exception e) { throw new RuntimeException( String.format( diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 2eb450852925f..5b82a0b386b7c 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -109,7 +109,7 @@ public K getKey() { keyBytes.get(bytes); keyBytes.position(keyBytes.position() - bytes.length); try { - return CoderUtils.decodeFromByteArray(keyCoder, bytes); + return CoderUtils.decodeFromByteArray(keyCoder, bytes, Coder.Context.NESTED); } catch (CoderException e) { throw new RuntimeException("Error decoding key.", e); } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index ec4a2b90d34d0..57f7694d95ece 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -1631,10 +1631,10 @@ public void finishBundle(FinishBundleContext context) { assertThat( stripStreamRecordFromWindowedValue(testHarness.getOutput()), contains( - WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), - WindowedValue.valueInGlobalWindow(KV.of("key2", "d")), WindowedValue.valueInGlobalWindow(KV.of("key", "a")), WindowedValue.valueInGlobalWindow(KV.of("key", "b")), + WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), + WindowedValue.valueInGlobalWindow(KV.of("key2", "d")), WindowedValue.valueInGlobalWindow(KV.of("key3", "finishBundle")))); doFnOperator = doFnOperatorSupplier.get(); @@ -1652,10 +1652,10 @@ public void finishBundle(FinishBundleContext context) { assertThat( stripStreamRecordFromWindowedValue(testHarness.getOutput()), contains( - WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), - WindowedValue.valueInGlobalWindow(KV.of("key2", "d")), WindowedValue.valueInGlobalWindow(KV.of("key", "a")), WindowedValue.valueInGlobalWindow(KV.of("key", "b")), + WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), + WindowedValue.valueInGlobalWindow(KV.of("key2", "d")), WindowedValue.valueInGlobalWindow(KV.of("key3", "finishBundle")))); // repeat to see if elements are evicted @@ -1665,10 +1665,10 @@ public void finishBundle(FinishBundleContext context) { assertThat( stripStreamRecordFromWindowedValue(testHarness.getOutput()), contains( - WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), - WindowedValue.valueInGlobalWindow(KV.of("key2", "d")), WindowedValue.valueInGlobalWindow(KV.of("key", "a")), WindowedValue.valueInGlobalWindow(KV.of("key", "b")), + WindowedValue.valueInGlobalWindow(KV.of("key2", "c")), + WindowedValue.valueInGlobalWindow(KV.of("key2", "d")), WindowedValue.valueInGlobalWindow(KV.of("key3", "finishBundle")))); } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtilsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtilsTest.java index 06b5d0131233c..61d7ed141fb31 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtilsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkKeyUtilsTest.java @@ -53,11 +53,12 @@ public void testNullKey() { @SuppressWarnings("ByteBufferBackingArray") public void testCoderContext() throws Exception { byte[] bytes = {1, 1, 1}; + byte[] encodedBytes = {3, 1, 1, 1}; ByteString key = ByteString.copyFrom(bytes); ByteStringCoder coder = ByteStringCoder.of(); ByteBuffer encoded = FlinkKeyUtils.encodeKey(key, coder); - // Ensure outer context is used where no length encoding is used. - assertThat(encoded.array(), is(bytes)); + // Ensure nested context is used. + assertThat(encoded.array(), is(encodedBytes)); } }