Skip to content

Commit

Permalink
SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 …
Browse files Browse the repository at this point in the history
…instead of complaining
  • Loading branch information
techaddict committed Apr 10, 2014
1 parent 0d0493f commit 3bdf5f6
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions python/pyspark/mllib/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
# limitations under the License.
#

from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
from pyspark import SparkContext, RDD
import numpy as np

from pyspark.serializers import Serializer
import struct
Expand Down Expand Up @@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
return ar.copy()

def _serialize_double_vector(v):
"""Serialize a double vector into a mutually understood format."""
"""Serialize a double vector into a mutually understood format.
>>> x = array([1,2,3])
>>> y = _deserialize_double_vector(_serialize_double_vector(x))
>>> array_equal(y, array([1.0, 2.0, 3.0]))
True
"""
if type(v) != ndarray:
raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v))
"""complex is only datatype that can't be converted to float64"""
if issubdtype(v.dtype, complex):
raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v))
if v.dtype != float64:
raise TypeError("_serialize_double_vector called on an ndarray of %s; "
"wanted ndarray of float64" % v.dtype)
v = v.astype(float64)
if v.ndim != 1:
raise TypeError("_serialize_double_vector called on a %ddarray; "
"wanted a 1darray" % v.ndim)
Expand Down

0 comments on commit 3bdf5f6

Please sign in to comment.