Merge pull request #1685 from apache/master
GulajavaMinistudio authored Oct 14, 2024
2 parents 410da04 + 1aae160 commit 8e494b6
Showing 25 changed files with 613 additions and 181 deletions.
6 changes: 6 additions & 0 deletions assembly/pom.xml
@@ -117,6 +117,12 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-client-jvm_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
</exclusion>
</exclusions>
<scope>provided</scope>
</dependency>

@@ -7063,7 +7063,7 @@
},
"_LEGACY_ERROR_TEMP_2097" : {
"message" : [
"Could not execute broadcast in <timeout> secs. You can increase the timeout for broadcasts via <broadcastTimeout> or disable broadcast join by setting <autoBroadcastJoinThreshold> to -1."
"Could not execute broadcast in <timeout> secs. You can increase the timeout for broadcasts via <broadcastTimeout> or disable broadcast join by setting <autoBroadcastJoinThreshold> to -1 or remove the broadcast hint if it exists in your code."
]
},
"_LEGACY_ERROR_TEMP_2098" : {
@@ -24,19 +24,20 @@ import com.fasterxml.jackson.core.{JsonEncoding, JsonGenerator}
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.scala.DefaultScalaModule

import org.apache.spark.util.SparkErrorUtils.tryWithResource

private[spark] trait JsonUtils {

protected val mapper: ObjectMapper = new ObjectMapper().registerModule(DefaultScalaModule)
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)

def toJsonString(block: JsonGenerator => Unit): String = {
val baos = new ByteArrayOutputStream()
val generator = mapper.createGenerator(baos, JsonEncoding.UTF8)
block(generator)
generator.close()
baos.close()
new String(baos.toByteArray, StandardCharsets.UTF_8)
tryWithResource(new ByteArrayOutputStream()) { baos =>
tryWithResource(mapper.createGenerator(baos, JsonEncoding.UTF8)) { generator =>
block(generator)
}
new String(baos.toByteArray, StandardCharsets.UTF_8)
}
}
}
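The refactor above replaces manual close() calls with tryWithResource, so the output stream and JSON generator are closed even when the write block throws. Purely as a cross-language illustration (not part of this change), the same close-on-exception guarantee in Python is what a `with` block provides:

import io

def to_json_string(write) -> str:
    # buf is closed even if write(...) raises, mirroring tryWithResource above
    with io.StringIO() as buf:
        write(buf)
        return buf.getvalue()

print(to_json_string(lambda buf: buf.write('{"ok": true}')))  # prints {"ok": true}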

22 changes: 16 additions & 6 deletions mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,12 +17,13 @@

package org.apache.spark.ml.util

import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.sql.catalyst.util.AttributeNameParser
import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


/**
* Utils for handling schemas.
*/
@@ -206,18 +206,27 @@ private[spark] object SchemaUtils {
checkColumnTypes(schema, colName, typeCandidates)
}

def toSQLId(parts: String): String = {
AttributeNameParser.parseAttributeName(parts).map(QuotingUtils.quoteIdentifier).mkString(".")
}

/**
* Get schema field.
* @param schema input schema
* @param colName column name, nested column name is supported.
*/
def getSchemaField(schema: StructType, colName: String): StructField = {
val colSplits = AttributeNameParser.parseAttributeName(colName)
var field = schema(colSplits(0))
for (colSplit <- colSplits.slice(1, colSplits.length)) {
field = field.dataType.asInstanceOf[StructType](colSplit)
val fieldOpt = schema.findNestedField(colSplits, resolver = SQLConf.get.resolver)
if (fieldOpt.isEmpty) {
throw new SparkIllegalArgumentException(
errorClass = "FIELD_NOT_FOUND",
messageParameters = Map(
"fieldName" -> toSQLId(colName),
"fields" -> schema.fields.map(f => toSQLId(f.name)).mkString(", "))
)
}
field
fieldOpt.get._2
}

/**
4 changes: 4 additions & 0 deletions pom.xml
@@ -3075,6 +3075,10 @@
reduce the cost of migration in subsequent versions.
-->
<arg>-Wconf:cat=deprecation&amp;msg=it will become a keyword in Scala 3:e</arg>
<!--
SPARK-49937 ban calls to the method `SparkThrowable#getErrorClass`
-->
<arg>-Wconf:cat=deprecation&amp;msg=method getErrorClass in trait SparkThrowable is deprecated:e</arg>
</args>
<jvmArgs>
<jvmArg>-Xss128m</jvmArg>
4 changes: 3 additions & 1 deletion project/SparkBuild.scala
@@ -254,7 +254,9 @@ object SparkBuild extends PomBuild {
// reduce the cost of migration in subsequent versions.
"-Wconf:cat=deprecation&msg=it will become a keyword in Scala 3:e",
// SPARK-46938 to prevent enum scan on pmml-model, under spark-mllib module.
"-Wconf:cat=other&site=org.dmg.pmml.*:w"
"-Wconf:cat=other&site=org.dmg.pmml.*:w",
// SPARK-49937 ban calls to the method `SparkThrowable#getErrorClass`
"-Wconf:cat=deprecation&msg=method getErrorClass in trait SparkThrowable is deprecated:e"
)
}
)
4 changes: 2 additions & 2 deletions python/pyspark/errors/error-conditions.json
@@ -94,9 +94,9 @@
"Could not get batch id from <obj_name>."
]
},
"CANNOT_INFER_ARRAY_TYPE": {
"CANNOT_INFER_ARRAY_ELEMENT_TYPE": {
"message": [
"Can not infer Array Type from a list with None as the first element."
"Can not infer the element data type, an non-empty list starting with an non-None value is required."
]
},
"CANNOT_INFER_EMPTY_SCHEMA": {
6 changes: 2 additions & 4 deletions python/pyspark/pandas/data_type_ops/datetime_ops.py
@@ -34,6 +34,7 @@
)
from pyspark.sql.utils import pyspark_column_op
from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.base import IndexOpsMixin
from pyspark.pandas.data_type_ops.base import (
DataTypeOps,
@@ -150,10 +151,7 @@ class DatetimeNTZOps(DatetimeOps):
"""

def _cast_spark_column_timestamp_to_long(self, scol: Column) -> Column:
from pyspark import SparkContext

jvm = SparkContext._active_spark_context._jvm
return Column(jvm.PythonSQLUtils.castTimestampNTZToLong(scol._jc))
return SF.timestamp_ntz_to_long(scol)

def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike:
dtype, spark_type = pandas_on_spark_type(dtype)
4 changes: 4 additions & 0 deletions python/pyspark/pandas/spark/functions.py
@@ -39,6 +39,10 @@ def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") ->
return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column)))


def timestamp_ntz_to_long(col: Column) -> Column:
return _invoke_internal_function_over_columns("timestamp_ntz_to_long", col)


def product(col: Column, dropna: bool) -> Column:
return _invoke_internal_function_over_columns("pandas_product", col, F.lit(dropna))
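Illustrative call pattern for the wrapper added above (both this helper and PythonSQLUtils.internalFn are internal APIs; an active classic, non-Connect SparkSession is assumed, and the column name `ts` is made up):

from pyspark.sql import functions as F
from pyspark.pandas.spark import functions as SF

# Produces a Column backed by internalFn("timestamp_ntz_to_long", ts), replacing
# the direct SparkContext._jvm call previously made in datetime_ops.py.
long_ts = SF.timestamp_ntz_to_long(F.col("ts"))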

20 changes: 5 additions & 15 deletions python/pyspark/sql/connect/expressions.py
@@ -301,7 +301,7 @@ def _infer_type(cls, value: Any) -> DataType:
return NullType()
elif isinstance(value, (bytes, bytearray)):
return BinaryType()
elif isinstance(value, bool):
elif isinstance(value, (bool, np.bool_)):
return BooleanType()
elif isinstance(value, int):
if JVM_INT_MIN <= value <= JVM_INT_MAX:
@@ -323,10 +323,8 @@ def _infer_type(cls, value: Any) -> DataType:
return StringType()
elif isinstance(value, decimal.Decimal):
return DecimalType()
elif isinstance(value, datetime.datetime) and is_timestamp_ntz_preferred():
return TimestampNTZType()
elif isinstance(value, datetime.datetime):
return TimestampType()
return TimestampNTZType() if is_timestamp_ntz_preferred() else TimestampType()
elif isinstance(value, datetime.date):
return DateType()
elif isinstance(value, datetime.timedelta):
@@ -335,23 +333,15 @@
dt = _from_numpy_type(value.dtype)
if dt is not None:
return dt
elif isinstance(value, np.bool_):
return BooleanType()
elif isinstance(value, list):
# follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type'
# right now, it's dedicated for pyspark.ml params like array<...>, array<array<...>>
if len(value) == 0:
raise PySparkValueError(
errorClass="CANNOT_BE_EMPTY",
messageParameters={"item": "value"},
)
first = value[0]
if first is None:
if len(value) == 0 or value[0] is None:
raise PySparkTypeError(
errorClass="CANNOT_INFER_ARRAY_TYPE",
errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE",
messageParameters={},
)
return ArrayType(LiteralExpression._infer_type(first), True)
return ArrayType(LiteralExpression._infer_type(value[0]), True)

raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE",
@@ -193,6 +193,8 @@ def main(infile: IO, outfile: IO) -> None:
reader.stop()
except BaseException as e:
handle_worker_exception(e, outfile)
# ensure that the updates to the socket are flushed
outfile.flush()
sys.exit(-1)
send_accumulator_updates(outfile)
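The added flush matters because the worker exits immediately after reporting the failure: without it, the error details could still be sitting in the buffered socket file and the driver would never see them. A minimal sketch of the pattern (illustrative only, not the real worker protocol):

import sys
from typing import IO

def report_failure_and_exit(outfile: IO) -> None:
    outfile.write(b"<serialized error>")  # hypothetical payload, for illustration
    outfile.flush()                       # push buffered bytes onto the socket before exiting
    sys.exit(-1)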
