diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index b9a2829a1ca0b..5e7965554d825 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -464,7 +464,10 @@ def saveMetadata(
         metadataJson = DefaultParamsWriter._get_metadata_to_save(
             instance, sc, extraMetadata, paramMap
         )
-        sc.parallelize([metadataJson], 1).saveAsTextFile(metadataPath)
+        spark = SparkSession.getActiveSession()
+        spark.createDataFrame(  # type: ignore[union-attr]
+            [(metadataJson,)], schema=["value"]
+        ).coalesce(1).write.text(metadataPath)
 
     @staticmethod
     def _get_metadata_to_save(
@@ -577,7 +580,8 @@ def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") ->
             If non empty, this is checked against the loaded metadata.
         """
         metadataPath = os.path.join(path, "metadata")
-        metadataStr = sc.textFile(metadataPath, 1).first()
+        spark = SparkSession.getActiveSession()
+        metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[union-attr,index]
         loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
         return loadedVals
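
For reference, a minimal sketch of the round trip the new DataFrame-based code path performs, assuming an active SparkSession; `metadata_json` and `metadata_path` are hypothetical values used only for illustration:

```python
from pyspark.sql import SparkSession

# Assumes an active (or newly created) SparkSession.
spark = SparkSession.builder.master("local[1]").appName("metadata-roundtrip").getOrCreate()

metadata_json = '{"class": "example.Estimator", "timestamp": 0}'  # hypothetical metadata
metadata_path = "/tmp/example-metadata"  # hypothetical output directory

# Write: a single-row DataFrame coalesced to one partition, saved as text
# (mirrors the replacement for sc.parallelize([metadataJson], 1).saveAsTextFile(...)).
spark.createDataFrame([(metadata_json,)], schema=["value"]).coalesce(1).write.text(metadata_path)

# Read: load the text file back and take the first row's single column
# (mirrors the replacement for sc.textFile(metadataPath, 1).first()).
metadata_str = spark.read.text(metadata_path).first()[0]
assert metadata_str == metadata_json
```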