Skip to content

Commit

Permalink
Merge branch 'main' into topic/lru-cache
Browse files Browse the repository at this point in the history
  • Loading branch information
mpilquist committed Jan 15, 2025
2 parents 3a201b6 + b314c69 commit 9df2da7
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 5 deletions.
60 changes: 60 additions & 0 deletions modules/core/shared/src/main/scala/data/Cache.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data


import scala.collection.immutable.SortedMap

/**
 * Immutable cache with least-recently-used eviction, holding at most `max` entries.
 *
 * Each access made through `lookup` or `insert` stamps the key with the current `counter` value
 * and returns a new cache whose counter is one higher. The key with the smallest stamp is the
 * one evicted when room is needed.
 *
 * Only `max` and `entries` sit in the first parameter list, so case-class equality and hashing
 * ignore the access bookkeeping in the second list.
 *
 * @param max              maximum number of entries; when `max <= 0` nothing is ever stored
 * @param entries          the cached key/value pairs
 * @param accesses         access stamp -> key; being sorted, its head is the least recently used key
 * @param accessesInverted key -> the stamp under which that key currently appears in `accesses`
 * @param counter          the stamp that the next access will be given
 */
final case class Cache[K, V] private (
  max: Int,
  entries: Map[K, V]
)(accesses: SortedMap[Long, K],
  accessesInverted: Map[K, Long],
  counter: Long
) {

  /** `accesses` with the stamp currently recorded for `k` (if any) removed. */
  private def accessesWithoutKey(k: K): SortedMap[Long, K] =
    accessesInverted.get(k).fold(accesses)(oldCounter => accesses - oldCounter)

  /**
   * Returns the value stored under `k`, if any.
   *
   * The cache with the refreshed access order is discarded; use `lookup` to keep it.
   */
  def get(k: K): Option[V] =
    lookup(k).map(_._2)

  /**
   * Looks up `k` and marks it as the most recently used key.
   *
   * @return the updated cache paired with the value, or `None` when `k` is not cached
   */
  def lookup(k: K): Option[(Cache[K, V], V)] =
    entries.get(k).map { v =>
      // Replace the key's previous stamp with the current counter value.
      val newAccesses = accessesWithoutKey(k) + (counter -> k)
      val newCache = new Cache(max, entries)(newAccesses, accessesInverted + (k -> counter), counter + 1)
      newCache -> v
    }

  /** Same as `insert`, but drops the information about what was evicted. */
  def put(k: K, v: V): Cache[K, V] =
    insert(k, v)._1

  /**
   * Stores `v` under `k` and marks `k` as the most recently used key.
   *
   * @return the updated cache together with the entry that was pushed out, if any:
   *         - the new entry itself when `max <= 0` (it is never stored),
   *         - the least recently used entry when the cache is full and `k` is a new key,
   *         - the previous entry for `k` when it is replaced by a different value,
   *         - `None` otherwise.
   */
  def insert(k: K, v: V): (Cache[K, V], Option[(K, V)]) =
    if (max <= 0) (this, Some((k, v)))
    else if (entries.size >= max && !containsKey(k)) {
      // Full and `k` is new: evict the key with the smallest access stamp.
      val (counterToEvict, keyToEvict) = accesses.head
      val newCache = new Cache(max, entries - keyToEvict + (k -> v))(
        accessesWithoutKey(k) - counterToEvict + (counter -> k),
        // Drop the evicted key from the reverse index as well; otherwise it would keep one
        // stale entry per eviction and grow without bound.
        accessesInverted - keyToEvict + (k -> counter),
        counter + 1
      )
      (newCache, Some((keyToEvict, entries(keyToEvict))))
    } else {
      // Either there is room or `k` is already present: add/overwrite and refresh its stamp.
      val newCache = new Cache(max, entries + (k -> v))(
        accessesWithoutKey(k) + (counter -> k),
        accessesInverted + (k -> counter),
        counter + 1
      )
      // An overwritten value is reported only when it differs from the new one.
      (newCache, entries.get(k).filter(_ != v).map(k -> _))
    }

  /** Whether `k` is currently cached. Does not affect the access order. */
  def containsKey(k: K): Boolean = entries.contains(k)

  /** The cached values. Does not affect the access order. */
  def values: Iterable[V] = entries.values

  /** Renders the entries from least to most recently used. */
  override def toString: String =
    accesses.map { case (_, k) => s"$k -> ${entries(k)}" }.mkString("Cache(", ", ", ")")
}

object Cache {

  /** Creates an empty cache holding at most `max` entries; a negative `max` is clamped to 0. */
  def empty[K, V](max: Int): Cache[K, V] =
    new Cache(max max 0, Map.empty)(SortedMap.empty, Map.empty, 0L)
}


