[SPARK-48970][PYTHON][ML] Avoid using SparkSession.getActiveSession in spark ML reader/writer

### What changes were proposed in this pull request?

`SparkSession.getActiveSession` returns a thread-local session, but the Spark ML reader/writer might run in a different thread from the one that created the session, which causes `SparkSession.getActiveSession` to return None.
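As a sketch of the failure mode (assuming a plain local PySpark setup; the exact behavior depends on how Python threads are mapped to JVM threads):

```python
import threading

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

def worker() -> None:
    # The active session is tracked per thread, so a freshly spawned
    # thread may see no active session and get None back here.
    print(SparkSession.getActiveSession())  # may print: None

print(SparkSession.getActiveSession())  # prints the session created above

t = threading.Thread(target=worker)
t.start()
t.join()
spark.stop()
```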

### Why are the changes needed?

It fixes bugs like the following:
```
        spark = SparkSession.getActiveSession()
>       spark.createDataFrame(  # type: ignore[union-attr]
            [(metadataJson,)], schema=["value"]
        ).coalesce(1).write.text(metadataPath)
E       AttributeError: 'NoneType' object has no attribute 'createDataFrame'
```
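The fix swaps the bare `getActiveSession()` call for `SparkSession._getActiveSessionOrCreate()`, an internal PySpark helper that falls back to creating (or reusing) a session when no thread-local one is active. Conceptually it behaves like this hand-rolled sketch (an approximation, not the actual implementation):

```python
from typing import Optional

from pyspark.sql import SparkSession

def get_active_session_or_create() -> SparkSession:
    # Prefer the thread-local active session; otherwise fall back to
    # builder.getOrCreate(), which reuses the default session if one
    # exists and only then creates a brand-new one.
    spark: Optional[SparkSession] = SparkSession.getActiveSession()
    if spark is None:
        spark = SparkSession.builder.getOrCreate()
    return spark
```

The Scala side achieves the same effect by building the session from the `SparkContext` that is already passed in (`SparkSession.builder().sparkContext(sc).getOrCreate()`), which never depends on thread-local state.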

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Manually.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47453 from WeichenXu123/SPARK-48970.

Authored-by: Weichen Xu <weichen.xu@databricks.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
WeichenXu123 committed Jul 23, 2024
1 parent 285489b commit fba4c8c
Showing 2 changed files with 7 additions and 7 deletions.
```diff
@@ -588,7 +588,7 @@ private[ml] object DefaultParamsReader {
    */
   def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = {
     val metadataPath = new Path(path, "metadata").toString
-    val spark = SparkSession.getActiveSession.get
+    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
     val metadataStr = spark.read.text(metadataPath).first().getString(0)
     parseMetadata(metadataStr, expectedClassName)
   }
```
python/pyspark/ml/util.py (6 additions, 6 deletions)

```diff
@@ -464,10 +464,10 @@ def saveMetadata(
         metadataJson = DefaultParamsWriter._get_metadata_to_save(
             instance, sc, extraMetadata, paramMap
         )
-        spark = SparkSession.getActiveSession()
-        spark.createDataFrame(  # type: ignore[union-attr]
-            [(metadataJson,)], schema=["value"]
-        ).coalesce(1).write.text(metadataPath)
+        spark = SparkSession._getActiveSessionOrCreate()
+        spark.createDataFrame([(metadataJson,)], schema=["value"]).coalesce(1).write.text(
+            metadataPath
+        )

     @staticmethod
     def _get_metadata_to_save(
```
```diff
@@ -580,8 +580,8 @@ def loadMetadata(path: str, sc: "SparkContext", expectedClassName: str = "") ->
             If non empty, this is checked against the loaded metadata.
         """
         metadataPath = os.path.join(path, "metadata")
-        spark = SparkSession.getActiveSession()
-        metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[union-attr,index]
+        spark = SparkSession._getActiveSessionOrCreate()
+        metadataStr = spark.read.text(metadataPath).first()[0]  # type: ignore[index]
         loadedVals = DefaultParamsReader._parseMetaData(metadataStr, expectedClassName)
         return loadedVals
```

