Commit d052c07

Python tests now pass with iterator pandas

holdenk committed Apr 8, 2014
1 parent 3bcd81d commit d052c07
Showing 3 changed files with 46 additions and 6 deletions.
3 changes: 2 additions & 1 deletion python/pyspark/join.py
@@ -31,6 +31,7 @@
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

+from pyspark.resultitr import ResultItr

def _do_python_join(rdd, other, numPartitions, dispatch):
    vs = rdd.map(lambda (k, v): (k, (1, v)))
@@ -88,5 +89,5 @@ def dispatch(seq):
                vbuf.append(v)
            elif n == 2:
                wbuf.append(v)
-        return (iter(vbuf), iter(wbuf))
+        return (ResultItr(vbuf), ResultItr(wbuf))
    return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch)
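
The iter() results were replaced because a plain Python 2 iterator cannot be pickled, which breaks PySpark's serialization of grouped values between workers and the driver. A minimal sketch of the difference (not part of this commit; assumes PySpark's resultitr module is importable):

    import pickle

    # Plain list iterators are not picklable in Python 2; Spark would hit this
    # whenever grouped values crossed a serialization boundary.
    try:
        pickle.dumps(iter([1, 2, 3]))
    except TypeError as e:
        print("plain iterator: %s" % e)  # e.g. "can't pickle listiterator objects"

    # A list-backed iterator such as ResultItr pickles fine, since only the
    # backing list and cursor position are serialized.
    from pyspark.resultitr import ResultItr
    restored = pickle.loads(pickle.dumps(ResultItr([1, 2, 3])))
    print(list(restored))  # [1, 2, 3]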
10 changes: 5 additions & 5 deletions python/pyspark/rdd.py
@@ -38,6 +38,7 @@
from pyspark.statcounter import StatCounter
from pyspark.rddsampler import RDDSampler
from pyspark.storagelevel import StorageLevel
+from pyspark.resultitr import ResultItr

from py4j.java_collections import ListConverter, MapConverter

@@ -1118,7 +1119,7 @@ def groupByKey(self, numPartitions=None):
        Hash-partitions the resulting RDD into numPartitions partitions.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> sorted(x.groupByKey().collect())
+        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
        [('a', [1, 1]), ('b', [1])]
        """

@@ -1133,7 +1134,7 @@ def mergeCombiners(a, b):
            return a + b

        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
-                                 numPartitions).mapValues(lambda x: iter(x))
+                                 numPartitions).mapValues(lambda x: ResultItr(x))

    # TODO: add tests
    def flatMapValues(self, f):
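
With this change groupByKey hands back a picklable ResultItr per key rather than a plain iterator, so callers materialize the values explicitly. A short usage sketch (not part of the diff; assumes a live SparkContext named sc):

    rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])

    # Each value group arrives as a ResultItr; list() materializes it, and
    # len() works without consuming the iterator.
    pairs = sorted((k, list(vs)) for (k, vs) in rdd.groupByKey().collect())
    print(pairs)  # [('a', [1, 1]), ('b', [1])]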
@@ -1180,7 +1181,7 @@ def cogroup(self, other, numPartitions=None):
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
-        >>> sorted(list(x.cogroup(y).collect()))
+        >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect())))
        [('a', ([1], [2])), ('b', ([4], []))]
        """
        return python_cogroup(self, other, numPartitions)
@@ -1217,7 +1218,7 @@ def keyBy(self, f):
        >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x)
        >>> y = sc.parallelize(zip(range(0,5), range(0,5)))
-        >>> sorted(x.cogroup(y).collect())
+        >>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect()))
        [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))]
        """
        return self.map(lambda x: (f(x), x))
@@ -1317,7 +1318,6 @@ def getStorageLevel(self):
# keys in the pairs. This could be an expensive operation, since those
# hashes aren't retained.

-
class PipelinedRDD(RDD):
    """
    Pipelined maps:
39 changes: 39 additions & 0 deletions python/pyspark/resultitr.py
@@ -0,0 +1,39 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

__all__ = ["ResultItr"]

import collections

class ResultItr(collections.Iterator):
    """
    A special result iterator. This is used because the standard iterator cannot be pickled.
    """
    def __init__(self, data):
        self.data = data
        self.index = 0
        self.maxindex = len(data)
    def next(self):
        # Walk the backing list by index (Python 2 iterator protocol).
        if self.index == self.maxindex:
            raise StopIteration
        v = self.data[self.index]
        self.index += 1
        return v
    def __iter__(self):
        return iter(self.data)
    def __len__(self):
        return len(self.data)
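
A standalone sketch of ResultItr's behavior (plain Python 2, no SparkContext needed; assumes pyspark is on the path):

    from pyspark.resultitr import ResultItr

    r = ResultItr([10, 20, 30])
    print(len(r))    # 3 -- __len__ delegates to the backing list
    print(r.next())  # 10 -- next() advances the internal cursor
    print(list(r))   # [10, 20, 30] -- note __iter__ reads the backing list
                     # from the start, so earlier next() calls are not reflected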
