diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index b1b1c04c7d..aaad495dfd 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -338,6 +338,10 @@ def to_spark_df(self) -> pyspark.sql.DataFrame:
 
     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
         """Return dataset as Pandas DataFrame synchronously"""
+        spark_session = get_spark_session_or_start_new_with_repoconfig(
+            self._config.offline_store
+        )
+        spark_session.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
         return self.to_spark_df().toPandas()
 
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
@@ -457,7 +461,7 @@ def get_spark_session_or_start_new_with_repoconfig(
 
 
 def _get_entity_df_event_timestamp_range(
-    entity_df: Union[pd.DataFrame, str],
+    entity_df: Union[pd.DataFrame, str, pyspark.sql.DataFrame],
     entity_df_event_timestamp_col: str,
     spark_session: SparkSession,
 ) -> Tuple[datetime, datetime]:
@@ -496,7 +500,8 @@ def _get_entity_df_event_timestamp_range(
 
 
 def _get_entity_schema(
-    spark_session: SparkSession, entity_df: Union[pandas.DataFrame, str]
+    spark_session: SparkSession,
+    entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
 ) -> Dict[str, np.dtype]:
     if isinstance(entity_df, pd.DataFrame):
         return dict(zip(entity_df.columns, entity_df.dtypes))
@@ -518,7 +523,7 @@ def _get_entity_schema(
 def _upload_entity_df(
     spark_session: SparkSession,
     table_name: str,
-    entity_df: Union[pandas.DataFrame, str],
+    entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
     event_timestamp_col: str,
 ) -> None:
     if isinstance(entity_df, pd.DataFrame):
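
The diff makes two changes: _to_df_internal now enables Arrow-backed columnar transfer before calling toPandas(), and the three entity_df helpers widen their type hints to accept a pyspark.sql.DataFrame directly. Below is a minimal standalone sketch (not part of the patch) illustrating both behaviors; it assumes a local pyspark installation, and the column names and sample data are hypothetical.

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder.master("local[*]").appName("arrow-demo").getOrCreate()
    )

    # 1. Arrow-accelerated toPandas(): with this conf set (as the patched
    #    _to_df_internal now does on the session it obtains via
    #    get_spark_session_or_start_new_with_repoconfig), Spark transfers
    #    columnar Arrow batches instead of serializing row by row.
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    entity_df = spark.createDataFrame(
        pd.DataFrame(
            {
                "driver_id": [1001, 1002],
                "event_timestamp": pd.to_datetime(["2021-04-12", "2021-04-13"]),
            }
        )
    )

    # 2. A pyspark.sql.DataFrame as the entity_df: the widened Union types on
    #    _get_entity_df_event_timestamp_range, _get_entity_schema, and
    #    _upload_entity_df mean a caller can hand over a Spark DataFrame like
    #    this one without first round-tripping through pandas or a SQL string.
    print(entity_df.toPandas())  # uses the Arrow path enabled above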