Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement SortedSet for JS/native to improve performance #1167

Merged
merged 2 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,199 @@

package androidx.compose.ui.node


/**
* Implements [SortedSet] via a min-heap (implemented via an array) and a hash-map mapping the
* elements to their indices in the heap.
*
* The performance of this implementation is:
* - [add], [remove]: O(logN), due to the heap.
* - [first], [contains]: O(1), due to the hash map.
*/
internal actual class SortedSet<E> actual constructor(
private val comparator: Comparator<in E>
) {
private val list = mutableListOf<E>()

actual fun first(): E = list.first()
/**
* Compares two elements using the [comparator].
*/
private inline operator fun E.compareTo(value: E): Int = comparator.compare(this, value)

/**
* The heap array.
*/
private val itemTree = arrayListOf<E>()

/**
* Returns whether the index is the root of the tree.
*/
private val Int.isRootIndex get() = this == 0

/**
* Returns the index of the parent node.
*/
private val Int.parentIndex get() = (this - 1) shr 1

/**
* Returns the index of the left child node.
*/
private val Int.leftChildIndex get() = (this shl 1) + 1

/**
* Returns the index of the right child node.
*/
private val Int.rightChildIndex get() = (this shl 1) + 2

/**
* Maps each element to its index in [itemTree].
*/
private val indexByElement = mutableMapOf<E, Int>()

/**
* Inserts [element], if it's not already in the set.
*
* @returns whether actually inserted.
*/
actual fun add(element: E): Boolean {
var index = list.binarySearch(element, comparator)
if (index < 0) {
index = -index - 1
} else {
if (element in indexByElement) {
return false
}
list.add(index, element)

// Insert the item at the rightmost leaf
val index = itemTree.size
itemTree.add(element)
indexByElement[element] = index // This is the initial index; heapifyUp will update it

// Fix the heap
heapifyUp(index)

return true
}

/**
* Removes [element], if it's in the set.
*
* @return whether actually removed.
*/
actual fun remove(element: E): Boolean {
val index = list.binarySearch(element, comparator)
val found = index in list.indices
if (found) {
list.removeAt(index)
// Get the index in the tree and remove it
val index = indexByElement.remove(element) ?: return false

// Remove the rightmost leaf (to move it in place of the remove element)
val rightMostLeafElement = itemTree.removeLast()

// If the removed element is the rightmost leaf, then there's no need to move it, or to fix
// the heap. This also takes care of the case when the set is empty after removal.
if (index < itemTree.size) {
itemTree[index] = rightMostLeafElement
indexByElement[rightMostLeafElement] = index

// Restore min-heap invariant
if (!index.isRootIndex && (itemTree[index.parentIndex] >= itemTree[index])) {
heapifyUp(index)
} else {
heapifyDown(index)
}
}
return found

return true
}

actual fun contains(element: E): Boolean {
val index = list.binarySearch(element, comparator)
return index in list.indices && list[index] == element
/**
* Returns the smallest item in the set, according to [comparator].
*/
actual fun first() = itemTree[0]

/**
* Returns whether the set is empty.
*/
actual fun isEmpty(): Boolean = itemTree.isEmpty()

/**
* Returns whether the set contains the given element.
*/
actual fun contains(element: E) = element in indexByElement

/**
* Bubbles up the element at the given index until the min-heap invariant is restored.
*/
private fun heapifyUp(index: Int) {
val element = itemTree[index] // The element being bubbled up
var currentIndex = index // The index we're currently comparing to its parent

while (!currentIndex.isRootIndex) {
val parentIndex = currentIndex.parentIndex

// If the order is correct, stop
if (itemTree[parentIndex] <= element) {
break
}

// Swap
swap(currentIndex, parentIndex)

// Continue with parent
currentIndex = parentIndex
}
}

actual fun isEmpty(): Boolean = list.isEmpty()
/**
* Sinks down the element at the given index until the min-heap invariant is restored.
*/
private fun heapifyDown(index: Int) {
val element = itemTree[index] // The element being sunk down
var currentIndex = index // The index we're currently comparing to its children

while (true) {
val leftChildIndex = currentIndex.leftChildIndex
if (leftChildIndex >= itemTree.size) {
break
}
val leftChildElement = itemTree[leftChildIndex]
val rightChildIndex = currentIndex.rightChildIndex

val indexOfSmallerElement: Int
val smallerElement: E
if (rightChildIndex >= itemTree.size) { // There's no right child
// Look at left child
indexOfSmallerElement = leftChildIndex
smallerElement = leftChildElement
} else {
val rightChildElement = itemTree[rightChildIndex]
// Look at the smaller child
if (leftChildElement < rightChildElement) {
indexOfSmallerElement = leftChildIndex
smallerElement = leftChildElement
} else {
indexOfSmallerElement = rightChildIndex
smallerElement = rightChildElement
}
}

if (element <= smallerElement) {
break
}

swap(currentIndex, indexOfSmallerElement)
currentIndex = indexOfSmallerElement
}
}

/**
* Swaps the elements at the given indices in [itemTree], and updates the indices in
* [indexByElement].
*/
private fun swap(index1: Int, index2: Int) {
// Get the items
val item1 = itemTree[index1]
val item2 = itemTree[index2]

// Swap the items
itemTree[index1] = item2
itemTree[index2] = item1

// Update the indices
indexByElement[item1] = index2
indexByElement[item2] = index1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import kotlin.test.assertFalse
import kotlin.test.assertTrue

class SortedSetTest {
private fun <E: Comparable<*>> sortedSetOf(vararg elements: E): SortedSet<E> =
sortedSetOf(compareBy { it }, *elements)
private fun <E: Comparable<E>> sortedSetOf(vararg elements: E): SortedSet<E> =
sortedSetOf(naturalOrder(), *elements)

private fun <E> sortedSetOf(comparator: Comparator<in E>, vararg elements: E): SortedSet<E> {
val set = SortedSet(comparator)
Expand All @@ -35,23 +35,38 @@ class SortedSetTest {
return set
}

private fun <E> assertOrderEquals(expect: Iterable<E>, actual: SortedSet<E>) {
private fun <E> assertOrderEquals(
expect: Iterable<E>,
actual: SortedSet<E>,
message: String? = null
) {
for (e in expect) {
assertEquals(e, actual.first())
assertTrue(actual.contains(e))
assertTrue(actual.remove(e))
assertFalse(actual.contains(e))
}
assertTrue(actual.isEmpty())
assertTrue(actual.isEmpty(), message)
}

@Test
fun correctOrder() {
assertOrderEquals(listOf(1), sortedSetOf(1))
assertOrderEquals(listOf(1, 2, 5, 6), sortedSetOf(1, 2, 5, 6))
assertOrderEquals(listOf(1, 2, 5, 6), sortedSetOf(2, 6, 1, 5))
val numbers = (1..1000).map { Random.nextInt(10_000_000) }.distinct()
val set = sortedSetOf(*numbers.toTypedArray())
assertOrderEquals(numbers.sorted(), set)
val (seed, _, numbers) = generateRandomInts(amount = 1000)
logSeedOnFailure(seed) {
val set = sortedSetOf(*numbers.toTypedArray())
assertOrderEquals(numbers.sorted(), set, "Wrong order with seed $seed")
}
}

@Test
fun hashmapTest() {
val map = mutableMapOf<Int, Int>()
map[1] = 0
map.remove(1)
assertFalse(map.keys.contains(1))
}

@Test
Expand All @@ -69,4 +84,55 @@ class SortedSetTest {
val set = sortedSetOf(compareBy { it.length }, "B", "AAA", "DD")
assertOrderEquals(listOf("B", "DD", "AAA"), set)
}

@Test
fun removeNonMember() {
val set = sortedSetOf(1, 2, 3, 4, 5)
assertFalse(set.remove(0))
}

@Test
fun removeRandom() {
val (seed, random, numbers) = generateRandomInts(amount = 1000, seed = -1290005190)
logSeedOnFailure(seed) {
val set = sortedSetOf(*numbers.toTypedArray())
numbers.sort()
val countToRemove = random.nextInt(numbers.size)
repeat(countToRemove) {
val index = random.nextInt(until = numbers.size)
val number = numbers.removeAt(index)
set.remove(number)
}
assertOrderEquals(numbers, set, "Wrong order after removing with seed $seed")
}
}

@Test
fun removeLastAdded() {
val set = sortedSetOf(1, 2, 3)
set.remove(3)
assertOrderEquals(listOf(1, 2), set)
}

private fun generateRandomInts(
amount: Int,
seed: Int? = null
): Triple<Int, Random, MutableList<Int>> {
val actualSeed = seed ?: Random.nextInt()
val random = Random(actualSeed)
val numbers = (1..amount)
.mapTo(mutableSetOf()) { random.nextInt(10_000_000) }
.toMutableList()
.apply { shuffle(random) }
return Triple(actualSeed, random, numbers)
}
}

private inline fun logSeedOnFailure(seed: Int, block: () -> Unit) {
try {
block()
} catch (e: Throwable) {
println("Test failed with seed $seed")
throw e
}
}
Loading