compress vectors in VectorAssembler
mengxr committed May 7, 2015
1 parent 1712a7c commit 6d90d45
Showing 2 changed files with 10 additions and 2 deletions.
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -102,6 +102,6 @@ object VectorAssembler {
       case o =>
         throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
     }
-    Vectors.sparse(cur, indices.result(), values.result())
+    Vectors.sparse(cur, indices.result(), values.result()).compressed
   }
 }
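Note: `compressed` on org.apache.spark.mllib.linalg.Vector returns whichever of the dense or sparse encodings uses less storage, so assemble now emits the smaller representation for each row. A minimal standalone sketch of that behavior, mirroring the assertions in the new test below (not part of this diff):

    import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}

    // Mostly zeros: the sparse encoding is smaller, so compressed yields a SparseVector.
    val mostlyZeros = Vectors.dense(0.0, 0.0, 0.0, 4.0).compressed
    assert(mostlyZeros.isInstanceOf[SparseVector])

    // Mostly nonzeros: the dense encoding is smaller, so compressed yields a DenseVector.
    val mostlyNonzeros = Vectors.sparse(4, Array(0, 1, 2, 3), Array(1.0, 2.0, 3.0, 4.0)).compressed
    assert(mostlyNonzeros.isInstanceOf[DenseVector])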
mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SQLContext}

@@ -48,6 +48,14 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("assemble should compress vectors") {
+    import org.apache.spark.ml.feature.VectorAssembler.assemble
+    val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0))
+    assert(v1.isInstanceOf[SparseVector])
+    val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0)))
+    assert(v2.isInstanceOf[DenseVector])
+  }
+
   test("VectorAssembler") {
     val df = sqlContext.createDataFrame(Seq(
       (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
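For context, a hypothetical end-to-end sketch of using VectorAssembler (the column names, local Spark setup, and toDF call are illustrative assumptions, not part of this commit); with this change, each assembled features vector comes back in its compressed form:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.ml.feature.VectorAssembler

    // Hypothetical local setup; the test suite above gets this from MLlibTestSparkContext.
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("assembler-example"))
    val sqlContext = new SQLContext(sc)

    val df = sqlContext.createDataFrame(Seq(
      (0, 0.0, Vectors.dense(1.0, 2.0))
    )).toDF("id", "x", "vec")

    // Assemble the numeric column and the vector column into a single "features" column.
    val assembler = new VectorAssembler()
      .setInputCols(Array("x", "vec"))
      .setOutputCol("features")
    assembler.transform(df).select("features").show()

    sc.stop()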
