diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 4810cf40e2315..420f2a887e3b7 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -817,6 +817,38 @@ def test_schema_with_collations_on_non_string_types(self): PySparkTypeError, lambda: _parse_datatype_json_string(collations_in_nested_map_json) ) + def test_array_type_from_json(self): + arrayWithoutCollations = ArrayType(StringType(), True) + arrayWithCollations = ArrayType(StringType("UNICODE"), True) + array_json = {"type": "array", "elementType": "string", "containsNull": True} + collationsMap = {"element": "UNICODE"} + + self.assertEqual(arrayWithoutCollations, ArrayType.fromJson(array_json)) + self.assertEqual( + arrayWithCollations, + ArrayType.fromJson(array_json, fieldPath="", collationsMap=collationsMap), + ) + self.assertEqual( + arrayWithCollations, ArrayType.fromJson(array_json, collationsMap=collationsMap) + ) + + def test_map_type_from_json(self): + mapWithoutCollations = MapType(StringType(), StringType(), True) + mapWithCollations = MapType(StringType("UNICODE"), StringType("UNICODE"), True) + map_json = { + "type": "map", + "keyType": "string", + "valueType": "string", + "valueContainsNull": True, + } + collationsMap = {"key": "UNICODE", "value": "UNICODE"} + + self.assertEqual(mapWithoutCollations, MapType.fromJson(map_json)) + self.assertEqual( + mapWithCollations, MapType.fromJson(map_json, fieldPath="", collationsMap=collationsMap) + ) + self.assertEqual(mapWithCollations, MapType.fromJson(map_json, collationsMap=collationsMap)) + def test_schema_with_bad_collations_provider(self): from pyspark.sql.types import _parse_datatype_json_string, _COLLATIONS_METADATA_KEY diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index d2adc53a3618f..d4286afc1b03c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -771,11 +771,13 @@ def jsonValue(self) -> Dict[str, Any]: def fromJson( cls, json: Dict[str, Any], - fieldPath: str, - collationsMap: Optional[Dict[str, str]], + fieldPath: str = "", + collationsMap: Optional[Dict[str, str]] = None, ) -> "ArrayType": elementType = _parse_datatype_json_value( - json["elementType"], fieldPath + ".element", collationsMap + json["elementType"], + "element" if fieldPath == "" else fieldPath + ".element", + collationsMap, ) return ArrayType(elementType, json["containsNull"]) @@ -911,12 +913,14 @@ def jsonValue(self) -> Dict[str, Any]: def fromJson( cls, json: Dict[str, Any], - fieldPath: str, - collationsMap: Optional[Dict[str, str]], + fieldPath: str = "", + collationsMap: Optional[Dict[str, str]] = None, ) -> "MapType": - keyType = _parse_datatype_json_value(json["keyType"], fieldPath + ".key", collationsMap) + keyType = _parse_datatype_json_value( + json["keyType"], "key" if fieldPath == "" else fieldPath + ".key", collationsMap + ) valueType = _parse_datatype_json_value( - json["valueType"], fieldPath + ".value", collationsMap + json["valueType"], "value" if fieldPath == "" else fieldPath + ".value", collationsMap ) return MapType( keyType,