UDFs for Mapping Feed Ranges to Buckets #43092

Merged · 6 commits · Jan 6, 2025
Changes from 4 commits
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-1_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.36.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-2_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.36.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-3_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.36.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-4_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.36.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos-spark_3-5_2-12/CHANGELOG.md
@@ -3,6 +3,7 @@
### 4.36.0-beta.1 (Unreleased)

#### Features Added
* Added the UDFs `GetFeedRangesForBuckets` and `GetBucketForPartitionKey` to ease mapping of a Cosmos DB partition key to a Databricks table partition key. See [PR 43092](https://github.com/Azure/azure-sdk-for-java/pull/43092)

#### Breaking Changes

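The changelog entries above describe the two new UDFs. A minimal usage sketch, assuming a live Spark session `spark` with the Cosmos Spark connector on the classpath; the registration pattern mirrors the PR's integration test, while the partition key definition, bucket count, and key value are illustrative:

```scala
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangesForBuckets}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

// Register both UDFs with the Spark session (return types as used in the PR's tests).
spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
spark.udf.register("GetBucketForPartitionKey", new GetBucketForPartitionKey(), IntegerType)

// Illustrative single-path hash partition key definition.
val pkDefinition = """{"paths":["/id"],"kind":"Hash"}"""

// Split the full feed range into 5 buckets, returned as "min-max" strings...
val feedRanges = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
  .collect()(0)
  .getList[String](0)
  .toArray(new Array[String](0))

// ...then resolve the bucket index for a single partition key value.
val bucket = new GetBucketForPartitionKey().call(pkDefinition, "45170a78-eac0-4d3a-be5e-9b00bb5f4649", feedRanges)
```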
@@ -48,17 +48,6 @@ private[cosmos] object SparkBridgeInternal {
s"${database.getClient.getServiceEndpoint}|${database.getId}|${container.getId}"
}

private[cosmos] def getNormalizedEffectiveRange
(
container: CosmosAsyncContainer,
feedRange: FeedRange
) : NormalizedRange = {

SparkBridgeImplementationInternal
.rangeToNormalizedRange(
container.getNormalizedEffectiveRange(feedRange).block)
}

private[cosmos] def getPartitionKeyRanges
(
container: CosmosAsyncContainer
@@ -9,12 +9,13 @@ import com.azure.cosmos.implementation.guava25.base.MoreObjects.firstNonNull
import com.azure.cosmos.implementation.guava25.base.Strings.emptyToNull
import com.azure.cosmos.implementation.query.CompositeContinuationToken
import com.azure.cosmos.implementation.routing.Range
import com.azure.cosmos.models.{FeedRange, PartitionKey, PartitionKeyBuilder, PartitionKeyDefinition, SparkModelBridgeInternal}
import com.azure.cosmos.models.{FeedRange, PartitionKey, PartitionKeyBuilder, PartitionKeyDefinition, PartitionKind, SparkModelBridgeInternal}
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.spark.{ChangeFeedOffset, CosmosConstants, NormalizedRange}
import com.azure.cosmos.{CosmosAsyncClient, CosmosClientBuilder, DirectConnectionConfig, SparkBridgeInternal}
import com.fasterxml.jackson.databind.ObjectMapper

import scala.collection.convert.ImplicitConversions.`list asScalaBuffer`
import scala.collection.mutable

// scalastyle:off underscore.import
@@ -189,6 +190,11 @@ private[cosmos] object SparkBridgeImplementationInternal extends BasicLoggingTrait
new Range[String](range.min, range.max, true, false)
}

private[cosmos] def toCosmosRange(range: String): Range[String] = {
val parts = range.split("-")
new Range[String](parts(0), parts(1), true, false)
}

def doRangesOverlap(left: NormalizedRange, right: NormalizedRange): Boolean = {
Range.checkOverlapping(toCosmosRange(left), toCosmosRange(right))
}
@@ -204,7 +210,7 @@
partitionKeyDefinitionJson: String
): NormalizedRange = {
val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
partitionKeyToNormalizedRange(new PartitionKey(partitionKeyValue), pkDefinition)
partitionKeyToNormalizedRange(getPartitionKeyValue(pkDefinition, partitionKeyValue), pkDefinition)
}

private[cosmos] def partitionKeyToNormalizedRange(
@@ -220,28 +226,67 @@
partitionKeyValueJsonArray: Object,
partitionKeyDefinitionJson: String
): NormalizedRange = {
val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
val partitionKey = getPartitionKeyValue(pkDefinition, partitionKeyValueJsonArray)
val feedRange = FeedRange
.forLogicalPartition(partitionKey)
.asInstanceOf[FeedRangePartitionKeyImpl]

val effectiveRange = feedRange.getEffectiveRange(pkDefinition)
rangeToNormalizedRange(effectiveRange)
}

private[cosmos] def trySplitFeedRanges
(
partitionKeyDefinitionJson: String,
feedRange: FeedRangeEpkImpl,
bucketCount: Int
): Array[String] = {

val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
val feedRanges = FeedRangeInternal.trySplitCore(pkDefinition, feedRange.getRange, bucketCount)
val normalizedRanges = new Array[String](feedRanges.size())
for (i <- feedRanges.indices) {
val normalizedRange = rangeToNormalizedRange(feedRanges(i).getRange)
normalizedRanges(i) = s"${normalizedRange.min}-${normalizedRange.max}"
}
normalizedRanges
}

def findBucket(feedRanges: Array[String], pkValue: Object, pkDefinition: PartitionKeyDefinition):Int = {
val pk = getPartitionKeyValue(pkDefinition, pkValue)
val feedRangeFromPk = FeedRange.forLogicalPartition(pk).asInstanceOf[FeedRangePartitionKeyImpl]
val effectiveRangeFromPk = feedRangeFromPk.getEffectiveRange(pkDefinition)

for (i <- feedRanges.indices) {
val range = SparkBridgeImplementationInternal.toCosmosRange(feedRanges(i))
if (range.contains(effectiveRangeFromPk.getMin)) {
return i
}
}
throw new IllegalArgumentException("The partition key value does not belong to any of the feed ranges")
}

val partitionKey = new PartitionKeyBuilder()
private def getPartitionKeyValue(pkDefinition: PartitionKeyDefinition, pkValue: Object): PartitionKey = {
val partitionKey = new PartitionKeyBuilder()
var pk: PartitionKey = null
if (pkDefinition.getKind.equals(PartitionKind.MULTI_HASH)) {
val objectMapper = new ObjectMapper()
val json = partitionKeyValueJsonArray.toString
val json = pkValue.toString
try {
val partitionKeyValues = objectMapper.readValue(json, classOf[Array[String]])
for (value <- partitionKeyValues) {
partitionKey.add(value.trim)
}
partitionKey.build()
val partitionKeyValues = objectMapper.readValue(json, classOf[Array[String]])
for (value <- partitionKeyValues) {
partitionKey.add(value.trim)
}
pk = partitionKey.build()
} catch {
case e: Exception =>
logInfo("Invalid partition key paths: " + json, e)
case e: Exception =>
logInfo("Invalid partition key paths: " + json, e)
}

val feedRange = FeedRange
.forLogicalPartition(partitionKey.build())
.asInstanceOf[FeedRangePartitionKeyImpl]

val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
val effectiveRange = feedRange.getEffectiveRange(pkDefinition)
rangeToNormalizedRange(effectiveRange)
} else if (pkDefinition.getKind.equals(PartitionKind.HASH)) {
pk = new PartitionKey(pkValue)
}
pk
}

def setIoThreadCountPerCoreFactor
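To make the lookup rule in `findBucket` above easier to follow outside the connector, here is a standalone restatement. It assumes the feed ranges are the `"min-max"` strings produced by `trySplitFeedRanges` (min-inclusive, max-exclusive) and that the hex effective-partition-key strings compare lexicographically, so it is a sketch of the idea rather than a drop-in replacement for the `Range`-based implementation:

```scala
// Illustrative only: given bucket boundaries as "min-max" strings and the
// effective-partition-key min of a document, return the index of the first
// bucket whose [min, max) interval contains it.
def findBucketSketch(feedRanges: Array[String], effectiveMin: String): Int = {
  val idx = feedRanges.indexWhere { rangeStr =>
    val sep = rangeStr.indexOf('-')       // the first bucket looks like "-05C1...", so min may be ""
    val min = rangeStr.substring(0, sep)
    val max = rangeStr.substring(sep + 1)
    effectiveMin.compareTo(min) >= 0 && effectiveMin.compareTo(max) < 0
  }
  if (idx < 0) {
    throw new IllegalArgumentException(
      "The partition key value does not belong to any of the feed ranges")
  }
  idx
}
```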
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark.udf

import com.azure.cosmos.implementation.SparkBridgeImplementationInternal
import com.azure.cosmos.models.SparkModelBridgeInternal
import com.azure.cosmos.spark.CosmosPredicates.requireNotNullOrEmpty
import org.apache.spark.sql.api.java.UDF3

@SerialVersionUID(1L)
class GetBucketForPartitionKey extends UDF3[String, Object, Array[String], Int] {
override def call
(
partitionKeyDefinitionJson: String,
partitionKeyValue: Object,
feedRangesForBuckets: Array[String]
): Int = {
requireNotNullOrEmpty(partitionKeyDefinitionJson, "partitionKeyDefinitionJson")

val pkDefinition = SparkModelBridgeInternal.createPartitionKeyDefinitionFromJson(partitionKeyDefinitionJson)
SparkBridgeImplementationInternal.findBucket(feedRangesForBuckets, partitionKeyValue, pkDefinition)
}
}
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark.udf

import com.azure.cosmos.implementation.SparkBridgeImplementationInternal
import com.azure.cosmos.implementation.feedranges.FeedRangeEpkImpl
import com.azure.cosmos.models.FeedRange
import com.azure.cosmos.spark.CosmosPredicates.requireNotNullOrEmpty
import org.apache.spark.sql.api.java.UDF2

@SerialVersionUID(1L)
class GetFeedRangesForBuckets extends UDF2[String, Int, Array[String]] {
override def call
(
partitionKeyDefinitionJson: String,
bucketCount: Int
): Array[String] = {

requireNotNullOrEmpty(partitionKeyDefinitionJson, "partitionKeyDefinitionJson")

SparkBridgeImplementationInternal.trySplitFeedRanges(partitionKeyDefinitionJson,
FeedRange.forFullRange().asInstanceOf[FeedRangeEpkImpl],
bucketCount)
}
}
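For hierarchical (MultiHash) partition keys, the PR's integration test passes the key value as a JSON array string covering the leading path(s). A minimal direct-call sketch, with the definition and values purely illustrative:

```scala
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangesForBuckets}

// Illustrative hierarchical partition key definition; the value is a JSON array string.
val hpkDefinition = """{"paths":["/tenantId","/userId","/sessionId"],"kind":"MultiHash"}"""
val hpkValue = """["tenant-1"]"""

// Both UDFs can also be invoked directly, without registering them with Spark SQL.
val feedRanges = new GetFeedRangesForBuckets().call(hpkDefinition, 5)
val bucket = new GetBucketForPartitionKey().call(hpkDefinition, hpkValue, feedRanges)
```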
@@ -0,0 +1,132 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.CosmosAsyncContainer
import com.azure.cosmos.implementation.{SparkBridgeImplementationInternal, Utils}
import com.azure.cosmos.models.CosmosQueryRequestOptions
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangesForBuckets}
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

import java.util.UUID
import scala.collection.mutable

class FeedRangesForBucketsITest
extends IntegrationSpec
with SparkWithDropwizardAndSlf4jMetrics
with CosmosClient
with CosmosContainer
with BasicLoggingTrait
with MetricAssertions {

//scalastyle:off multiple.string.literals
//scalastyle:off magic.number

override def afterEach(): Unit = {
this.reinitializeContainer()
}

"feed ranges" can "be split into different buckets" in {
spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
var pkDefinition = "{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"
val dummyDf = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
val expectedFeedRanges = Array("-05C1C9CD673398", "05C1C9CD673398-05C1D9CD673398",
"05C1D9CD673398-05C1E399CD6732", "05C1E399CD6732-05C1E9CD673398", "05C1E9CD673398-FF")
val feedRange = dummyDf
.collect()(0)
.getList[String](0)
.toArray

assert(feedRange.sameElements(expectedFeedRanges), "Feed ranges do not match the expected values")
val lastId = "45170a78-eac0-4d3a-be5e-9b00bb5f4649"

var bucket = new GetBucketForPartitionKey().call(pkDefinition, lastId, expectedFeedRanges)
assert(bucket == 0, "Bucket does not match the expected value")

// test with hpk partition key definition
pkDefinition = "{\"paths\":[\"/tenantId\",\"/userId\",\"/sessionId\"],\"kind\":\"MultiHash\"}"
val pkValues = "[\"" + lastId + "\"]"

bucket = new GetBucketForPartitionKey().call(pkDefinition, pkValues, expectedFeedRanges)
assert(bucket == 4, "Bucket does not match the expected value")

}

"feed ranges" can "be converted into buckets for new partition key" in {
feedRangesForBuckets(false)
}

"feed ranges" can "be converted into buckets for new hierarchical partition key" in {
feedRangesForBuckets(true)
}

def feedRangesForBuckets(hpk: Boolean): Unit = {
val container = cosmosClient.getDatabase(cosmosDatabase).getContainer(cosmosContainer)
val docs = createItems(container, 50, hpk)

spark.udf.register("GetFeedRangesForBuckets", new GetFeedRangesForBuckets(), ArrayType(StringType))
val pkDefinition = if (hpk) {"{\"paths\":[\"/tenantId\",\"/userId\",\"/sessionId\"],\"kind\":\"MultiHash\"}"}
else {"{\"paths\":[\"/id\"],\"kind\":\"Hash\"}"}

val dummyDf = spark.sql(s"SELECT GetFeedRangesForBuckets('$pkDefinition', 5)")
val feedRanges = dummyDf
.collect()(0)
.getList[String](0)
.toArray(new Array[String](0))

spark.udf.register("GetBucketForPartitionKey", new GetBucketForPartitionKey(), IntegerType)
val bucketToDocsMap = mutable.Map[Int, List[ObjectNode]]().withDefaultValue(List())

for (doc <- docs) {
val lastId = if (!hpk) doc.get("id").asText() else "[\"" + doc.get("tenantId").asText() + "\"]"
val bucket = new GetBucketForPartitionKey().call(pkDefinition, lastId, feedRanges)
// Add the document to the corresponding bucket in the map
bucketToDocsMap(bucket) = doc :: bucketToDocsMap(bucket)
}

for (i <- feedRanges.indices) {
val range = SparkBridgeImplementationInternal.toCosmosRange(feedRanges(i))
val feedRange = SparkBridgeImplementationInternal.toFeedRange(SparkBridgeImplementationInternal.rangeToNormalizedRange(range))
val requestOptions = new CosmosQueryRequestOptions().setFeedRange(feedRange)
container.queryItems("SELECT * FROM c", requestOptions, classOf[ObjectNode]).byPage().collectList().block().forEach { rsp =>
val results = rsp.getResults
var numDocs = 0
val expectedResults = bucketToDocsMap(i)
results.forEach(doc => {
assert(expectedResults.collect({
case expectedDoc if expectedDoc.get("id").asText() == doc.get("id").asText() => expectedDoc
}).size > 0, "Document not found in the expected bucket")
numDocs += 1
})
assert(numDocs == results.size(), "Number of documents in the bucket does not match the number of docs for that feed range")
}
}
}

def createItems(container: CosmosAsyncContainer, numOfItems: Int, hpk: Boolean): Array[ObjectNode] = {
val docs = new Array[ObjectNode](numOfItems)
for (sequenceNumber <- 1 to numOfItems) {
val lastId = UUID.randomUUID().toString
val objectNode = Utils.getSimpleObjectMapper.createObjectNode()
objectNode.put("name", "Shrodigner's cat")
objectNode.put("type", "cat")
objectNode.put("age", 20)
objectNode.put("sequenceNumber", sequenceNumber)
objectNode.put("id", lastId)
if (hpk) {
objectNode.put("tenantId", lastId)
objectNode.put("userId", "userId1")
objectNode.put("sessionId", "sessionId1")
}
docs(sequenceNumber - 1) = objectNode
container.createItem(objectNode).block()
}
docs
}

//scalastyle:on magic.number
//scalastyle:on multiple.string.literals
}
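The test above checks that the bucket assignment agrees with the documents returned per feed range. As an end-to-end illustration of the feature's stated purpose (mapping a Cosmos DB partition key to a Databricks table partition key), one possible way to apply the mapping on write is sketched below. This pattern is not part of the PR: `cosmosDf`, the `/id` partition key, the bucket count, and the output path are all assumptions, and the bucket lookup is wrapped in a plain Scala UDF so that only the key column crosses the UDF boundary:

```scala
import com.azure.cosmos.spark.udf.{GetBucketForPartitionKey, GetFeedRangesForBuckets}
import org.apache.spark.sql.functions.{col, udf}

// Assumed: `spark` is the active session and `cosmosDf` was read via the Cosmos Spark connector.
val pkDefinition = """{"paths":["/id"],"kind":"Hash"}"""
val feedRanges = new GetFeedRangesForBuckets().call(pkDefinition, 16)

// Resolve the bucket for each row's partition key value.
val bucketOfId = udf((id: String) => new GetBucketForPartitionKey().call(pkDefinition, id, feedRanges))

cosmosDf
  .withColumn("bucket", bucketOfId(col("id")))
  .write
  .partitionBy("bucket")          // bucket becomes the Databricks table partition column
  .format("delta")
  .mode("append")
  .save("/mnt/tables/cosmos_mirror")
```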

@@ -5,11 +5,9 @@ package com.azure.cosmos.spark
import com.azure.cosmos.SparkBridgeInternal
import com.azure.cosmos.implementation.changefeed.common.ChangeFeedState
import com.azure.cosmos.implementation.{TestConfigurations, Utils}
import com.azure.cosmos.models.PartitionKey
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.spark.udf.{CreateChangeFeedOffsetFromSpark2, CreateSpark2ContinuationsFromChangeFeedOffset, GetFeedRangeForPartitionKeyValue}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.functions
import org.apache.spark.sql.types._

import java.io.{BufferedReader, InputStreamReader}