diff --git a/python/pyspark/streaming_tests.py b/python/pyspark/streaming_tests.py index ef308fdd6aa59..c35d352c66ca5 100644 --- a/python/pyspark/streaming_tests.py +++ b/python/pyspark/streaming_tests.py @@ -275,7 +275,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_mapPartitions_batch(self): - """Basic operation test for DStream.mapPartitions with batch deserializer""" + """Basic operation test for DStream.mapPartitions with batch deserializer.""" test_input = [range(1, 5), range(5, 9), range(9, 13)] numSlices = 2 @@ -288,7 +288,7 @@ def f(iterator): self.assertEqual(expected_output, output) def test_mapPartitions_unbatch(self): - """Basic operation test for DStream.mapPartitions with unbatch deserializer""" + """Basic operation test for DStream.mapPartitions with unbatch deserializer.""" test_input = [range(1, 4), range(4, 7), range(7, 10)] numSlices = 2 @@ -301,8 +301,8 @@ def f(iterator): self.assertEqual(expected_output, output) def test_countByValue_batch(self): - """Basic operation test for DStream.countByValue with batch deserializer""" - test_input = [range(1, 5) + range(1,5), range(5, 7) + range(5, 9), ["a"] * 2 + ["b"] + [""] ] + """Basic operation test for DStream.countByValue with batch deserializer.""" + test_input = [range(1, 5) + range(1,5), range(5, 7) + range(5, 9), ["a", "a", "b", ""]] def test_func(dstream): return dstream.countByValue() @@ -315,7 +315,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_countByValue_unbatch(self): - """Basic operation test for DStream.countByValue with unbatch deserializer""" + """Basic operation test for DStream.countByValue with unbatch deserializer.""" test_input = [range(1, 4), [1, 1, ""], ["a", "a", "b"]] def test_func(dstream): @@ -328,30 +328,72 @@ def test_func(dstream): self._sort_result_based_on_key(result) self.assertEqual(expected_output, output) + def test_groupByKey_batch(self): + """Basic operation test for DStream.groupByKey with batch deserializer.""" + test_input = [range(1, 5), [1, 1, 1, 2, 2, 3], ["a", "a", "b", "", "", ""]] + def test_func(dstream): + return dstream.map(lambda x: (x,1)).groupByKey() + expected_output = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + scattered_output = self._run_stream(test_input, test_func, expected_output) + output = self._convert_iter_value_to_list(scattered_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def test_groupByKey_unbatch(self): + """Basic operation test for DStream.groupByKey with unbatch deserializer.""" + test_input = [range(1, 4), [1, 1, ""], ["a", "a", "b"]] + def test_func(dstream): + return dstream.map(lambda x: (x,1)).groupByKey() + expected_output = [[(1, [1]), (2, [1]), (3, [1])], + [(1, [1, 1]), ("", [1])], + [("a", [1, 1]), ("b", [1])]] + scattered_output = self._run_stream(test_input, test_func, expected_output) + output = self._convert_iter_value_to_list(scattered_output) + for result in (output, expected_output): + self._sort_result_based_on_key(result) + self.assertEqual(expected_output, output) + + def _convert_iter_value_to_list(self, outputs): + """Return key value pair list. Value is converted to iterator to list.""" + result = list() + for output in outputs: + result.append(map(lambda (x, y): (x, list(y)), output)) + return result + def _sort_result_based_on_key(self, outputs): + """Sort the list base onf first value.""" for output in outputs: output.sort(key=lambda x: x[0]) def _run_stream(self, test_input, test_func, expected_output, numSlices=None): - """Start stream and return the output""" - # Generate input stream with user-defined input + """ + Start stream and return the output. + @param test_input: dataset for the test. This should be list of lists. + @param test_func: wrapped test_function. This function should return PythonDstream object. + @param expexted_output: expected output for this testcase. + @param numSlices: the number of slices in the rdd in the dstream. + """ + # Generate input stream with user-defined input. numSlices = numSlices or self.numInputPartitions test_input_stream = self.ssc._testInputStream(test_input, numSlices) - # Apply test function to stream + # Apply test function to stream. test_stream = test_func(test_input_stream) - # Add job to get output from stream + # Add job to get output from stream. test_stream._test_output(self.result) self.ssc.start() start_time = time.time() - # loop until get the result from stream + # Loop until get the expected the number of the result from the stream. while True: current_time = time.time() - # check time out + # Check time out. if (current_time - start_time) > self.timeout: break self.ssc.awaitTermination(50) - # check if the output is the same length of expexted output + # Check if the output is the same length of expexted output. if len(expected_output) == len(self.result): break @@ -372,9 +414,5 @@ def tearDownClass(cls): PySparkStreamingTestCase.tearDownClass() - - - - if __name__ == "__main__": unittest.main()