From 5b1d73fee3994784fd826aea1a7fc81c261d093f Mon Sep 17 00:00:00 2001
From: mdshalda
Date: Wed, 5 Apr 2017 22:23:00 -0700
Subject: [PATCH 01/15] BEAM-1851: Update documentation for fixedSizeGlobally
 to be clear about memory constraints.

---
 .../src/main/java/org/apache/beam/sdk/transforms/Sample.java | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
index 3d35c808265d5..f3bd07a27acb2 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Sample.java
@@ -81,6 +81,10 @@ public static <T> PTransform<PCollection<T>, PCollection<T>> any(long limit) {
    * selected elements. If the input {@code PCollection} has fewer than {@code sampleSize} elements,
    * then the output {@code Iterable} will be all the input's elements.
    *
+   * <p>All of the elements of the output {@code PCollection} should fit into
+   * main memory of a single worker machine. This operation does not
+   * run in parallel.
+   *
    * <p>Example of use:
    *
    * <pre>{@code
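For context, a minimal sketch of how the documented transform is used, in the spirit of the javadoc example that follows in this file (pipeline construction elided; the sample size of 10 is illustrative):

  PCollection<String> input = ...;
  // Materializes a sample of at most 10 elements as a single Iterable on one
  // worker, which is the reason for the memory caveat this patch adds.
  PCollection<Iterable<String>> sample =
      input.apply(Sample.<String>fixedSizeGlobally(10));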
From aac8b03289a296011a976b1edc0f66991e13013a Mon Sep 17 00:00:00 2001
From: Robert Bradshaw 
Date: Wed, 5 Apr 2017 11:55:20 -0700
Subject: [PATCH 02/15] Use absolute import for dataflow iobase test.

---
 .../apache_beam/runners/dataflow/native_io/iobase_test.py       | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
index 1f82fdf1478a1..7610baff6b479 100644
--- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase_test.py
@@ -21,7 +21,7 @@
 import unittest
 
 from apache_beam import error, pvalue
-from iobase import (
+from apache_beam.runners.dataflow.native_io.iobase import (
     _dict_printable_fields,
     _NativeWrite,
     ConcatPosition,

From aaae9d776e958d6d891c2d2c635d0164f07132a1 Mon Sep 17 00:00:00 2001
From: Sourabh Bajaj 
Date: Thu, 6 Apr 2017 16:23:33 -0700
Subject: [PATCH 03/15] [BEAM-778] Fix the Compressed file seek tests on
 Windows

---
 sdks/python/apache_beam/io/filesystem.py      |   2 +-
 sdks/python/apache_beam/io/filesystem_test.py | 242 +++++++++---------
 2 files changed, 118 insertions(+), 126 deletions(-)

diff --git a/sdks/python/apache_beam/io/filesystem.py b/sdks/python/apache_beam/io/filesystem.py
index e6c3c298fa825..85c7f066e6f2b 100644
--- a/sdks/python/apache_beam/io/filesystem.py
+++ b/sdks/python/apache_beam/io/filesystem.py
@@ -259,7 +259,7 @@ def flush(self):
 
   @property
   def seekable(self):
-    return self._file.mode == 'r'
+    return 'r' in self._file.mode
 
   def _clear_read_buffer(self):
     """Clears the read buffer by removing all the contents and
diff --git a/sdks/python/apache_beam/io/filesystem_test.py b/sdks/python/apache_beam/io/filesystem_test.py
index 168925d999214..607393d3a5552 100644
--- a/sdks/python/apache_beam/io/filesystem_test.py
+++ b/sdks/python/apache_beam/io/filesystem_test.py
@@ -17,48 +17,23 @@
 #
 
 """Unit tests for filesystem module."""
-import shutil
+import bz2
+import gzip
 import os
 import unittest
 import tempfile
-import bz2
-import gzip
 from StringIO import StringIO
 
 from apache_beam.io.filesystem import CompressedFile, CompressionTypes
 
 
-class _TestCaseWithTempDirCleanUp(unittest.TestCase):
+class TestCompressedFile(unittest.TestCase):
   """Base class for TestCases that deals with TempDir clean-up.
 
   Inherited test cases will call self._new_tempdir() to start a temporary dir
   which will be deleted at the end of the tests (when tearDown() is called).
   """
 
-  def setUp(self):
-    self._tempdirs = []
-
-  def tearDown(self):
-    for path in self._tempdirs:
-      if os.path.exists(path):
-        shutil.rmtree(path)
-    self._tempdirs = []
-
-  def _new_tempdir(self):
-    result = tempfile.mkdtemp()
-    self._tempdirs.append(result)
-    return result
-
-  def _create_temp_file(self, name='', suffix=''):
-    if not name:
-      name = tempfile.template
-    file_name = tempfile.NamedTemporaryFile(
-        delete=False, prefix=name,
-        dir=self._new_tempdir(), suffix=suffix).name
-    return file_name
-
-
-class TestCompressedFile(_TestCaseWithTempDirCleanUp):
   content = """- the BEAM -
 How things really are we would like to know.
 Does
@@ -72,9 +47,21 @@ class TestCompressedFile(_TestCaseWithTempDirCleanUp):
   # in compressed file and not just in the internal buffer
   read_block_size = 4
 
-  def _create_compressed_file(self, compression_type, content,
-                              name='', suffix=''):
-    file_name = self._create_temp_file(name, suffix)
+  def setUp(self):
+    self._tempfiles = []
+
+  def tearDown(self):
+    for path in self._tempfiles:
+      if os.path.exists(path):
+        os.remove(path)
+
+  def _create_temp_file(self):
+    path = tempfile.NamedTemporaryFile(delete=False).name
+    self._tempfiles.append(path)
+    return path
+
+  def _create_compressed_file(self, compression_type, content):
+    file_name = self._create_temp_file()
 
     if compression_type == CompressionTypes.BZIP2:
       compress_factory = bz2.BZ2File
@@ -83,139 +70,144 @@ def _create_compressed_file(self, compression_type, content,
     else:
       assert False, "Invalid compression type: %s" % compression_type
 
-    with compress_factory(file_name, 'w') as f:
+    with compress_factory(file_name, 'wb') as f:
       f.write(content)
 
     return file_name
 
   def test_seekable_enabled_on_read(self):
-    readable = CompressedFile(open(self._create_temp_file(), 'r'))
-    self.assertTrue(readable.seekable)
+    with open(self._create_temp_file(), 'rb') as f:
+      readable = CompressedFile(f)
+      self.assertTrue(readable.seekable)
 
   def test_seekable_disabled_on_write(self):
-    writeable = CompressedFile(open(self._create_temp_file(), 'w'))
-    self.assertFalse(writeable.seekable)
+    with open(self._create_temp_file(), 'wb') as f:
+      writeable = CompressedFile(f)
+      self.assertFalse(writeable.seekable)
 
   def test_seekable_disabled_on_append(self):
-    writeable = CompressedFile(open(self._create_temp_file(), 'a'))
-    self.assertFalse(writeable.seekable)
+    with open(self._create_temp_file(), 'ab') as f:
+      writeable = CompressedFile(f)
+      self.assertFalse(writeable.seekable)
 
   def test_seek_set(self):
     for compression_type in [CompressionTypes.BZIP2, CompressionTypes.GZIP]:
       file_name = self._create_compressed_file(compression_type, self.content)
-
-      compressed_fd = CompressedFile(open(file_name, 'r'), compression_type,
-                                     read_size=self.read_block_size)
-      reference_fd = StringIO(self.content)
-
-      # Note: content (readline) check must come before position (tell) check
-      # because cStringIO's tell() reports out of bound positions (if we seek
-      # beyond the file) up until a real read occurs.
-      # _CompressedFile.tell() always stays within the bounds of the
-      # uncompressed content.
-      for seek_position in (-1, 0, 1,
-                            len(self.content)-1, len(self.content),
-                            len(self.content) + 1):
-        compressed_fd.seek(seek_position, os.SEEK_SET)
-        reference_fd.seek(seek_position, os.SEEK_SET)
-
-        uncompressed_line = compressed_fd.readline()
-        reference_line = reference_fd.readline()
-        self.assertEqual(uncompressed_line, reference_line)
-
-        uncompressed_position = compressed_fd.tell()
-        reference_position = reference_fd.tell()
-        self.assertEqual(uncompressed_position, reference_position)
+      with open(file_name, 'rb') as f:
+        compressed_fd = CompressedFile(f, compression_type,
+                                       read_size=self.read_block_size)
+        reference_fd = StringIO(self.content)
+
+        # Note: content (readline) check must come before position (tell) check
+        # because cStringIO's tell() reports out of bound positions (if we seek
+        # beyond the file) up until a real read occurs.
+        # _CompressedFile.tell() always stays within the bounds of the
+        # uncompressed content.
+        for seek_position in (-1, 0, 1,
+                              len(self.content)-1, len(self.content),
+                              len(self.content) + 1):
+          compressed_fd.seek(seek_position, os.SEEK_SET)
+          reference_fd.seek(seek_position, os.SEEK_SET)
+
+          uncompressed_line = compressed_fd.readline()
+          reference_line = reference_fd.readline()
+          self.assertEqual(uncompressed_line, reference_line)
+
+          uncompressed_position = compressed_fd.tell()
+          reference_position = reference_fd.tell()
+          self.assertEqual(uncompressed_position, reference_position)
 
   def test_seek_cur(self):
     for compression_type in [CompressionTypes.BZIP2, CompressionTypes.GZIP]:
       file_name = self._create_compressed_file(compression_type, self.content)
-
-      compressed_fd = CompressedFile(open(file_name, 'r'), compression_type,
-                                     read_size=self.read_block_size)
-      reference_fd = StringIO(self.content)
-
-      # Test out of bound, inbound seeking in both directions
-      for seek_position in (-1, 0, 1,
-                            len(self.content) / 2,
-                            len(self.content) / 2,
-                            -1 * len(self.content) / 2):
-        compressed_fd.seek(seek_position, os.SEEK_CUR)
-        reference_fd.seek(seek_position, os.SEEK_CUR)
-
-        uncompressed_line = compressed_fd.readline()
-        expected_line = reference_fd.readline()
-        self.assertEqual(uncompressed_line, expected_line)
-
-        reference_position = reference_fd.tell()
-        uncompressed_position = compressed_fd.tell()
-        self.assertEqual(uncompressed_position, reference_position)
+      with open(file_name, 'rb') as f:
+        compressed_fd = CompressedFile(f, compression_type,
+                                       read_size=self.read_block_size)
+        reference_fd = StringIO(self.content)
+
+        # Test out of bound, inbound seeking in both directions
+        for seek_position in (-1, 0, 1,
+                              len(self.content) / 2,
+                              len(self.content) / 2,
+                              -1 * len(self.content) / 2):
+          compressed_fd.seek(seek_position, os.SEEK_CUR)
+          reference_fd.seek(seek_position, os.SEEK_CUR)
+
+          uncompressed_line = compressed_fd.readline()
+          expected_line = reference_fd.readline()
+          self.assertEqual(uncompressed_line, expected_line)
+
+          reference_position = reference_fd.tell()
+          uncompressed_position = compressed_fd.tell()
+          self.assertEqual(uncompressed_position, reference_position)
 
   def test_read_from_end_returns_no_data(self):
     for compression_type in [CompressionTypes.BZIP2, CompressionTypes.GZIP]:
       file_name = self._create_compressed_file(compression_type, self.content)
+      with open(file_name, 'rb') as f:
+        compressed_fd = CompressedFile(f, compression_type,
+                                       read_size=self.read_block_size)
 
-      compressed_fd = CompressedFile(open(file_name, 'r'), compression_type,
-                                     read_size=self.read_block_size)
-
-      seek_position = 0
-      compressed_fd.seek(seek_position, os.SEEK_END)
+        seek_position = 0
+        compressed_fd.seek(seek_position, os.SEEK_END)
 
-      expected_data = ''
-      uncompressed_data = compressed_fd.read(10)
+        expected_data = ''
+        uncompressed_data = compressed_fd.read(10)
 
-      self.assertEqual(uncompressed_data, expected_data)
+        self.assertEqual(uncompressed_data, expected_data)
 
   def test_seek_outside(self):
     for compression_type in [CompressionTypes.BZIP2, CompressionTypes.GZIP]:
       file_name = self._create_compressed_file(compression_type, self.content)
+      with open(file_name, 'rb') as f:
+        compressed_fd = CompressedFile(f, compression_type,
+                                       read_size=self.read_block_size)
 
-      compressed_fd = CompressedFile(open(file_name, 'r'), compression_type,
-                                     read_size=self.read_block_size)
-
-      for whence in (os.SEEK_CUR, os.SEEK_SET, os.SEEK_END):
-        seek_position = -1 * len(self.content) - 10
-        compressed_fd.seek(seek_position, whence)
+        for whence in (os.SEEK_CUR, os.SEEK_SET, os.SEEK_END):
+          seek_position = -1 * len(self.content) - 10
+          compressed_fd.seek(seek_position, whence)
 
-        expected_position = 0
-        uncompressed_position = compressed_fd.tell()
-        self.assertEqual(uncompressed_position, expected_position)
+          expected_position = 0
+          uncompressed_position = compressed_fd.tell()
+          self.assertEqual(uncompressed_position, expected_position)
 
-        seek_position = len(self.content) + 20
-        compressed_fd.seek(seek_position, whence)
+          seek_position = len(self.content) + 20
+          compressed_fd.seek(seek_position, whence)
 
-        expected_position = len(self.content)
-        uncompressed_position = compressed_fd.tell()
-        self.assertEqual(uncompressed_position, expected_position)
+          expected_position = len(self.content)
+          uncompressed_position = compressed_fd.tell()
+          self.assertEqual(uncompressed_position, expected_position)
 
   def test_read_and_seek_back_to_beginning(self):
     for compression_type in [CompressionTypes.BZIP2, CompressionTypes.GZIP]:
       file_name = self._create_compressed_file(compression_type, self.content)
-      compressed_fd = CompressedFile(open(file_name, 'r'), compression_type,
-                                     read_size=self.read_block_size)
+      with open(file_name, 'rb') as f:
+        compressed_fd = CompressedFile(f, compression_type,
+                                       read_size=self.read_block_size)
 
-      first_pass = compressed_fd.readline()
-      compressed_fd.seek(0, os.SEEK_SET)
-      second_pass = compressed_fd.readline()
+        first_pass = compressed_fd.readline()
+        compressed_fd.seek(0, os.SEEK_SET)
+        second_pass = compressed_fd.readline()
 
-      self.assertEqual(first_pass, second_pass)
+        self.assertEqual(first_pass, second_pass)
 
   def test_tell(self):
     lines = ['line%d\n' % i for i in range(10)]
     tmpfile = self._create_temp_file()
-    writeable = CompressedFile(open(tmpfile, 'w'))
-    current_offset = 0
-    for line in lines:
-      writeable.write(line)
-      current_offset += len(line)
-      self.assertEqual(current_offset, writeable.tell())
-
-    writeable.close()
-    readable = CompressedFile(open(tmpfile))
-    current_offset = 0
-    while True:
-      line = readable.readline()
-      current_offset += len(line)
-      self.assertEqual(current_offset, readable.tell())
-      if not line:
-        break
+    with open(tmpfile, 'w') as f:
+      writeable = CompressedFile(f)
+      current_offset = 0
+      for line in lines:
+        writeable.write(line)
+        current_offset += len(line)
+        self.assertEqual(current_offset, writeable.tell())
+
+    with open(tmpfile) as f:
+      readable = CompressedFile(f)
+      current_offset = 0
+      while True:
+        line = readable.readline()
+        current_offset += len(line)
+        self.assertEqual(current_offset, readable.tell())
+        if not line:
+          break

From e05e60117f03fbcb1ba99fd5205b04cd37a917da Mon Sep 17 00:00:00 2001
From: Ahmet Altay 
Date: Fri, 7 Apr 2017 11:21:17 -0700
Subject: [PATCH 04/15] Clean up in textio and tfrecordio

---
 sdks/python/apache_beam/io/textio.py     |  1 -
 sdks/python/apache_beam/io/tfrecordio.py | 12 +++++++-----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 8122fae04b844..9217e740e5d51 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -385,7 +385,6 @@ def __init__(
     """
 
     super(ReadFromText, self).__init__(**kwargs)
-    self._strip_trailing_newlines = strip_trailing_newlines
     self._source = _TextSource(
         file_pattern, min_bundle_size, compression_type,
         strip_trailing_newlines, coder, validate=validate,
diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py
index 8b9d9eab8256e..e2b41bffbe772 100644
--- a/sdks/python/apache_beam/io/tfrecordio.py
+++ b/sdks/python/apache_beam/io/tfrecordio.py
@@ -201,10 +201,11 @@ def __init__(self,
       A ReadFromTFRecord transform object.
     """
     super(ReadFromTFRecord, self).__init__(**kwargs)
-    self._args = (file_pattern, coder, compression_type, validate)
+    self._source = _TFRecordSource(file_pattern, coder, compression_type,
+                                   validate)
 
   def expand(self, pvalue):
-    return pvalue.pipeline | Read(_TFRecordSource(*self._args))
+    return pvalue.pipeline | Read(self._source)
 
 
 class _TFRecordSink(fileio.FileSink):
@@ -270,8 +271,9 @@ def __init__(self,
       A WriteToTFRecord transform object.
     """
     super(WriteToTFRecord, self).__init__(**kwargs)
-    self._args = (file_path_prefix, coder, file_name_suffix, num_shards,
-                  shard_name_template, compression_type)
+    self._sink = _TFRecordSink(file_path_prefix, coder, file_name_suffix,
+                               num_shards, shard_name_template,
+                               compression_type)
 
   def expand(self, pcoll):
-    return pcoll | Write(_TFRecordSink(*self._args))
+    return pcoll | Write(self._sink)

From ef344ee4b34abc19f27ec84a2b56f906009db23d Mon Sep 17 00:00:00 2001
From: Mark Liu 
Date: Fri, 7 Apr 2017 14:30:44 -0700
Subject: [PATCH 05/15] [BEAM-1823] Improve ValidatesRunner Test Log

---
 .../runners/dataflow/test_dataflow_runner.py           | 10 +++++++++-
 sdks/python/run_postcommit.sh                          |  6 ++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py
index aca291ab672d5..046313ad7b3c9 100644
--- a/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/test_dataflow_runner.py
@@ -18,7 +18,7 @@
 """Wrapper of Beam runners that's built for running and verifying e2e tests."""
 
 from apache_beam.internal import pickler
-from apache_beam.utils.pipeline_options import TestOptions
+from apache_beam.utils.pipeline_options import TestOptions, GoogleCloudOptions
 from apache_beam.runners.dataflow.dataflow_runner import DataflowRunner
 
 
@@ -37,6 +37,14 @@ def run(self, pipeline):
     options.on_success_matcher = None
 
     self.result = super(TestDataflowRunner, self).run(pipeline)
+    if self.result.has_job:
+      project = pipeline.options.view_as(GoogleCloudOptions).project
+      job_id = self.result.job_id()
+      # TODO(markflyhigh)(BEAM-1890): Use print since Nose doesn't show logs
+      # in some cases.
+      print (
+          'Found: https://console.cloud.google.com/dataflow/job/%s?project=%s' %
+          (job_id, project))
     self.result.wait_until_finish()
 
     if on_success_matcher:
diff --git a/sdks/python/run_postcommit.sh b/sdks/python/run_postcommit.sh
index dd3182a723da1..7714d7a254dd8 100755
--- a/sdks/python/run_postcommit.sh
+++ b/sdks/python/run_postcommit.sh
@@ -73,7 +73,8 @@ echo "mock" >> postcommit_requirements.txt
 # Run ValidatesRunner tests on Google Cloud Dataflow service
 echo ">>> RUNNING DATAFLOW RUNNER VALIDATESRUNNER TESTS"
 python setup.py nosetests \
-  -a ValidatesRunner \
+  --attr ValidatesRunner \
+  --nocapture \
   --processes=4 \
   --process-timeout=600 \
   --test-pipeline-options=" \
@@ -89,7 +90,8 @@ python setup.py nosetests \
 # and validate that jobs finish successfully.
 echo ">>> RUNNING TEST DATAFLOW RUNNER it tests"
 python setup.py nosetests \
-  -a IT \
+  --attr IT \
+  --nocapture \
   --processes=4 \
   --process-timeout=600 \
   --test-pipeline-options=" \

From 8179d8c8af9f3afaa785755b13822d2bd508f789 Mon Sep 17 00:00:00 2001
From: Eugene Kirpichov 
Date: Fri, 7 Apr 2017 15:54:03 -0700
Subject: [PATCH 06/15] Cleanup: removes two unused constants

---
 .../sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java     | 2 --
 1 file changed, 2 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
index 8e3a37cde34e2..01552973969d9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
@@ -85,8 +85,6 @@ public class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory {
   public static final String PROCESS_CONTEXT_PARAMETER_METHOD = "processContext";
   public static final String ON_TIMER_CONTEXT_PARAMETER_METHOD = "onTimerContext";
   public static final String WINDOW_PARAMETER_METHOD = "window";
-  public static final String INPUT_PROVIDER_PARAMETER_METHOD = "inputProvider";
-  public static final String OUTPUT_RECEIVER_PARAMETER_METHOD = "outputReceiver";
   public static final String RESTRICTION_TRACKER_PARAMETER_METHOD = "restrictionTracker";
   public static final String STATE_PARAMETER_METHOD = "state";
   public static final String TIMER_PARAMETER_METHOD = "timer";

From 03a99b4df18ceeb7cd61953c461c4cca377a37ae Mon Sep 17 00:00:00 2001
From: Eugene Kirpichov 
Date: Fri, 7 Apr 2017 12:09:47 -0700
Subject: [PATCH 07/15] [BEAM-65] Adds HasDefaultTracker for RestrictionTracker
 inference

Allows a restriction type to implement HasDefaultTracker; in that case
the splittable DoFn itself does not need to implement NewTracker - only
ProcessElement and GetInitialRestriction.
---
 ...edSplittableProcessElementInvokerTest.java |  5 --
 .../runners/core/SplittableParDoTest.java     | 53 +++++--------
 .../reflect/ByteBuddyDoFnInvokerFactory.java  | 24 +++++-
 .../transforms/reflect/DoFnSignatures.java    | 74 ++++++++++++-------
 .../splittabledofn/HasDefaultTracker.java     | 30 ++++++++
 .../splittabledofn/OffsetRange.java           |  8 +-
 .../sdk/transforms/SplittableDoFnTest.java    | 20 -----
 .../transforms/reflect/DoFnInvokersTest.java  | 54 +++++++++++---
 .../DoFnSignaturesSplittableDoFnTest.java     | 43 +++++++++++
 9 files changed, 215 insertions(+), 96 deletions(-)
 create mode 100644 sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java

diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
index d7c98899a6111..b85f4812b31cb 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java
@@ -71,11 +71,6 @@ public ProcessContinuation process(ProcessContext context, OffsetRangeTracker tr
     public OffsetRange getInitialRestriction(Integer element) {
       throw new UnsupportedOperationException("Should not be called in this test");
     }
-
-    @NewTracker
-    public OffsetRangeTracker newTracker(OffsetRange range) {
-      throw new UnsupportedOperationException("Should not be called in this test");
-    }
   }
 
   private SplittableProcessElementInvoker.Result
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java
index ee94ee07e611e..6205777f83ca9 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java
@@ -45,6 +45,7 @@
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnTester;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -72,10 +73,20 @@ public class SplittableParDoTest {
   private static final Duration MAX_BUNDLE_DURATION = Duration.standardSeconds(5);
 
   // ----------------- Tests for whether the transform sets boundedness correctly --------------
-  private static class SomeRestriction implements Serializable {}
+  private static class SomeRestriction
+      implements Serializable, HasDefaultTracker<SomeRestriction, SomeRestrictionTracker> {
+    @Override
+    public SomeRestrictionTracker newTracker() {
+      return new SomeRestrictionTracker(this);
+    }
+  }
 
   private static class SomeRestrictionTracker implements RestrictionTracker<SomeRestriction> {
-    private final SomeRestriction someRestriction = new SomeRestriction();
+    private final SomeRestriction someRestriction;
+
+    public SomeRestrictionTracker(SomeRestriction someRestriction) {
+      this.someRestriction = someRestriction;
+    }
 
     @Override
     public SomeRestriction currentRestriction() {
@@ -96,11 +107,6 @@ public void processElement(ProcessContext context, SomeRestrictionTracker tracke
     public SomeRestriction getInitialRestriction(Integer element) {
       return null;
     }
-
-    @NewTracker
-    public SomeRestrictionTracker newTracker(SomeRestriction restriction) {
-      return null;
-    }
   }
 
   private static class UnboundedFakeFn extends DoFn<Integer, String> {
@@ -114,11 +120,6 @@ public ProcessContinuation processElement(
     public SomeRestriction getInitialRestriction(Integer element) {
       return null;
     }
-
-    @NewTracker
-    public SomeRestrictionTracker newTracker(SomeRestriction restriction) {
-      return null;
-    }
   }
 
   private static PCollection<Integer> makeUnboundedCollection(Pipeline pipeline) {
@@ -376,11 +377,6 @@ public void process(ProcessContext c, SomeRestrictionTracker tracker) {
     public SomeRestriction getInitialRestriction(Integer elem) {
       return new SomeRestriction();
     }
-
-    @NewTracker
-    public SomeRestrictionTracker newTracker(SomeRestriction restriction) {
-      return new SomeRestrictionTracker();
-    }
   }
 
   @Test
@@ -438,11 +434,6 @@ public ProcessContinuation process(ProcessContext c, SomeRestrictionTracker trac
     public SomeRestriction getInitialRestriction(Integer elem) {
       return new SomeRestriction();
     }
-
-    @NewTracker
-    public SomeRestrictionTracker newTracker(SomeRestriction restriction) {
-      return new SomeRestrictionTracker();
-    }
   }
 
   @Test
@@ -474,12 +465,18 @@ public void testResumeSetsTimer() throws Exception {
     assertThat(tester.takeOutputElements(), contains("42"));
   }
 
-  private static class SomeCheckpoint implements Serializable {
+  private static class SomeCheckpoint
+      implements Serializable, HasDefaultTracker<SomeCheckpoint, SomeCheckpointTracker> {
     private int firstUnprocessedIndex;
 
     private SomeCheckpoint(int firstUnprocessedIndex) {
       this.firstUnprocessedIndex = firstUnprocessedIndex;
     }
+
+    @Override
+    public SomeCheckpointTracker newTracker() {
+      return new SomeCheckpointTracker(this);
+    }
   }
 
   private static class SomeCheckpointTracker implements RestrictionTracker<SomeCheckpoint> {
@@ -543,11 +540,6 @@ public ProcessContinuation process(ProcessContext c, SomeCheckpointTracker track
     public SomeCheckpoint getInitialRestriction(Integer elem) {
       throw new UnsupportedOperationException("Expected to be supplied explicitly in this test");
     }
-
-    @NewTracker
-    public SomeCheckpointTracker newTracker(SomeCheckpoint restriction) {
-      return new SomeCheckpointTracker(restriction);
-    }
   }
 
   @Test
@@ -658,11 +650,6 @@ public SomeRestriction getInitialRestriction(Integer element) {
       return new SomeRestriction();
     }
 
-    @NewTracker
-    public SomeRestrictionTracker newTracker(SomeRestriction restriction) {
-      return new SomeRestrictionTracker();
-    }
-
     @Setup
     public void setup() {
       assertEquals(State.BEFORE_SETUP, state);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
index 01552973969d9..6746d3a81e137 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
@@ -74,6 +74,8 @@
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
+import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.util.Timer;
 import org.apache.beam.sdk.util.UserCodeException;
 import org.apache.beam.sdk.values.TypeDescriptor;
@@ -257,6 +259,15 @@ public  Coder invokeGetRestrictionCoder(CoderRegistr
     }
   }
 
+  /** Default implementation of {@link DoFn.NewTracker}, for delegation by bytebuddy. */
+  public static class DefaultNewTracker {
+    /** Uses {@link HasDefaultTracker} to produce the tracker. */
+    @SuppressWarnings("unused")
+    public static RestrictionTracker invokeNewTracker(Object restriction) {
+      return ((HasDefaultTracker) restriction).newTracker();
+    }
+  }
+
   /** Generates a {@link DoFnInvoker} class for the given {@link DoFnSignature}. */
   private static Class<? extends DoFnInvoker<?, ?>> generateInvokerClass(DoFnSignature signature) {
     Class<? extends DoFn<?, ?>> fnClass = signature.fnClass();
@@ -306,7 +317,7 @@ public String subclass(TypeDescription.Generic superClass) {
             .method(ElementMatchers.named("invokeGetRestrictionCoder"))
             .intercept(getRestrictionCoderDelegation(clazzDescription, signature))
             .method(ElementMatchers.named("invokeNewTracker"))
-            .intercept(delegateWithDowncastOrThrow(clazzDescription, signature.newTracker()));
+            .intercept(newTrackerDelegation(clazzDescription, signature.newTracker()));
 
     DynamicType.Unloaded unloaded = builder.make();
 
@@ -346,6 +357,17 @@ private static Implementation splitRestrictionDelegation(
     }
   }
 
+  private static Implementation newTrackerDelegation(
+      TypeDescription doFnType, @Nullable DoFnSignature.NewTrackerMethod signature) {
+    if (signature == null) {
+      // We must have already verified that in this case the restriction type
+      // is a subtype of HasDefaultTracker.
+      return MethodDelegation.to(DefaultNewTracker.class);
+    } else {
+      return delegateWithDowncastOrThrow(doFnType, signature);
+    }
+  }
+
   /** Delegates to the given method if available, or does nothing. */
   private static Implementation delegateOrNoop(
       TypeDescription doFnType, DoFnSignature.DoFnMethod method) {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index 61b91570a6e0c..006d012cbb000 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -54,6 +54,7 @@
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
+import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.Timer;
@@ -368,42 +369,35 @@ private static DoFnSignature parseSignature(Class> fnClass)
               errors.forMethod(DoFn.Teardown.class, teardownMethod), teardownMethod));
     }
 
-    DoFnSignature.GetInitialRestrictionMethod getInitialRestriction = null;
-    ErrorReporter getInitialRestrictionErrors = null;
+    ErrorReporter getInitialRestrictionErrors;
     if (getInitialRestrictionMethod != null) {
       getInitialRestrictionErrors =
           errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestrictionMethod);
       signatureBuilder.setGetInitialRestriction(
-          getInitialRestriction =
               analyzeGetInitialRestrictionMethod(
                   getInitialRestrictionErrors, fnT, getInitialRestrictionMethod, inputT));
     }
 
-    DoFnSignature.SplitRestrictionMethod splitRestriction = null;
     if (splitRestrictionMethod != null) {
       ErrorReporter splitRestrictionErrors =
           errors.forMethod(DoFn.SplitRestriction.class, splitRestrictionMethod);
       signatureBuilder.setSplitRestriction(
-          splitRestriction =
               analyzeSplitRestrictionMethod(
                   splitRestrictionErrors, fnT, splitRestrictionMethod, inputT));
     }
 
-    DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = null;
     if (getRestrictionCoderMethod != null) {
       ErrorReporter getRestrictionCoderErrors =
           errors.forMethod(DoFn.GetRestrictionCoder.class, getRestrictionCoderMethod);
       signatureBuilder.setGetRestrictionCoder(
-          getRestrictionCoder =
               analyzeGetRestrictionCoderMethod(
                   getRestrictionCoderErrors, fnT, getRestrictionCoderMethod));
     }
 
-    DoFnSignature.NewTrackerMethod newTracker = null;
     if (newTrackerMethod != null) {
       ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod);
       signatureBuilder.setNewTracker(
-          newTracker = analyzeNewTrackerMethod(newTrackerErrors, fnT, newTrackerMethod));
+          analyzeNewTrackerMethod(newTrackerErrors, fnT, newTrackerMethod));
     }
 
     signatureBuilder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors));
@@ -501,38 +495,66 @@ private static void verifySplittableMethods(DoFnSignature signature, ErrorReport
     ErrorReporter processElementErrors =
         errors.forMethod(DoFn.ProcessElement.class, processElement.targetMethod());
 
+    final TypeDescriptor<?> trackerT;
+    final String originOfTrackerT;
+
     List<String> missingRequiredMethods = new ArrayList<>();
     if (getInitialRestriction == null) {
       missingRequiredMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName());
     }
     if (newTracker == null) {
-      missingRequiredMethods.add("@" + DoFn.NewTracker.class.getSimpleName());
+      if (getInitialRestriction != null
+          && getInitialRestriction
+              .restrictionT()
+              .isSubtypeOf(TypeDescriptor.of(HasDefaultTracker.class))) {
+        trackerT =
+            getInitialRestriction
+                .restrictionT()
+                .resolveType(HasDefaultTracker.class.getTypeParameters()[1]);
+        originOfTrackerT =
+            String.format(
+                "restriction type %s of @%s method %s",
+                formatType(getInitialRestriction.restrictionT()),
+                DoFn.GetInitialRestriction.class.getSimpleName(),
+                format(getInitialRestriction.targetMethod()));
+      } else {
+        missingRequiredMethods.add("@" + DoFn.NewTracker.class.getSimpleName());
+        trackerT = null;
+        originOfTrackerT = null;
+      }
+    } else {
+      trackerT = newTracker.trackerT();
+      originOfTrackerT =
+          String.format(
+              "%s method %s",
+              DoFn.NewTracker.class.getSimpleName(), format(newTracker.targetMethod()));
+      ErrorReporter getInitialRestrictionErrors =
+          errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod());
+      TypeDescriptor<?> restrictionT = getInitialRestriction.restrictionT();
+      getInitialRestrictionErrors.checkArgument(
+          restrictionT.equals(newTracker.restrictionT()),
+          "Uses restriction type %s, but @%s method %s uses restriction type %s",
+          formatType(restrictionT),
+          DoFn.NewTracker.class.getSimpleName(),
+          format(newTracker.targetMethod()),
+          formatType(newTracker.restrictionT()));
     }
+
     if (!missingRequiredMethods.isEmpty()) {
       processElementErrors.throwIllegalArgument(
           "Splittable, but does not define the following required methods: %s",
           missingRequiredMethods);
     }
 
-    processElementErrors.checkArgument(
-        processElement.trackerT().equals(newTracker.trackerT()),
-        "Has tracker type %s, but @%s method %s uses tracker type %s",
-        formatType(processElement.trackerT()),
-        DoFn.NewTracker.class.getSimpleName(),
-        format(newTracker.targetMethod()),
-        formatType(newTracker.trackerT()));
-
     ErrorReporter getInitialRestrictionErrors =
         errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod());
     TypeDescriptor<?> restrictionT = getInitialRestriction.restrictionT();
-
-    getInitialRestrictionErrors.checkArgument(
-        restrictionT.equals(newTracker.restrictionT()),
-        "Uses restriction type %s, but @%s method %s uses restriction type %s",
-        formatType(restrictionT),
-        DoFn.NewTracker.class.getSimpleName(),
-        format(newTracker.targetMethod()),
-        formatType(newTracker.restrictionT()));
+    processElementErrors.checkArgument(
+        processElement.trackerT().equals(trackerT),
+        "Has tracker type %s, but the DoFn's tracker type was inferred as %s from %s",
+        formatType(processElement.trackerT()),
+        trackerT,
+        originOfTrackerT);
 
     if (getRestrictionCoder != null) {
       getInitialRestrictionErrors.checkArgument(
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java
new file mode 100644
index 0000000000000..3366dfecaacdd
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/HasDefaultTracker.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.transforms.splittabledofn;
+
+/**
+ * Interface for restrictions for which a default implementation of {@link
+ * org.apache.beam.sdk.transforms.DoFn.NewTracker} is available, depending only on the restriction
+ * itself.
+ */
+public interface HasDefaultTracker<
+    RestrictionT extends HasDefaultTracker<RestrictionT, TrackerT>,
+    TrackerT extends RestrictionTracker<RestrictionT>> {
+  /** Creates a new tracker for {@code this}. */
+  TrackerT newTracker();
+}
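For context, a hedged sketch of what implementing this interface looks like for a custom restriction type. DateRange and DateRangeTracker are hypothetical names used only for illustration (they are not part of this patch), and the sketch assumes imports of java.io.Serializable and the splittabledofn interfaces; the tracker implements the two methods RestrictionTracker declares at this point in the codebase:

  class DateRange
      implements Serializable, HasDefaultTracker<DateRange, DateRangeTracker> {
    final long fromDay;
    final long toDay; // half-open interval [fromDay, toDay)

    DateRange(long fromDay, long toDay) {
      this.fromDay = fromDay;
      this.toDay = toDay;
    }

    @Override
    public DateRangeTracker newTracker() {
      // Satisfies the recursive type bound above: DateRange's default
      // tracker is a RestrictionTracker<DateRange>.
      return new DateRangeTracker(this);
    }
  }

  class DateRangeTracker implements RestrictionTracker<DateRange> {
    private DateRange range;
    private long lastClaimedDay;

    DateRangeTracker(DateRange range) {
      this.range = range;
      this.lastClaimedDay = range.fromDay - 1;
    }

    boolean tryClaim(long day) {
      if (day >= range.toDay) {
        return false;
      }
      lastClaimedDay = day;
      return true;
    }

    @Override
    public DateRange currentRestriction() {
      return range;
    }

    @Override
    public DateRange checkpoint() {
      // Split off the unprocessed suffix as the residual restriction and
      // shrink the primary restriction to the prefix already claimed.
      DateRange residual = new DateRange(lastClaimedDay + 1, range.toDay);
      range = new DateRange(range.fromDay, lastClaimedDay + 1);
      return residual;
    }
  }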
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRange.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRange.java
index 67031c432df67..104f5f2564a06 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRange.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRange.java
@@ -22,7 +22,8 @@
 import java.io.Serializable;
 
 /** A restriction represented by a range of integers [from, to). */
-public class OffsetRange implements Serializable {
+public class OffsetRange
+    implements Serializable, HasDefaultTracker<OffsetRange, OffsetRangeTracker> {
   private final long from;
   private final long to;
 
@@ -40,6 +41,11 @@ public long getTo() {
     return to;
   }
 
+  @Override
+  public OffsetRangeTracker newTracker() {
+    return new OffsetRangeTracker(this);
+  }
+
   @Override
   public String toString() {
     return "[" + from + ", " + to + ')';
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java
index d926f6cd96701..154a088347b0b 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java
@@ -88,11 +88,6 @@ public void splitRange(
       receiver.output(new OffsetRange(range.getFrom(), (range.getFrom() + range.getTo()) / 2));
       receiver.output(new OffsetRange((range.getFrom() + range.getTo()) / 2, range.getTo()));
     }
-
-    @NewTracker
-    public OffsetRangeTracker newTracker(OffsetRange range) {
-      return new OffsetRangeTracker(range);
-    }
   }
 
   private static class ReifyTimestampsFn<T> extends DoFn<T, TimestampedValue<T>> {
@@ -220,11 +215,6 @@ public ProcessContinuation processElement(ProcessContext c, OffsetRangeTracker t
     public OffsetRange getInitialRange(String element) {
       return new OffsetRange(0, MAX_INDEX);
     }
-
-    @NewTracker
-    public OffsetRangeTracker newTracker(OffsetRange range) {
-      return new OffsetRangeTracker(range);
-    }
   }
 
   @Test
@@ -259,11 +249,6 @@ public void process(ProcessContext c, OffsetRangeTracker tracker) {
     public OffsetRange getInitialRestriction(Integer value) {
       return new OffsetRange(0, 1);
     }
-
-    @NewTracker
-    public OffsetRangeTracker newTracker(OffsetRange range) {
-      return new OffsetRangeTracker(range);
-    }
   }
 
   @Test
@@ -357,11 +342,6 @@ public OffsetRange getInitialRestriction(String value) {
       return new OffsetRange(0, 1);
     }
 
-    @NewTracker
-    public OffsetRangeTracker newTracker(OffsetRange range) {
-      return new OffsetRangeTracker(range);
-    }
-
     @Setup
     public void setUp() {
       assertEquals(State.BEFORE_SETUP, state);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
index 5d3746f752adc..425f453639534 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
@@ -42,6 +42,7 @@
 import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.FakeArgumentProvider;
 import org.apache.beam.sdk.transforms.reflect.testhelper.DoFnInvokersTestHelper;
+import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -391,19 +392,49 @@ public RestrictionTracker restrictionTracker() {
             }));
   }
 
+  private static class RestrictionWithDefaultTracker
+      implements HasDefaultTracker<RestrictionWithDefaultTracker, DefaultTracker> {
+    @Override
+    public DefaultTracker newTracker() {
+      return new DefaultTracker();
+    }
+  }
+
+  private static class DefaultTracker implements RestrictionTracker<RestrictionWithDefaultTracker> {
+    @Override
+    public RestrictionWithDefaultTracker currentRestriction() {
+      throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public RestrictionWithDefaultTracker checkpoint() {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  private static class CoderForDefaultTracker extends CustomCoder<RestrictionWithDefaultTracker> {
+    public static CoderForDefaultTracker of() {
+      return new CoderForDefaultTracker();
+    }
+
+    @Override
+    public void encode(
+        RestrictionWithDefaultTracker value, OutputStream outStream, Context context) {}
+
+    @Override
+    public RestrictionWithDefaultTracker decode(InputStream inStream, Context context) {
+      return null;
+    }
+  }
+
   @Test
   public void testSplittableDoFnDefaultMethods() throws Exception {
     class MockFn extends DoFn<String, String> {
       @ProcessElement
-      public void processElement(ProcessContext c, SomeRestrictionTracker tracker) {}
+      public void processElement(ProcessContext c, DefaultTracker tracker) {}
 
       @GetInitialRestriction
-      public SomeRestriction getInitialRestriction(String element) {
-        return null;
-      }
-
-      @NewTracker
-      public SomeRestrictionTracker newTracker(SomeRestriction restriction) {
+      public RestrictionWithDefaultTracker getInitialRestriction(String element) {
         return null;
       }
     }
@@ -411,10 +442,10 @@ public SomeRestrictionTracker newTracker(SomeRestriction restriction) {
     DoFnInvoker<String, String> invoker = DoFnInvokers.invokerFor(fn);
 
     CoderRegistry coderRegistry = new CoderRegistry();
-    coderRegistry.registerCoder(SomeRestriction.class, SomeRestrictionCoder.class);
+    coderRegistry.registerCoder(RestrictionWithDefaultTracker.class, CoderForDefaultTracker.class);
     assertThat(
-        invoker.invokeGetRestrictionCoder(coderRegistry),
-        instanceOf(SomeRestrictionCoder.class));
+        invoker.invokeGetRestrictionCoder(coderRegistry),
+        instanceOf(CoderForDefaultTracker.class));
     invoker.invokeSplitRestriction(
         "blah",
         "foo",
@@ -430,6 +461,9 @@ public void output(String output) {
         });
     assertEquals(
         ProcessContinuation.stop(), invoker.invokeProcessElement(mockArgumentProvider));
+    assertThat(
+        invoker.invokeNewTracker(new RestrictionWithDefaultTracker()),
+        instanceOf(DefaultTracker.class));
   }
 
   // ---------------------------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java
index c10d199943183..052feb812b456 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java
@@ -32,6 +32,7 @@
 import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.AnonymousMethod;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignaturesTestUtils.FakeDoFn;
+import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.values.PCollection;
@@ -330,6 +331,48 @@ public void process(ProcessContext context, SomeRestrictionTracker tracker) {}
     DoFnSignatures.getSignature(BadFn.class);
   }
 
+  abstract class SomeDefaultTracker implements RestrictionTracker<RestrictionWithDefaultTracker> {}
+  abstract class RestrictionWithDefaultTracker
+      implements HasDefaultTracker<RestrictionWithDefaultTracker, SomeDefaultTracker> {}
+
+  @Test
+  public void testHasDefaultTracker() throws Exception {
+    class Fn extends DoFn<Integer, String> {
+      @ProcessElement
+      public void process(ProcessContext c, SomeDefaultTracker tracker) {}
+
+      @GetInitialRestriction
+      public RestrictionWithDefaultTracker getInitialRestriction(Integer element) {
+        return null;
+      }
+    }
+
+    DoFnSignature signature = DoFnSignatures.getSignature(Fn.class);
+    assertEquals(
+        SomeDefaultTracker.class, signature.processElement().trackerT().getRawType());
+  }
+
+  @Test
+  public void testRestrictionHasDefaultTrackerProcessUsesWrongTracker() throws Exception {
+    class Fn extends DoFn<Integer, String> {
+      @ProcessElement
+      public void process(ProcessContext c, SomeRestrictionTracker tracker) {}
+
+      @GetInitialRestriction
+      public RestrictionWithDefaultTracker getInitialRestriction(Integer element) {
+        return null;
+      }
+    }
+
+    thrown.expectMessage(
+        "Has tracker type SomeRestrictionTracker, but the DoFn's tracker type was inferred as ");
+    thrown.expectMessage("SomeDefaultTracker");
+    thrown.expectMessage(
+        "from restriction type RestrictionWithDefaultTracker "
+            + "of @GetInitialRestriction method getInitialRestriction(Integer)");
+    DoFnSignatures.getSignature(Fn.class);
+  }
+
   @Test
   public void testNewTrackerReturnsWrongType() throws Exception {
     class BadFn extends DoFn {

From 19f407c9497c911ca3cb61d989aa5a78c84896cf Mon Sep 17 00:00:00 2001
From: Eugene Kirpichov 
Date: Fri, 7 Apr 2017 17:00:39 -0700
Subject: [PATCH 08/15] Clarifies doc of ProcessElement re: HasDefaultTracker

---
 .../src/main/java/org/apache/beam/sdk/transforms/DoFn.java  | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
index de33612710fd7..5139290dd4975 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java
@@ -40,6 +40,7 @@
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.HasDisplayData;
+import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
 import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
@@ -548,10 +549,11 @@ public interface OutputReceiver<T> {
    * <ul>
    *
    * <li>It must define a {@link GetInitialRestriction} method.
    * <li>It may define a {@link SplitRestriction} method.
-   * <li>It must define a {@link NewTracker} method returning the same type as the type of
+   * <li>It may define a {@link NewTracker} method returning the same type as the type of
    *     the {@link RestrictionTracker} argument of {@link ProcessElement}, which in turn must be a
    *     subtype of {@code RestrictionTracker<R>} where {@code R} is the restriction type returned
-   *     by {@link GetInitialRestriction}.
+   *     by {@link GetInitialRestriction}. This method is optional in case the restriction type
+   *     returned by {@link GetInitialRestriction} implements {@link HasDefaultTracker}.
    * <li>It may define a {@link GetRestrictionCoder} method.
    * <li>The type of restrictions used by all of these methods must be the same.
  • Its {@link ProcessElement} method may return a {@link ProcessContinuation} to From dad7ace2b6143850fac0ca5359d9f56f5f0df2c1 Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Thu, 6 Apr 2017 15:07:05 -0700 Subject: [PATCH 09/15] Fixes SDF issues re: watermarks and stop/resume See detailed discussion in document: https://docs.google.com/document/d/1BGc8pM1GOvZhwR9SARSVte-20XEoBUxrGJ5gTWXdv3c/edit --- .../beam/runners/core/DoFnAdapters.java | 5 + ...oundedSplittableProcessElementInvoker.java | 125 +++++++------ .../beam/runners/core/SimpleDoFnRunner.java | 5 + .../beam/runners/core/SplittableParDo.java | 15 +- .../core/SplittableProcessElementInvoker.java | 22 +-- ...edSplittableProcessElementInvokerTest.java | 16 +- .../runners/core/SplittableParDoTest.java | 171 +++--------------- .../org/apache/beam/sdk/transforms/DoFn.java | 78 ++------ .../beam/sdk/transforms/DoFnTester.java | 5 + .../reflect/ByteBuddyDoFnInvokerFactory.java | 20 +- .../sdk/transforms/reflect/DoFnInvoker.java | 4 +- .../sdk/transforms/reflect/DoFnSignature.java | 10 +- .../transforms/reflect/DoFnSignatures.java | 22 +-- .../splittabledofn/OffsetRangeTracker.java | 33 +++- .../splittabledofn/RestrictionTracker.java | 8 + .../sdk/transforms/SplittableDoFnTest.java | 17 +- .../transforms/reflect/DoFnInvokersTest.java | 99 +++------- .../DoFnSignaturesProcessElementTest.java | 2 +- .../DoFnSignaturesSplittableDoFnTest.java | 74 +------- .../OffsetRangeTrackerTest.java | 49 ++++- 20 files changed, 263 insertions(+), 517 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java index 693cb2fbf0f13..deb3b7ebf18c0 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnAdapters.java @@ -285,6 +285,11 @@ public PaneInfo pane() { return context.pane(); } + @Override + public void updateWatermark(Instant watermark) { + throw new UnsupportedOperationException("Only splittable DoFn's can use updateWatermark()"); + } + @Override public BoundedWindow window() { return context.window(); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index 357094c2e0122..27fd0a3640a0d 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -97,70 +97,57 @@ public Result invokeProcessElement( final WindowedValue element, final TrackerT tracker) { final ProcessContext processContext = new ProcessContext(element, tracker); - DoFn.ProcessContinuation cont = - invoker.invokeProcessElement( - new DoFnInvoker.ArgumentProvider() { - @Override - public DoFn.ProcessContext processContext( - DoFn doFn) { - return processContext; - } - - @Override - public RestrictionTracker restrictionTracker() { - return tracker; - } - - // Unsupported methods below. 
- - @Override - public BoundedWindow window() { - throw new UnsupportedOperationException( - "Access to window of the element not supported in Splittable DoFn"); - } - - @Override - public DoFn.Context context(DoFn doFn) { - throw new IllegalStateException( - "Should not access context() from @" - + DoFn.ProcessElement.class.getSimpleName()); - } - - @Override - public DoFn.OnTimerContext onTimerContext( - DoFn doFn) { - throw new UnsupportedOperationException( - "Access to timers not supported in Splittable DoFn"); - } - - @Override - public State state(String stateId) { - throw new UnsupportedOperationException( - "Access to state not supported in Splittable DoFn"); - } - - @Override - public Timer timer(String timerId) { - throw new UnsupportedOperationException( - "Access to timers not supported in Splittable DoFn"); - } - }); - RestrictionT residual; - RestrictionT forcedCheckpoint = processContext.extractCheckpoint(); - if (cont.shouldResume()) { - if (forcedCheckpoint == null) { - // If no checkpoint was forced, the call returned voluntarily (i.e. all tryClaim() calls - // succeeded) - but we still need to have a checkpoint to resume from. - residual = tracker.checkpoint(); - } else { - // A checkpoint was forced - i.e. the call probably (but not guaranteed) returned because of - // a failed tryClaim() call. - residual = forcedCheckpoint; - } - } else { - residual = null; - } - return new Result(residual, cont); + invoker.invokeProcessElement( + new DoFnInvoker.ArgumentProvider() { + @Override + public DoFn.ProcessContext processContext( + DoFn doFn) { + return processContext; + } + + @Override + public RestrictionTracker restrictionTracker() { + return tracker; + } + + // Unsupported methods below. + + @Override + public BoundedWindow window() { + throw new UnsupportedOperationException( + "Access to window of the element not supported in Splittable DoFn"); + } + + @Override + public DoFn.Context context(DoFn doFn) { + throw new IllegalStateException( + "Should not access context() from @" + + DoFn.ProcessElement.class.getSimpleName()); + } + + @Override + public DoFn.OnTimerContext onTimerContext( + DoFn doFn) { + throw new UnsupportedOperationException( + "Access to timers not supported in Splittable DoFn"); + } + + @Override + public State state(String stateId) { + throw new UnsupportedOperationException( + "Access to state not supported in Splittable DoFn"); + } + + @Override + public Timer timer(String timerId) { + throw new UnsupportedOperationException( + "Access to timers not supported in Splittable DoFn"); + } + }); + + tracker.checkDone(); + return new Result( + processContext.extractCheckpoint(), processContext.getLastReportedWatermark()); } private class ProcessContext extends DoFn.ProcessContext { @@ -176,6 +163,7 @@ private class ProcessContext extends DoFn.ProcessContext { private RestrictionT checkpoint; // A handle on the scheduled action to take a checkpoint. 
private Future scheduledCheckpoint; + private Instant lastReportedWatermark; public ProcessContext(WindowedValue element, TrackerT tracker) { fn.super(); @@ -240,6 +228,15 @@ public PaneInfo pane() { return element.getPane(); } + @Override + public synchronized void updateWatermark(Instant watermark) { + lastReportedWatermark = watermark; + } + + public synchronized Instant getLastReportedWatermark() { + return lastReportedWatermark; + } + @Override public PipelineOptions getPipelineOptions() { return pipelineOptions; diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index 77286b257a2b5..98d88b6cd1e2b 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -541,6 +541,11 @@ public PaneInfo pane() { return windowedValue.getPane(); } + @Override + public void updateWatermark(Instant watermark) { + throw new UnsupportedOperationException("Only splittable DoFn's can use updateWatermark()"); + } + @Override public void output(OutputT output) { context.outputWindowedValue(windowedValue.withValue(output)); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDo.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDo.java index 0b311c7751100..c16bf44a6c60f 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDo.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDo.java @@ -324,14 +324,12 @@ public static class ProcessFn< /** * The state cell containing a watermark hold for the output of this {@link DoFn}. The hold is * acquired during the first {@link DoFn.ProcessElement} call for each element and restriction, - * and is released when the {@link DoFn.ProcessElement} call returns {@link - * DoFn.ProcessContinuation#stop}. + * and is released when the {@link DoFn.ProcessElement} call returns and there is no residual + * restriction captured by the {@link SplittableProcessElementInvoker}. * *

A hold is needed to avoid letting the output watermark immediately progress together with
* the input watermark when the first {@link DoFn.ProcessElement} call for this element
* completes.
-
- *

    The hold is updated with the future output watermark reported by ProcessContinuation. */ private static final StateTag> watermarkHoldTag = StateTags.makeSystemTagInternal( @@ -461,7 +459,7 @@ public void processElement(final ProcessContext c) { invoker, elementAndRestriction.element(), tracker); // Save state for resuming. - if (!result.getContinuation().shouldResume()) { + if (result.getResidualRestriction() == null) { // All work for this element/restriction is completed. Clear state and release hold. elementState.clear(); restrictionState.clear(); @@ -469,16 +467,15 @@ public void processElement(final ProcessContext c) { return; } restrictionState.write(result.getResidualRestriction()); - Instant futureOutputWatermark = result.getContinuation().getWatermark(); + Instant futureOutputWatermark = result.getFutureOutputWatermark(); if (futureOutputWatermark == null) { futureOutputWatermark = elementAndRestriction.element().getTimestamp(); } - Instant wakeupTime = - timerInternals.currentProcessingTime().plus(result.getContinuation().resumeDelay()); holdState.add(futureOutputWatermark); // Set a timer to continue processing this element. timerInternals.setTimer( - TimerInternals.TimerData.of(stateNamespace, wakeupTime, TimeDomain.PROCESSING_TIME)); + TimerInternals.TimerData.of( + stateNamespace, timerInternals.currentProcessingTime(), TimeDomain.PROCESSING_TIME)); } /** diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java index cfc39e7bdb953..ced6c015039f8 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableProcessElementInvoker.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.util.WindowedValue; +import org.joda.time.Instant; /** * A runner-specific hook for invoking a {@link DoFn.ProcessElement} method for a splittable {@link @@ -31,25 +32,24 @@ public abstract class SplittableProcessElementInvoker< InputT, OutputT, RestrictionT, TrackerT extends RestrictionTracker> { /** Specifies how to resume a splittable {@link DoFn.ProcessElement} call. */ public class Result { - @Nullable private final RestrictionT residualRestriction; - private final DoFn.ProcessContinuation continuation; + @Nullable + private final RestrictionT residualRestriction; + private final Instant futureOutputWatermark; public Result( - @Nullable RestrictionT residualRestriction, DoFn.ProcessContinuation continuation) { + @Nullable RestrictionT residualRestriction, Instant futureOutputWatermark) { this.residualRestriction = residualRestriction; - this.continuation = continuation; + this.futureOutputWatermark = futureOutputWatermark; } - /** - * Can be {@code null} only if {@link #getContinuation} specifies the call should not resume. - */ + /** If {@code null}, means the call should not resume. 
*/ @Nullable public RestrictionT getResidualRestriction() { return residualRestriction; } - public DoFn.ProcessContinuation getContinuation() { - return continuation; + public Instant getFutureOutputWatermark() { + return futureOutputWatermark; } } @@ -57,8 +57,8 @@ public DoFn.ProcessContinuation getContinuation() { * Invokes the {@link DoFn.ProcessElement} method using the given {@link DoFnInvoker} for the * original {@link DoFn}, on the given element and with the given {@link RestrictionTracker}. * - * @return Information on how to resume the call: residual restriction and a {@link - * DoFn.ProcessContinuation}. + * @return Information on how to resume the call: residual restriction and a + * future output watermark. */ public abstract Result invokeProcessElement( DoFnInvoker invoker, WindowedValue element, TrackerT tracker); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java index b85f4812b31cb..965380bc5d08d 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvokerTest.java @@ -17,15 +17,11 @@ */ package org.apache.beam.runners.core; -import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; -import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.lessThan; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; import java.util.Collection; import java.util.concurrent.Executors; @@ -54,17 +50,12 @@ private SomeFn(Duration sleepBeforeEachOutput) { } @ProcessElement - public ProcessContinuation process(ProcessContext context, OffsetRangeTracker tracker) + public void process(ProcessContext context, OffsetRangeTracker tracker) throws Exception { - OffsetRange range = tracker.currentRestriction(); - for (int i = (int) range.getFrom(); i < range.getTo(); ++i) { - if (!tracker.tryClaim(i)) { - return resume(); - } + for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) { Thread.sleep(sleepBeforeEachOutput.getMillis()); context.output("" + i); } - return stop(); } @GetInitialRestriction @@ -111,7 +102,6 @@ public void sideOutputWindowedValue( public void testInvokeProcessElementOutputBounded() throws Exception { SplittableProcessElementInvoker.Result res = runTest(10000, Duration.ZERO); - assertTrue(res.getContinuation().shouldResume()); OffsetRange residualRange = res.getResidualRestriction(); // Should process the first 100 elements. assertEquals(1000, residualRange.getFrom()); @@ -122,7 +112,6 @@ public void testInvokeProcessElementOutputBounded() throws Exception { public void testInvokeProcessElementTimeBounded() throws Exception { SplittableProcessElementInvoker.Result res = runTest(10000, Duration.millis(100)); - assertTrue(res.getContinuation().shouldResume()); OffsetRange residualRange = res.getResidualRestriction(); // Should process ideally around 30 elements - but due to timing flakiness, we can't enforce // that precisely. Just test that it's not egregiously off. 
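To make the new contract concrete, the claim-loop idiom that the rewritten fns above follow can be sketched as below. EmitOffsetsFn is a hypothetical illustration, not code from this patch; it assumes only APIs that appear elsewhere in this series (OffsetRange, OffsetRangeTracker, @NewTracker). Returning normally no longer means "done": that determination now rests with the tracker and the invoker's checkpoint.

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.splittabledofn.OffsetRange;
import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker;

class EmitOffsetsFn extends DoFn<String, String> {
  @ProcessElement
  public void process(ProcessContext c, OffsetRangeTracker tracker) {
    // tryClaim(i) fails either because the range is exhausted or because the runner
    // forced a checkpoint; the DoFn does not need to distinguish the two cases.
    for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) {
      c.output(c.element() + ":" + i);
    }
  }

  @GetInitialRestriction
  public OffsetRange getInitialRestriction(String element) {
    return new OffsetRange(0, 100);
  }

  @NewTracker
  public OffsetRangeTracker newTracker(OffsetRange range) {
    return new OffsetRangeTracker(range);
  }
}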
@@ -135,7 +124,6 @@ public void testInvokeProcessElementTimeBounded() throws Exception { public void testInvokeProcessElementVoluntaryReturn() throws Exception { SplittableProcessElementInvoker.Result res = runTest(5, Duration.millis(100)); - assertFalse(res.getContinuation().shouldResume()); assertNull(res.getResidualRestriction()); } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java index 6205777f83ca9..f8d60959beb2e 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java @@ -17,9 +17,6 @@ */ package org.apache.beam.runners.core; -import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; -import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; -import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.not; @@ -43,9 +40,13 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement; +import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement; import org.apache.beam.sdk.transforms.DoFnTester; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRange; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -97,8 +98,12 @@ public SomeRestriction currentRestriction() { public SomeRestriction checkpoint() { return someRestriction; } + + @Override + public void checkDone() {} } + @BoundedPerElement private static class BoundedFakeFn extends DoFn { @ProcessElement public void processElement(ProcessContext context, SomeRestrictionTracker tracker) {} @@ -109,12 +114,10 @@ public SomeRestriction getInitialRestriction(Integer element) { } } + @UnboundedPerElement private static class UnboundedFakeFn extends DoFn { @ProcessElement - public ProcessContinuation processElement( - ProcessContext context, SomeRestrictionTracker tracker) { - return stop(); - } + public void processElement(ProcessContext context, SomeRestrictionTracker tracker) {} @GetInitialRestriction public SomeRestriction getInitialRestriction(Integer element) { @@ -422,164 +425,40 @@ public void testTrivialProcessFnPropagatesOutputsWindowsAndTimestamp() throws Ex } } - /** A simple splittable {@link DoFn} that outputs the given element every 5 seconds forever. 
*/ - private static class SelfInitiatedResumeFn extends DoFn { - @ProcessElement - public ProcessContinuation process(ProcessContext c, SomeRestrictionTracker tracker) { - c.output(c.element().toString()); - return resume().withResumeDelay(Duration.standardSeconds(5)).withWatermark(c.timestamp()); - } - - @GetInitialRestriction - public SomeRestriction getInitialRestriction(Integer elem) { - return new SomeRestriction(); - } - } - - @Test - public void testResumeSetsTimer() throws Exception { - DoFn fn = new SelfInitiatedResumeFn(); - Instant base = Instant.now(); - ProcessFnTester tester = - new ProcessFnTester<>( - base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(SomeRestriction.class), - MAX_OUTPUTS_PER_BUNDLE, MAX_BUNDLE_DURATION); - - tester.startElement(42, new SomeRestriction()); - assertThat(tester.takeOutputElements(), contains("42")); - - // Should resume after 5 seconds: advancing by 3 seconds should have no effect. - assertFalse(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); - assertTrue(tester.takeOutputElements().isEmpty()); - - // 6 seconds should be enough - should invoke the fn again. - assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); - assertThat(tester.takeOutputElements(), contains("42")); - - // Should again resume after 5 seconds: advancing by 3 seconds should again have no effect. - assertFalse(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); - assertTrue(tester.takeOutputElements().isEmpty()); - - // 6 seconds should again be enough. - assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(3))); - assertThat(tester.takeOutputElements(), contains("42")); - } - - private static class SomeCheckpoint - implements Serializable, HasDefaultTracker { - private int firstUnprocessedIndex; - - private SomeCheckpoint(int firstUnprocessedIndex) { - this.firstUnprocessedIndex = firstUnprocessedIndex; - } - - @Override - public SomeCheckpointTracker newTracker() { - return new SomeCheckpointTracker(this); - } - } - - private static class SomeCheckpointTracker implements RestrictionTracker { - private SomeCheckpoint current; - private boolean isActive = true; - - private SomeCheckpointTracker(SomeCheckpoint current) { - this.current = current; - } - - @Override - public SomeCheckpoint currentRestriction() { - return current; - } - - public boolean tryUpdateCheckpoint(int firstUnprocessedIndex) { - if (!isActive) { - return false; - } - current = new SomeCheckpoint(firstUnprocessedIndex); - return true; - } - - @Override - public SomeCheckpoint checkpoint() { - isActive = false; - return current; - } - } - /** - * A splittable {@link DoFn} that generates the sequence [init, init + total) in batches of given - * size. + * A splittable {@link DoFn} that generates the sequence [init, init + total). 
*/ private static class CounterFn extends DoFn { - private final int numTotalOutputs; - private final int numOutputsPerCall; - - private CounterFn(int numTotalOutputs, int numOutputsPerCall) { - this.numTotalOutputs = numTotalOutputs; - this.numOutputsPerCall = numOutputsPerCall; - } - @ProcessElement - public ProcessContinuation process(ProcessContext c, SomeCheckpointTracker tracker) { - int start = tracker.currentRestriction().firstUnprocessedIndex; - for (int i = 0; i < numOutputsPerCall; ++i) { - int index = start + i; - if (!tracker.tryUpdateCheckpoint(index + 1)) { - return resume(); - } - if (index >= numTotalOutputs) { - return stop(); - } - c.output(String.valueOf(c.element() + index)); + public void process(ProcessContext c, OffsetRangeTracker tracker) { + for (long i = tracker.currentRestriction().getFrom(); + tracker.tryClaim(i); ++i) { + c.output(String.valueOf(c.element() + i)); } - return resume(); } @GetInitialRestriction - public SomeCheckpoint getInitialRestriction(Integer elem) { + public OffsetRange getInitialRestriction(Integer elem) { throw new UnsupportedOperationException("Expected to be supplied explicitly in this test"); } } - @Test - public void testResumeCarriesOverState() throws Exception { - DoFn fn = new CounterFn(3, 1); - Instant base = Instant.now(); - ProcessFnTester tester = - new ProcessFnTester<>( - base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(SomeCheckpoint.class), - MAX_OUTPUTS_PER_BUNDLE, MAX_BUNDLE_DURATION); - - tester.startElement(42, new SomeCheckpoint(0)); - assertThat(tester.takeOutputElements(), contains("42")); - assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); - assertThat(tester.takeOutputElements(), contains("43")); - assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); - assertThat(tester.takeOutputElements(), contains("44")); - assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); - // After outputting all 3 items, should not output anything more. - assertEquals(0, tester.takeOutputElements().size()); - // Should also not ask to resume. - assertFalse(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); - } - @Test public void testCheckpointsAfterNumOutputs() throws Exception { int max = 100; - // Create an fn that attempts to 2x output more than checkpointing allows. - DoFn fn = new CounterFn(2 * max + max / 2, 2 * max); + DoFn fn = new CounterFn(); Instant base = Instant.now(); int baseIndex = 42; - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( - base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(SomeCheckpoint.class), + base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(OffsetRange.class), max, MAX_BUNDLE_DURATION); List elements; - tester.startElement(baseIndex, new SomeCheckpoint(0)); + // Create an fn that attempts to 2x output more than checkpointing allows. + tester.startElement(baseIndex, new OffsetRange(0, 2 * max + max / 2)); elements = tester.takeOutputElements(); assertEquals(max, elements.size()); // Should output the range [0, max) @@ -609,18 +488,18 @@ public void testCheckpointsAfterDuration() throws Exception { // But bound bundle duration - the bundle should terminate. Duration maxBundleDuration = Duration.standardSeconds(1); // Create an fn that attempts to 2x output more than checkpointing allows. 
- DoFn fn = new CounterFn(max, max); + DoFn fn = new CounterFn(); Instant base = Instant.now(); int baseIndex = 42; - ProcessFnTester tester = + ProcessFnTester tester = new ProcessFnTester<>( - base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(SomeCheckpoint.class), + base, fn, BigEndianIntegerCoder.of(), SerializableCoder.of(OffsetRange.class), max, maxBundleDuration); List elements; - tester.startElement(baseIndex, new SomeCheckpoint(0)); + tester.startElement(baseIndex, new OffsetRange(0, Long.MAX_VALUE)); // Bundle should terminate, and should do at least some processing. elements = tester.takeOutputElements(); assertFalse(elements.isEmpty()); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 5139290dd4975..e35457cd4576c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -21,7 +21,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import com.google.auto.value.AutoValue; import java.io.Serializable; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; @@ -32,7 +31,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; @@ -293,6 +291,18 @@ public abstract class ProcessContext extends Context { * See {@link Window} for more information. */ public abstract PaneInfo pane(); + + /** + * Gives the runner a (best-effort) lower bound about the timestamps of future output associated + * with the current element. + * + *

    If the {@link DoFn} has multiple outputs, the watermark applies to all of them. + * + *

    Only splittable {@link DoFn DoFns} are allowed to call this method. It is safe to call + * this method from a different thread than the one running {@link ProcessElement}, but + * all calls must finish before {@link ProcessElement} returns. + */ + public abstract void updateWatermark(Instant watermark); } /** @@ -556,15 +566,11 @@ public interface OutputReceiver { * returned by {@link GetInitialRestriction} implements {@link HasDefaultTracker}. *

  • It may define a {@link GetRestrictionCoder} method.
  • The type of restrictions used by all of these methods must be the same.
- • Its {@link ProcessElement} method may return a {@link ProcessContinuation} to
-   indicate whether there is more work to be done for the current element.
  • Its {@link ProcessElement} method must not use any extra context parameters, such as
    {@link BoundedWindow}.
  • The {@link DoFn} itself may be annotated with {@link BoundedPerElement} or
    {@link UnboundedPerElement}, but not both at the same time. If it's not annotated with
-   either of these, it's assumed to be {@link BoundedPerElement} if its {@link
-   ProcessElement} method returns {@code void} and {@link UnboundedPerElement} if it
-   returns a {@link ProcessContinuation}.
+   either of these, it's assumed to be {@link BoundedPerElement}.
*

A non-splittable {@link DoFn} must not define any of these methods. @@ -692,61 +698,9 @@ public interface OutputReceiver { @Experimental(Kind.SPLITTABLE_DO_FN) public @interface UnboundedPerElement {} - // This can't be put into ProcessContinuation itself due to the following problem: - // http://ternarysearch.blogspot.com/2013/07/static-initialization-deadlock.html - private static final ProcessContinuation PROCESS_CONTINUATION_STOP = - new AutoValue_DoFn_ProcessContinuation(false, Duration.ZERO, null); - - /** - * When used as a return value of {@link ProcessElement}, indicates whether there is more work to - * be done for the current element. - */ - @Experimental(Kind.SPLITTABLE_DO_FN) - @AutoValue - public abstract static class ProcessContinuation { - /** Indicates that there is no more work to be done for the current element. */ - public static ProcessContinuation stop() { - return PROCESS_CONTINUATION_STOP; - } - - /** Indicates that there is more work to be done for the current element. */ - public static ProcessContinuation resume() { - return new AutoValue_DoFn_ProcessContinuation(true, Duration.ZERO, null); - } - - /** - * If false, the {@link DoFn} promises that there is no more work remaining for the current - * element, so the runner should not resume the {@link ProcessElement} call. - */ - public abstract boolean shouldResume(); - - /** - * A minimum duration that should elapse between the end of this {@link ProcessElement} call and - * the {@link ProcessElement} call continuing processing of the same element. By default, zero. - */ - public abstract Duration resumeDelay(); - - /** - * A lower bound provided by the {@link DoFn} on timestamps of the output that will be emitted - * by future {@link ProcessElement} calls continuing processing of the current element. - * - *

A runner should treat an absent value as equivalent to the timestamp of the input element. - */ - @Nullable - public abstract Instant getWatermark(); - - /** Builder method to set the value of {@link #resumeDelay()}. */ - public ProcessContinuation withResumeDelay(Duration resumeDelay) { - return new AutoValue_DoFn_ProcessContinuation( - shouldResume(), resumeDelay, getWatermark()); - } - - /** Builder method to set the value of {@link #getWatermark()}. */ - public ProcessContinuation withWatermark(Instant watermark) { - return new AutoValue_DoFn_ProcessContinuation( - shouldResume(), resumeDelay(), watermark); - } - } + /** Do not use. See https://issues.apache.org/jira/browse/BEAM-1904 */ + @Deprecated + public class ProcessContinuation {} /** * Returns an {@link Aggregator} with aggregation logic specified by the {@link CombineFn} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 01f0291e527b6..88f40356bcbe8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -671,6 +671,11 @@ public PaneInfo pane() { return element.getPane(); } + @Override + public void updateWatermark(Instant watermark) { + throw new UnsupportedOperationException(); + } + @Override public PipelineOptions getPipelineOptions() { return context.getPipelineOptions(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 6746d3a81e137..4b0cbf74cf65d 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -50,6 +50,7 @@ import net.bytebuddy.implementation.bytecode.assign.Assigner; import net.bytebuddy.implementation.bytecode.assign.Assigner.Typing; import net.bytebuddy.implementation.bytecode.assign.TypeCasting; +import net.bytebuddy.implementation.bytecode.constant.NullConstant; import net.bytebuddy.implementation.bytecode.constant.TextConstant; import net.bytebuddy.implementation.bytecode.member.FieldAccess; import net.bytebuddy.implementation.bytecode.member.MethodInvocation; @@ -625,17 +626,6 @@ public StackManipulation dispatch(TimerParameter p) { * {@link ProcessElement} method. */ private static final class ProcessElementDelegation extends DoFnMethodDelegation { - private static final MethodDescription PROCESS_CONTINUATION_STOP_METHOD; - - static { - try { - PROCESS_CONTINUATION_STOP_METHOD = - new MethodDescription.ForLoadedMethod(DoFn.ProcessContinuation.class.getMethod("stop")); - } catch (NoSuchMethodException e) { - throw new RuntimeException("Failed to locate ProcessContinuation.stop()"); - } - } - private final DoFnSignature.ProcessElementMethod signature; /** Implementation of {@link MethodDelegation} for the {@link ProcessElement} method. 
*/ @@ -672,12 +662,8 @@ protected StackManipulation beforeDelegation(MethodDescription instrumentedMetho @Override protected StackManipulation afterDelegation(MethodDescription instrumentedMethod) { - if (TypeDescription.VOID.equals(targetMethod.getReturnType().asErasure())) { - return new StackManipulation.Compound( - MethodInvocation.invoke(PROCESS_CONTINUATION_STOP_METHOD), MethodReturn.REFERENCE); - } else { - return MethodReturn.of(targetMethod.getReturnType().asErasure()); - } + return new StackManipulation.Compound( + NullConstant.INSTANCE, MethodReturn.REFERENCE); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index 85831a7c45f95..cc06e70d070c0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -53,8 +53,8 @@ public interface DoFnInvoker { * Invoke the {@link DoFn.ProcessElement} method on the bound {@link DoFn}. * * @param extra Factory for producing extra parameter objects (such as window), if necessary. - * @return The {@link DoFn.ProcessContinuation} returned by the underlying method, or {@link - * DoFn.ProcessContinuation#stop()} if it returns {@code void}. + * @return {@code null} - see JIRA + * tracking the complete removal of {@link DoFn.ProcessContinuation}. */ DoFn.ProcessContinuation invokeProcessElement(ArgumentProvider extra); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 007d8be6b00ff..1be741f2db39f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -28,7 +28,6 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.DoFn.StateId; import org.apache.beam.sdk.transforms.DoFn.TimerId; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; @@ -397,21 +396,16 @@ public abstract static class ProcessElementMethod implements MethodWithExtraPara @Nullable public abstract TypeDescriptor windowT(); - /** Whether this {@link DoFn} returns a {@link ProcessContinuation} or void. 
*/ - public abstract boolean hasReturnValue(); - static ProcessElementMethod create( Method targetMethod, List extraParameters, TypeDescriptor trackerT, - @Nullable TypeDescriptor windowT, - boolean hasReturnValue) { + @Nullable TypeDescriptor windowT) { return new AutoValue_DoFnSignature_ProcessElementMethod( targetMethod, Collections.unmodifiableList(extraParameters), trackerT, - windowT, - hasReturnValue); + windowT); } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index 006d012cbb000..80dbe1062cbbc 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -17,8 +17,6 @@ */ package org.apache.beam.sdk.transforms.reflect; -import static com.google.common.base.Preconditions.checkState; - import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Predicates; @@ -428,8 +426,6 @@ private static DoFnSignature parseSignature(Class> fnClass) *

  • If the {@link DoFn} (or any of its supertypes) is annotated as {@link
    DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, use that. Only one of
    these must be specified.
- • If {@link DoFn.ProcessElement} returns {@link DoFn.ProcessContinuation}, assume it is
-   unbounded. Otherwise (if it returns {@code void}), assume it is bounded.
*
  • If {@link DoFn.ProcessElement} returns {@code void}, but the {@link DoFn} is annotated * {@link DoFn.UnboundedPerElement}, this is an error. * @@ -455,10 +451,7 @@ private static PCollection.IsBounded inferBoundedness( } if (processElement.isSplittable()) { if (isBounded == null) { - isBounded = - processElement.hasReturnValue() - ? PCollection.IsBounded.UNBOUNDED - : PCollection.IsBounded.BOUNDED; + isBounded = PCollection.IsBounded.BOUNDED; } } else { errors.checkArgument( @@ -467,7 +460,6 @@ private static PCollection.IsBounded inferBoundedness( + ((isBounded == PCollection.IsBounded.BOUNDED) ? DoFn.BoundedPerElement.class.getSimpleName() : DoFn.UnboundedPerElement.class.getSimpleName())); - checkState(!processElement.hasReturnValue(), "Should have been inferred splittable"); isBounded = PCollection.IsBounded.BOUNDED; } return isBounded; @@ -691,10 +683,8 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( TypeDescriptor outputT, FnAnalysisContext fnContext) { errors.checkArgument( - void.class.equals(m.getReturnType()) - || DoFn.ProcessContinuation.class.equals(m.getReturnType()), - "Must return void or %s", - DoFn.ProcessContinuation.class.getSimpleName()); + void.class.equals(m.getReturnType()), + "Must return void"); MethodAnalysisContext methodContext = MethodAnalysisContext.create(); @@ -734,11 +724,7 @@ static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( } return DoFnSignature.ProcessElementMethod.create( - m, - methodContext.getExtraParameters(), - trackerT, - windowT, - DoFn.ProcessContinuation.class.equals(m.getReturnType())); + m, methodContext.getExtraParameters(), trackerT, windowT); } private static void checkParameterOneOf( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java index 87c7bfdb2b3c9..0271a0d1f3f1e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTracker.java @@ -19,6 +19,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import org.apache.beam.sdk.transforms.DoFn; /** * A {@link RestrictionTracker} for claiming offsets in an {@link OffsetRange} in a monotonically @@ -27,6 +30,7 @@ public class OffsetRangeTracker implements RestrictionTracker { private OffsetRange range; private Long lastClaimedOffset = null; + private Long lastAttemptedOffset = null; public OffsetRangeTracker(OffsetRange range) { this.range = checkNotNull(range); @@ -59,12 +63,13 @@ public synchronized OffsetRange checkpoint() { */ public synchronized boolean tryClaim(long i) { checkArgument( - lastClaimedOffset == null || i > lastClaimedOffset, - "Trying to claim offset %s while last claimed was %s", + lastAttemptedOffset == null || i > lastAttemptedOffset, + "Trying to claim offset %s while last attempted was %s", i, - lastClaimedOffset); + lastAttemptedOffset); checkArgument( i >= range.getFrom(), "Trying to claim offset %s before start of the range %s", i, range); + lastAttemptedOffset = i; // No respective checkArgument for i < range.to() - it's ok to try claiming offsets beyond it. 
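// (lastAttemptedOffset was recorded above even for a failed claim past the end,
// which is what lets checkDone() below accept the tracker as finished.)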
if (i >= range.getTo()) { return false; @@ -72,4 +77,26 @@ public synchronized boolean tryClaim(long i) { lastClaimedOffset = i; return true; } + + /** + * Marks that there are no more offsets to be claimed in the range. + * + *

    E.g., a {@link DoFn} reading a file and claiming the offset of each record in the file might + * call this if it hits EOF - even though the last attempted claim was before the end of the + * range, there are no more offsets to claim. + */ + public synchronized void markDone() { + lastAttemptedOffset = Long.MAX_VALUE; + } + + @Override + public synchronized void checkDone() throws IllegalStateException { + checkState( + lastAttemptedOffset >= range.getTo() - 1, + "Last attempted offset was %s in range %s, claiming work in [%s, %s) was not attempted", + lastAttemptedOffset, + range, + lastAttemptedOffset + 1, + range.getTo()); + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java index e9b718e269591..27ef68f4a980c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/splittabledofn/RestrictionTracker.java @@ -38,5 +38,13 @@ public interface RestrictionTracker { */ RestrictionT checkpoint(); + /** + * Called by the runner after {@link DoFn.ProcessElement} returns. + * + *

    Must throw an exception with an informative error message, if there is still any unclaimed + * work remaining in the restriction. + */ + void checkDone() throws IllegalStateException; + // TODO: Add the more general splitRemainderAfterFraction() and other methods. } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java index 154a088347b0b..a122f673c5a56 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/SplittableDoFnTest.java @@ -18,8 +18,6 @@ package org.apache.beam.sdk.transforms; import static com.google.common.base.Preconditions.checkState; -import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.resume; -import static org.apache.beam.sdk.transforms.DoFn.ProcessContinuation.stop; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -67,14 +65,10 @@ public class SplittableDoFnTest { static class PairStringWithIndexToLength extends DoFn> { @ProcessElement - public ProcessContinuation process(ProcessContext c, OffsetRangeTracker tracker) { + public void process(ProcessContext c, OffsetRangeTracker tracker) { for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) { c.output(KV.of(c.element(), (int) i)); - if (i % 3 == 0) { - return resume(); - } } - return stop(); } @GetInitialRestriction @@ -196,19 +190,14 @@ private static int snapToNextBlock(int index, int[] blockStarts) { } @ProcessElement - public ProcessContinuation processElement(ProcessContext c, OffsetRangeTracker tracker) { + public void processElement(ProcessContext c, OffsetRangeTracker tracker) { int[] blockStarts = {-1, 0, 12, 123, 1234, 12345, 34567, MAX_INDEX}; int trueStart = snapToNextBlock((int) tracker.currentRestriction().getFrom(), blockStarts); - int trueEnd = snapToNextBlock((int) tracker.currentRestriction().getTo(), blockStarts); - for (int i = trueStart; i < trueEnd; ++i) { - if (!tracker.tryClaim(blockStarts[i])) { - return resume(); - } + for (int i = trueStart; tracker.tryClaim(blockStarts[i]); ++i) { for (int index = blockStarts[i]; index < blockStarts[i + 1]; ++index) { c.output(index); } } - return stop(); } @GetInitialRestriction diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java index 425f453639534..8b4df4c3e8901 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java @@ -39,7 +39,6 @@ import org.apache.beam.sdk.coders.CustomCoder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; import org.apache.beam.sdk.transforms.reflect.DoFnInvoker.FakeArgumentProvider; import org.apache.beam.sdk.transforms.reflect.testhelper.DoFnInvokersTestHelper; import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker; @@ -83,8 +82,8 @@ public void setUp() { when(mockArgumentProvider.processContext(Matchers.any())).thenReturn(mockProcessContext); } - private ProcessContinuation invokeProcessElement(DoFn fn) { - return 
DoFnInvokers.invokerFor(fn).invokeProcessElement(mockArgumentProvider); + private void invokeProcessElement(DoFn fn) { + DoFnInvokers.invokerFor(fn).invokeProcessElement(mockArgumentProvider); } private void invokeOnTimer(String timerId, DoFn fn) { @@ -113,7 +112,7 @@ class MockFn extends DoFn { public void processElement(ProcessContext c) throws Exception {} } MockFn mockFn = mock(MockFn.class); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(mockFn)); + invokeProcessElement(mockFn); verify(mockFn).processElement(mockProcessContext); } @@ -134,7 +133,7 @@ public void processElement(DoFn.ProcessContext c) {} public void testDoFnWithProcessElementInterface() throws Exception { IdentityUsingInterfaceWithProcessElement fn = mock(IdentityUsingInterfaceWithProcessElement.class); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); verify(fn).processElement(mockProcessContext); } @@ -155,14 +154,14 @@ public void process(DoFn.ProcessContext c) { @Test public void testDoFnWithMethodInSuperclass() throws Exception { IdentityChildWithoutOverride fn = mock(IdentityChildWithoutOverride.class); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); verify(fn).process(mockProcessContext); } @Test public void testDoFnWithMethodInSubclass() throws Exception { IdentityChildWithOverride fn = mock(IdentityChildWithOverride.class); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); verify(fn).process(mockProcessContext); } @@ -173,7 +172,7 @@ class MockFn extends DoFn { public void processElement(ProcessContext c, IntervalWindow w) throws Exception {} } MockFn fn = mock(MockFn.class); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); verify(fn).processElement(mockProcessContext, mockWindow); } @@ -197,7 +196,7 @@ public void processElement(ProcessContext c, @StateId(stateId) ValueState { - @DoFn.ProcessElement - public ProcessContinuation processElement(ProcessContext c, SomeRestrictionTracker tracker) - throws Exception { - return null; - } - - @GetInitialRestriction - public SomeRestriction getInitialRestriction(String element) { - return null; - } - - @NewTracker - public SomeRestrictionTracker newTracker(SomeRestriction restriction) { - return null; - } - } - MockFn fn = mock(MockFn.class); - when(fn.processElement(mockProcessContext, null)).thenReturn(ProcessContinuation.resume()); - assertEquals(ProcessContinuation.resume(), invokeProcessElement(fn)); - } - @Test public void testDoFnWithStartBundleSetupTeardown() throws Exception { class MockFn extends DoFn { @@ -306,9 +281,7 @@ public SomeRestriction decode(InputStream inStream, Context context) { /** Public so Mockito can do "delegatesTo()" in the test below. 
*/ public static class MockFn extends DoFn { @ProcessElement - public ProcessContinuation processElement(ProcessContext c, SomeRestrictionTracker tracker) { - return null; - } + public void processElement(ProcessContext c, SomeRestrictionTracker tracker) {} @GetInitialRestriction public SomeRestriction getInitialRestriction(String element) { @@ -360,7 +333,7 @@ public void splitRestriction( .splitRestriction( eq("blah"), same(restriction), Mockito.>any()); when(fn.newTracker(restriction)).thenReturn(tracker); - when(fn.processElement(mockProcessContext, tracker)).thenReturn(ProcessContinuation.resume()); + fn.processElement(mockProcessContext, tracker); assertEquals(coder, invoker.invokeGetRestrictionCoder(new CoderRegistry())); assertEquals(restriction, invoker.invokeGetInitialRestriction("blah")); @@ -376,8 +349,6 @@ public void output(SomeRestriction output) { }); assertEquals(Arrays.asList(part1, part2, part3), outputs); assertEquals(tracker, invoker.invokeNewTracker(restriction)); - assertEquals( - ProcessContinuation.resume(), invoker.invokeProcessElement( new FakeArgumentProvider() { @Override @@ -389,7 +360,7 @@ public DoFn.ProcessContext processContext(DoFn f public RestrictionTracker restrictionTracker() { return tracker; } - })); + }); } private static class RestrictionWithDefaultTracker @@ -410,6 +381,9 @@ public RestrictionWithDefaultTracker currentRestriction() { public RestrictionWithDefaultTracker checkpoint() { throw new UnsupportedOperationException(); } + + @Override + public void checkDone() throws IllegalStateException {} } private static class CoderForDefaultTracker extends CustomCoder { @@ -459,8 +433,7 @@ public void output(String output) { assertEquals("foo", output); } }); - assertEquals( - ProcessContinuation.stop(), invoker.invokeProcessElement(mockArgumentProvider)); + invoker.invokeProcessElement(mockArgumentProvider); assertThat( invoker.invokeNewTracker(new RestrictionWithDefaultTracker()), instanceOf(DefaultTracker.class)); @@ -550,14 +523,14 @@ public void processThis(ProcessContext c) {} @Test public void testLocalPrivateDoFnClass() throws Exception { PrivateDoFnClass fn = mock(PrivateDoFnClass.class); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); verify(fn).processThis(mockProcessContext); } @Test public void testStaticPackagePrivateDoFnClass() throws Exception { DoFn fn = mock(DoFnInvokersTestHelper.newStaticPackagePrivateDoFn().getClass()); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); DoFnInvokersTestHelper.verifyStaticPackagePrivateDoFn(fn, mockProcessContext); } @@ -565,28 +538,28 @@ public void testStaticPackagePrivateDoFnClass() throws Exception { public void testInnerPackagePrivateDoFnClass() throws Exception { DoFn fn = mock(new DoFnInvokersTestHelper().newInnerPackagePrivateDoFn().getClass()); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); DoFnInvokersTestHelper.verifyInnerPackagePrivateDoFn(fn, mockProcessContext); } @Test public void testStaticPrivateDoFnClass() throws Exception { DoFn fn = mock(DoFnInvokersTestHelper.newStaticPrivateDoFn().getClass()); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); DoFnInvokersTestHelper.verifyStaticPrivateDoFn(fn, mockProcessContext); } @Test public void testInnerPrivateDoFnClass() throws Exception { DoFn fn = mock(new DoFnInvokersTestHelper().newInnerPrivateDoFn().getClass()); - 
assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); DoFnInvokersTestHelper.verifyInnerPrivateDoFn(fn, mockProcessContext); } @Test public void testAnonymousInnerDoFn() throws Exception { DoFn fn = mock(new DoFnInvokersTestHelper().newInnerAnonymousDoFn().getClass()); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); DoFnInvokersTestHelper.verifyInnerAnonymousDoFn(fn, mockProcessContext); } @@ -594,7 +567,7 @@ public void testAnonymousInnerDoFn() throws Exception { public void testStaticAnonymousDoFnInOtherPackage() throws Exception { // Can't use mockito for this one - the anonymous class is final and can't be mocked. DoFn fn = DoFnInvokersTestHelper.newStaticAnonymousDoFn(); - assertEquals(ProcessContinuation.stop(), invokeProcessElement(fn)); + invokeProcessElement(fn); DoFnInvokersTestHelper.verifyStaticAnonymousDoFnInvoked(fn, mockProcessContext); } @@ -622,32 +595,6 @@ public DoFn.ProcessContext processContext(DoFn() { - @ProcessElement - public ProcessContinuation processElement( - @SuppressWarnings("unused") ProcessContext c, SomeRestrictionTracker tracker) { - throw new IllegalArgumentException("bogus"); - } - - @GetInitialRestriction - public SomeRestriction getInitialRestriction(Integer element) { - return null; - } - - @NewTracker - public SomeRestrictionTracker newTracker(SomeRestriction restriction) { - return null; - } - }) - .invokeProcessElement(new FakeArgumentProvider()); - } - @Test public void testStartBundleException() throws Exception { DoFnInvoker invoker = diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java index 44ae5c4f2425a..d321f54d68bd7 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesProcessElementTest.java @@ -50,7 +50,7 @@ private void method(DoFn.ProcessContext c, Integer n) {} @Test public void testBadReturnType() throws Exception { thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Must return void or ProcessContinuation"); + thrown.expectMessage("Must return void"); analyzeProcessElementMethod( new AnonymousMethod() { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java index 052feb812b456..b937e84c5c961 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesSplittableDoFnTest.java @@ -59,20 +59,6 @@ private abstract static class SomeRestrictionTracker private abstract static class SomeRestrictionCoder implements Coder {} - @Test - public void testReturnsProcessContinuation() throws Exception { - DoFnSignature.ProcessElementMethod signature = - analyzeProcessElementMethod( - new AnonymousMethod() { - private DoFn.ProcessContinuation method( - DoFn.ProcessContext context) { - return null; - } - }); - - assertTrue(signature.hasReturnValue()); - } - @Test public void testHasRestrictionTracker() throws Exception { DoFnSignature.ProcessElementMethod signature = @@ -157,54 +143,6 @@ public void 
process(ProcessContext context) {} .isBoundedPerElement()); } - private static class BaseFnWithContinuation extends DoFn { - @ProcessElement - public ProcessContinuation processElement( - ProcessContext context, SomeRestrictionTracker tracker) { - return null; - } - - @GetInitialRestriction - public SomeRestriction getInitialRestriction(Integer element) { - return null; - } - - @NewTracker - public SomeRestrictionTracker newTracker(SomeRestriction restriction) { - return null; - } - } - - @Test - public void testSplittableIsBoundedByDefault() throws Exception { - assertEquals( - PCollection.IsBounded.UNBOUNDED, - DoFnSignatures - .getSignature(BaseFnWithContinuation.class) - .isBoundedPerElement()); - } - - @Test - public void testSplittableRespectsBoundednessAnnotation() throws Exception { - @BoundedPerElement - class BoundedFnWithContinuation extends BaseFnWithContinuation {} - - assertEquals( - PCollection.IsBounded.BOUNDED, - DoFnSignatures - .getSignature(BoundedFnWithContinuation.class) - .isBoundedPerElement()); - - @UnboundedPerElement - class UnboundedFnWithContinuation extends BaseFnWithContinuation {} - - assertEquals( - PCollection.IsBounded.UNBOUNDED, - DoFnSignatures - .getSignature(UnboundedFnWithContinuation.class) - .isBoundedPerElement()); - } - @Test public void testUnsplittableButDeclaresBounded() throws Exception { @BoundedPerElement @@ -234,10 +172,8 @@ public void process(ProcessContext context) {} public void testSplittableWithAllFunctions() throws Exception { class GoodSplittableDoFn extends DoFn { @ProcessElement - public ProcessContinuation processElement( - ProcessContext context, SomeRestrictionTracker tracker) { - return null; - } + public void processElement( + ProcessContext context, SomeRestrictionTracker tracker) {} @GetInitialRestriction public SomeRestriction getInitialRestriction(Integer element) { @@ -262,7 +198,6 @@ public SomeRestrictionCoder getRestrictionCoder() { DoFnSignature signature = DoFnSignatures.getSignature(GoodSplittableDoFn.class); assertEquals(SomeRestrictionTracker.class, signature.processElement().trackerT().getRawType()); assertTrue(signature.processElement().isSplittable()); - assertTrue(signature.processElement().hasReturnValue()); assertEquals( SomeRestriction.class, signature.getInitialRestriction().restrictionT().getRawType()); assertEquals(SomeRestriction.class, signature.splitRestriction().restrictionT().getRawType()); @@ -279,9 +214,7 @@ public SomeRestrictionCoder getRestrictionCoder() { public void testSplittableWithAllFunctionsGeneric() throws Exception { class GoodGenericSplittableDoFn extends DoFn { @ProcessElement - public ProcessContinuation processElement(ProcessContext context, TrackerT tracker) { - return null; - } + public void processElement(ProcessContext context, TrackerT tracker) {} @GetInitialRestriction public RestrictionT getInitialRestriction(Integer element) { @@ -309,7 +242,6 @@ public CoderT getRestrictionCoder() { SomeRestriction, SomeRestrictionTracker, SomeRestrictionCoder>() {}.getClass()); assertEquals(SomeRestrictionTracker.class, signature.processElement().trackerT().getRawType()); assertTrue(signature.processElement().isSplittable()); - assertTrue(signature.processElement().hasReturnValue()); assertEquals( SomeRestriction.class, signature.getInitialRestriction().restrictionT().getRawType()); assertEquals(SomeRestriction.class, signature.splitRestriction().restrictionT().getRawType()); diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java index c8a530c655f09..831894ca96929 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/splittabledofn/OffsetRangeTrackerTest.java @@ -95,7 +95,7 @@ public void testCheckpointAfterFailedClaim() throws Exception { @Test public void testNonMonotonicClaim() throws Exception { - expected.expectMessage("Trying to claim offset 103 while last claimed was 110"); + expected.expectMessage("Trying to claim offset 103 while last attempted was 110"); OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); assertTrue(tracker.tryClaim(105)); assertTrue(tracker.tryClaim(110)); @@ -108,4 +108,51 @@ public void testClaimBeforeStartOfRange() throws Exception { OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); tracker.tryClaim(90); } + + @Test + public void testCheckDoneAfterTryClaimPastEndOfRange() { + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); + assertTrue(tracker.tryClaim(150)); + assertTrue(tracker.tryClaim(175)); + assertFalse(tracker.tryClaim(220)); + tracker.checkDone(); + } + + @Test + public void testCheckDoneAfterTryClaimAtEndOfRange() { + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); + assertTrue(tracker.tryClaim(150)); + assertTrue(tracker.tryClaim(175)); + assertFalse(tracker.tryClaim(200)); + tracker.checkDone(); + } + + @Test + public void testCheckDoneAfterTryClaimRightBeforeEndOfRange() { + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); + assertTrue(tracker.tryClaim(150)); + assertTrue(tracker.tryClaim(175)); + assertTrue(tracker.tryClaim(199)); + tracker.checkDone(); + } + + @Test + public void testCheckDoneWhenNotDone() { + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); + assertTrue(tracker.tryClaim(150)); + assertTrue(tracker.tryClaim(175)); + expected.expectMessage( + "Last attempted offset was 175 in range [100, 200), " + + "claiming work in [176, 200) was not attempted"); + tracker.checkDone(); + } + + @Test + public void testCheckDoneWhenExplicitlyMarkedDone() { + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(100, 200)); + assertTrue(tracker.tryClaim(150)); + assertTrue(tracker.tryClaim(175)); + tracker.markDone(); + tracker.checkDone(); + } } From 29c280211c2431f29c5552c35bd3435c65e4975b Mon Sep 17 00:00:00 2001 From: Eugene Kirpichov Date: Fri, 7 Apr 2017 14:00:05 -0700 Subject: [PATCH 10/15] Adds tests for the watermark hold (previously untested) --- .../runners/core/SplittableParDoTest.java | 56 ++++++++++++++++++- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java index f8d60959beb2e..d30111326363b 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SplittableParDoTest.java @@ -19,6 +19,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasItems; import 
static org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -36,6 +37,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; @@ -208,7 +210,7 @@ private static class ProcessFnTester< private Instant currentProcessingTime; private InMemoryTimerInternals timerInternals; - private InMemoryStateInternals stateInternals; + private TestInMemoryStateInternals stateInternals; ProcessFnTester( Instant currentProcessingTime, @@ -223,7 +225,7 @@ private static class ProcessFnTester< fn, inputCoder, restrictionCoder, IntervalWindow.getCoder()); this.tester = DoFnTester.of(processFn); this.timerInternals = new InMemoryTimerInternals(); - this.stateInternals = InMemoryStateInternals.forKey("dummy"); + this.stateInternals = new TestInMemoryStateInternals<>("dummy"); processFn.setStateInternalsFactory( new StateInternalsFactory() { @Override @@ -335,6 +337,9 @@ List takeOutputElements() { return tester.takeOutputElements(); } + public Instant getWatermarkHold() { + return stateInternals.earliestWatermarkHold(); + } } private static class OutputWindowedValueToDoFnTester @@ -425,6 +430,53 @@ public void testTrivialProcessFnPropagatesOutputsWindowsAndTimestamp() throws Ex } } + private static class WatermarkUpdateFn extends DoFn { + @ProcessElement + public void process(ProcessContext c, OffsetRangeTracker tracker) { + for (long i = tracker.currentRestriction().getFrom(); tracker.tryClaim(i); ++i) { + c.updateWatermark(c.element().plus(Duration.standardSeconds(i))); + c.output(String.valueOf(i)); + } + } + + @GetInitialRestriction + public OffsetRange getInitialRestriction(Instant elem) { + throw new IllegalStateException("Expected to be supplied explicitly in this test"); + } + + @NewTracker + public OffsetRangeTracker newTracker(OffsetRange range) { + return new OffsetRangeTracker(range); + } + } + + @Test + public void testUpdatesWatermark() throws Exception { + DoFn fn = new WatermarkUpdateFn(); + Instant base = Instant.now(); + + ProcessFnTester tester = + new ProcessFnTester<>( + base, + fn, + InstantCoder.of(), + SerializableCoder.of(OffsetRange.class), + 3, + MAX_BUNDLE_DURATION); + + tester.startElement(base, new OffsetRange(0, 8)); + assertThat(tester.takeOutputElements(), hasItems("0", "1", "2")); + assertEquals(base.plus(Duration.standardSeconds(2)), tester.getWatermarkHold()); + + assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); + assertThat(tester.takeOutputElements(), hasItems("3", "4", "5")); + assertEquals(base.plus(Duration.standardSeconds(5)), tester.getWatermarkHold()); + + assertTrue(tester.advanceProcessingTimeBy(Duration.standardSeconds(1))); + assertThat(tester.takeOutputElements(), hasItems("6", "7")); + assertEquals(null, tester.getWatermarkHold()); + } + /** * A splittable {@link DoFn} that generates the sequence [init, init + total). */ From 8484ef9168940cf929648fd53982acc6740c5107 Mon Sep 17 00:00:00 2001 From: Thomas Weise Date: Tue, 4 Apr 2017 00:33:35 -0700 Subject: [PATCH 11/15] BEAM-1887 Switch Apex ParDo to new DoFn. 
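A minimal sketch of what the switch means, assuming only the DoFnInvokers API shown earlier in this series; the real wiring, with bundles, state, timers and side inputs, is in ApexParDoOperator below.

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;

class NewDoFnSketch {
  // Minimal happy path: build a generated invoker for the DoFn, then feed it one
  // element's worth of arguments (ProcessContext, tracker, window, ...) per call.
  static <InputT, OutputT> void processOneElement(
      DoFn<InputT, OutputT> fn,
      DoFnInvoker.ArgumentProvider<InputT, OutputT> argumentProvider) {
    DoFnInvoker<InputT, OutputT> invoker = DoFnInvokers.invokerFor(fn);
    invoker.invokeSetup();
    invoker.invokeProcessElement(argumentProvider);
    invoker.invokeTeardown();
  }
}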
--- .../apache/beam/runners/apex/ApexRunner.java | 2 - .../translation/GroupByKeyTranslator.java | 4 +- .../apex/translation/ParDoTranslator.java | 15 +- .../apex/translation/TranslationContext.java | 16 +- .../translation/WindowAssignTranslator.java | 58 ++---- .../operators/ApexGroupByKeyOperator.java | 21 +- .../operators/ApexParDoOperator.java | 187 ++++++++++++++---- .../operators/ApexProcessFnOperator.java | 184 +++++++++++++++++ .../translation/utils/ApexStateInternals.java | 73 +++++-- .../utils/StateInternalsProxy.java | 67 +++++++ .../ApexGroupByKeyOperatorTest.java | 2 +- .../apex/translation/ParDoTranslatorTest.java | 2 +- .../utils/ApexStateInternalsTest.java | 25 ++- .../SingletonKeyedWorkItemCoder.java | 2 +- 14 files changed, 512 insertions(+), 146 deletions(-) create mode 100644 runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexProcessFnOperator.java create mode 100644 runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/StateInternalsProxy.java diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java index d23fc14030a25..1c99f8dacd86f 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java @@ -67,8 +67,6 @@ * A {@link PipelineRunner} that translates the * pipeline to an Apex DAG and executes it on an Apex cluster. * - *
<p>
    Currently execution is always in embedded mode, - * launch on Hadoop cluster will be added in subsequent iteration. */ @SuppressWarnings({"rawtypes", "unchecked"}) public class ApexRunner extends PipelineRunner { diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslator.java index b46e3eb3ea16d..2e0bae70186a4 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/GroupByKeyTranslator.java @@ -31,9 +31,9 @@ class GroupByKeyTranslator implements TransformTranslator @Override public void translate(GroupByKey transform, TranslationContext context) { - PCollection> input = (PCollection>) context.getInput(); + PCollection> input = context.getInput(); ApexGroupByKeyOperator group = new ApexGroupByKeyOperator<>(context.getPipelineOptions(), - input, context.stateInternalsFactory() + input, context.getStateBackend() ); context.addOperator(group, group.output); context.addStream(input, group.input); diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java index 75722c793ee3e..fa9d21d4406a4 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java @@ -82,23 +82,22 @@ public void translate(ParDo.MultiOutput transform, TranslationC } List outputs = context.getOutputs(); - PCollection input = (PCollection) context.getInput(); + PCollection input = context.getInput(); List> sideInputs = transform.getSideInputs(); Coder inputCoder = input.getCoder(); WindowedValueCoder wvInputCoder = FullWindowedValueCoder.of( inputCoder, input.getWindowingStrategy().getWindowFn().windowCoder()); - ApexParDoOperator operator = - new ApexParDoOperator<>( + ApexParDoOperator operator = new ApexParDoOperator<>( context.getPipelineOptions(), doFn, transform.getMainOutputTag(), transform.getSideOutputTags().getAll(), - ((PCollection) context.getInput()).getWindowingStrategy(), + input.getWindowingStrategy(), sideInputs, wvInputCoder, - context.stateInternalsFactory()); + context.getStateBackend()); Map, OutputPort> ports = Maps.newHashMapWithExpectedSize(outputs.size()); for (TaggedPValue output : outputs) { @@ -126,15 +125,15 @@ public void translate(ParDo.MultiOutput transform, TranslationC context.addOperator(operator, ports); context.addStream(context.getInput(), operator.input); if (!sideInputs.isEmpty()) { - addSideInputs(operator, sideInputs, context); + addSideInputs(operator.sideInput1, sideInputs, context); } } static void addSideInputs( - ApexParDoOperator operator, + Operator.InputPort sideInputPort, List> sideInputs, TranslationContext context) { - Operator.InputPort[] sideInputPorts = {operator.sideInput1}; + Operator.InputPort[] sideInputPorts = {sideInputPort}; if (sideInputs.size() > sideInputPorts.length) { PCollection unionCollection = unionSideInputs(sideInputs, context); context.addStream(unionCollection, sideInputPorts[0]); diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java index fc49fc7b80822..81507ef02fcb5 100644 --- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/TranslationContext.java @@ -31,9 +31,9 @@ import java.util.Map; import org.apache.beam.runners.apex.ApexPipelineOptions; import org.apache.beam.runners.apex.translation.utils.ApexStateInternals; +import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translation.utils.CoderAdapterStreamCodec; -import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.transforms.AppliedPTransform; @@ -89,16 +89,16 @@ public List getInputs() { return getCurrentTransform().getInputs(); } - public PValue getInput() { - return Iterables.getOnlyElement(getCurrentTransform().getInputs()).getValue(); + public InputT getInput() { + return (InputT) Iterables.getOnlyElement(getCurrentTransform().getInputs()).getValue(); } public List getOutputs() { return getCurrentTransform().getOutputs(); } - public PValue getOutput() { - return Iterables.getOnlyElement(getCurrentTransform().getOutputs()).getValue(); + public OutputT getOutput() { + return (OutputT) Iterables.getOnlyElement(getCurrentTransform().getOutputs()).getValue(); } private AppliedPTransform getCurrentTransform() { @@ -192,10 +192,10 @@ public void populateDAG(DAG dag) { } /** - * Return the {@link StateInternalsFactory} for the pipeline translation. + * Return the state backend for the pipeline translation. * @return */ - public StateInternalsFactory stateInternalsFactory() { - return new ApexStateInternals.ApexStateInternalsFactory(); + public ApexStateBackend getStateBackend() { + return new ApexStateInternals.ApexStateBackend(); } } diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/WindowAssignTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/WindowAssignTranslator.java index 6106f75c32e40..f34f9eecd9341 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/WindowAssignTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/WindowAssignTranslator.java @@ -18,61 +18,35 @@ package org.apache.beam.runners.apex.translation; -import java.util.Collections; -import org.apache.beam.runners.apex.ApexPipelineOptions; -import org.apache.beam.runners.apex.translation.operators.ApexParDoOperator; -import org.apache.beam.runners.core.AssignWindowsDoFn; -import org.apache.beam.runners.core.DoFnAdapters; -import org.apache.beam.runners.core.OldDoFn; -import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.runners.apex.translation.operators.ApexProcessFnOperator; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionView; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.beam.sdk.values.TupleTagList; /** - * {@link Window} is translated to {link ApexParDoOperator} that wraps an {@link - * AssignWindowsDoFn}. 
+ * {@link Window} is translated to {@link ApexProcessFnOperator#assignWindows}. */ class WindowAssignTranslator implements TransformTranslator> { private static final long serialVersionUID = 1L; @Override public void translate(Window.Assign transform, TranslationContext context) { - PCollection output = (PCollection) context.getOutput(); - PCollection input = (PCollection) context.getInput(); - @SuppressWarnings("unchecked") - WindowingStrategy windowingStrategy = - (WindowingStrategy) output.getWindowingStrategy(); + PCollection output = context.getOutput(); + PCollection input = context.getInput(); - OldDoFn fn = - (transform.getWindowFn() == null) - ? DoFnAdapters.toOldDoFn(new IdentityFn()) - : new AssignWindowsDoFn<>(transform.getWindowFn()); + if (transform.getWindowFn() == null) { + // no work to do + context.addAlias(output, input); + } else { + @SuppressWarnings("unchecked") + WindowFn windowFn = (WindowFn) transform.getWindowFn(); + ApexProcessFnOperator operator = ApexProcessFnOperator.assignWindows(windowFn, + context.getPipelineOptions()); + context.addOperator(operator, operator.outputPort); + context.addStream(context.getInput(), operator.inputPort); + } - ApexParDoOperator operator = - new ApexParDoOperator( - context.getPipelineOptions().as(ApexPipelineOptions.class), - fn, - new TupleTag(), - TupleTagList.empty().getAll(), - windowingStrategy, - Collections.>emptyList(), - WindowedValue.getFullCoder( - input.getCoder(), windowingStrategy.getWindowFn().windowCoder()), - context.stateInternalsFactory()); - context.addOperator(operator, operator.output); - context.addStream(context.getInput(), operator.input); } - private static class IdentityFn extends DoFn { - @ProcessElement - public void processElement(ProcessContext c) { - c.output(c.element()); - } - } } diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java index 3508c3ebed05b..4551c9c639501 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexGroupByKeyOperator.java @@ -39,6 +39,7 @@ import java.util.Map; import java.util.Set; import org.apache.beam.runners.apex.ApexPipelineOptions; +import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translation.utils.SerializablePipelineOptions; import org.apache.beam.runners.core.GroupAlsoByWindowViaWindowSetDoFn; @@ -95,7 +96,6 @@ public class ApexGroupByKeyOperator implements Operator { private final SerializablePipelineOptions serializedOptions; @Bind(JavaSerializer.class) private final StateInternalsFactory stateInternalsFactory; - private Map> perKeyStateInternals = new HashMap<>(); private Map> activeTimers = new HashMap<>(); private transient ProcessContext context; @@ -135,13 +135,13 @@ public void process(ApexStreamTuple>> t) { @SuppressWarnings("unchecked") public ApexGroupByKeyOperator(ApexPipelineOptions pipelineOptions, PCollection> input, - StateInternalsFactory stateInternalsFactory) { + ApexStateBackend stateBackend) { checkNotNull(pipelineOptions); this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); this.windowingStrategy = (WindowingStrategy) input.getWindowingStrategy(); 
this.keyCoder = ((KvCoder) input.getCoder()).getKeyCoder(); this.valueCoder = ((KvCoder) input.getCoder()).getValueCoder(); - this.stateInternalsFactory = stateInternalsFactory; + this.stateInternalsFactory = stateBackend.newStateInternalsFactory(keyCoder); } @SuppressWarnings("unused") // for Kryo @@ -222,18 +222,7 @@ private void processElement(WindowedValue> windowedValue) throws Except } private StateInternals getStateInternalsForKey(K key) { - final ByteBuffer keyBytes; - try { - keyBytes = ByteBuffer.wrap(CoderUtils.encodeToByteArray(keyCoder, key)); - } catch (CoderException e) { - throw new RuntimeException(e); - } - StateInternals stateInternals = perKeyStateInternals.get(keyBytes); - if (stateInternals == null) { - stateInternals = stateInternalsFactory.stateInternalsForKey(key); - perKeyStateInternals.put(keyBytes, stateInternals); - } - return stateInternals; + return stateInternalsFactory.stateInternalsForKey(key); } private void registerActiveTimer(K key, TimerInternals.TimerData timer) { @@ -423,7 +412,7 @@ public Aggregator createAggregato * An implementation of Beam's {@link TimerInternals}. * */ - public class ApexTimerInternals implements TimerInternals { + private class ApexTimerInternals implements TimerInternals { @Deprecated @Override diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java index 7f2512a455d23..1fc91c80523c9 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java @@ -25,41 +25,52 @@ import com.datatorrent.common.util.BaseOperator; import com.esotericsoftware.kryo.serializers.FieldSerializer.Bind; import com.esotericsoftware.kryo.serializers.JavaSerializer; -import com.google.common.base.Throwables; import com.google.common.collect.Iterables; import com.google.common.collect.Maps; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.beam.runners.apex.ApexPipelineOptions; import org.apache.beam.runners.apex.ApexRunner; +import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; import org.apache.beam.runners.apex.translation.utils.NoOpStepContext; import org.apache.beam.runners.apex.translation.utils.SerializablePipelineOptions; +import org.apache.beam.runners.apex.translation.utils.StateInternalsProxy; import org.apache.beam.runners.apex.translation.utils.ValueAndCoderKryoSerializable; import org.apache.beam.runners.core.AggregatorFactory; -import org.apache.beam.runners.core.DoFnAdapters; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.DoFnRunners.OutputManager; import org.apache.beam.runners.core.ExecutionContext; -import org.apache.beam.runners.core.OldDoFn; +import org.apache.beam.runners.core.InMemoryTimerInternals; +import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.PushbackSideInputDoFnRunner; import org.apache.beam.runners.core.SideInputHandler; import org.apache.beam.runners.core.StateInternals; -import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.StateNamespace; +import 
org.apache.beam.runners.core.StatefulDoFnRunner; +import org.apache.beam.runners.core.TimerInternals; +import org.apache.beam.runners.core.TimerInternalsFactory; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.transforms.Aggregator; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.reflect.DoFnInvoker; +import org.apache.beam.sdk.transforms.reflect.DoFnInvokers; import org.apache.beam.sdk.util.NullSideInputReader; import org.apache.beam.sdk.util.SideInputReader; +import org.apache.beam.sdk.util.TimeDomain; import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowingStrategy; +import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; +import org.joda.time.Instant; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,7 +84,7 @@ public class ApexParDoOperator extends BaseOperator implements @Bind(JavaSerializer.class) private final SerializablePipelineOptions pipelineOptions; @Bind(JavaSerializer.class) - private final OldDoFn doFn; + private final DoFn doFn; @Bind(JavaSerializer.class) private final TupleTag mainOutputTag; @Bind(JavaSerializer.class) @@ -83,6 +94,12 @@ public class ApexParDoOperator extends BaseOperator implements @Bind(JavaSerializer.class) private final List> sideInputs; + private StateInternalsProxy currentKeyStateInternals; + // TODO: if the operator gets restored to checkpointed state due to a failure, + // the timer state is lost. + private final transient CurrentKeyTimerInternals currentKeyTimerInternals = + new CurrentKeyTimerInternals<>(); + private final StateInternals sideInputStateInternals; private final ValueAndCoderKryoSerializable>> pushedBack; private LongMin pushedBackWatermark = new LongMin(); @@ -93,17 +110,17 @@ public class ApexParDoOperator extends BaseOperator implements private transient SideInputHandler sideInputHandler; private transient Map, DefaultOutputPort>> sideOutputPortMapping = Maps.newHashMapWithExpectedSize(5); + private transient DoFnInvoker doFnInvoker; - @Deprecated public ApexParDoOperator( ApexPipelineOptions pipelineOptions, - OldDoFn doFn, + DoFn doFn, TupleTag mainOutputTag, List> sideOutputTags, WindowingStrategy windowingStrategy, List> sideInputs, Coder> inputCoder, - StateInternalsFactory stateInternalsFactory + ApexStateBackend stateBackend ) { this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions); this.doFn = doFn; @@ -111,7 +128,8 @@ public ApexParDoOperator( this.sideOutputTags = sideOutputTags; this.windowingStrategy = windowingStrategy; this.sideInputs = sideInputs; - this.sideInputStateInternals = stateInternalsFactory.stateInternalsForKey(null); + this.sideInputStateInternals = new StateInternalsProxy<>( + stateBackend.newStateInternalsFactory(VoidCoder.of())); if (sideOutputTags.size() > sideOutputPorts.length) { String msg = String.format("Too many side outputs (currently only supporting %s).", @@ -125,27 +143,6 @@ public ApexParDoOperator( } - public ApexParDoOperator( - ApexPipelineOptions pipelineOptions, - DoFn doFn, - TupleTag mainOutputTag, - List> sideOutputTags, - WindowingStrategy windowingStrategy, - List> sideInputs, - Coder> inputCoder, - StateInternalsFactory stateInternalsFactory - ) { - this( - pipelineOptions, - 
DoFnAdapters.toOldDoFn(doFn), - mainOutputTag, - sideOutputTags, - windowingStrategy, - sideInputs, - inputCoder, - stateInternalsFactory); - } - @SuppressWarnings("unused") // for Kryo private ApexParDoOperator() { this.pipelineOptions = null; @@ -255,6 +252,17 @@ public void output(TupleTag tag, WindowedValue tuple) { private Iterable> processElementInReadyWindows(WindowedValue elem) { try { pushbackDoFnRunner.startBundle(); + if (currentKeyStateInternals != null) { + InputT value = elem.getValue(); + Object key; + if (value instanceof KeyedWorkItem) { + key = ((KeyedWorkItem) value).key(); + } else { + key = ((KV) value).getKey(); + } + ((StateInternalsProxy) currentKeyStateInternals).setKey(key); + currentKeyTimerInternals.currentKey = key; + } Iterable> pushedBack = pushbackDoFnRunner .processElementInReadyWindows(elem); pushbackDoFnRunner.finishBundle(); @@ -305,6 +313,19 @@ public void setup(OperatorContext context) { sideOutputPortMapping.put(sideOutputTags.get(i), port); } + NoOpStepContext stepContext = new NoOpStepContext() { + + @Override + public StateInternals stateInternals() { + return currentKeyStateInternals; + } + + @Override + public TimerInternals timerInternals() { + return currentKeyTimerInternals; + } + + }; DoFnRunner doFnRunner = DoFnRunners.simpleRunner( pipelineOptions.get(), doFn, @@ -312,21 +333,47 @@ public void setup(OperatorContext context) { this, mainOutputTag, sideOutputTags, - new NoOpStepContext(), + stepContext, new NoOpAggregatorFactory(), windowingStrategy ); + doFnInvoker = DoFnInvokers.invokerFor(doFn); + doFnInvoker.invokeSetup(); + + if (this.currentKeyStateInternals != null) { + + StatefulDoFnRunner.CleanupTimer cleanupTimer = + new StatefulDoFnRunner.TimeInternalsCleanupTimer( + stepContext.timerInternals(), windowingStrategy); + + @SuppressWarnings({"rawtypes"}) + Coder windowCoder = windowingStrategy.getWindowFn().windowCoder(); + + @SuppressWarnings({"unchecked"}) + StatefulDoFnRunner.StateCleaner stateCleaner = + new StatefulDoFnRunner.StateInternalsStateCleaner<>( + doFn, stepContext.stateInternals(), windowCoder); + + doFnRunner = DoFnRunners.defaultStatefulDoFnRunner( + doFn, + doFnRunner, + stepContext, + new NoOpAggregatorFactory(), + windowingStrategy, + cleanupTimer, + stateCleaner); + } + pushbackDoFnRunner = PushbackSideInputDoFnRunner.create(doFnRunner, sideInputs, sideInputHandler); - try { - doFn.setup(); - } catch (Exception e) { - Throwables.propagateIfPossible(e); - throw new RuntimeException(e); - } + } + @Override + public void teardown() { + doFnInvoker.invokeTeardown(); + super.teardown(); } @Override @@ -393,4 +440,70 @@ public void clear() { } + private class CurrentKeyTimerInternals implements TimerInternals { + + private TimerInternalsFactory factory = new TimerInternalsFactory() { + @Override + public TimerInternals timerInternalsForKey(K key) { + InMemoryTimerInternals timerInternals = perKeyTimerInternals.get(key); + if (timerInternals == null) { + perKeyTimerInternals.put(key, timerInternals = new InMemoryTimerInternals()); + } + return timerInternals; + } + }; + + // TODO: durable state store + final Map perKeyTimerInternals = new HashMap<>(); + private K currentKey; + + @Override + public void setTimer(StateNamespace namespace, String timerId, Instant target, + TimeDomain timeDomain) { + factory.timerInternalsForKey(currentKey).setTimer( + namespace, timerId, target, timeDomain); + } + + @Override + public void setTimer(TimerData timerData) { + throw new UnsupportedOperationException(); + } + + @Override + 
public void deleteTimer(StateNamespace namespace, String timerId, TimeDomain timeDomain) { + throw new UnsupportedOperationException(); + } + + @Override + public void deleteTimer(StateNamespace namespace, String timerId) { + throw new UnsupportedOperationException(); + } + + @Override + public void deleteTimer(TimerData timerKey) { + throw new UnsupportedOperationException(); + } + + @Override + public Instant currentProcessingTime() { + throw new UnsupportedOperationException(); + } + + @Override + public Instant currentSynchronizedProcessingTime() { + throw new UnsupportedOperationException(); + } + + @Override + public Instant currentInputWatermarkTime() { + return new Instant(currentInputWatermark); + } + + @Override + public Instant currentOutputWatermarkTime() { + throw new UnsupportedOperationException(); + } + + } + } diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexProcessFnOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexProcessFnOperator.java new file mode 100644 index 0000000000000..835c9e0cdec8f --- /dev/null +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexProcessFnOperator.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.apex.translation.operators; + +import com.datatorrent.api.DefaultInputPort; +import com.datatorrent.api.DefaultOutputPort; +import com.datatorrent.api.annotation.OutputPortFieldAnnotation; +import com.datatorrent.common.util.BaseOperator; +import com.esotericsoftware.kryo.serializers.FieldSerializer.Bind; +import com.esotericsoftware.kryo.serializers.JavaSerializer; +import com.google.common.base.Throwables; +import com.google.common.collect.Iterables; +import java.io.Serializable; +import java.util.Collection; +import java.util.Collections; +import org.apache.beam.runners.apex.ApexPipelineOptions; +import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple; +import org.apache.beam.runners.core.KeyedWorkItem; +import org.apache.beam.runners.core.KeyedWorkItems; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.WindowFn; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.KV; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Apex operator for simple native map operations. 
+ */ +public class ApexProcessFnOperator extends BaseOperator { + + private static final Logger LOG = LoggerFactory.getLogger(ApexProcessFnOperator.class); + private boolean traceTuples = false; + @Bind(JavaSerializer.class) + private final ApexOperatorFn fn; + + public ApexProcessFnOperator(ApexOperatorFn fn, boolean traceTuples) { + super(); + this.traceTuples = traceTuples; + this.fn = fn; + } + + @SuppressWarnings("unused") + private ApexProcessFnOperator() { + // for Kryo + fn = null; + } + + private final transient OutputEmitter>> outputEmitter = + new OutputEmitter>>() { + @Override + public void emit(ApexStreamTuple> tuple) { + if (traceTuples) { + LOG.debug("\nemitting {}\n", tuple); + } + outputPort.emit(tuple); + } + }; + + /** + * Something that emits results. + */ + public interface OutputEmitter { + void emit(T tuple); + }; + + /** + * The processing logic for this operator. + */ + public interface ApexOperatorFn extends Serializable { + void process(ApexStreamTuple> input, + OutputEmitter>> outputEmitter) throws Exception; + } + + /** + * Convert {@link KV} into {@link KeyedWorkItem}s. + */ + public static class ToKeyedWorkItems implements ApexOperatorFn> { + @Override + public final void process(ApexStreamTuple>> tuple, + OutputEmitter>> outputEmitter) { + + if (tuple instanceof ApexStreamTuple.WatermarkTuple) { + outputEmitter.emit(tuple); + } else { + for (WindowedValue> in : tuple.getValue().explodeWindows()) { + KeyedWorkItem kwi = KeyedWorkItems.elementsWorkItem(in.getValue().getKey(), + Collections.singletonList(in.withValue(in.getValue().getValue()))); + outputEmitter.emit(ApexStreamTuple.DataTuple.of(in.withValue(kwi))); + } + } + } + } + + public static ApexProcessFnOperator assignWindows( + WindowFn windowFn, ApexPipelineOptions options) { + ApexOperatorFn fn = new AssignWindows(windowFn); + return new ApexProcessFnOperator(fn, options.isTupleTracingEnabled()); + } + + /** + * Function for implementing {@link org.apache.beam.sdk.transforms.windowing.Window.Assign}. + */ + private static class AssignWindows implements ApexOperatorFn { + private final WindowFn windowFn; + + private AssignWindows(WindowFn windowFn) { + this.windowFn = windowFn; + } + + @Override + public final void process(ApexStreamTuple> tuple, + OutputEmitter>> outputEmitter) throws Exception { + if (tuple instanceof ApexStreamTuple.WatermarkTuple) { + outputEmitter.emit(tuple); + } else { + final WindowedValue input = tuple.getValue(); + Collection windows = + (windowFn).assignWindows( + (windowFn).new AssignContext() { + @Override + public T element() { + return input.getValue(); + } + + @Override + public Instant timestamp() { + return input.getTimestamp(); + } + + @Override + public BoundedWindow window() { + return Iterables.getOnlyElement(input.getWindows()); + } + }); + for (W w: windows) { + WindowedValue wv = WindowedValue.of(input.getValue(), input.getTimestamp(), + w, input.getPane()); + outputEmitter.emit(ApexStreamTuple.DataTuple.of(wv)); + } + } + } + } + + /** + * Input port. + */ + public final transient DefaultInputPort>> inputPort = + new DefaultInputPort>>() { + @Override + public void process(ApexStreamTuple> tuple) { + try { + fn.process(tuple, outputEmitter); + } catch (Exception e) { + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); + } + } + }; + + /** + * Output port. 
+ */ + @OutputPortFieldAnnotation(optional = true) + public final transient DefaultOutputPort>> + outputPort = new DefaultOutputPort<>(); + +} diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java index c59afc5961ca8..cfc57cd575365 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.apex.translation.utils; +import com.datatorrent.netlet.util.Slice; import com.esotericsoftware.kryo.DefaultSerializer; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.serializers.JavaSerializer; @@ -27,7 +28,9 @@ import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateInternalsFactory; import org.apache.beam.runners.core.StateNamespace; @@ -35,6 +38,7 @@ import org.apache.beam.runners.core.StateTag.StateBinder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; +import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -42,6 +46,7 @@ import org.apache.beam.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; +import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.state.BagState; import org.apache.beam.sdk.util.state.CombiningState; @@ -56,22 +61,18 @@ import org.joda.time.Instant; /** - * Implementation of {@link StateInternals} that can be serialized and - * checkpointed with the operator. Suitable for small states, in the future this - * should be based on the incremental state saving components in the Apex - * library. + * Implementation of {@link StateInternals} for transient use. + * + *
<p>
    For fields that need to be serialized, use {@link ApexStateInternalsFactory} + * or {@link StateInternalsProxy} */ -@DefaultSerializer(JavaSerializer.class) -public class ApexStateInternals implements StateInternals, Serializable { - private static final long serialVersionUID = 1L; - public static ApexStateInternals forKey(K key) { - return new ApexStateInternals<>(key); - } - +public class ApexStateInternals implements StateInternals { private final K key; + private final Table stateTable; - protected ApexStateInternals(K key) { + protected ApexStateInternals(K key, Table stateTable) { this.key = key; + this.stateTable = stateTable; } @Override @@ -79,11 +80,6 @@ public K getKey() { return key; } - /** - * Serializable state for internals (namespace to state tag to coded value). - */ - private final Table stateTable = HashBasedTable.create(); - @Override public T state(StateNamespace namespace, StateTag address) { return state(namespace, address, StateContexts.nullContext()); @@ -437,17 +433,54 @@ public Boolean read() { } /** - * Factory for {@link ApexStateInternals}. + * Implementation of {@link StateInternals} that can be serialized and + * checkpointed with the operator. Suitable for small states, in the future this + * should be based on the incremental state saving components in the Apex + * library. * * @param key type */ + @DefaultSerializer(JavaSerializer.class) public static class ApexStateInternalsFactory implements StateInternalsFactory, Serializable { private static final long serialVersionUID = 1L; + /** + * Serializable state for internals (namespace to state tag to coded value). + */ + private Map> perKeyState = new HashMap<>(); + private final Coder keyCoder; + + private ApexStateInternalsFactory(Coder keyCoder) { + this.keyCoder = keyCoder; + } @Override - public StateInternals stateInternalsForKey(K key) { - return ApexStateInternals.forKey(key); + public ApexStateInternals stateInternalsForKey(K key) { + final Slice keyBytes; + try { + keyBytes = (key != null) ? new Slice(CoderUtils.encodeToByteArray(keyCoder, key)) : + new Slice(null); + } catch (CoderException e) { + throw new RuntimeException(e); + } + Table stateTable = perKeyState.get(keyBytes); + if (stateTable == null) { + stateTable = HashBasedTable.create(); + perKeyState.put(keyBytes, stateTable); + } + return new ApexStateInternals<>(key, stateTable); + } + + } + + /** + * Factory to create the state internals. + */ + public static class ApexStateBackend implements Serializable { + private static final long serialVersionUID = 1L; + + public ApexStateInternalsFactory newStateInternalsFactory(Coder keyCoder) { + return new ApexStateInternalsFactory(keyCoder); } } diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/StateInternalsProxy.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/StateInternalsProxy.java new file mode 100644 index 0000000000000..1f28364269215 --- /dev/null +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/StateInternalsProxy.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.apex.translation.utils; + +import com.esotericsoftware.kryo.DefaultSerializer; +import com.esotericsoftware.kryo.serializers.JavaSerializer; +import java.io.Serializable; +import org.apache.beam.runners.core.StateInternals; +import org.apache.beam.runners.core.StateInternalsFactory; +import org.apache.beam.runners.core.StateNamespace; +import org.apache.beam.runners.core.StateTag; +import org.apache.beam.sdk.util.state.State; +import org.apache.beam.sdk.util.state.StateContext; + +/** + * State internals for reusable processing context. + * @param + */ +@DefaultSerializer(JavaSerializer.class) +public class StateInternalsProxy implements StateInternals, Serializable { + + private final StateInternalsFactory factory; + private transient K currentKey; + + public StateInternalsProxy(ApexStateInternals.ApexStateInternalsFactory factory) { + this.factory = factory; + } + + public StateInternalsFactory getFactory() { + return this.factory; + } + + public void setKey(K key) { + currentKey = key; + } + + @Override + public K getKey() { + return currentKey; + } + + @Override + public T state(StateNamespace namespace, StateTag address) { + return factory.stateInternalsForKey(currentKey).state(namespace, address); + } + + @Override + public T state(StateNamespace namespace, StateTag address, + StateContext c) { + return factory.stateInternalsForKey(currentKey).state(namespace, address, c); + } +} diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java index fb80d0c33d8c8..4b73114da4fee 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ApexGroupByKeyOperatorTest.java @@ -66,7 +66,7 @@ public void testGlobalWindowMinTimestamp() throws Exception { input.setCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())); ApexGroupByKeyOperator operator = new ApexGroupByKeyOperator<>(options, - input, new ApexStateInternals.ApexStateInternalsFactory() + input, new ApexStateInternals.ApexStateBackend() ); operator.setup(null); diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java index 3bcba00e685d5..2760d06fa518e 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java @@ -216,7 +216,7 @@ public void testSerialization() throws Exception { WindowingStrategy.globalDefault(), Collections.>singletonList(singletonView), coder, - new ApexStateInternals.ApexStateInternalsFactory()); + new ApexStateInternals.ApexStateBackend()); operator.setup(null); operator.beginWindow(0); WindowedValue wv1 = WindowedValue.valueInGlobalWindow(1); diff --git 
a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java index 4f4ecfb3fc81c..7160e4544ab1e 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternalsTest.java @@ -24,6 +24,8 @@ import com.datatorrent.lib.util.KryoCloneUtils; import java.util.Arrays; +import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend; +import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateInternalsFactory; import org.apache.beam.runners.core.StateMerging; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaceForTest; @@ -76,7 +78,9 @@ public class ApexStateInternalsTest { @Before public void initStateInternals() { - underTest = new ApexStateInternals<>(null); + underTest = new ApexStateInternals.ApexStateBackend() + .newStateInternalsFactory(StringUtf8Coder.of()) + .stateInternalsForKey((String) null); } @Test @@ -344,16 +348,21 @@ public void testMergeLatestWatermarkIntoSource() throws Exception { @Test public void testSerialization() throws Exception { - ApexStateInternals original = new ApexStateInternals(null); - ValueState value = original.state(NAMESPACE_1, STRING_VALUE_ADDR); - assertEquals(original.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + ApexStateInternalsFactory sif = new ApexStateBackend(). + newStateInternalsFactory(StringUtf8Coder.of()); + ApexStateInternals keyAndState = sif.stateInternalsForKey("dummy"); + + ValueState value = keyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR); + assertEquals(keyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR), value); value.write("hello"); - ApexStateInternals cloned; - assertNotNull("Serialization", cloned = KryoCloneUtils.cloneObject(original)); - ValueState clonedValue = cloned.state(NAMESPACE_1, STRING_VALUE_ADDR); + ApexStateInternalsFactory cloned; + assertNotNull("Serialization", cloned = KryoCloneUtils.cloneObject(sif)); + ApexStateInternals clonedKeyAndState = cloned.stateInternalsForKey("dummy"); + + ValueState clonedValue = clonedKeyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR); assertThat(clonedValue.read(), Matchers.equalTo("hello")); - assertEquals(cloned.state(NAMESPACE_1, STRING_VALUE_ADDR), value); + assertEquals(clonedKeyAndState.state(NAMESPACE_1, STRING_VALUE_ADDR), value); } } diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SingletonKeyedWorkItemCoder.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SingletonKeyedWorkItemCoder.java index d95ed7c624439..fe96eb1e2d8ba 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SingletonKeyedWorkItemCoder.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SingletonKeyedWorkItemCoder.java @@ -36,7 +36,7 @@ import org.apache.beam.sdk.util.WindowedValue; /** - * Singleton keyed word iteam coder. + * Singleton keyed work item coder. * @param * @param */ From 042b28a681d1ee0c8942a0ec61df014c23a0fdaf Mon Sep 17 00:00:00 2001 From: Dan Halperin Date: Fri, 7 Apr 2017 14:56:09 -0700 Subject: [PATCH 12/15] TestDataflowRunner: better error handling 1. 
There was a race in which pipelines without PAsserts might erroneously pass because they would be canceled, which would in turn cause the watermarks to reach max infinity, which would in turn (because there are no PAsserts) cause the main streaming poll loop to think the pipeline succeeded. Fix this by making the error presence available to the main polling loop, and only canceling from there. 2. The fact we were canceling from two places meant we could get double-cancelations that led to test failures. Fix both these issues (I hope). --- .../dataflow/testing/TestDataflowRunner.java | 34 +++++++++++-------- .../testing/TestDataflowRunnerTest.java | 3 +- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java index d220bb013752b..dc32466ceb5ed 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunner.java @@ -111,8 +111,8 @@ DataflowPipelineJob run(Pipeline pipeline, DataflowRunner runner) { assertThat(job, testPipelineOptions.getOnCreateMatcher()); - CancelWorkflowOnError messageHandler = new CancelWorkflowOnError( - job, new MonitoringUtil.LoggingHandler()); + final ErrorMonitorMessagesHandler messageHandler = + new ErrorMonitorMessagesHandler(job, new MonitoringUtil.LoggingHandler()); try { final Optional<Boolean> success; @@ -126,6 +126,10 @@ public Optional<Boolean> call() throws Exception { for (;;) { JobMetrics metrics = getJobMetrics(job); Optional<Boolean> success = checkForPAssertSuccess(job, metrics); + if (messageHandler.hasSeenError()) { + return Optional.of(false); + } + if (success.isPresent() && (!success.get() || atMaxWatermark(job, metrics))) { // It's possible that the streaming pipeline doesn't use PAssert. // So checkForSuccess() will return true before job is finished. @@ -312,18 +316,22 @@ public String toString() { } /** - * Cancels the workflow on the first error message it sees. + * Monitors job log output messages for errors. * *
<p>
    Creates an error message representing the concatenation of all error messages seen. */ - private static class CancelWorkflowOnError implements JobMessagesHandler { + private static class ErrorMonitorMessagesHandler implements JobMessagesHandler { private final DataflowPipelineJob job; private final JobMessagesHandler messageHandler; private final StringBuffer errorMessage; - private CancelWorkflowOnError(DataflowPipelineJob job, JobMessagesHandler messageHandler) { + private volatile boolean hasSeenError; + + private ErrorMonitorMessagesHandler( + DataflowPipelineJob job, JobMessagesHandler messageHandler) { this.job = job; this.messageHandler = messageHandler; this.errorMessage = new StringBuffer(); + this.hasSeenError = false; } @Override @@ -335,20 +343,16 @@ public void process(List messages) { LOG.info("Dataflow job {} threw exception. Failure message was: {}", job.getJobId(), message.getMessageText()); errorMessage.append(message.getMessageText()); + hasSeenError = true; } } - if (errorMessage.length() > 0) { - LOG.info("Cancelling Dataflow job {}", job.getJobId()); - try { - job.cancel(); - } catch (Exception ignore) { - // The TestDataflowRunner will thrown an AssertionError with the job failure - // messages. - } - } } - private String getErrorMessage() { + boolean hasSeenError() { + return hasSeenError; + } + + String getErrorMessage() { return errorMessage.toString(); } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java index d3eccbb4df19d..e4fa788e4fbcd 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/testing/TestDataflowRunnerTest.java @@ -32,6 +32,7 @@ import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -207,7 +208,7 @@ public State answer(InvocationOnMock invocation) { runner.run(p, mockRunner); } catch (AssertionError expected) { assertThat(expected.getMessage(), containsString("FooException")); - verify(mockJob, atLeastOnce()).cancel(); + verify(mockJob, never()).cancel(); return; } // Note that fail throws an AssertionError which is why it is placed out here From 42f63342be667649eff5ea65dd9714c8b2f75cff Mon Sep 17 00:00:00 2001 From: rjoshi2 Date: Fri, 7 Apr 2017 18:12:09 -0700 Subject: [PATCH 13/15] Fix for potentially unclosed streams in ApexYarnLauncher --- .../beam/runners/apex/ApexYarnLauncher.java | 109 +++++++++--------- 1 file changed, 55 insertions(+), 54 deletions(-) diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexYarnLauncher.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexYarnLauncher.java index 6bc42f087231b..198b9bff52f25 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexYarnLauncher.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexYarnLauncher.java @@ -173,56 +173,57 @@ public void shutdown(ShutdownMode arg0) throws LauncherException { * @throws IOException when dependency information cannot be read */ public static List getYarnDeployDependencies() 
throws IOException { - InputStream dependencyTree = ApexRunner.class.getResourceAsStream("dependency-tree"); - BufferedReader br = new BufferedReader(new InputStreamReader(dependencyTree)); - String line = null; - List excludes = new ArrayList<>(); - int excludeLevel = Integer.MAX_VALUE; - while ((line = br.readLine()) != null) { - for (int i = 0; i < line.length(); i++) { - char c = line.charAt(i); - if (Character.isLetter(c)) { - if (i > excludeLevel) { - excludes.add(line.substring(i)); - } else { - if (line.substring(i).startsWith("org.apache.hadoop")) { - excludeLevel = i; - excludes.add(line.substring(i)); - } else { - excludeLevel = Integer.MAX_VALUE; + try (InputStream dependencyTree = ApexRunner.class.getResourceAsStream("dependency-tree")) { + try (BufferedReader br = new BufferedReader(new InputStreamReader(dependencyTree))) { + String line; + List excludes = new ArrayList<>(); + int excludeLevel = Integer.MAX_VALUE; + while ((line = br.readLine()) != null) { + for (int i = 0; i < line.length(); i++) { + char c = line.charAt(i); + if (Character.isLetter(c)) { + if (i > excludeLevel) { + excludes.add(line.substring(i)); + } else { + if (line.substring(i).startsWith("org.apache.hadoop")) { + excludeLevel = i; + excludes.add(line.substring(i)); + } else { + excludeLevel = Integer.MAX_VALUE; + } + } + break; } } - break; } - } - } - br.close(); - - Set excludeJarFileNames = Sets.newHashSet(); - for (String exclude : excludes) { - String[] mvnc = exclude.split(":"); - String fileName = mvnc[1] + "-"; - if (mvnc.length == 6) { - fileName += mvnc[4] + "-" + mvnc[3]; // with classifier - } else { - fileName += mvnc[3]; - } - fileName += ".jar"; - excludeJarFileNames.add(fileName); - } - ClassLoader classLoader = ApexYarnLauncher.class.getClassLoader(); - URL[] urls = ((URLClassLoader) classLoader).getURLs(); - List dependencyJars = new ArrayList<>(); - for (int i = 0; i < urls.length; i++) { - File f = new File(urls[i].getFile()); - // dependencies can also be directories in the build reactor, - // the Apex client will automatically create jar files for those. - if (f.exists() && !excludeJarFileNames.contains(f.getName())) { - dependencyJars.add(f); + Set excludeJarFileNames = Sets.newHashSet(); + for (String exclude : excludes) { + String[] mvnc = exclude.split(":"); + String fileName = mvnc[1] + "-"; + if (mvnc.length == 6) { + fileName += mvnc[4] + "-" + mvnc[3]; // with classifier + } else { + fileName += mvnc[3]; + } + fileName += ".jar"; + excludeJarFileNames.add(fileName); + } + + ClassLoader classLoader = ApexYarnLauncher.class.getClassLoader(); + URL[] urls = ((URLClassLoader) classLoader).getURLs(); + List dependencyJars = new ArrayList<>(); + for (int i = 0; i < urls.length; i++) { + File f = new File(urls[i].getFile()); + // dependencies can also be directories in the build reactor, + // the Apex client will automatically create jar files for those. 
+ if (f.exists() && !excludeJarFileNames.contains(f.getName())) { + dependencyJars.add(f); + } + } + return dependencyJars; } } - return dependencyJars; } /** @@ -238,17 +239,17 @@ public static void createJar(File dir, File jarFile) throws IOException { throw new RuntimeException("Failed to remove " + jarFile); } URI uri = URI.create("jar:" + jarFile.toURI()); - try (final FileSystem zipfs = FileSystems.newFileSystem(uri, env);) { + try (final FileSystem zipfs = FileSystems.newFileSystem(uri, env)) { File manifestFile = new File(dir, JarFile.MANIFEST_NAME); Files.createDirectory(zipfs.getPath("META-INF")); - final OutputStream out = Files.newOutputStream(zipfs.getPath(JarFile.MANIFEST_NAME)); - if (!manifestFile.exists()) { - new Manifest().write(out); - } else { - FileUtils.copyFile(manifestFile, out); + try (final OutputStream out = Files.newOutputStream(zipfs.getPath(JarFile.MANIFEST_NAME))) { + if (!manifestFile.exists()) { + new Manifest().write(out); + } else { + FileUtils.copyFile(manifestFile, out); + } } - out.close(); final java.nio.file.Path root = dir.toPath(); Files.walkFileTree(root, new java.nio.file.SimpleFileVisitor() { @@ -274,9 +275,9 @@ public FileVisitResult preVisitDirectory(Path dir, BasicFileAttributes attrs) public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { String name = relativePath + file.getFileName(); if (!JarFile.MANIFEST_NAME.equals(name)) { - final OutputStream out = Files.newOutputStream(zipfs.getPath(name)); - FileUtils.copyFile(file.toFile(), out); - out.close(); + try (final OutputStream out = Files.newOutputStream(zipfs.getPath(name))) { + FileUtils.copyFile(file.toFile(), out); + } } return super.visitFile(file, attrs); } From 9e294dc05c1f1ac2aa2ac15f364394698a2767a7 Mon Sep 17 00:00:00 2001 From: Amit Sela Date: Tue, 4 Apr 2017 11:46:38 +0300 Subject: [PATCH 14/15] [BEAM-1737] Implement a Single-output ParDo as a Multi-output ParDo with a single output. remove use of EvaluationContext in DStream lambda, it is not serializable and also redundant in this case. implement pardo as multido. cache only if this is an actual MultiDo. --- .../spark/translation/DoFnFunction.java | 130 ------------------ .../translation/TransformTranslator.java | 74 ++++------ .../StreamingTransformTranslator.java | 78 +++-------- 3 files changed, 47 insertions(+), 235 deletions(-) delete mode 100644 runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java deleted file mode 100644 index 11761b6b161d2..0000000000000 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.beam.runners.spark.translation; - -import java.util.Collections; -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import org.apache.beam.runners.core.DoFnRunner; -import org.apache.beam.runners.core.DoFnRunners; -import org.apache.beam.runners.spark.aggregators.NamedAggregators; -import org.apache.beam.runners.spark.aggregators.SparkAggregators; -import org.apache.beam.runners.spark.metrics.SparkMetricsContainer; -import org.apache.beam.runners.spark.util.SideInputBroadcast; -import org.apache.beam.runners.spark.util.SparkSideInputReader; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.util.WindowingStrategy; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.TupleTag; -import org.apache.spark.Accumulator; -import org.apache.spark.api.java.function.FlatMapFunction; - - -/** - * Beam's Do functions correspond to Spark's FlatMap functions. - * - * @param Input element type. - * @param Output element type. - */ -public class DoFnFunction - implements FlatMapFunction>, WindowedValue> { - - private final Accumulator aggregatorsAccum; - private final Accumulator metricsAccum; - private final String stepName; - private final DoFn doFn; - private final SparkRuntimeContext runtimeContext; - private final Map, KV, SideInputBroadcast>> sideInputs; - private final WindowingStrategy windowingStrategy; - - /** - * @param aggregatorsAccum The Spark {@link Accumulator} that backs the Beam Aggregators. - * @param doFn The {@link DoFn} to be wrapped. - * @param runtimeContext The {@link SparkRuntimeContext}. - * @param sideInputs Side inputs used in this {@link DoFn}. - * @param windowingStrategy Input {@link WindowingStrategy}. 
- */ - public DoFnFunction( - Accumulator aggregatorsAccum, - Accumulator metricsAccum, - String stepName, - DoFn doFn, - SparkRuntimeContext runtimeContext, - Map, KV, SideInputBroadcast>> sideInputs, - WindowingStrategy windowingStrategy) { - this.aggregatorsAccum = aggregatorsAccum; - this.metricsAccum = metricsAccum; - this.stepName = stepName; - this.doFn = doFn; - this.runtimeContext = runtimeContext; - this.sideInputs = sideInputs; - this.windowingStrategy = windowingStrategy; - } - - @Override - public Iterable> call( - Iterator> iter) throws Exception { - DoFnOutputManager outputManager = new DoFnOutputManager(); - - DoFnRunner doFnRunner = - DoFnRunners.simpleRunner( - runtimeContext.getPipelineOptions(), - doFn, - new SparkSideInputReader(sideInputs), - outputManager, - new TupleTag() { - }, - Collections.>emptyList(), - new SparkProcessContext.NoOpStepContext(), - new SparkAggregators.Factory(runtimeContext, aggregatorsAccum), - windowingStrategy); - - DoFnRunner doFnRunnerWithMetrics = - new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum); - - return new SparkProcessContext<>(doFn, doFnRunnerWithMetrics, outputManager) - .processPartition(iter); - } - - private class DoFnOutputManager - implements SparkProcessContext.SparkOutputManager> { - - private final List> outputs = new LinkedList<>(); - - @Override - public void clear() { - outputs.clear(); - } - - @Override - public Iterator> iterator() { - return outputs.iterator(); - } - - @Override - @SuppressWarnings("unchecked") - public synchronized void output(TupleTag tag, WindowedValue output) { - outputs.add((WindowedValue) output); - } - } - -} diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 6290bbaf85ea0..7894c4eac1da2 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -24,7 +24,6 @@ import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers; import com.google.common.base.Optional; -import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import java.util.Collections; @@ -334,8 +333,7 @@ public String toNativeString() { }; } - private static TransformEvaluator> - parDo() { + private static TransformEvaluator> parDo() { return new TransformEvaluator>() { @Override public void evaluate( @@ -351,51 +349,31 @@ public void evaluate( context.getInput(transform).getWindowingStrategy(); Accumulator aggAccum = AggregatorsAccumulator.getInstance(); Accumulator metricsAccum = MetricsAccumulator.getInstance(); - Map, KV, SideInputBroadcast>> sideInputs = - TranslationUtils.getSideInputs(transform.getSideInputs(), context); - if (transform.getSideOutputTags().size() == 0) { - // Don't tag with the output and filter for a single-output ParDo, as it's additional - // identity transforms. - // Also see BEAM-1737 for failures when the two versions are condensed. 
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 6290bbaf85ea0..7894c4eac1da2 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -24,7 +24,6 @@
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
 import com.google.common.base.Optional;
-import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.util.Collections;
@@ -334,8 +333,7 @@ public String toNativeString() {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>
-      parDo() {
+  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
     return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
       @Override
       public void evaluate(
@@ -351,51 +349,31 @@ public void evaluate(
             context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
         Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
-        Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        if (transform.getSideOutputTags().size() == 0) {
-          // Don't tag with the output and filter for a single-output ParDo, as it's additional
-          // identity transforms.
-          // Also see BEAM-1737 for failures when the two versions are condensed.
-          PCollection<OutputT> output =
-              (PCollection<OutputT>)
-                  Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
-          context.putDataset(
-              output,
-              new BoundedDataset<>(
-                  inRDD.mapPartitions(
-                      new DoFnFunction<>(
-                          aggAccum,
-                          metricsAccum,
-                          stepName,
-                          doFn,
-                          context.getRuntimeContext(),
-                          sideInputs,
-                          windowingStrategy))));
-        } else {
-          JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
-              inRDD
-                  .mapPartitionsToPair(
-                      new MultiDoFnFunction<>(
-                          aggAccum,
-                          metricsAccum,
-                          stepName,
-                          doFn,
-                          context.getRuntimeContext(),
-                          transform.getMainOutputTag(),
-                          TranslationUtils.getSideInputs(transform.getSideInputs(), context),
-                          windowingStrategy))
-                  .cache();
-          for (TaggedPValue output : context.getOutputs(transform)) {
-            @SuppressWarnings("unchecked")
-            JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
-                all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
-            @SuppressWarnings("unchecked")
-            // Object is the best we can do since different outputs can have different tags
-            JavaRDD<WindowedValue<Object>> values =
-                (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
-            context.putDataset(output.getValue(), new BoundedDataset<>(values));
-          }
+        JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
+            inRDD.mapPartitionsToPair(
+                new MultiDoFnFunction<>(
+                    aggAccum,
+                    metricsAccum,
+                    stepName,
+                    doFn,
+                    context.getRuntimeContext(),
+                    transform.getMainOutputTag(),
+                    TranslationUtils.getSideInputs(transform.getSideInputs(), context),
+                    windowingStrategy));
+        List<TaggedPValue> outputs = context.getOutputs(transform);
+        if (outputs.size() > 1) {
+          // cache the RDD if we're going to filter it more than once.
+          all.cache();
+        }
+        for (TaggedPValue output : outputs) {
+          @SuppressWarnings("unchecked")
+          JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
+              all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
+          @SuppressWarnings("unchecked")
+          // Object is the best we can do since different outputs can have different tags
+          JavaRDD<WindowedValue<Object>> values =
+              (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
+          context.putDataset(output.getValue(), new BoundedDataset<>(values));
         }
       }
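The hunk above replaces the unconditional .cache() of the old multi-output branch with a conditional one: each per-tag filter triggers its own traversal of the upstream RDD, so with two or more outputs caching avoids recomputing the ParDo for every output, while with exactly one output the extra persistence would be pure overhead. A rough standalone illustration against the plain Spark Java API — local master, toy data, and numOutputs standing in for context.getOutputs(transform).size():

    import java.util.Arrays;

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;

    public class ConditionalCacheSketch {

        public static void main(String[] args) {
            SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("cache-sketch");
            try (JavaSparkContext jsc = new JavaSparkContext(conf)) {
                JavaRDD<Integer> all = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6));

                int numOutputs = 2; // stand-in for context.getOutputs(transform).size()
                if (numOutputs > 1) {
                    // Without this, each filter below recomputes `all` from scratch;
                    // with it, the second traversal reads the partitions from memory.
                    all.cache();
                }

                JavaRDD<Integer> evens = all.filter(x -> x % 2 == 0); // one traversal
                JavaRDD<Integer> odds = all.filter(x -> x % 2 != 0);  // a second traversal

                System.out.println(evens.collect() + " " + odds.collect());
            }
        }
    }

This is also why the single-output fast path could be deleted: skipping cache() for one output recovers the same cost profile without a second code path.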
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 2d2854f0635b9..d4c6c9de3bea4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -43,7 +43,6 @@
 import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet;
 import org.apache.beam.runners.spark.translation.BoundedDataset;
 import org.apache.beam.runners.spark.translation.Dataset;
-import org.apache.beam.runners.spark.translation.DoFnFunction;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
 import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
@@ -368,8 +367,7 @@ public String toNativeString() {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>
-      multiDo() {
+  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
     return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
       public void evaluate(
           final ParDo.MultiOutput<InputT, OutputT> transform, final EvaluationContext context) {
@@ -388,16 +386,14 @@ public void evaluate(
         final String stepName = context.getCurrentTransform().getFullName();
 
         if (transform.getSideOutputTags().size() == 0) {
-          // Don't tag with the output and filter for a single-output ParDo, as it's additional
-          // identity transforms.
-          // Also see BEAM-1737 for failures when the two versions are condensed.
-          JavaDStream<WindowedValue<OutputT>> outStream =
-              dStream.transform(
-                  new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() {
+          JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
+              dStream.transformToPair(
+                  new Function<
+                      JavaRDD<WindowedValue<InputT>>,
+                      JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
                     @Override
-                    public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
-                        throws Exception {
-                      final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+                    public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
+                        JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
                       final Accumulator<NamedAggregators> aggAccum =
                           AggregatorsAccumulator.getInstance();
                       final Accumulator<SparkMetricsContainer> metricsAccum =
@@ -405,59 +401,27 @@ public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
                       final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
                           sideInputs =
                               TranslationUtils.getSideInputs(
-                                  transform.getSideInputs(), jsc, pviews);
-                      return rdd.mapPartitions(
-                          new DoFnFunction<>(
+                                  transform.getSideInputs(),
+                                  JavaSparkContext.fromSparkContext(rdd.context()),
+                                  pviews);
+                      return rdd.mapPartitionsToPair(
+                          new MultiDoFnFunction<>(
                               aggAccum,
                               metricsAccum,
                               stepName,
                               doFn,
                               runtimeContext,
+                              transform.getMainOutputTag(),
                               sideInputs,
                               windowingStrategy));
                     }
                   });
-
-          PCollection<OutputT> output =
-              (PCollection<OutputT>)
-                  Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
-          context.putDataset(
-              output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
-        } else {
-          JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
-              dStream
-                  .transformToPair(
-                      new Function<
-                          JavaRDD<WindowedValue<InputT>>,
-                          JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
-                        @Override
-                        public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
-                            JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
-                          String stepName = context.getCurrentTransform().getFullName();
-                          final Accumulator<NamedAggregators> aggAccum =
-                              AggregatorsAccumulator.getInstance();
-                          final Accumulator<SparkMetricsContainer> metricsAccum =
-                              MetricsAccumulator.getInstance();
-                          final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
-                              sideInputs =
-                                  TranslationUtils.getSideInputs(
-                                      transform.getSideInputs(),
-                                      JavaSparkContext.fromSparkContext(rdd.context()),
-                                      pviews);
-                          return rdd.mapPartitionsToPair(
-                              new MultiDoFnFunction<>(
-                                  aggAccum,
-                                  metricsAccum,
-                                  stepName,
-                                  doFn,
-                                  runtimeContext,
-                                  transform.getMainOutputTag(),
-                                  sideInputs,
-                                  windowingStrategy));
-                        }
-                      })
-                  .cache();
-          for (TaggedPValue output : context.getOutputs(transform)) {
+          List<TaggedPValue> outputs = context.getOutputs(transform);
+          if (outputs.size() > 1) {
+            // cache the DStream if we're going to filter it more than once.
+            all.cache();
+          }
+          for (TaggedPValue output : outputs) {
             @SuppressWarnings("unchecked")
             JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
                 all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
@@ -525,7 +489,7 @@ public JavaRDD<WindowedValue<Iterable<InputT>>> call(
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
-    EVALUATORS.put(ParDo.MultiOutput.class, multiDo());
+    EVALUATORS.put(ParDo.MultiOutput.class, parDo());
     EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
    EVALUATORS.put(CreateStream.class, createFromQueue());
     EVALUATORS.put(Window.Assign.class, window());
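With the streaming multiDo() renamed to parDo(), both translators now register the same evaluator under ParDo.MultiOutput in their class-keyed EVALUATORS maps. A small sketch of that registry pattern follows; ParDoLike, GroupByKeyLike, and the Consumer-based evaluator type are hypothetical stand-ins (the real maps key on Beam transform classes and hold TransformEvaluator instances):

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Consumer;

    public class EvaluatorRegistrySketch {
        // Hypothetical stand-ins for Beam's transform types.
        static class ParDoLike {}
        static class GroupByKeyLike {}

        // Each transform class maps to the code that translates it.
        private static final Map<Class<?>, Consumer<Object>> EVALUATORS = new HashMap<>();

        static {
            EVALUATORS.put(ParDoLike.class, t -> System.out.println("translate ParDo"));
            EVALUATORS.put(GroupByKeyLike.class, t -> System.out.println("translate GroupByKey"));
        }

        static void evaluate(Object transform) {
            Consumer<Object> evaluator = EVALUATORS.get(transform.getClass());
            if (evaluator == null) {
                throw new UnsupportedOperationException(
                    "No evaluator registered for " + transform.getClass());
            }
            evaluator.accept(transform);
        }

        public static void main(String[] args) {
            evaluate(new ParDoLike());      // translate ParDo
            evaluate(new GroupByKeyLike()); // translate GroupByKey
        }
    }

Keeping one entry per transform class is what makes the rename safe: only the value stored under ParDo.MultiOutput changes, and every lookup site stays untouched.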
From 870e42dec6870d0170724a7eff849e95f9ff266b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jean-Baptiste=20Onofr=C3=A9?= 
Date: Tue, 9 Aug 2016 11:17:32 +0200
Subject: [PATCH 15/15] Add DSLs module

---
 dsls/pom.xml | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 pom.xml      |  1 +
 2 files changed, 57 insertions(+)
 create mode 100644 dsls/pom.xml

diff --git a/dsls/pom.xml b/dsls/pom.xml
new file mode 100644
index 0000000000000..6e0017115a41b
--- /dev/null
+++ b/dsls/pom.xml
@@ -0,0 +1,56 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one or more
+    contributor license agreements.  See the NOTICE file distributed with
+    this work for additional information regarding copyright ownership.
+    The ASF licenses this file to You under the Apache License, Version 2.0
+    (the "License"); you may not use this file except in compliance with
+    the License.  You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing, software
+    distributed under the License is distributed on an "AS IS" BASIS,
+    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+    See the License for the specific language governing permissions and
+    limitations under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+
+  <modelVersion>4.0.0</modelVersion>
+
+  <parent>
+    <groupId>org.apache.beam</groupId>
+    <artifactId>beam-parent</artifactId>
+    <version>0.7.0-SNAPSHOT</version>
+    <relativePath>../pom.xml</relativePath>
+  </parent>
+
+  <artifactId>beam-dsls-parent</artifactId>
+  <packaging>pom</packaging>
+  <name>Apache Beam :: DSLs</name>
+
+  <modules>
+  </modules>
+
+  <build>
+    <pluginManagement>
+      <plugins>
+        <plugin>
+          <groupId>org.apache.maven.plugins</groupId>
+          <artifactId>maven-jar-plugin</artifactId>
+          <executions>
+            <execution>
+              <id>default-test-jar</id>
+              <goals>
+                <goal>test-jar</goal>
+              </goals>
+            </execution>
+          </executions>
+        </plugin>
+      </plugins>
+    </pluginManagement>
+  </build>
+
+</project>
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 09f39859b0416..ef312c19373f7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -152,6 +152,7 @@
     <module>sdks/java/build-tools</module>
     <module>sdks</module>
     <module>runners</module>
+    <module>dsls</module>
     <module>examples</module>
     <module>sdks/java/javadoc</module>