Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Jul 26, 2023
1 parent b7cc3af commit d60e289
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
4 changes: 3 additions & 1 deletion python/pyspark/sql/connect/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,16 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret.year_month_interval.start_field = data_type.startField
ret.year_month_interval.end_field = data_type.endField
elif isinstance(data_type, StructType):
struct = pb2.DataType.Struct()
for field in data_type.fields:
struct_field = pb2.DataType.StructField()
struct_field.name = field.name
struct_field.data_type.CopyFrom(pyspark_types_to_proto_types(field.dataType))
struct_field.nullable = field.nullable
if field.metadata is not None and len(field.metadata) > 0:
struct_field.metadata = json.dumps(field.metadata)
ret.struct.fields.append(struct_field)
struct.fields.append(struct_field)
ret.struct.CopyFrom(struct)
elif isinstance(data_type, MapType):
ret.map.key_type.CopyFrom(pyspark_types_to_proto_types(data_type.keyType))
ret.map.value_type.CopyFrom(pyspark_types_to_proto_types(data_type.valueType))
Expand Down
16 changes: 8 additions & 8 deletions python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import QuietTest
from pyspark.testing.utils import QuietTest, assertDataFrameEqual

if have_pandas:
import pandas as pd
Expand Down Expand Up @@ -529,7 +529,7 @@ def iter_f(it):
self.assertListEqual([i, i + 1], f[1])

def test_vectorized_udf_struct_empty(self):
df = self.spark.range(10)
df = self.spark.range(3)
return_type = StructType()

def _scalar_f(id):
Expand All @@ -542,12 +542,12 @@ def iter_f(it):
for id in it:
yield _scalar_f(id)

for f, udf_type in [(scalar_f, PandasUDFType.SCALAR), (iter_f, PandasUDFType.SCALAR_ITER)]:
actual = df.withColumn("f", f(col("id"))).collect()
for i, row in enumerate(actual):
id, f = row
self.assertEqual(i, id)
self.assertEqual(Row(), f)
for f, udf_type in [(scalar_f, "SCALAR"), (iter_f, "SCALAR_ITER")]:
with self.subTest(udf_type=udf_type):
assertDataFrameEqual(
df.withColumn("f", f(col("id"))),
[Row(id=0, f=Row()), Row(id=1, f=Row()), Row(id=2, f=Row())],
)

def test_vectorized_udf_nested_struct(self):
with QuietTest(self.sc):
Expand Down

0 comments on commit d60e289

Please sign in to comment.