diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py index 350152a2cdb7a..18bbb9dd033d7 100644 --- a/python/pyspark/pandas/sql_formatter.py +++ b/python/pyspark/pandas/sql_formatter.py @@ -31,6 +31,7 @@ from pyspark.pandas.utils import default_session from pyspark.pandas.frame import DataFrame from pyspark.pandas.series import Series +from pyspark.sql.utils import is_remote __all__ = ["sql"] @@ -265,7 +266,10 @@ def _convert_value(self, val: Any, name: str) -> Optional[str]: val._to_spark().createOrReplaceTempView(df_name) return df_name elif isinstance(val, str): - return lit(val)._jc.expr().sql() # for escaped characters. + if is_remote(): + return f"'{val}'" + else: + return lit(val)._jc.expr().sql() # for escaped characters. else: return val diff --git a/python/pyspark/pandas/tests/connect/test_parity_sql.py b/python/pyspark/pandas/tests/connect/test_parity_sql.py index 6c2979f785a44..c042de6b90073 100644 --- a/python/pyspark/pandas/tests/connect/test_parity_sql.py +++ b/python/pyspark/pandas/tests/connect/test_parity_sql.py @@ -30,12 +30,6 @@ def test_sql_with_index_col(self): def test_sql_with_pandas_on_spark_objects(self): super().test_sql_with_pandas_on_spark_objects() - @unittest.skip( - "TODO(SPARK-43665): Enable PandasSQLStringFormatter.vformat to work with Spark Connect." - ) - def test_sql_with_python_objects(self): - super().test_sql_with_python_objects() - if __name__ == "__main__": from pyspark.pandas.tests.connect.test_parity_sql import * # noqa: F401