Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mpilquist committed Jan 16, 2025
1 parent f6a00e0 commit c582579
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 54 deletions.
63 changes: 36 additions & 27 deletions modules/core/shared/src/main/scala/data/Cache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,66 @@

package skunk.data


import scala.collection.immutable.SortedMap

final case class Cache[K, V] private (
/**
* Immutable, least recently used cache.
*
* Entries are stored in a hash map. Upon insertion and upon each access of an entry,
* a numeric stamp is assigned to each entry. Stamps start at 0 and increase for each
* insertion/access. The `accesses` field stores a sorted map of stamp to entry key.
* Resultantly, the head of `accesses` is the key of the least recently used entry.
*/
sealed abstract case class Cache[K, V] private (
max: Int,
entries: Map[K, V]
)(accesses: SortedMap[Long, K],
accessesInverted: Map[K, Long],
counter: Long
)(accesses: SortedBiMap[Long, K],
stamp: Long
) {
assert(entries.size == accesses.size)

private def accessesWithoutKey(k: K): SortedMap[Long, K] =
accessesInverted.get(k).fold(accesses)(oldCounter => accesses - oldCounter)
def size: Int = entries.size

def contains(k: K): Boolean = entries.contains(k)

def get(k: K): Option[V] =
lookup(k).map(_._2)

def lookup(k: K): Option[(Cache[K, V], V)] =
def get(k: K): Option[(Cache[K, V], V)] =
entries.get(k) match {
case Some(v) =>
val newAccesses = accessesWithoutKey(k) + (counter -> k)
val newCache = new Cache(max, entries)(newAccesses, accessesInverted + (k -> counter), counter + 1)
val newAccesses = accesses + (stamp -> k)
val newCache = Cache(max, entries, newAccesses, stamp + 1)
Some(newCache -> v)
case None =>
None
}

def put(k: K, v: V): Cache[K, V] =
insert(k, v)._1

def insert(k: K, v: V): (Cache[K, V], Option[(K, V)]) =
def put(k: K, v: V): (Cache[K, V], Option[(K, V)]) =
if (max <= 0) (this, Some((k, v)))
else if (entries.size >= max && !containsKey(k)) {
val (counterToEvict, keyToEvict) = accesses.head
val newCache = new Cache(max, entries - keyToEvict + (k -> v))(accessesWithoutKey(k) - counterToEvict + (counter -> k), accessesInverted + (k -> counter), counter + 1)
else if (entries.size >= max && !contains(k)) {
val (stampToEvict, keyToEvict) = accesses.head
val newEntries = entries - keyToEvict + (k -> v)
val newAccesses = accesses - stampToEvict + (stamp -> k)
val newCache = Cache(max, newEntries, newAccesses, stamp + 1)
(newCache, Some((keyToEvict, entries(keyToEvict))))
} else {
val newCache = new Cache(max, entries + (k -> v))(accessesWithoutKey(k) + (counter -> k), accessesInverted + (k -> counter), counter + 1)
(newCache, entries.get(k).filter(_ != v).map(k -> _))
val newEntries = entries + (k -> v)
val newAccesses = accesses + (stamp -> k)
val newCache = Cache(max, newEntries, newAccesses, stamp + 1)
// If the new value is different than what was previously stored
// under this key, if anything, evict the old (k, v) pairing
val evicted = entries.get(k).filter(_ != v).map(k -> _)
(newCache, evicted)
}

def containsKey(k: K): Boolean = entries.contains(k)
def values: Iterable[V] = entries.values

override def toString: String =
accesses.map { case (_, k) => s"$k -> ${entries(k)}" }.mkString("Cache(", ", ", ")")
accesses.entries.iterator.map { case (_, k) => s"$k -> ${entries(k)}" }.mkString("Cache(", ", ", ")")
}

object Cache {
private def apply[K, V](max: Int, entries: Map[K, V], accesses: SortedBiMap[Long, K], stamp: Long): Cache[K, V] =
new Cache(max, entries)(accesses, stamp) {}

def empty[K, V](max: Int): Cache[K, V] =
new Cache(max max 0, Map.empty)(SortedMap.empty, Map.empty, 0L)
apply(max max 0, Map.empty, SortedBiMap.empty, 0L)
}


45 changes: 45 additions & 0 deletions modules/core/shared/src/main/scala/data/SortedBiMap.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data

import scala.collection.immutable.SortedMap
import scala.math.Ordering

/** Immutable bi-directional map that is sorted by key. */
sealed abstract case class SortedBiMap[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]) {
assert(entries.size == inverse.size)

def size: Int = entries.size

def head: (K, V) = entries.head

def get(k: K): Option[V] = entries.get(k)

def put(k: K, v: V): SortedBiMap[K, V] =
SortedBiMap(
inverse.get(v).fold(entries)(entries - _) + (k -> v),
entries.get(k).fold(inverse)(inverse - _) + (v -> k))

def +(kv: (K, V)): SortedBiMap[K, V] = put(kv._1, kv._2)

def -(k: K): SortedBiMap[K, V] =
get(k) match {
case Some(v) => SortedBiMap(entries - k, inverse - v)
case None => this
}

def inverseGet(v: V): Option[K] = inverse.get(v)

override def toString: String =
entries.iterator.map { case (k, v) => s"$k <-> $v" }.mkString("SortedBiMap(", ", ", ")")
}

object SortedBiMap {
private def apply[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]): SortedBiMap[K, V] =
new SortedBiMap[K, V](entries, inverse) {}

def empty[K: Ordering, V]: SortedBiMap[K, V] = apply(SortedMap.empty, Map.empty)
}

