From d60e289957e75f18e44bd4f3d7fbe28f03340035 Mon Sep 17 00:00:00 2001
From: Takuya UESHIN
Date: Wed, 26 Jul 2023 12:18:44 -0700
Subject: [PATCH] Fix pyspark_types_to_proto_types for empty StructType.

---
 python/pyspark/sql/connect/types.py                |  4 +++-
 .../sql/tests/pandas/test_pandas_udf_scalar.py     | 16 ++++++++--------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 2a21cdf067513..0db2833d2c1aa 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -170,6 +170,7 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
         ret.year_month_interval.start_field = data_type.startField
         ret.year_month_interval.end_field = data_type.endField
     elif isinstance(data_type, StructType):
+        struct = pb2.DataType.Struct()
         for field in data_type.fields:
             struct_field = pb2.DataType.StructField()
             struct_field.name = field.name
@@ -177,7 +178,8 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
             struct_field.nullable = field.nullable
             if field.metadata is not None and len(field.metadata) > 0:
                 struct_field.metadata = json.dumps(field.metadata)
-            ret.struct.fields.append(struct_field)
+            struct.fields.append(struct_field)
+        ret.struct.CopyFrom(struct)
     elif isinstance(data_type, MapType):
         ret.map.key_type.CopyFrom(pyspark_types_to_proto_types(data_type.keyType))
         ret.map.value_type.CopyFrom(pyspark_types_to_proto_types(data_type.valueType))
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index a6380b5d44cac..7a80547b3fc8b 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -60,7 +60,7 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 if have_pandas:
     import pandas as pd
@@ -529,7 +529,7 @@ def iter_f(it):
         self.assertListEqual([i, i + 1], f[1])
 
     def test_vectorized_udf_struct_empty(self):
-        df = self.spark.range(10)
+        df = self.spark.range(3)
         return_type = StructType()
 
         def _scalar_f(id):
@@ -542,12 +542,12 @@ def iter_f(it):
             for id in it:
                 yield _scalar_f(id)
 
-        for f, udf_type in [(scalar_f, PandasUDFType.SCALAR), (iter_f, PandasUDFType.SCALAR_ITER)]:
-            actual = df.withColumn("f", f(col("id"))).collect()
-            for i, row in enumerate(actual):
-                id, f = row
-                self.assertEqual(i, id)
-                self.assertEqual(Row(), f)
+        for f, udf_type in [(scalar_f, "SCALAR"), (iter_f, "SCALAR_ITER")]:
+            with self.subTest(udf_type=udf_type):
+                assertDataFrameEqual(
+                    df.withColumn("f", f(col("id"))),
+                    [Row(id=0, f=Row()), Row(id=1, f=Row()), Row(id=2, f=Row())],
+                )
 
     def test_vectorized_udf_nested_struct(self):
         with QuietTest(self.sc):
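
A minimal way to exercise the behaviour the types.py hunks address, sketched under the assumption that a PySpark build with the Spark Connect module is available and that the oneof in the Connect DataType proto is named "kind": converting an empty StructType should now populate the struct variant of the proto message, whereas before the field loop never ran and no variant was set.

    from pyspark.sql.types import StructType
    from pyspark.sql.connect.types import pyspark_types_to_proto_types

    # Convert a StructType with zero fields to its Spark Connect protobuf form.
    proto = pyspark_types_to_proto_types(StructType())

    # With the explicit CopyFrom in the patched code, the struct variant is set
    # even though it contains no fields; previously no variant was set at all.
    assert proto.WhichOneof("kind") == "struct"  # assumes the oneof is named "kind"
    assert len(proto.struct.fields) == 0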