diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index be0f904dc14d9..3059c96e893ee 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -1973,7 +1973,7 @@ def collect(self):
         [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
         with SCCallSiteSync(self._sc) as css:
-            bytesInJava = self._jdf.collectToPython().iterator()
+            bytesInJava = self._jdf.javaToPython().collect().iterator()
         cls = _create_cls(self.schema())
         tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
         tempFile.close()
@@ -1997,14 +1997,14 @@ def take(self, num):
         return self.limit(num).collect()
 
     def map(self, f):
+        """ Return a new RDD by applying a function to each Row, it's a
+        shorthand for df.rdd.map()
+        """
         return self.rdd.map(f)
 
-    # Convert each object in the RDD to a Row with the right class
-    # for this DataFrame, so that fields can be accessed as attributes.
     def mapPartitions(self, f, preservesPartitioning=False):
         """
-        Return a new RDD by applying a function to each partition of this RDD,
-        while tracking the index of the original partition.
+        Return a new RDD by applying a function to each partition.
 
         >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
         >>> def f(iterator): yield 1
@@ -2013,21 +2013,28 @@ def mapPartitions(self, f, preservesPartitioning=False):
         """
         return self.rdd.mapPartitions(f, preservesPartitioning)
 
-    # We override the default cache/persist/checkpoint behavior
-    # as we want to cache the underlying DataFrame object in the JVM,
-    # not the PythonRDD checkpointed by the super class
     def cache(self):
+        """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+        """
         self.is_cached = True
         self._jdf.cache()
         return self
 
     def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+        """ Set the storage level to persist its values across operations
+        after the first time it is computed. This can only be used to assign
+        a new storage level if the RDD does not have a storage level set yet.
+        If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
+        """
         self.is_cached = True
         javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
         self._jdf.persist(javaStorageLevel)
         return self
 
     def unpersist(self, blocking=True):
+        """ Mark it as non-persistent, and remove all blocks for it from
+        memory and disk.
+        """
         self.is_cached = False
         self._jdf.unpersist(blocking)
         return self
@@ -2359,11 +2366,11 @@ def _scalaMethod(name):
     """ Translate operators into methodName in Scala
 
     For example:
-    >>> scalaMethod('+')
+    >>> _scalaMethod('+')
     '$plus'
-    >>> scalaMethod('>=')
+    >>> _scalaMethod('>=')
     '$greater$eq'
-    >>> scalaMethod('cast')
+    >>> _scalaMethod('cast')
     'cast'
     """
     return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index eb48102229837..e8e207af462de 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -946,8 +946,7 @@ def test_apply_schema_with_udt(self):
         schema = StructType([StructField("label", DoubleType(), False),
                              StructField("point", ExamplePointUDT(), False)])
         df = self.sqlCtx.applySchema(rdd, schema)
-        # TODO: test collect with UDT
-        point = df.rdd.first().point
+        point = df.head().point
         self.assertEquals(point, ExamplePoint(1.0, 2.0))
 
     def test_parquet_with_udt(self):
@@ -984,11 +983,12 @@ def test_column_select(self):
         self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
 
     def test_aggregator(self):
-        from pyspark.sql import Aggregator as Agg
         df = self.df
         g = df.groupBy()
         self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
         self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+        # TODO(davies): fix aggregators
+        from pyspark.sql import Aggregator as Agg
         # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index dad36864b923e..d0bb3640f8c1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -590,17 +590,7 @@ class DataFrame protected[sql](
    */
   protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
     val fieldTypes = schema.fields.map(_.dataType)
-    val jrdd = this.rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+    val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
     SerDeUtil.javaToPython(jrdd)
   }
-  /**
-   * Serializes the Array[Row] returned by collect(), using the same format as javaToPython.
-   */
-  protected[sql] def collectToPython: JList[Array[Byte]] = {
-    val fieldTypes = schema.fields.map(_.dataType)
-    val pickle = new Pickler
-    new ArrayList[Array[Byte]](collect().map { row =>
-      EvaluatePython.rowToArray(row, fieldTypes)
-    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
-  }
 }