6 changes: 3 additions & 3 deletions modules/core/shared/src/main/scala/util/StatementCache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ object StatementCache {

def get(k: Statement[_]): F[Option[V]] =
ref.modify { case (c, evicted) =>
c.lookup(k.cacheKey) match {
c.get(k.cacheKey) match {
case Some((cʹ, v)) => (cʹ -> evicted, Some(v))
case None => (c -> evicted, None)
}
}

def put(k: Statement[_], v: V): F[Unit] =
ref.update { case (c, evicted) =>
val (c2, e) = c.insert(k.cacheKey, v)
val (c2, e) = c.put(k.cacheKey, v)
val evicted2 = e.filter(_ => trackEviction).fold(evicted) { case (_, v) => evicted + v }
(c2, evicted2)
}

def containsKey(k: Statement[_]): F[Boolean] =
ref.get.map(_._1.containsKey(k.cacheKey))
ref.get.map(_._1.contains(k.cacheKey))

def clear: F[Unit] =
ref.update { case (c, evicted) =>
Expand Down
48 changes: 24 additions & 24 deletions modules/tests/shared/src/test/scala/data/CacheTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ package skunk.data

import munit.ScalaCheckSuite
import org.scalacheck.Gen
import org.scalacheck.Prop._
import org.scalacheck.Prop.forAll

class CacheTest extends ScalaCheckSuite {

val genEmpty: Gen[Cache[Int, String]] =
Gen.choose(-1, 10).map(Cache.empty)

test("insert on empty cache results in eviction") {
val cache = Cache.empty(0).put("one", 1)
val cache = Cache.empty(0).put("one", 1)._1
assertEquals(cache.values.toList, Nil)
assert(!cache.containsKey("one"))
assert(!cache.contains("one"))
}

test("max is never negative") {
Expand All @@ -25,10 +25,10 @@ class CacheTest extends ScalaCheckSuite {
}
}

test("insert should allow lookup") {
test("insert should allow get") {
forAll(genEmpty) { c =>
val cʹ = c.put(1, "x")
assertEquals(cʹ.lookup(1), if (c.max == 0) None else Some((cʹ, "x")))
val cʹ = c.put(1, "x")._1
assertEquals(cʹ.get(1), if (c.max == 0) None else Some((cʹ, "x")))
}
}

Expand All @@ -37,43 +37,43 @@ class CacheTest extends ScalaCheckSuite {
val max = c.max

// Load up the cache such that one more insert will cause it to overflow
val cʹ = (0 until max).foldLeft(c) { case (c, n) => c.insert(n, "x")._1 }
assertEquals(.values.size, max)
val c1 = (0 until max).foldLeft(c) { case (c, n) => c.put(n, "x")._1 }
assertEquals(c1.values.size, max)

// Overflow the cache
val (cʹʹ, evicted) = cʹ.insert(max, "x")
assertEquals(cʹʹ.values.size, max)
val (c2, evicted) = c1.put(max, "x")
assertEquals(c2.values.size, max)
assertEquals(evicted, Some(0 -> "x"))

if (max > 2) {
// Access oldest element
val cʹʹʹ = cʹʹ.lookup(1).get._1
val c3 = c2.get(1).get._1

// Insert another element and make sure oldest element is not the element evicted
val (cʹʹʹʹ, evictedʹ) = cʹʹʹ.insert(max + 1, "x")
assertEquals(evictedʹ, Some(2 -> "x"))
val (c4, evicted1) = c3.put(max + 1, "x")
assertEquals(evicted1, Some(2 -> "x"))
}
}
}

test("eviction 2") {
val c1 = Cache.empty(2).put("one", 1)
val c1 = Cache.empty(2).put("one", 1)._1
assertEquals(c1.values.toList, List(1))
assertEquals(c1.get("one"), Some(1))
assertEquals(c1.get("one").map(_._2), Some(1))

val (c2, evicted2) = c1.insert("two", 2)
assert(c2.containsKey("one"))
assert(c2.containsKey("two"))
val (c2, evicted2) = c1.put("two", 2)
assert(c2.contains("one"))
assert(c2.contains("two"))
assertEquals(evicted2, None)

val (c3, evicted3) = c2.insert("one", 1)
assert(c3.containsKey("one"))
assert(c3.containsKey("two"))
val (c3, evicted3) = c2.put("one", 1)
assert(c3.contains("one"))
assert(c3.contains("two"))
assertEquals(evicted3, None)

val (c4, evicted4) = c2.insert("one", 0)
assert(c4.containsKey("one"))
assert(c4.containsKey("two"))
val (c4, evicted4) = c2.put("one", 0)
assert(c4.contains("one"))
assert(c4.contains("two"))
assertEquals(evicted4, Some("one" -> 1))
}
}
38 changes: 38 additions & 0 deletions modules/tests/shared/src/test/scala/data/SortedBiMapTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.data

import munit.ScalaCheckSuite
import org.scalacheck.Prop

class SortedBiMapTest extends ScalaCheckSuite {

test("put handles overwrites") {
val m = SortedBiMap.empty[Int, Int].put(1, 2)
assertEquals(m.size, 1)
assertEquals(m.get(1), Some(2))
assertEquals(m.inverseGet(2), Some(1))

val m2 = m.put(3, 2)
assertEquals(m2.size, 1)
assertEquals(m2.get(3), Some(2))
assertEquals(m2.inverseGet(2), Some(3))
assertEquals(m2.get(1), None)

val m3 = m2.put(3, 4)
assertEquals(m3.size, 1)
assertEquals(m3.get(3), Some(4))
assertEquals(m3.inverseGet(4), Some(3))
assertEquals(m3.inverseGet(2), None)
}

test("entries are sorted") {
Prop.forAll { (s: Set[Int]) =>
val m = s.foldLeft(SortedBiMap.empty[Int, String])((acc, i) => acc.put(i, i.toString))
assertEquals(m.size, s.size)
assertEquals(m.entries.keySet.toList, s.toList.sorted)
}
}
}

0 comments on commit c582579

Please sign in to comment.