Skip to content

Commit

Permalink
fix py style
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Nov 3, 2014
1 parent 4e84fce commit e98d9d0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
8 changes: 6 additions & 2 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ def _infer_schema(row):
fields = [StructField(k, _infer_type(v), True) for k, v in items]
return StructType(fields)


def _need_python_to_sql_conversion(dataType):
"""
Checks whether we need Python-to-SQL conversion for the given type.
Expand Down Expand Up @@ -665,12 +666,13 @@ def _need_python_to_sql_conversion(dataType):
return _need_python_to_sql_conversion(dataType.elementType)
elif isinstance(dataType, MapType):
return _need_python_to_sql_conversion(dataType.keyType) or \
_need_python_to_sql_conversion(dataType.valueType)
_need_python_to_sql_conversion(dataType.valueType)
elif isinstance(dataType, UserDefinedType):
return True
else:
return False


def _python_to_sql_converter(dataType):
"""
Returns a converter that converts a Python object into a SQL datum for the given type.
Expand All @@ -697,13 +699,14 @@ def _python_to_sql_converter(dataType):
if isinstance(dataType, StructType):
names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
converters = map(_python_to_sql_converter, types)

def converter(obj):
if isinstance(obj, dict):
return tuple(c(obj.get(n)) for n, c in zip(names, converters))
elif isinstance(obj, tuple):
if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
return tuple(c(v) for c, v in zip(converters, obj))
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
d = dict(obj)
return tuple(c(d.get(n)) for n, c in zip(names, converters))
else:
Expand All @@ -723,6 +726,7 @@ def converter(obj):
else:
raise ValueError("Unexpected type %r" % dataType)


def _create_converter(obj, dataType):
"""Create a converter to drop the names of fields in obj"""
if isinstance(dataType, ArrayType):
Expand Down
4 changes: 3 additions & 1 deletion python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ def heavy_foo(x):
self.sc.dump_profiles(d)
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))


class ExamplePointUDT(UserDefinedType):
"""
User-defined type (UDT) for ExamplePoint.
Expand All @@ -702,6 +703,7 @@ def module(cls):
def scalaUDT(cls):
return 'org.apache.spark.sql.test.ExamplePointUDT'


class ExamplePoint:
"""
An example class to demonstrate UDT in Scala, Java, and Python.
Expand All @@ -721,7 +723,7 @@ def __str__(self):

def __eq__(self, other):
return isinstance(other, ExamplePoint) and \
other.x == self.x and other.y == self.y
other.x == self.x and other.y == self.y


class SQLTests(ReusedPySparkTestCase):
Expand Down

0 comments on commit e98d9d0

Please sign in to comment.