Commit 7c4a6a9

address comments
mengxr committed Nov 3, 2014
1 parent 75223db commit 7c4a6a9
Showing 3 changed files with 33 additions and 33 deletions.
python/pyspark/sql.py (25 additions, 26 deletions)
```diff
@@ -39,7 +39,6 @@
 from array import array
 from operator import itemgetter
 from itertools import imap
-import importlib
 
 from py4j.protocol import Py4JError
 from py4j.java_collections import ListConverter, MapConverter
```
```diff
@@ -416,25 +415,15 @@ class UserDefinedType(DataType):
     """
 
     @classmethod
-    def sqlType(self):
-        """
-        Underlying SQL storage type for this UDT.
-        """
-        raise NotImplementedError("UDT must implement sqlType().")
-
-    @classmethod
-    def serialize(self, obj):
-        """
-        Converts a user-type object into a SQL datum.
-        """
-        raise NotImplementedError("UDT must implement serialize().")
+    def typeName(cls):
+        return cls.__name__.lower()
 
     @classmethod
-    def deserialize(self, datum):
+    def sqlType(cls):
         """
-        Converts a SQL datum into a user-type object.
+        Underlying SQL storage type for this UDT.
         """
-        raise NotImplementedError("UDT must implement deserialize().")
+        raise NotImplementedError("UDT must implement sqlType().")
 
     @classmethod
     def module(cls):
```
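Worth noting: the removed stubs were decorated with `@classmethod` yet named their first parameter `self`; after this hunk only type-level metadata (`typeName`, `sqlType`, and the `module`/`scalaUDT` hooks) stays on the class. A minimal check of the new `typeName` default against this revision (hypothetical REPL snippet; `ExamplePointUDT` is the test UDT updated later in this commit):

```python
from pyspark.tests import ExamplePointUDT

# typeName() is inherited from UserDefinedType and simply
# lowercases the subclass name.
print(ExamplePointUDT.typeName())  # -> 'examplepointudt'
```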
```diff
@@ -450,25 +439,35 @@ def scalaUDT(cls):
         """
         raise NotImplementedError("UDT must have a paired Scala UDT.")
 
-    @classmethod
-    def json(cls):
-        return json.dumps(cls.jsonValue(), separators=(',', ':'), sort_keys=True)
+    def serialize(self, obj):
+        """
+        Converts a user-type object into a SQL datum.
+        """
+        raise NotImplementedError("UDT must implement serialize().")
 
-    @classmethod
-    def jsonValue(cls):
+    def deserialize(self, datum):
+        """
+        Converts a SQL datum into a user-type object.
+        """
+        raise NotImplementedError("UDT must implement deserialize().")
+
+    def json(self):
+        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+    def jsonValue(self):
         schema = {
             "type": "udt",
-            "pyModule": cls.module(),
-            "pyClass": cls.__name__}
-        if cls.scalaUDT() is not None:
-            schema['class'] = cls.scalaUDT()
+            "pyModule": self.module(),
+            "pyClass": type(self).__name__,
+            "class": self.scalaUDT()
+        }
         return schema
 
     @classmethod
     def fromJson(cls, json):
         pyModule = json['pyModule']
         pyClass = json['pyClass']
-        m = importlib.import_module(pyModule)
+        m = __import__(pyModule, globals(), locals(), [pyClass], -1)
         UDT = getattr(m, pyClass)
         return UDT()
```
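Since `json`/`jsonValue` are now instance methods, a UDT serializes its own schema and `fromJson` can rebuild it. A sketch of that round trip under this revision (assumes the `pyspark.tests` UDT changed in the next file):

```python
import json

from pyspark.sql import UserDefinedType
from pyspark.tests import ExamplePointUDT

udt = ExamplePointUDT()
blob = udt.json()
# With sort_keys=True and compact separators, blob looks like:
# {"class":"org.apache.spark.sql.test.ExamplePointUDT","pyClass":"ExamplePointUDT",
#  "pyModule":"pyspark.tests","type":"udt"}

# fromJson re-imports the module via __import__ and instantiates the class.
udt2 = UserDefinedType.fromJson(json.loads(blob))
assert isinstance(udt2, ExamplePointUDT)
```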

python/pyspark/tests.py (6 additions, 6 deletions)
```diff
@@ -689,12 +689,6 @@ class ExamplePointUDT(UserDefinedType):
     def sqlType(self):
         return ArrayType(DoubleType(), False)
 
-    def serialize(self, obj):
-        return [obj.x, obj.y]
-
-    def deserialize(self, datum):
-        return ExamplePoint(datum[0], datum[1])
-
     @classmethod
     def module(cls):
         return 'pyspark.tests'
```
```diff
@@ -703,6 +697,12 @@ def module(cls):
     def scalaUDT(cls):
         return 'org.apache.spark.sql.test.ExamplePointUDT'
 
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
 
 class ExamplePoint:
     """
```
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala (2 additions, 1 deletion)
```diff
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.UserDefinedType
 import org.apache.spark.sql.execution.{SparkStrategies, _}
 import org.apache.spark.sql.json._
 import org.apache.spark.sql.parquet.ParquetRelation
```
```diff
@@ -483,7 +484,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case ArrayType(_, _) => true
       case MapType(_, _, _) => true
       case StructType(_) => true
-      case udt: UserDefinedType[_] => true
+      case _: UserDefinedType[_] => true
       case other => false
     }
```
