Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…ding into sparkml-target-encoding
  • Loading branch information
Enrique Rebollo committed Oct 16, 2024
2 parents 2d19f33 + 8d1cb76 commit 3db4002
Show file tree
Hide file tree
Showing 219 changed files with 3,382 additions and 1,730 deletions.
22 changes: 13 additions & 9 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -264,20 +264,20 @@ jobs:
with:
distribution: zulu
java-version: ${{ matrix.java }}
- name: Install Python 3.9
- name: Install Python 3.11
uses: actions/setup-python@v5
# We should install one Python that is higher than 3+ for SQL and Yarn because:
# - SQL component also has Python related tests, for example, IntegratedUDFTestUtils.
# - Yarn has a Python specific test too, for example, YarnClusterSuite.
if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect')
with:
python-version: '3.9'
python-version: '3.11'
architecture: x64
- name: Install Python packages (Python 3.9)
- name: Install Python packages (Python 3.11)
if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect')
run: |
python3.9 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1'
python3.9 -m pip list
python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1'
python3.11 -m pip list
# Run the tests.
- name: Run tests
env: ${{ fromJSON(inputs.envs) }}
Expand Down Expand Up @@ -608,14 +608,14 @@ jobs:
with:
input: sql/connect/common/src/main
against: 'https://github.com/apache/spark.git#branch=branch-3.5,subdir=connector/connect/common/src/main'
- name: Install Python 3.9
- name: Install Python 3.11
uses: actions/setup-python@v5
with:
python-version: '3.9'
python-version: '3.11'
- name: Install dependencies for Python CodeGen check
run: |
python3.9 -m pip install 'black==23.9.1' 'protobuf==4.25.1' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
python3.9 -m pip list
python3.11 -m pip install 'black==23.9.1' 'protobuf==4.25.1' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
python3.11 -m pip list
- name: Python CodeGen check
run: ./dev/connect-check-protos.py

Expand Down Expand Up @@ -1112,6 +1112,10 @@ jobs:
with:
distribution: zulu
java-version: ${{ inputs.java }}
- name: Install R
run: |
sudo apt update
sudo apt-get install r-base
- name: Start Minikube
uses: medyagh/setup-minikube@v0.0.18
with:
Expand Down
6 changes: 6 additions & 0 deletions assembly/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-client-jvm_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
</exclusion>
</exclusions>
<scope>provided</scope>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1363,9 +1363,9 @@ public static UTF8String trimRight(

public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim,
final int limit, final int collationId) {
if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) {
return input.split(delim, limit);
} else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) {
} else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) {
return lowercaseSplitSQL(input, delim, limit);
} else {
return icuSplitSQL(input, delim, limit, collationId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,38 +154,52 @@ public static class Collation {
*/
public final boolean supportsLowercaseEquality;

/**
* Support for Space Trimming implies that that based on specifier (for now only right trim)
* leading, trailing or both spaces are removed from the input string before comparison.
*/
public final boolean supportsSpaceTrimming;

/**
* Is Utf8 binary type as indicator if collation base type is UTF8 binary. Note currently only
* collations Utf8_Binary and Utf8_Binary_RTRIM are considered as Utf8 binary type.
*/
public final boolean isUtf8BinaryType;

/**
* Is Utf8 lcase type as indicator if collation base type is UTF8 lcase. Note currently only
* collations Utf8_Lcase and Utf8_Lcase_RTRIM are considered as Utf8 Lcase type.
*/
public final boolean isUtf8LcaseType;

public Collation(
String collationName,
String provider,
Collator collator,
Comparator<UTF8String> comparator,
String version,
ToLongFunction<UTF8String> hashFunction,
boolean supportsBinaryEquality,
boolean supportsBinaryOrdering,
boolean supportsLowercaseEquality) {
BiFunction<UTF8String, UTF8String, Boolean> equalsFunction,
boolean isUtf8BinaryType,
boolean isUtf8LcaseType,
boolean supportsSpaceTrimming) {
this.collationName = collationName;
this.provider = provider;
this.collator = collator;
this.comparator = comparator;
this.version = version;
this.hashFunction = hashFunction;
this.supportsBinaryEquality = supportsBinaryEquality;
this.supportsBinaryOrdering = supportsBinaryOrdering;
this.supportsLowercaseEquality = supportsLowercaseEquality;

// De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality
assert(!supportsBinaryOrdering || supportsBinaryEquality);
this.isUtf8BinaryType = isUtf8BinaryType;
this.isUtf8LcaseType = isUtf8LcaseType;
this.equalsFunction = equalsFunction;
this.supportsSpaceTrimming = supportsSpaceTrimming;
this.supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType;
this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType;
this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType;
// No Collation can simultaneously support binary equality and lowercase equality
assert(!supportsBinaryEquality || !supportsLowercaseEquality);

assert(SUPPORTED_PROVIDERS.contains(provider));

if (supportsBinaryEquality) {
this.equalsFunction = UTF8String::equals;
} else {
this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0;
}
}

/**
Expand Down Expand Up @@ -538,27 +552,61 @@ private static boolean isValidCollationId(int collationId) {
@Override
protected Collation buildCollation() {
if (caseSensitivity == CaseSensitivity.UNSPECIFIED) {
Comparator<UTF8String> comparator;
ToLongFunction<UTF8String> hashFunction;
BiFunction<UTF8String, UTF8String, Boolean> equalsFunction;
boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE;

if (spaceTrimming == SpaceTrimming.NONE) {
comparator = UTF8String::binaryCompare;
hashFunction = s -> (long) s.hashCode();
equalsFunction = UTF8String::equals;
} else {
comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare(
applyTrimmingPolicy(s2, spaceTrimming));
hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode();
equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals(
applyTrimmingPolicy(s2, spaceTrimming));
}

return new Collation(
normalizedCollationName(),
PROVIDER_SPARK,
null,
UTF8String::binaryCompare,
comparator,
"1.0",
s -> (long) s.hashCode(),
/* supportsBinaryEquality = */ true,
/* supportsBinaryOrdering = */ true,
/* supportsLowercaseEquality = */ false);
hashFunction,
equalsFunction,
/* isUtf8BinaryType = */ true,
/* isUtf8LcaseType = */ false,
spaceTrimming != SpaceTrimming.NONE);
} else {
Comparator<UTF8String> comparator;
ToLongFunction<UTF8String> hashFunction;

if (spaceTrimming == SpaceTrimming.NONE) {
comparator = CollationAwareUTF8String::compareLowerCase;
hashFunction = s ->
(long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode();
} else {
comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase(
applyTrimmingPolicy(s1, spaceTrimming),
applyTrimmingPolicy(s2, spaceTrimming));
hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(
applyTrimmingPolicy(s, spaceTrimming)).hashCode();
}

return new Collation(
normalizedCollationName(),
PROVIDER_SPARK,
null,
CollationAwareUTF8String::compareLowerCase,
comparator,
"1.0",
s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(),
/* supportsBinaryEquality = */ false,
/* supportsBinaryOrdering = */ false,
/* supportsLowercaseEquality = */ true);
hashFunction,
(s1, s2) -> comparator.compare(s1, s2) == 0,
/* isUtf8BinaryType = */ false,
/* isUtf8LcaseType = */ true,
spaceTrimming != SpaceTrimming.NONE);
}
}

