Skip to content

Commit

Permalink
speedup SparseVector
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Nov 23, 2014
1 parent ef6ce70 commit f0d3c40
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,6 @@ private[spark] object SerDe extends Serializable {
out.write(Opcodes.BINSTRING)
out.write(PickleUtils.integer_to_bytes(bytes.length))
out.write(bytes)

out.write(Opcodes.TUPLE1)
}

Expand Down Expand Up @@ -792,15 +791,32 @@ private[spark] object SerDe extends Serializable {

def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
  val v: SparseVector = obj.asInstanceOf[SparseVector]
  // Pack indices and values into one flat byte buffer: n int32 indices
  // followed by n float64 values (4 + 8 = 12 bytes per nonzero entry).
  // Native byte order must match the numpy.frombuffer reader on the
  // Python side.
  val n = v.indices.size
  val bytes = new Array[Byte](12 * n)
  val order = ByteOrder.nativeOrder()
  ByteBuffer.wrap(bytes).order(order).asIntBuffer().put(v.indices)
  ByteBuffer.wrap(bytes, 4 * n, 8 * n).order(order).asDoubleBuffer().put(v.values)

  // Emit pickle opcodes directly: (size: Int, packed_bytes: String) as a
  // 2-tuple, which SparseVector.__init__ unpacks on the Python side.
  out.write(Opcodes.BININT)
  out.write(PickleUtils.integer_to_bytes(v.size))
  out.write(Opcodes.BINSTRING)
  out.write(PickleUtils.integer_to_bytes(bytes.length))
  out.write(bytes)
  out.write(Opcodes.TUPLE2)
}

def construct(args: Array[Object]): Object = {
  // Expects the 2-tuple written by saveState: (size, packed byte string).
  if (args.length != 2) {
    throw new PickleException("should be 2")
  }
  // ISO-8859-1 maps each char back to a single byte, recovering the raw
  // buffer from the pickled BINSTRING without corruption.
  val bytes = args(1).asInstanceOf[String].getBytes("ISO-8859-1")
  // Layout: n int32 indices (4 bytes each) then n float64 values (8 bytes
  // each), so each nonzero entry occupies 12 bytes.
  val n = bytes.length / 12
  val indices = new Array[Int](n)
  val values = new Array[Double](n)
  val order = ByteOrder.nativeOrder()
  ByteBuffer.wrap(bytes, 0, n * 4).order(order).asIntBuffer().get(indices)
  ByteBuffer.wrap(bytes, n * 4, n * 8).order(order).asDoubleBuffer().get(values)
  new SparseVector(args(0).asInstanceOf[Int], indices, values)
}
}

Expand Down
29 changes: 17 additions & 12 deletions python/pyspark/mllib/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,21 +316,26 @@ def __init__(self, size, *args):
assert 1 <= len(args) <= 2, "must pass either 2 or 3 arguments"
if len(args) == 1:
    pairs = args[0]
    if isinstance(pairs, basestring):
        # Fast path: a packed binary string produced by __reduce__ /
        # the JVM SerDe — n int32 indices followed by n float64 values
        # (4 + 8 = 12 bytes per nonzero entry), in native byte order.
        l = len(pairs) / (4 + 8)
        self.indices = np.frombuffer(pairs, np.uint32, count=l)
        self.values = np.frombuffer(pairs, np.float64, l, offset=l * 4)
    else:
        # A dict, or an iterable of (index, value) pairs.
        if type(pairs) == dict:
            pairs = pairs.items()
        pairs = sorted(pairs)
        self.indices = np.array([p[0] for p in pairs], dtype=np.uint32)
        self.values = np.array([p[1] for p in pairs], dtype=np.float64)
else:
    # Separate index and value sequences of equal length.
    assert len(args[0]) == len(args[1]), "index and value arrays not same length"
    self.indices = np.array(args[0], dtype=np.uint32)
    self.values = np.array(args[1], dtype=np.float64)
# Indices must be strictly increasing; the binary round-trip above
# preserves order, so this is validated once at construction time.
for i in xrange(len(self.indices) - 1):
    if self.indices[i] >= self.indices[i + 1]:
        raise TypeError("indices array must be sorted")

def __reduce__(self):
    # Pickle as (size, packed_bytes): indices (int32) concatenated with
    # values (float64), matching the binary fast path in __init__.
    return (SparseVector, (self.size, self.indices.tostring() + self.values.tostring()))

def dot(self, other):
"""
Expand Down Expand Up @@ -466,8 +471,8 @@ def toArray(self):
Returns a copy of this SparseVector as a 1-dimensional NumPy array.
"""
arr = np.zeros((self.size,), dtype=np.float64)
for i in xrange(len(self.indices)):
arr[self.indices[i]] = self.values[i]
for i, v in zip(self.indices, self.values):
arr[i] = v
return arr

def __len__(self):
Expand Down Expand Up @@ -498,8 +503,8 @@ def __eq__(self, other):
"""
return (isinstance(other, self.__class__)
and other.size == self.size
and other.indices == self.indices
and other.values == self.values)
and all(other.indices == self.indices)
and all(other.values == self.values))

def __ne__(self, other):
return not self.__eq__(other)
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/mllib/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _squared_distance(a, b):
class VectorTests(PySparkTestCase):

def _test_serialize(self, v):
    # Round-trip through the Python pickler alone.
    self.assertEqual(v, ser.loads(ser.dumps(v)))
    # Round-trip Python -> JVM SerDe -> Python and compare with the
    # original vector.
    java_vec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
    round_tripped = ser.loads(str(self.sc._jvm.SerDe.dumps(java_vec)))
    self.assertEqual(v, round_tripped)
Expand Down

0 comments on commit f0d3c40

Please sign in to comment.