speed up dense vector
Davies Liu committed Nov 23, 2014
1 parent 90a6a46 commit ef6ce70
Showing 2 changed files with 32 additions and 10 deletions.
@@ -17,6 +17,7 @@

package org.apache.spark.mllib.api.python

+import java.nio.{ByteBuffer, ByteOrder}
import java.io.OutputStream
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

@@ -741,15 +742,31 @@ private[spark] object SerDe extends Serializable {

    def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
      val vector: DenseVector = obj.asInstanceOf[DenseVector]
-     saveObjects(out, pickler, vector.toArray)
+     val bytes = new Array[Byte](8 * vector.size)
+     val bb = ByteBuffer.wrap(bytes)
+     bb.order(ByteOrder.nativeOrder())
+     val db = bb.asDoubleBuffer()
+     db.put(vector.toArray)
+
+     out.write(Opcodes.BINSTRING)
+     out.write(PickleUtils.integer_to_bytes(bytes.length))
+     out.write(bytes)
+
+     out.write(Opcodes.TUPLE1)
    }

    def construct(args: Array[Object]): Object = {
      require(args.length == 1)
      if (args.length != 1) {
        throw new PickleException("should be 1")
      }
-     new DenseVector(args(0).asInstanceOf[Array[Double]])
+     val bytes = args(0).asInstanceOf[String].getBytes("ISO-8859-1")
+     val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
+     bb.order(ByteOrder.nativeOrder())
+     val db = bb.asDoubleBuffer()
+     val ans = new Array[Double](bytes.length/8)
+     db.get(ans)
+     Vectors.dense(ans)
    }
  }
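For orientation, here is a minimal sketch of the byte-level round trip the pickler above sets up, written in plain Python with struct and NumPy (not part of the patch; the values are illustrative): the JVM side writes the vector's doubles as a single native-order BINSTRING payload, and the Python side can rebuild the whole array with one frombuffer call instead of unpickling the elements one by one.

    import struct
    import numpy as np

    values = [1.0, 2.5, -3.0]
    # roughly what the Scala side produces via ByteBuffer/asDoubleBuffer:
    # 8 bytes per double, native byte order
    payload = struct.pack("=%dd" % len(values), *values)
    # roughly what DenseVector.__init__ now does when handed a bytes/str payload
    restored = np.frombuffer(payload, dtype=np.float64)
    assert list(restored) == values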

21 changes: 13 additions & 8 deletions python/pyspark/mllib/linalg.py
@@ -173,12 +173,16 @@ class DenseVector(Vector):
    A dense vector represented by a value array.
    """
    def __init__(self, ar):
-       if not isinstance(ar, array.array):
-           ar = array.array('d', ar)
+       if isinstance(ar, (bytearray, basestring)):
+           ar = np.frombuffer(ar, dtype=np.float64)
+       elif not isinstance(ar, np.ndarray):
+           ar = np.array(ar, dtype=np.float64)
+       if ar.dtype != np.float64:
+           ar = ar.astype(np.float64)
        self.array = ar

    def __reduce__(self):
-       return DenseVector, (self.array,)
+       return DenseVector, (self.array.tostring(),)

    def dot(self, other):
        """
@@ -207,9 +211,10 @@ def dot(self, other):
        ...
        AssertionError: dimension mismatch
        """
-       if type(other) == np.ndarray and other.ndim > 1:
-           assert len(self) == other.shape[0], "dimension mismatch"
-           return np.dot(self.toArray(), other)
+       if type(other) == np.ndarray:
+           if other.ndim > 1:
+               assert len(self) == other.shape[0], "dimension mismatch"
+           return np.dot(self.array, other)
        elif _have_scipy and scipy.sparse.issparse(other):
            assert len(self) == other.shape[0], "dimension mismatch"
            return other.transpose().dot(self.toArray())
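As a quick illustration of the widened ndarray branch above (plain NumPy with illustrative values, not part of the patch), np.dot now serves both the 1-D and the 2-D case, and the dimension assertion only applies when the other operand is a matrix:

    import numpy as np

    v = np.array([1.0, 2.0, 3.0])
    m = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # shape (3, 2)
    print(np.dot(v, v))  # 14.0, the 1-D case
    print(np.dot(v, m))  # [4. 5.], the 2-D case; len(v) must equal m.shape[0]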
@@ -261,7 +266,7 @@ def squared_distance(self, other):
        return np.dot(diff, diff)

    def toArray(self):
-       return np.array(self.array)
+       return self.array

    def __getitem__(self, item):
        return self.array[item]
@@ -276,7 +281,7 @@ def __repr__(self):
return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))

def __eq__(self, other):
return isinstance(other, DenseVector) and self.array == other.array
return isinstance(other, DenseVector) and all(self.array == other.array)

def __ne__(self, other):
return not self == other
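The all() in the new __eq__ is needed because, with NumPy-backed storage, == compares elementwise and returns an array rather than a single boolean. A small sketch in plain NumPy with illustrative values:

    import numpy as np

    a = np.array([1.0, 2.0])
    b = np.array([1.0, 2.0])
    print(a == b)       # [ True  True] -- an elementwise array, not one bool
    print(all(a == b))  # True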