Expand Down Expand Up @@ -917,16 +965,34 @@ protected Collation buildCollation() {
Collator collator = Collator.getInstance(resultLocale);
// Freeze ICU collator to ensure thread safety.
collator.freeze();

Comparator<UTF8String> comparator;
ToLongFunction<UTF8String> hashFunction;

if (spaceTrimming == SpaceTrimming.NONE) {
hashFunction = s -> (long) collator.getCollationKey(
s.toValidString()).hashCode();
comparator = (s1, s2) ->
collator.compare(s1.toValidString(), s2.toValidString());
} else {
comparator = (s1, s2) -> collator.compare(
applyTrimmingPolicy(s1, spaceTrimming).toValidString(),
applyTrimmingPolicy(s2, spaceTrimming).toValidString());
hashFunction = s -> (long) collator.getCollationKey(
applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode();
}

return new Collation(
normalizedCollationName(),
PROVIDER_ICU,
collator,
(s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()),
comparator,
ICU_COLLATOR_VERSION,
s -> (long) collator.getCollationKey(s.toValidString()).hashCode(),
/* supportsBinaryEquality = */ false,
/* supportsBinaryOrdering = */ false,
/* supportsLowercaseEquality = */ false);
hashFunction,
(s1, s2) -> comparator.compare(s1, s2) == 0,
/* isUtf8BinaryType = */ false,
/* isUtf8LcaseType = */ false,
spaceTrimming != SpaceTrimming.NONE);
}

@Override
Expand Down Expand Up @@ -1103,14 +1169,6 @@ public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) {
Collation.CollationSpecICU.AccentSensitivity.AI;
}

/**
* Returns whether the collation uses trim collation for the given collation id.
*/
public static boolean usesTrimCollation(int collationId) {
return Collation.CollationSpec.getSpaceTrimming(collationId) !=
Collation.CollationSpec.SpaceTrimming.NONE;
}

public static void assertValidProvider(String provider) throws SparkException {
if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) {
Map<String, String> params = Map.of(
Expand All @@ -1137,12 +1195,12 @@ public static String[] getICULocaleNames() {

public static UTF8String getCollationKey(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (usesTrimCollation(collationId)) {
if (collation.supportsSpaceTrimming) {
input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId);
}
if (collation.supportsBinaryEquality) {
if (collation.isUtf8BinaryType) {
return input;
} else if (collation.supportsLowercaseEquality) {
} else if (collation.isUtf8LcaseType) {
return CollationAwareUTF8String.lowerCaseCodePoints(input);
} else {
CollationKey collationKey = collation.collator.getCollationKey(
Expand All @@ -1153,12 +1211,12 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) {

public static byte[] getCollationKeyBytes(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (usesTrimCollation(collationId)) {
if (collation.supportsSpaceTrimming) {
input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId);
}
if (collation.supportsBinaryEquality) {
if (collation.isUtf8BinaryType) {
return input.getBytes();
} else if (collation.supportsLowercaseEquality) {
} else if (collation.isUtf8LcaseType) {
return CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes();
} else {
return collation.collator.getCollationKey(
Expand Down
Loading

0 comments on commit 3db4002

Please sign in to comment.