12 changes: 9 additions & 3 deletions modules/core/shared/src/main/scala/data/SemispaceCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ sealed abstract case class SemispaceCache[K, V](gen0: Map[K, V], gen1: Map[K, V]
else if (gen0.size < max) SemispaceCache(gen0 + (k -> v), gen1 - k, max, evicted - v) // room in gen0, done!
else SemispaceCache(Map(k -> v), gen0 - k, max, evicted ++ gen1.values - v)// no room in gen0, slide it down

def lookup(k: K): Option[(SemispaceCache[K, V], V)] =
def lookup(k: K): Option[(SemispaceCache[K, V], V)] =
gen0.get(k).tupleLeft(this) orElse // key is in gen0, done!
gen1.get(k).map(v => (insert(k, v), v)) // key is in gen1, copy to gen0

Expand All @@ -40,8 +40,14 @@ sealed abstract case class SemispaceCache[K, V](gen0: Map[K, V], gen1: Map[K, V]

object SemispaceCache {

private def apply[K, V](gen0: Map[K, V], gen1: Map[K, V], max: Int, evicted: EvictionSet[V]): SemispaceCache[K, V] =
new SemispaceCache[K, V](gen0, gen1, max, evicted) {}
private def apply[K, V](gen0: Map[K, V], gen1: Map[K, V], max: Int, evicted: EvictionSet[V]): SemispaceCache[K, V] = {
  val cache = new SemispaceCache[K, V](gen0, gen1, max, evicted) {}

  // Values of a generation that are also present in `evicted`.
  def overlapWithEvicted(generation: Map[K, V]): Set[V] =
    generation.values.toSet intersect evicted.toList.toSet

  // Sanity check on every construction: no value held in gen0 or gen1 may also be in `evicted`.
  val gen0Overlap: Set[V] = overlapWithEvicted(gen0)
  val gen1Overlap: Set[V] = overlapWithEvicted(gen1)
  assert(gen0Overlap.isEmpty, s"gen0 has overlapping values in evicted: ${gen0Overlap}")
  assert(gen1Overlap.isEmpty, s"gen1 has overlapping values in evicted: ${gen1Overlap}")
  cache
}

def empty[K, V](max: Int, trackEviction: Boolean): SemispaceCache[K, V] =
SemispaceCache[K, V](Map.empty, Map.empty, max max 0, if (trackEviction) EvictionSet.empty else new EvictionSet.ZeroEvictionSet)
Expand Down
79 changes: 79 additions & 0 deletions modules/tests/shared/src/test/scala/data/CacheTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data

import munit.ScalaCheckSuite
import org.scalacheck.Gen
import org.scalacheck.Prop._

/** Unit and property tests for the eviction behaviour of [[Cache]]. */
class CacheTest extends ScalaCheckSuite {

  // Empty caches built from a requested max drawn from `Gen.choose(-1, 10)`; the negative
  // request exercises the clamping checked in "max is never negative".
  val genEmpty: Gen[Cache[Int, String]] =
    Gen.choose(-1, 10).map(Cache.empty)

  test("insert on empty cache results in eviction") {
    // With max = 0 the inserted entry is never stored.
    val cache = Cache.empty(0).put("one", 1)
    assertEquals(cache.values.toList, Nil)
    assert(!cache.containsKey("one"))
  }

  test("max is never negative") {
    forAll(genEmpty) { c =>
      assert(c.max >= 0)
    }
  }

  test("insert should allow lookup") {
    forAll(genEmpty) { c =>
      val cʹ = c.put(1, "x")
      // Cache equality only compares `max` and `entries`, so the cache returned by `lookup`
      // compares equal to cʹ even though its access bookkeeping has advanced.
      assertEquals(cʹ.lookup(1), if (c.max == 0) None else Some((cʹ, "x")))
    }
  }

  test("eviction") {
    forAll(genEmpty) { c =>
      val max = c.max

      // Load up the cache such that one more insert will cause it to overflow
      val cʹ = (0 until max).foldLeft(c) { case (c, n) => c.insert(n, "x")._1 }
      assertEquals(cʹ.values.size, max)

      // Overflow the cache
      val (cʹʹ, evicted) = cʹ.insert(max, "x")
      assertEquals(cʹʹ.values.size, max)
      assertEquals(evicted, Some(0 -> "x"))

      if (max > 2) {
        // Access oldest element
        val cʹʹʹ = cʹʹ.lookup(1).get._1

        // Insert another element and make sure oldest element is not the element evicted
        // (only the evicted entry matters here, so the resulting cache is not bound)
        val (_, evictedʹ) = cʹʹʹ.insert(max + 1, "x")
        assertEquals(evictedʹ, Some(2 -> "x"))
      }
    }
  }

  test("eviction 2") {
    val c1 = Cache.empty(2).put("one", 1)
    assertEquals(c1.values.toList, List(1))
    assertEquals(c1.get("one"), Some(1))

    // Room left: nothing is evicted
    val (c2, evicted2) = c1.insert("two", 2)
    assert(c2.containsKey("one"))
    assert(c2.containsKey("two"))
    assertEquals(evicted2, None)

    // Re-inserting an existing key with the same value evicts nothing
    val (c3, evicted3) = c2.insert("one", 1)
    assert(c3.containsKey("one"))
    assert(c3.containsKey("two"))
    assertEquals(evicted3, None)

    // Overwriting an existing key with a different value reports the old entry as evicted
    val (c4, evicted4) = c2.insert("one", 0)
    assert(c4.containsKey("one"))
    assert(c4.containsKey("two"))
    assertEquals(evicted4, Some("one" -> 1))
  }
}
36 changes: 34 additions & 2 deletions modules/tests/shared/src/test/scala/data/SemispaceCacheTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,36 @@ class SemispaceCacheTest extends ScalaCheckSuite {

// Empty caches created via `SemispaceCache.empty(_, true)` from a requested max drawn
// from `Gen.choose(-1, 10)`.
val genEmpty: Gen[SemispaceCache[Int, String]] =
Gen.choose(-1, 10).map(SemispaceCache.empty(_, true))

// Steps a cache created with max = 2 through a fixed sequence of lookup-or-insert operations
// and, after each step, checks the exact contents of gen0, gen1 and the evicted list.
test("eviction should never contain values in gen0/gen1") {
val cache = SemispaceCache.empty(2, true).insert("one", 1)

// Inserting the same key/value again
val i1 = cache.insert("one", 1)
// Two doesn't exist; space in gen0, insert
val i2 = i1.lookup("two").map(_._1).getOrElse(i1.insert("two", 2))
assertEquals(i2.gen0, Map("one" -> 1, "two" -> 2))
assertEquals(i2.gen1, Map.empty[String, Int])
assertEquals(i2.evicted.toList, Nil)

// Three doesn't exist, hit max; slide gen0 -> gen1 and add to gen0
val i3 = i2.lookup("three").map(_._1).getOrElse(i2.insert("three", 3))
assertEquals(i3.gen0, Map("three" -> 3))
assertEquals(i3.gen1, Map("one" -> 1, "two" -> 2))
assertEquals(i3.evicted.toList, Nil)

// One exists in gen1; pull up to gen0 and REMOVE from gen1
val i4 = i3.lookup("one").map(_._1).getOrElse(i3.insert("one", 1))
assertEquals(i4.gen0, Map("one" -> 1, "three" -> 3))
assertEquals(i4.gen1, Map("two" -> 2))
assertEquals(i4.evicted.toList, Nil)

// Four doesn't exist; gen0 is full so push to gen1
// insert four to gen0 and evict gen1
val i5 = i4.lookup("four").map(_._1).getOrElse(i4.insert("four", 4))
assertEquals(i5.gen0, Map("four" -> 4))
assertEquals(i5.gen1, Map("one" -> 1, "three" -> 3))
// Only the value that was left behind in the old gen1 ends up evicted
assertEquals(i5.evicted.toList, List(2))
}

test("insert on empty cache results in eviction") {
val cache = SemispaceCache.empty(0, true).insert("one", 1)
Expand Down Expand Up @@ -73,7 +103,7 @@ class SemispaceCacheTest extends ScalaCheckSuite {
val max = c.max

// Load up the cache such that it overflows by 1
val cʹ = (0 to max).foldLeft(c) { case (c, n) => c.insert(n, "x") }
val cʹ = (0 to max).foldLeft(c) { case (c, n) => c.insert(n, n.toString) }
assertEquals(cʹ.gen0.size, 1 min max)
assertEquals(cʹ.gen1.size, max)

Expand All @@ -82,7 +112,9 @@ class SemispaceCacheTest extends ScalaCheckSuite {
case None => assertEquals(max, 0)
case Some((cʹʹ, _)) =>
assertEquals(cʹʹ.gen0.size, 2 min max)
assertEquals(cʹʹ.gen1.size, max)
// When we promote 0 to gen0, we remove it from gen1
assertEquals(cʹʹ.gen1.size, max-1 max 1)
assertEquals(cʹʹ.evicted.toList, Nil)
}

}
Expand Down

0 comments on commit 9df2da7

Please sign in to comment.