Skip to content

Commit

Permalink
Reimplement SortedSet for JS/native to improve performance (#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-sasha committed Mar 6, 2024
1 parent 6f43e83 commit 2d8a99e
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,199 @@

package androidx.compose.ui.node


/**
* Implements [SortedSet] via a min-heap (implemented via an array) and a hash-map mapping the
* elements to their indices in the heap.
*
* The performance of this implementation is:
* - [add], [remove]: O(logN), due to the heap.
* - [first], [contains]: O(1), due to the hash map.
*/
internal actual class SortedSet<E> actual constructor(
private val comparator: Comparator<in E>
) {
private val list = mutableListOf<E>()

actual fun first(): E = list.first()
/**
* Compares two elements using the [comparator].
*/
private inline operator fun E.compareTo(value: E): Int = comparator.compare(this, value)

/**
* The heap array.
*/
private val itemTree = arrayListOf<E>()

/**
* Returns whether the index is the root of the tree.
*/
private val Int.isRootIndex get() = this == 0

/**
* Returns the index of the parent node.
*/
private val Int.parentIndex get() = (this - 1) shr 1

/**
* Returns the index of the left child node.
*/
private val Int.leftChildIndex get() = (this shl 1) + 1

/**
* Returns the index of the right child node.
*/
private val Int.rightChildIndex get() = (this shl 1) + 2

/**
* Maps each element to its index in [itemTree].
*/
private val indexByElement = mutableMapOf<E, Int>()

/**
* Inserts [element], if it's not already in the set.
*
* @returns whether actually inserted.
*/
actual fun add(element: E): Boolean {
var index = list.binarySearch(element, comparator)
if (index < 0) {
index = -index - 1
} else {
if (element in indexByElement) {
return false
}
list.add(index, element)

// Insert the item at the rightmost leaf
val index = itemTree.size
itemTree.add(element)
indexByElement[element] = index // This is the initial index; heapifyUp will update it

// Fix the heap
heapifyUp(index)

return true
}

/**
* Removes [element], if it's in the set.
*
* @return whether actually removed.
*/
actual fun remove(element: E): Boolean {
val index = list.binarySearch(element, comparator)
val found = index in list.indices
if (found) {
list.removeAt(index)
// Get the index in the tree and remove it
val index = indexByElement.remove(element) ?: return false

// Remove the rightmost leaf (to move it in place of the remove element)
val rightMostLeafElement = itemTree.removeLast()

// If the removed element is the rightmost leaf, then there's no need to move it, or to fix
// the heap. This also takes care of the case when the set is empty after removal.
if (index < itemTree.size) {
itemTree[index] = rightMostLeafElement
indexByElement[rightMostLeafElement] = index

// Restore min-heap invariant
if (!index.isRootIndex && (itemTree[index.parentIndex] >= itemTree[index])) {
heapifyUp(index)
} else {
heapifyDown(index)
}
}
return found

return true
}

actual fun contains(element: E): Boolean {
val index = list.binarySearch(element, comparator)
return index in list.indices && list[index] == element
/**
* Returns the smallest item in the set, according to [comparator].
*/
actual fun first() = itemTree[0]

/**
* Returns whether the set is empty.
*/
actual fun isEmpty(): Boolean = itemTree.isEmpty()

/**
* Returns whether the set contains the given element.
*/
actual fun contains(element: E) = element in indexByElement

/**
* Bubbles up the element at the given index until the min-heap invariant is restored.
*/
private fun heapifyUp(index: Int) {
val element = itemTree[index] // The element being bubbled up
var currentIndex = index // The index we're currently comparing to its parent

while (!currentIndex.isRootIndex) {
val parentIndex = currentIndex.parentIndex

// If the order is correct, stop
if (itemTree[parentIndex] <= element) {
break
}

// Swap
swap(currentIndex, parentIndex)

// Continue with parent
currentIndex = parentIndex
}
}

actual fun isEmpty(): Boolean = list.isEmpty()
/**
* Sinks down the element at the given index until the min-heap invariant is restored.
*/
private fun heapifyDown(index: Int) {
val element = itemTree[index] // The element being sunk down
var currentIndex = index // The index we're currently comparing to its children

while (true) {
val leftChildIndex = currentIndex.leftChildIndex
if (leftChildIndex >= itemTree.size) {
break
}
val leftChildElement = itemTree[leftChildIndex]
val rightChildIndex = currentIndex.rightChildIndex

val indexOfSmallerElement: Int
val smallerElement: E
if (rightChildIndex >= itemTree.size) { // There's no right child
// Look at left child
indexOfSmallerElement = leftChildIndex
smallerElement = leftChildElement
} else {
val rightChildElement = itemTree[rightChildIndex]
// Look at the smaller child
if (leftChildElement < rightChildElement) {
indexOfSmallerElement = leftChildIndex
smallerElement = leftChildElement
} else {
indexOfSmallerElement = rightChildIndex
smallerElement = rightChildElement
}
}

if (element <= smallerElement) {
break
}

swap(currentIndex, indexOfSmallerElement)
currentIndex = indexOfSmallerElement
}
}

/**
* Swaps the elements at the given indices in [itemTree], and updates the indices in
* [indexByElement].
*/
private fun swap(index1: Int, index2: Int) {
// Get the items
val item1 = itemTree[index1]
val item2 = itemTree[index2]

// Swap the items
itemTree[index1] = item2
itemTree[index2] = item1

// Update the indices
indexByElement[item1] = index2
indexByElement[item2] = index1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import kotlin.test.assertFalse
import kotlin.test.assertTrue

class SortedSetTest {
private fun <E: Comparable<*>> sortedSetOf(vararg elements: E): SortedSet<E> =
sortedSetOf(compareBy { it }, *elements)
private fun <E: Comparable<E>> sortedSetOf(vararg elements: E): SortedSet<E> =
sortedSetOf(naturalOrder(), *elements)

private fun <E> sortedSetOf(comparator: Comparator<in E>, vararg elements: E): SortedSet<E> {
val set = SortedSet(comparator)
Expand All @@ -35,23 +35,38 @@ class SortedSetTest {
return set
}

private fun <E> assertOrderEquals(expect: Iterable<E>, actual: SortedSet<E>) {
private fun <E> assertOrderEquals(
expect: Iterable<E>,
actual: SortedSet<E>,
message: String? = null
) {
for (e in expect) {
assertEquals(e, actual.first())
assertTrue(actual.contains(e))
assertTrue(actual.remove(e))
assertFalse(actual.contains(e))
}
assertTrue(actual.isEmpty())
assertTrue(actual.isEmpty(), message)
}

@Test
fun correctOrder() {
assertOrderEquals(listOf(1), sortedSetOf(1))
assertOrderEquals(listOf(1, 2, 5, 6), sortedSetOf(1, 2, 5, 6))
assertOrderEquals(listOf(1, 2, 5, 6), sortedSetOf(2, 6, 1, 5))
val numbers = (1..1000).map { Random.nextInt(10_000_000) }.distinct()
val set = sortedSetOf(*numbers.toTypedArray())
assertOrderEquals(numbers.sorted(), set)
val (seed, _, numbers) = generateRandomInts(amount = 1000)
logSeedOnFailure(seed) {
val set = sortedSetOf(*numbers.toTypedArray())
assertOrderEquals(numbers.sorted(), set, "Wrong order with seed $seed")
}
}

@Test
fun hashmapTest() {
val map = mutableMapOf<Int, Int>()
map[1] = 0
map.remove(1)
assertFalse(map.keys.contains(1))
}

@Test
Expand All @@ -69,4 +84,55 @@ class SortedSetTest {
val set = sortedSetOf(compareBy { it.length }, "B", "AAA", "DD")
assertOrderEquals(listOf("B", "DD", "AAA"), set)
}

@Test
fun removeNonMember() {
val set = sortedSetOf(1, 2, 3, 4, 5)
assertFalse(set.remove(0))
}

@Test
fun removeRandom() {
val (seed, random, numbers) = generateRandomInts(amount = 1000, seed = -1290005190)
logSeedOnFailure(seed) {
val set = sortedSetOf(*numbers.toTypedArray())
numbers.sort()
val countToRemove = random.nextInt(numbers.size)
repeat(countToRemove) {
val index = random.nextInt(until = numbers.size)
val number = numbers.removeAt(index)
set.remove(number)
}
assertOrderEquals(numbers, set, "Wrong order after removing with seed $seed")
}
}

@Test
fun removeLastAdded() {
val set = sortedSetOf(1, 2, 3)
set.remove(3)
assertOrderEquals(listOf(1, 2), set)
}

private fun generateRandomInts(
amount: Int,
seed: Int? = null
): Triple<Int, Random, MutableList<Int>> {
val actualSeed = seed ?: Random.nextInt()
val random = Random(actualSeed)
val numbers = (1..amount)
.mapTo(mutableSetOf()) { random.nextInt(10_000_000) }
.toMutableList()
.apply { shuffle(random) }
return Triple(actualSeed, random, numbers)
}
}

private inline fun logSeedOnFailure(seed: Int, block: () -> Unit) {
try {
block()
} catch (e: Throwable) {
println("Test failed with seed $seed")
throw e
}
}

0 comments on commit 2d8a99e

Please sign in to comment.