From d052c077360ee35af1dc7f395b1ccddae67ab5ef Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Thu, 27 Mar 2014 20:58:08 -0700
Subject: [PATCH] Python tests now pass with iterator pandas

---
 python/pyspark/join.py      |  3 ++-
 python/pyspark/rdd.py       | 10 +++++-----
 python/pyspark/resultitr.py | 39 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 46 insertions(+), 6 deletions(-)
 create mode 100644 python/pyspark/resultitr.py

diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 9feb4362dc469..febc223a645a4 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -31,6 +31,7 @@
 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """
 
+from pyspark.resultitr import ResultItr
 
 def _do_python_join(rdd, other, numPartitions, dispatch):
     vs = rdd.map(lambda (k, v): (k, (1, v)))
@@ -88,5 +89,5 @@ def dispatch(seq):
                 vbuf.append(v)
             elif n == 2:
                 wbuf.append(v)
-        return (iter(vbuf), iter(wbuf))
+        return (ResultItr(vbuf), ResultItr(wbuf))
     return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 99693a4bff75f..feb00731af8e9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -38,6 +38,7 @@
 from pyspark.statcounter import StatCounter
 from pyspark.rddsampler import RDDSampler
 from pyspark.storagelevel import StorageLevel
+from pyspark.resultitr import ResultItr
 from py4j.java_collections import ListConverter, MapConverter
 
 
@@ -1118,7 +1119,7 @@ def groupByKey(self, numPartitions=None):
         Hash-partitions the resulting RDD with into numPartitions partitions.
 
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> sorted(x.groupByKey().collect())
+        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
         """
 
@@ -1133,7 +1134,7 @@ def mergeCombiners(a, b):
             return a + b
 
         return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
-                numPartitions).mapValues(lambda x: iter(x))
+                numPartitions).mapValues(lambda x: ResultItr(x))
 
     # TODO: add tests
     def flatMapValues(self, f):
@@ -1180,7 +1181,7 @@ def cogroup(self, other, numPartitions=None):
 
         >>> x = sc.parallelize([("a", 1), ("b", 4)])
         >>> y = sc.parallelize([("a", 2)])
-        >>> sorted(list(x.cogroup(y).collect()))
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
         [('a', ([1], [2])), ('b', ([4], []))]
         """
         return python_cogroup(self, other, numPartitions)
@@ -1217,7 +1218,7 @@ def keyBy(self, f):
 
         >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
         >>> y = sc.parallelize(zip(range(0,5), range(0,5)))
-        >>> sorted(x.cogroup(y).collect())
+        >>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect()))
         [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
         """
         return self.map(lambda x: (f(x), x))
@@ -1317,7 +1318,6 @@ def getStorageLevel(self):
 # keys in the pairs. This could be an expensive operation, since those
 # hashes aren't retained.
 
-
 class PipelinedRDD(RDD):
     """
     Pipelined maps:
diff --git a/python/pyspark/resultitr.py b/python/pyspark/resultitr.py
new file mode 100644
index 0000000000000..c0a07f64ecb78
--- /dev/null
+++ b/python/pyspark/resultitr.py
@@ -0,0 +1,39 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__all__ = ["ResultItr"]
+
+import collections
+
+class ResultItr(collections.Iterator):
+    """
+    A special result iterator. This is used because the standard iterator cannot be pickled.
+    """
+    def __init__(self, data):
+        self.data = data
+        self.index = 0
+        self.maxindex = len(data)
+    def next(self):
+        if self.index == self.maxindex:
+            raise StopIteration
+        v = self.data[self.index]
+        self.index += 1
+        return v
+    def __iter__(self):
+        return iter(self.data)
+    def __len__(self):
+        return len(self.data)
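
A minimal standalone sketch (not part of the patch) of how the new wrapper is
expected to behave, assuming the module is importable as pyspark.resultitr once
the patch above is applied. Unlike the plain listiterator previously returned by
iter(vbuf), a ResultItr built on a list survives pickling, supports len(), and
can be re-iterated, which is why the doctests above now map list() over the
grouped values:

    import pickle

    from pyspark.resultitr import ResultItr

    # What groupByKey()/cogroup() now hand back per key: a wrapper over the
    # buffered values rather than a one-shot iterator.
    grouped = ResultItr([1, 1])

    print list(grouped)   # [1, 1] -- values can be materialized
    print list(grouped)   # [1, 1] -- and re-iterated (a plain iterator would be exhausted)
    print len(grouped)    # 2      -- a plain iterator has no len()

    # Round-trips through pickle, which a Python 2 listiterator cannot do.
    restored = pickle.loads(pickle.dumps(grouped))
    print list(restored)  # [1, 1]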