[SPARK-45597][PYTHON][SQL] Support creating table using a Python data source in SQL (DSv2 exec)

### What changes were proposed in this pull request?

This PR is the same as apache#44233, but instead of going through `V1Table` it uses the original DSv2 interface directly, reusing the UDTF execution code.

### Why are the changes needed?

So that Python Data Sources can be used everywhere else as well, including from SparkR and Scala.

### Does this PR introduce _any_ user-facing change?

Yes. Users can register their Python Data Sources and use them in SQL, SparkR, etc.
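
As an illustration, here is a minimal sketch of the workflow this enables, assuming an active `spark` session and the Python Data Source API in `pyspark.sql.datasource`; the class names, the `my_source` provider name, and the exact registration call are illustrative and may differ from the API at the time of this change:

```python
from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import IntegerType, StringType, StructField, StructType


class MyReader(DataSourceReader):
    def read(self, partition):
        # Yield rows matching the schema declared by the data source.
        yield (0, "hello")
        yield (1, "world")


class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        # Short name used as the provider, e.g. `USING my_source` in SQL.
        return "my_source"

    def schema(self):
        return StructType([
            StructField("id", IntegerType()),
            StructField("value", StringType()),
        ])

    def reader(self, schema):
        return MyReader()


# Register with the current session; after that the short name resolves
# like any other data source provider.
spark.dataSource.register(MyDataSource)
spark.sql("CREATE TABLE t USING my_source")
spark.sql("SELECT * FROM t").show()
```

Because the registered source is looked up as a regular DSv2 `TableProvider` (`PythonTableProvider` in this change), the same provider name can then be referenced from SQL, Scala, or SparkR.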

### How was this patch tested?

Unit tests were added, and the change was manually tested.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#44269
Closes apache#44233
Closes apache#43784

Closes apache#44305 from HyukjinKwon/SPARK-45597-3.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Dec 15, 2023
1 parent 69237c9 commit a1b0da2
Showing 26 changed files with 290 additions and 359 deletions.
@@ -17,13 +17,11 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.{BinaryType, StructType}
import org.apache.spark.sql.types.StructType

/**
* FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
@@ -103,42 +101,6 @@ case class PythonMapInArrow(
copy(child = newChild)
}

/**
* Represents a Python data source.
*/
case class PythonDataSource(
dataSource: PythonFunction,
outputSchema: StructType,
override val output: Seq[Attribute]) extends LeafNode {
require(output.forall(_.resolved),
"Unresolved attributes found when constructing PythonDataSource.")
override protected def stringArgs: Iterator[Any] = {
Iterator(output)
}
final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_DATA_SOURCE)
}

/**
* Represents a list of Python data source partitions.
*/
case class PythonDataSourcePartitions(
output: Seq[Attribute],
partitions: Seq[Array[Byte]]) extends LeafNode {
override protected def stringArgs: Iterator[Any] = {
if (partitions.isEmpty) {
Iterator("<empty>", output)
} else {
Iterator(output)
}
}
}

object PythonDataSourcePartitions {
def schema: StructType = new StructType().add("partition", BinaryType)

def getOutputAttrs: Seq[Attribute] = toAttributes(schema)
}

/**
* Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe
* This is used by DataFrame.groupby().cogroup().apply().
48 changes: 6 additions & 42 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,12 +17,11 @@

package org.apache.spark.sql

import java.util.{Locale, Properties, ServiceConfigurationError}
import java.util.{Locale, Properties}

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
import org.apache.spark.Partition
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -209,45 +208,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}

val isUserDefinedDataSource =
sparkSession.sessionState.dataSourceManager.dataSourceExists(source)

Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
case Success(providerOpt) =>
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(source)
}
providerOpt.flatMap { provider =>
DataSourceV2Utils.loadV2Source(
sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
case Failure(exception) =>
// Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
// For the following not found exceptions, if the user-defined data source is defined,
// we can instead return the user-defined data source.
val isNotFoundError = exception match {
case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
case _ => false
}
if (isNotFoundError && isUserDefinedDataSource) {
loadUserDefinedDataSource(paths)
} else {
// Throw the original exception.
throw exception
}
}
}

private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
// Add `path` and `paths` options to the extra options if specified.
val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
Dataset.ofRows(sparkSession, plan)
DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
}

private def loadV1Source(paths: String*) = {
@@ -43,6 +43,6 @@ private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataS
| pythonExec: ${dataSource.dataSourceCls.pythonExec}
""".stripMargin)

dataSourceManager.registerDataSource(name, dataSource.builder)
dataSourceManager.registerDataSource(name, dataSource)
}
}
@@ -780,7 +780,7 @@ class SparkSession private(
DataSource.lookupDataSource(runner, sessionState.conf) match {
case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
Dataset.ofRows(self, ExternalCommandExecutor(
source.getDeclaredConstructor().newInstance()
DataSource.newDataSourceInstance(runner, source)
.asInstanceOf[ExternalCommandRunner], command, options))

case _ =>
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.{PlanPythonDataSourceScan, PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
@@ -42,8 +42,7 @@ class SparkOptimizer(
V2ScanRelationPushDown :+
V2ScanPartitioningAndOrdering :+
V2Writes :+
PruneFileSourcePartitions :+
PlanPythonDataSourceScan
PruneFileSourcePartitions

override def preCBORules: Seq[Rule[LogicalPlan]] =
OptimizeMetadataOnlyDeleteFromTable :: Nil
@@ -102,8 +101,7 @@ class SparkOptimizer(
V2ScanRelationPushDown.ruleName :+
V2ScanPartitioningAndOrdering.ruleName :+
V2Writes.ruleName :+
ReplaceCTERefWithRepartition.ruleName :+
PlanPythonDataSourceScan.ruleName
ReplaceCTERefWithRepartition.ruleName

/**
* Optimization batches that are executed before the regular optimization batches (also before
@@ -753,8 +753,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ArrowEvalPythonUDTF(udtf, requiredChildOutput, resultAttrs, child, evalType) =>
ArrowEvalPythonUDTFExec(
udtf, requiredChildOutput, resultAttrs, planLater(child), evalType) :: Nil
case PythonDataSourcePartitions(output, partitions) =>
PythonDataSourcePartitionsExec(output, partitions) :: Nil
case _ =>
Nil
}
@@ -45,7 +45,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
@@ -1025,7 +1025,9 @@ object DDLUtils extends Logging {

def checkDataColNames(provider: String, schema: StructType): Unit = {
val source = try {
DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
DataSource.newDataSourceInstance(
provider,
DataSource.lookupDataSource(provider, SQLConf.get))
} catch {
case e: Throwable =>
logError(s"Failed to find data source: $provider when check data column names.", e)
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -264,8 +264,9 @@ case class AlterTableAddColumnsCommand(
}

if (DDLUtils.isDatasourceTable(catalogTable)) {
DataSource.lookupDataSource(catalogTable.provider.get, conf).
getConstructor().newInstance() match {
DataSource.newDataSourceInstance(
catalogTable.provider.get,
DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
// For datasource table, this command can only support the following File format.
// TextFileFormat only default to one column "value"
// Hive type is already considered as hive serde table, so the logic will not
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
import org.apache.spark.sql.execution.python.PythonTableProvider
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -105,13 +106,14 @@ case class DataSource(
// [[FileDataSourceV2]] will still be used if we call the load()/save() method in
// [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
// instead of `providingClass`.
cls.getDeclaredConstructor().newInstance() match {
DataSource.newDataSourceInstance(className, cls) match {
case f: FileDataSourceV2 => f.fallbackFileFormat
case _ => cls
}
}

private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()
private[sql] def providingInstance(): Any =
DataSource.newDataSourceInstance(className, providingClass)

private def newHadoopConfiguration(): Configuration =
sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -622,6 +624,15 @@ object DataSource extends Logging {
"org.apache.spark.sql.sources.HadoopFsRelationProvider",
"org.apache.spark.Logging")

/** Create the instance of the datasource */
def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = {
providingClass match {
case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
case cls => cls.getDeclaredConstructor().newInstance()
}
}

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
@@ -649,6 +660,9 @@ object DataSource extends Logging {
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
// TODO(SPARK-45600): should be session-based.
val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
_.sessionState.dataSourceManager.dataSourceExists(provider))
if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw QueryCompilationErrors.orcNotUsedWithHiveEnabledError()
} else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
@@ -657,6 +671,8 @@
throw QueryCompilationErrors.failedToFindAvroDataSourceError(provider1)
} else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
} else if (isUserDefinedDataSource) {
classOf[PythonTableProvider]
} else {
throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
}
@@ -673,6 +689,14 @@
}
case head :: Nil =>
// there is exactly one registered alias
// TODO(SPARK-45600): should be session-based.
val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
_.sessionState.dataSourceManager.dataSourceExists(provider))
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(provider)
}
head.getClass
case sources =>
// There are multiple registered aliases for the input. If there is single datasource
@@ -708,17 +732,18 @@
def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = {
val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
.split(",").map(_.trim)
val cls = lookupDataSource(provider, conf)
val providingClass = lookupDataSource(provider, conf)
val instance = try {
cls.getDeclaredConstructor().newInstance()
newDataSourceInstance(provider, providingClass)
} catch {
// Throw the original error from the data source implementation.
case e: java.lang.reflect.InvocationTargetException => throw e.getCause
}
instance match {
case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None
case t: TableProvider
if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
if !useV1Sources.contains(
providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
Some(t)
case _ => None
}
@@ -21,36 +21,28 @@ import java.util.Locale
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource


/**
* A manager for user-defined data sources. It is used to register and lookup data sources by
* their short names or fully qualified names.
*/
class DataSourceManager extends Logging {

private type DataSourceBuilder = (
SparkSession, // Spark session
String, // provider name
Option[StructType], // user specified schema
CaseInsensitiveMap[String] // options
) => LogicalPlan

private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
// TODO(SPARK-45917): Statically load Python Data Source so idempotently Python
// Data Sources can be loaded even when the Driver is restarted.
private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)

/**
* Register a data source builder for the given provider.
* Note that the provider name is case-insensitive.
*/
def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
val normalizedName = normalize(name)
val previousValue = dataSourceBuilders.put(normalizedName, builder)
val previousValue = dataSourceBuilders.put(normalizedName, source)
if (previousValue != null) {
logWarning(f"The data source $name replaced a previously registered data source.")
}
@@ -60,7 +52,7 @@ class DataSourceManager extends Logging {
* Returns a data source builder for the given provider and throw an exception if
* it does not exist.
*/
def lookupDataSource(name: String): DataSourceBuilder = {
def lookupDataSource(name: String): UserDefinedPythonDataSource = {
if (dataSourceExists(name)) {
dataSourceBuilders.get(normalize(name))
} else {